package testutils

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"testing"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// allowedTestTables is a whitelist of table names that are safe to use
// in DELETE FROM / TRUNCATE statements in test utilities.
// This prevents SQL injection via table name interpolation in fmt.Sprintf.
//
// Every table named in this package's hard-coded cleanup lists (ResetTestDB,
// getDefaultTables) must appear here, otherwise validateTableName panics in the
// middle of a cleanup, after earlier tables have already been emptied.
var allowedTestTables = map[string]bool{
	"users":                     true,
	"user_sessions":             true,
	"user_profiles":             true,
	"user_settings":             true,
	"user_roles":                true,
	"user_blocks":               true,
	"roles":                     true,
	"role_permissions":          true,
	"permissions":               true,
	"tracks":                    true,
	"track_likes":               true,
	"track_plays":               true,
	"track_comments":            true,
	"track_shares":              true,
	"track_versions":            true,
	"track_history":             true,
	"playlists":                 true,
	"playlist_tracks":           true,
	"playlist_collaborators":    true,
	"playlist_follows":          true,
	"playlist_share_links":      true,
	"messages":                  true,
	"rooms":                     true,
	"room_members":              true,
	"notifications":             true,
	"follows":                   true,
	"likes":                     true,
	"comments":                  true,
	"posts":                     true,
	"sessions":                  true,
	"jobs":                      true,
	"audit_logs":                true,
	"refresh_tokens":            true,
	"password_reset_tokens":     true,
	"email_verification_tokens": true,
	"federated_identities":      true,
	"files":                     true,
	"file_uploads":              true,
	"file_metadata":             true,
	"file_conversions":          true,
	"analytics_events":          true,
	"admin_settings":            true,
	"webhooks":                  true,
	"webhook_failures":          true,
	"playback_history":          true,
	"queues":                    true,
	"queue_items":               true,
	"hls_streams":               true,
	"hls_transcode_queue":       true,
	"bitrate_adaptation_logs":   true,
	// Referenced by ResetTestDB and/or getDefaultTables; previously missing,
	// which made those cleanups panic.
	"mfa_configs":    true,
	"recovery_codes": true,
	"oauth_accounts": true,
}

// validateTableName checks that a table name is in the allowed whitelist.
// An optional "public." schema prefix is ignored for the lookup.
// It panics if the name is not allowed, preventing potential SQL injection
// through table-name interpolation.
func validateTableName(table string) {
	if allowedTestTables[strings.TrimPrefix(table, "public.")] {
		return
	}
	panic(fmt.Sprintf("SECURITY: table name %q is not in the allowed whitelist for test cleanup", table))
}

// SetupTestDB creates a connection to the test container database.
// It ensures the container is running and the schema is migrated. // The container is shared across tests (singleton in setup.go), so be mindful of data state. func SetupTestDB() *gorm.DB { dsn, err := GetTestContainerDB(context.Background()) if err != nil { panic(fmt.Sprintf("failed to setup test db container: %v", err)) } db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { panic(fmt.Sprintf("failed to connect to test db: %v", err)) } return db } // CleanupTestDB closes the SQL connection. // Note: It does NOT stop the container. func CleanupTestDB(db *gorm.DB) error { if db == nil { return nil } sqlDB, err := db.DB() if err != nil { return err } return sqlDB.Close() } // ResetTestDB deletes all data from the database to ensure a clean state. // It respects foreign key constraints by deleting in the correct order. func ResetTestDB(db *gorm.DB) error { if db == nil { return nil } // Supprimer toutes les données dans l'ordre pour respecter les contraintes de clés étrangères // L'ordre inverse de création (ou celui qui respecte les FK) tables := []string{ "messages", "playlist_tracks", "role_permissions", "user_roles", "permissions", "roles", "room_members", "rooms", "tracks", "playlists", "refresh_tokens", "sessions", "users", "user_profiles", "audit_logs", "mfa_configs", "recovery_codes", } for _, table := range tables { // Use TRUNCATE CASCADE for Postgres which is faster and handles FKs better // But TRUNCATE cannot be used easily if tables are referenced by others unless CASCADE is used. // Also, we need to check if table exists to avoid errors? // With the container setup, tables should always exist. // Validate table name against whitelist before interpolation validateTableName(table) // For simplicity and safety, we try DELETE or TRUNCATE CASCADE. // TRUNCATE table_name CASCADE; if err := db.Exec(fmt.Sprintf("TRUNCATE TABLE %s CASCADE", table)).Error; err != nil { // If TRUNCATE fails (e.g. 
permissions?), fallback to DELETE // Also ignore if table doesn't exist (though it should) _ = db.Exec(fmt.Sprintf("DELETE FROM %s", table)) } } return nil } // GetDBStats retourne les statistiques de la base de données de test func GetDBStats(db *gorm.DB) (*sql.DBStats, error) { if db == nil { return nil, nil } sqlDB, err := db.DB() if err != nil { return nil, err } stats := sqlDB.Stats() return &stats, nil } // CleanupOptions configure le comportement du cleanup (T0049) type CleanupOptions struct { Cascade bool UseTransaction bool SkipForeignKeys bool Tables []string // Si spécifié, nettoie uniquement ces tables } // CleanupDatabaseWithOptions nettoie avec options (T0049) func CleanupDatabaseWithOptions(t *testing.T, db *gorm.DB, opts CleanupOptions) error { var dbInstance *gorm.DB if opts.UseTransaction { tx := db.Begin() defer func() { if r := recover(); r != nil { tx.Rollback() panic(r) } }() dbInstance = tx defer tx.Rollback() } else { dbInstance = db } return cleanupTables(t, dbInstance, opts) } func cleanupTables(t *testing.T, db *gorm.DB, opts CleanupOptions) error { sqlDB, err := db.DB() if err != nil { return fmt.Errorf("failed to get sql.DB: %w", err) } driverName := sqlDB.Driver() driverType := fmt.Sprintf("%T", driverName) isPostgreSQL := !contains(driverType, "sqlite") if !opts.SkipForeignKeys { if isPostgreSQL { if err := db.Exec("SET session_replication_role = 'replica'").Error; err != nil { t.Logf("Warning: Failed to disable foreign keys: %v", err) } defer func() { if err := db.Exec("SET session_replication_role = 'origin'").Error; err != nil { t.Logf("Warning: Failed to re-enable foreign keys: %v", err) } }() } else { // SQLite if err := db.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { t.Logf("Warning: Failed to disable foreign keys: %v", err) } defer func() { if err := db.Exec("PRAGMA foreign_keys = ON").Error; err != nil { t.Logf("Warning: Failed to re-enable foreign keys: %v", err) } }() } } tables := opts.Tables if len(tables) == 0 { 
tables = getAllTables(t, db, isPostgreSQL) } for _, table := range tables { // Validate table name against whitelist before interpolation validateTableName(table) var query string if opts.Cascade && isPostgreSQL { // CASCADE est supporté par PostgreSQL query = fmt.Sprintf("TRUNCATE TABLE %s CASCADE", table) } else { // Pour SQLite ou sans cascade, utiliser DELETE FROM query = fmt.Sprintf("DELETE FROM %s", table) } if err := db.Exec(query).Error; err != nil { t.Logf("Warning: Failed to cleanup table %s: %v", table, err) // Continue avec les autres tables } } return nil } // contains vérifie si une chaîne contient une sous-chaîne (utilitaire pour détection DB) func contains(s, substr string) bool { for i := 0; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { return true } } return false } // getAllTables récupère la liste de toutes les tables (T0049) func getAllTables(t *testing.T, db *gorm.DB, isPostgreSQL bool) []string { var tables []string if isPostgreSQL { query := ` SELECT tablename FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename ` rows, err := db.Raw(query).Rows() if err != nil { t.Logf("Warning: Failed to get table list: %v", err) return getDefaultTables() } defer rows.Close() for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { t.Logf("Warning: Failed to scan table name: %v", err) continue } tables = append(tables, tableName) } } else { // SQLite query := ` SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name ` rows, err := db.Raw(query).Rows() if err != nil { t.Logf("Warning: Failed to get table list: %v", err) return getDefaultTables() } defer rows.Close() for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { t.Logf("Warning: Failed to scan table name: %v", err) continue } tables = append(tables, tableName) } } if len(tables) == 0 { return getDefaultTables() } return tables } // getDefaultTables retourne la liste par défaut 
des tables (T0049) func getDefaultTables() []string { return []string{ "messages", "playlist_tracks", "role_permissions", "user_roles", "permissions", "roles", "playlists", "tracks", "refresh_tokens", "room_members", "rooms", "users", "oauth_accounts", "user_profiles", "sessions", "audit_logs", "mfa_configs", "recovery_codes", } } // RegisterCleanupHook enregistre un hook de cleanup (T0049) func RegisterCleanupHook(t *testing.T, hook func()) { t.Cleanup(hook) } // CleanupWithTransaction nettoie avec une transaction (T0049) func CleanupWithTransaction(t *testing.T, db *gorm.DB, cleanupFunc func(*gorm.DB)) error { tx := db.Begin() defer func() { if r := recover(); r != nil { tx.Rollback() panic(r) } }() cleanupFunc(tx) return tx.Rollback().Error } // CleanupSpecificTables nettoie uniquement les tables spécifiées (T0049) func CleanupSpecificTables(t *testing.T, db *gorm.DB, tables []string) error { opts := CleanupOptions{ Cascade: true, UseTransaction: false, SkipForeignKeys: false, Tables: tables, } return CleanupDatabaseWithOptions(t, db, opts) }