veza/veza-backend-api/internal/testutils/db.go
2026-03-05 23:03:43 +01:00

392 lines
10 KiB
Go

package testutils
import (
"context"
"database/sql"
"fmt"
"strings"
"testing"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// allowedTestTables is a whitelist of table names that are safe to use
// in DELETE FROM / TRUNCATE statements in test utilities.
// This prevents SQL injection via table name interpolation in fmt.Sprintf.
//
// Every table named by ResetTestDB and getDefaultTables must be listed here:
// validateTableName panics on any name that is missing.
var allowedTestTables = map[string]bool{
	"users":                     true,
	"user_sessions":             true,
	"user_profiles":             true,
	"user_settings":             true,
	"user_roles":                true,
	"user_blocks":               true,
	"roles":                     true,
	"role_permissions":          true,
	"permissions":               true,
	"tracks":                    true,
	"track_likes":               true,
	"track_plays":               true,
	"track_comments":            true,
	"track_shares":              true,
	"track_versions":            true,
	"track_history":             true,
	"playlists":                 true,
	"playlist_tracks":           true,
	"playlist_collaborators":    true,
	"playlist_follows":          true,
	"playlist_share_links":      true,
	"messages":                  true,
	"rooms":                     true,
	"room_members":              true,
	"notifications":             true,
	"follows":                   true,
	"likes":                     true,
	"comments":                  true,
	"posts":                     true,
	"sessions":                  true,
	"jobs":                      true,
	"audit_logs":                true,
	"refresh_tokens":            true,
	"password_reset_tokens":     true,
	"email_verification_tokens": true,
	"federated_identities":      true,
	"files":                     true,
	"file_uploads":              true,
	"file_metadata":             true,
	"file_conversions":          true,
	"analytics_events":          true,
	"admin_settings":            true,
	"webhooks":                  true,
	"webhook_failures":          true,
	"playback_history":          true,
	"queues":                    true,
	"queue_items":               true,
	"hls_streams":               true,
	"hls_transcode_queue":       true,
	"bitrate_adaptation_logs":   true,

	// Referenced by ResetTestDB / getDefaultTables; previously missing, which
	// made those helpers panic.
	"oauth_accounts": true,
	"mfa_configs":    true,
	"recovery_codes": true,
}
// validateTableName guards every place in this file that splices a table
// name into SQL via fmt.Sprintf. The name is accepted when, after removing an
// optional leading "public." schema qualifier, it is a key of
// allowedTestTables; any other name causes a panic whose message quotes the
// name exactly as passed in.
func validateTableName(table string) {
	if bare := strings.TrimPrefix(table, "public."); allowedTestTables[bare] {
		return
	}
	panic(fmt.Sprintf("SECURITY: table name %q is not in the allowed whitelist for test cleanup", table))
}
// SetupTestDB opens a silent-logging GORM PostgreSQL handle on the test
// container database whose DSN is returned by GetTestContainerDB.
//
// GetTestContainerDB is defined elsewhere in the package; per the original
// notes it presumably starts a shared (singleton) container and migrates the
// schema, so data may persist between tests — TODO confirm in setup.go.
// It panics if the DSN cannot be obtained or the connection cannot be opened.
func SetupTestDB() *gorm.DB {
	ctx := context.Background()
	dsn, dsnErr := GetTestContainerDB(ctx)
	if dsnErr != nil {
		panic(fmt.Sprintf("failed to setup test db container: %v", dsnErr))
	}

	silent := logger.Default.LogMode(logger.Silent)
	conn, openErr := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: silent})
	if openErr != nil {
		panic(fmt.Sprintf("failed to connect to test db: %v", openErr))
	}
	return conn
}
// CleanupTestDB closes the *sql.DB connection pool underlying db. It does not
// stop the test container. A nil db is a no-op. Any error from obtaining the
// pool or from closing it is returned.
func CleanupTestDB(db *gorm.DB) error {
	if db == nil {
		return nil
	}
	pool, err := db.DB()
	if err == nil {
		err = pool.Close()
	}
	return err
}
// ResetTestDB wipes the data of a fixed set of application tables so a test
// can start from a clean state. A nil db is a no-op.
//
// For each table it runs TRUNCATE ... CASCADE and, if that fails, falls back
// to DELETE FROM. Failures of the fallback are deliberately ignored, so a
// table that is absent from the current schema does not abort the reset. The
// returned error is therefore currently always nil.
//
// All table names are checked with validateTableName before any statement
// runs. An unlisted name panics without leaving the database half-wiped.
func ResetTestDB(db *gorm.DB) error {
	if db == nil {
		return nil
	}
	// Children before parents, so that the DELETE fallback (which, unlike
	// TRUNCATE ... CASCADE, is subject to foreign-key checks) does not delete a
	// parent row that is still referenced. "users" is last because the tables
	// above it presumably reference it — verify against the migrations.
	tables := []string{
		"messages",
		"playlist_tracks",
		"role_permissions",
		"user_roles",
		"permissions",
		"roles",
		"room_members",
		"rooms",
		"tracks",
		"playlists",
		"refresh_tokens",
		"sessions",
		"user_profiles",
		"audit_logs",
		"mfa_configs",
		"recovery_codes",
		"users",
	}

	// Validate everything first. validateTableName panics, and doing so midway
	// through the destructive loop would leave some tables wiped and others not.
	for _, table := range tables {
		validateTableName(table)
	}

	for _, table := range tables {
		// TRUNCATE ... CASCADE is fast and also clears rows in dependent tables.
		if err := db.Exec(fmt.Sprintf("TRUNCATE TABLE %s CASCADE", table)).Error; err != nil {
			// Best-effort fallback (e.g. missing TRUNCATE privilege or missing
			// table); its error is intentionally ignored, see the doc comment.
			_ = db.Exec(fmt.Sprintf("DELETE FROM %s", table))
		}
	}
	return nil
}
// GetDBStats returns a snapshot of the connection-pool statistics of the
// test database. For a nil db it returns (nil, nil). If the underlying
// *sql.DB cannot be obtained, it returns that error.
func GetDBStats(db *gorm.DB) (*sql.DBStats, error) {
	if db == nil {
		return nil, nil
	}
	pool, err := db.DB()
	if err == nil {
		snapshot := pool.Stats()
		return &snapshot, nil
	}
	return nil, err
}
// CleanupOptions configures the behavior of CleanupDatabaseWithOptions (T0049).
type CleanupOptions struct {
	// Cascade makes the cleanup use TRUNCATE ... CASCADE instead of
	// DELETE FROM. It only takes effect on PostgreSQL.
	Cascade bool

	// UseTransaction runs the cleanup on a transaction obtained from db.Begin()
	// instead of directly on db.
	UseTransaction bool

	// SkipForeignKeys, when true, skips the step that temporarily disables
	// foreign-key enforcement around the cleanup. When false (the zero value),
	// enforcement is disabled first and re-enabled afterwards.
	SkipForeignKeys bool

	// Tables, if non-empty, restricts the cleanup to these tables. Otherwise
	// the table list is discovered from the database, falling back to a
	// built-in default list.
	Tables []string
}
// CleanupDatabaseWithOptions empties test tables according to opts (T0049).
//
// With opts.UseTransaction the cleanup runs inside a transaction. The
// transaction is committed when the cleanup succeeds and rolled back
// otherwise, including when it panics (the panic keeps propagating).
//
// Without a transaction, every statement runs on one pinned connection.
// Session-level settings issued by cleanupTables (PostgreSQL's
// session_replication_role, SQLite's PRAGMA foreign_keys) then apply to the
// cleanup statements themselves and are reset on that same connection before
// it returns to the pool. If db is already a transaction, it is used as is.
func CleanupDatabaseWithOptions(t *testing.T, db *gorm.DB, opts CleanupOptions) error {
	if !opts.UseTransaction {
		// A handle that is already a transaction is bound to one connection.
		// Opening a second connection could block on locks that transaction holds.
		if _, inTx := db.Statement.ConnPool.(gorm.TxCommitter); inTx {
			return cleanupTables(t, db, opts)
		}
		return db.Connection(func(conn *gorm.DB) error {
			return cleanupTables(t, conn, opts)
		})
	}

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("begin cleanup transaction: %w", tx.Error)
	}
	committed := false
	defer func() {
		// Runs on error returns and during a panic. Rolling back an already
		// finished transaction is harmless.
		if !committed {
			tx.Rollback()
		}
	}()

	if err := cleanupTables(t, tx, opts); err != nil {
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("commit cleanup transaction: %w", err)
	}
	committed = true
	return nil
}

// cleanupTables empties the tables selected by opts, using db. db may be a
// plain handle, a pinned connection or a transaction.
//
// Table names come from opts.Tables or, when that is empty, from
// getAllTables. Every name is checked with validateTableName, which panics on
// names outside allowedTestTables, before any data is touched.
//
// Unless opts.SkipForeignKeys is set, foreign-key enforcement is disabled
// for the duration and re-enabled afterwards. Per-table failures are logged
// and skipped; the function currently always returns nil.
func cleanupTables(t *testing.T, db *gorm.DB, opts CleanupOptions) error {
	// The dialector name is available on transactions and pinned connections
	// alike, whereas db.DB() may fail for them. As before, anything that is
	// not SQLite is treated as PostgreSQL.
	isPostgreSQL := !contains(db.Dialector.Name(), "sqlite")

	tables := opts.Tables
	if len(tables) == 0 {
		tables = getAllTables(t, db, isPostgreSQL)
	}
	// Validate every name up front so a bad one panics before anything is wiped.
	for _, table := range tables {
		validateTableName(table)
	}

	if !opts.SkipForeignKeys {
		disableFK, enableFK := "PRAGMA foreign_keys = OFF", "PRAGMA foreign_keys = ON"
		if isPostgreSQL {
			disableFK, enableFK = "SET session_replication_role = 'replica'", "SET session_replication_role = 'origin'"
		}
		// NOTE: SQLite ignores PRAGMA foreign_keys inside a transaction.
		if err := db.Exec(disableFK).Error; err != nil {
			t.Logf("Warning: Failed to disable foreign keys: %v", err)
		}
		defer func() {
			if err := db.Exec(enableFK).Error; err != nil {
				t.Logf("Warning: Failed to re-enable foreign keys: %v", err)
			}
		}()
	}

	for _, table := range tables {
		query := fmt.Sprintf("DELETE FROM %s", table)
		if opts.Cascade && isPostgreSQL {
			// CASCADE is only supported by PostgreSQL.
			query = fmt.Sprintf("TRUNCATE TABLE %s CASCADE", table)
		}
		if err := db.Exec(query).Error; err != nil {
			// Best-effort: log and carry on with the remaining tables.
			t.Logf("Warning: Failed to cleanup table %s: %v", table, err)
		}
	}
	return nil
}
// contains reports whether substr occurs within s. An empty substr always
// matches. cleanupTables uses it to detect SQLite; it delegates to
// strings.Contains instead of scanning by hand.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
// getAllTables lists the tables of the connected database for cleanupTables
// (T0049). On PostgreSQL it reads pg_tables for the "public" schema. Otherwise
// it reads sqlite_master, excluding SQLite's internal sqlite_* tables.
//
// Only names present in allowedTestTables (after stripping an optional
// "public." prefix) are returned. Any other discovered table would make
// validateTableName panic in cleanupTables, so such tables are skipped and
// logged instead. If the query fails, or no allowed table is found,
// getDefaultTables() is returned.
func getAllTables(t *testing.T, db *gorm.DB, isPostgreSQL bool) []string {
	query := `
SELECT name
FROM sqlite_master
WHERE type='table' AND name NOT LIKE 'sqlite_%'
ORDER BY name
`
	if isPostgreSQL {
		query = `
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
ORDER BY tablename
`
	}

	rows, err := db.Raw(query).Rows()
	if err != nil {
		t.Logf("Warning: Failed to get table list: %v", err)
		return getDefaultTables()
	}
	defer rows.Close()

	var tables, skipped []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			t.Logf("Warning: Failed to scan table name: %v", err)
			continue
		}
		if !allowedTestTables[strings.TrimPrefix(tableName, "public.")] {
			skipped = append(skipped, tableName)
			continue
		}
		tables = append(tables, tableName)
	}
	// Surface iteration errors that rows.Next() hides; keep what was read.
	if err := rows.Err(); err != nil {
		t.Logf("Warning: Failed while iterating table list: %v", err)
	}
	if len(skipped) > 0 {
		t.Logf("Skipping tables not in the cleanup whitelist: %v", skipped)
	}

	if len(tables) == 0 {
		return getDefaultTables()
	}
	return tables
}
// getDefaultTables returns the fallback table list used when the database's
// own table list cannot be read or yields nothing usable (T0049).
//
// Every name here must also be listed in allowedTestTables. Tables are
// ordered children before parents, so DELETE FROM without disabled foreign
// keys does not delete a still-referenced parent. "users" is last because the
// tables above it presumably reference it — verify against the migrations.
func getDefaultTables() []string {
	return []string{
		"messages",
		"playlist_tracks",
		"role_permissions",
		"user_roles",
		"permissions",
		"roles",
		"playlists",
		"tracks",
		"refresh_tokens",
		"room_members",
		"rooms",
		"oauth_accounts",
		"user_profiles",
		"sessions",
		"audit_logs",
		"mfa_configs",
		"recovery_codes",
		"users",
	}
}
// RegisterCleanupHook schedules hook through t.Cleanup (T0049). The testing
// package runs it after the test and all its subtests have completed.
func RegisterCleanupHook(t *testing.T, hook func()) {
	t.Cleanup(func() { hook() })
}
// CleanupWithTransaction runs cleanupFunc against a fresh transaction and then
// always rolls that transaction back (T0049), so any writes cleanupFunc makes
// are discarded. If cleanupFunc panics, the transaction is rolled back and the
// panic propagates.
//
// It returns an error if the transaction cannot be started (cleanupFunc is
// then not called); otherwise it returns the error of the final rollback.
// t is currently unused.
func CleanupWithTransaction(t *testing.T, db *gorm.DB, cleanupFunc func(*gorm.DB)) error {
	tx := db.Begin()
	if tx.Error != nil {
		// Don't hand a transaction that never started to cleanupFunc.
		return fmt.Errorf("begin transaction: %w", tx.Error)
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()
	cleanupFunc(tx)
	return tx.Rollback().Error
}
// CleanupSpecificTables empties the given tables (T0049) via
// CleanupDatabaseWithOptions. It uses TRUNCATE ... CASCADE where supported,
// disables foreign-key enforcement for the duration and does not use a
// transaction. If tables is empty, CleanupDatabaseWithOptions falls back to
// every table it discovers.
func CleanupSpecificTables(t *testing.T, db *gorm.DB, tables []string) error {
	// UseTransaction and SkipForeignKeys stay at their zero value (false).
	return CleanupDatabaseWithOptions(t, db, CleanupOptions{
		Tables:  tables,
		Cascade: true,
	})
}