veza/veza-backend-api/internal/testutils/db.go

319 lines
7.5 KiB
Go

package testutils
import (
"context"
"database/sql"
"fmt"
"testing"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// SetupTestDB opens a GORM connection to the test-container database.
//
// The DSN is obtained from GetTestContainerDB. According to the original
// notes, the container is a singleton shared across tests (see setup.go) and
// its schema is already migrated, so callers must be mindful of data left
// behind by other tests. GORM's logger is silenced for the connection.
//
// Any failure to obtain the DSN or to open the connection panics.
func SetupTestDB() *gorm.DB {
	dsn, dsnErr := GetTestContainerDB(context.Background())
	if dsnErr != nil {
		panic(fmt.Sprintf("failed to setup test db container: %v", dsnErr))
	}

	dialector := postgres.Open(dsn)
	cfg := &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	}

	conn, openErr := gorm.Open(dialector, cfg)
	if openErr != nil {
		panic(fmt.Sprintf("failed to connect to test db: %v", openErr))
	}
	return conn
}
// CleanupTestDB closes the *sql.DB that backs db.
//
// A nil db is a no-op and returns nil. The function only closes the
// connection; it does NOT stop the test container.
func CleanupTestDB(db *gorm.DB) error {
	if db != nil {
		conn, err := db.DB()
		if err == nil {
			err = conn.Close()
		}
		return err
	}
	return nil
}
// ResetTestDB empties a fixed list of application tables so that a test can
// start from a clean state.
//
// Each table is emptied with TRUNCATE ... CASCADE. If that statement fails,
// a plain DELETE FROM is attempted instead and its outcome is discarded.
// The reset is best-effort: the function always returns nil, and a nil db is
// a no-op.
func ResetTestDB(db *gorm.DB) error {
	if db == nil {
		return nil
	}

	// Tables to empty, in the order the statements are issued.
	targets := []string{
		"messages",
		"playlist_tracks",
		"role_permissions",
		"user_roles",
		"permissions",
		"roles",
		"room_members",
		"rooms",
		"tracks",
		"playlists",
		"refresh_tokens",
		"sessions",
		"users",
		"user_profiles",
		"audit_logs",
		"mfa_configs",
		"recovery_codes",
	}

	for _, name := range targets {
		// CASCADE lets TRUNCATE proceed even when other tables reference
		// this one.
		truncate := fmt.Sprintf("TRUNCATE TABLE %s CASCADE", name)
		if db.Exec(truncate).Error == nil {
			continue
		}

		// TRUNCATE failed (for example for lack of privileges, or because the
		// table is missing): fall back to DELETE and deliberately ignore the
		// result, keeping the reset best-effort.
		_ = db.Exec(fmt.Sprintf("DELETE FROM %s", name))
	}
	return nil
}
// GetDBStats returns the connection-pool statistics of the *sql.DB that
// backs db.
//
// A nil db yields (nil, nil). If the underlying *sql.DB cannot be obtained,
// the error from db.DB() is returned unchanged.
func GetDBStats(db *gorm.DB) (*sql.DBStats, error) {
	if db == nil {
		return nil, nil
	}

	conn, err := db.DB()
	if err != nil {
		return nil, err
	}

	snapshot := new(sql.DBStats)
	*snapshot = conn.Stats()
	return snapshot, nil
}
// CleanupOptions configures how CleanupDatabaseWithOptions empties the
// database (T0049).
type CleanupOptions struct {
	// Cascade selects TRUNCATE TABLE ... CASCADE instead of DELETE FROM.
	// It only takes effect when the database is not SQLite.
	Cascade bool

	// UseTransaction runs the cleanup on a transaction opened with db.Begin
	// rather than directly on the supplied handle.
	UseTransaction bool

	// SkipForeignKeys, when true, skips the step that switches foreign-key
	// enforcement off for the duration of the cleanup. When false, that
	// enforcement is disabled before the tables are emptied and re-enabled
	// afterwards.
	SkipForeignKeys bool

	// Tables restricts the cleanup to the listed tables. When empty, the
	// table list is discovered from the database.
	Tables []string
}
// CleanupDatabaseWithOptions empties database tables as described by opts
// (T0049).
//
// Without opts.UseTransaction the cleanup statements run directly on db.
// With opts.UseTransaction they run inside a transaction that is committed
// when the cleanup succeeds and rolled back when it returns an error or
// panics. Previously the transaction was always rolled back, which discarded
// the cleanup it had just performed.
//
// A nil db is a no-op and returns nil, matching the other helpers in this
// file. Errors from starting or committing the transaction are returned.
func CleanupDatabaseWithOptions(t *testing.T, db *gorm.DB, opts CleanupOptions) error {
	if db == nil {
		return nil
	}
	if !opts.UseTransaction {
		return cleanupTables(t, db, opts)
	}

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("failed to begin cleanup transaction: %w", tx.Error)
	}

	// Deferred functions also run while a panic unwinds, so this single
	// guard covers both the error path and the panic path. The panic itself
	// keeps propagating because it is never recovered here.
	committed := false
	defer func() {
		if !committed {
			tx.Rollback()
		}
	}()

	if err := cleanupTables(t, tx, opts); err != nil {
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("failed to commit cleanup transaction: %w", err)
	}
	committed = true
	return nil
}
// cleanupTables empties tables through db according to opts.
//
// The dialect is taken from the GORM dialector name: any name containing
// "sqlite" is treated as SQLite, everything else as PostgreSQL. Asking the
// dialector (rather than db.DB() and the SQL driver's type name) also works
// when db is a transaction handle.
//
// Unless opts.SkipForeignKeys is set, foreign-key enforcement is switched
// off before the tables are emptied and switched back on when the function
// returns. If opts.Tables is empty, the table list comes from getAllTables.
// Tables are emptied with TRUNCATE ... CASCADE when opts.Cascade is set on
// PostgreSQL, and with DELETE FROM otherwise.
//
// Failures of individual statements are logged through t and do not stop
// the cleanup. A nil db is a no-op. An error is returned only when the
// dialect cannot be determined.
func cleanupTables(t *testing.T, db *gorm.DB, opts CleanupOptions) error {
	if db == nil {
		return nil
	}
	if db.Config == nil || db.Dialector == nil {
		return fmt.Errorf("failed to detect database dialect: no dialector configured")
	}
	isPostgreSQL := !contains(db.Dialector.Name(), "sqlite")

	if !opts.SkipForeignKeys {
		disableFK, enableFK := "PRAGMA foreign_keys = OFF", "PRAGMA foreign_keys = ON"
		if isPostgreSQL {
			disableFK = "SET session_replication_role = 'replica'"
			enableFK = "SET session_replication_role = 'origin'"
		}

		// NOTE(review): session_replication_role is a per-session setting. If
		// db is a pooled handle rather than a transaction, the statements
		// below may run on a different connection than this SET -- confirm.
		if err := db.Exec(disableFK).Error; err != nil {
			t.Logf("Warning: Failed to disable foreign keys: %v", err)
		}
		defer func() {
			if err := db.Exec(enableFK).Error; err != nil {
				t.Logf("Warning: Failed to re-enable foreign keys: %v", err)
			}
		}()
	}

	tables := opts.Tables
	if len(tables) == 0 {
		tables = getAllTables(t, db, isPostgreSQL)
	}

	// CASCADE is only used on PostgreSQL; SQLite and the non-cascade mode
	// use DELETE FROM.
	statement := "DELETE FROM %s"
	if opts.Cascade && isPostgreSQL {
		statement = "TRUNCATE TABLE %s CASCADE"
	}

	for _, table := range tables {
		if err := db.Exec(fmt.Sprintf(statement, table)).Error; err != nil {
			// Keep going: one failing table must not prevent the others
			// from being cleaned.
			t.Logf("Warning: Failed to cleanup table %s: %v", table, err)
		}
	}
	return nil
}
// contains reports whether substr occurs in s, comparing bytes.
// An empty substr is found in every string, including the empty one.
// It is used by the database-dialect detection in this file.
func contains(s, substr string) bool {
	width := len(substr)
	// Slide a window of len(substr) bytes over s, tracking the window's end.
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == substr {
			return true
		}
	}
	return false
}
// getAllTables returns the names of the tables to clean, read from the
// database catalogue (T0049).
//
// On PostgreSQL it lists pg_tables for the "public" schema; otherwise it
// lists sqlite_master, excluding SQLite's internal sqlite_* tables. Names
// are returned in alphabetical order.
//
// The list from getDefaultTables is returned instead when the query fails,
// when iterating the result set fails, or when no table is found. Rows that
// cannot be scanned are logged through t and skipped.
func getAllTables(t *testing.T, db *gorm.DB, isPostgreSQL bool) []string {
	query := `
SELECT name
FROM sqlite_master
WHERE type='table' AND name NOT LIKE 'sqlite_%'
ORDER BY name
`
	if isPostgreSQL {
		query = `
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
ORDER BY tablename
`
	}

	rows, err := db.Raw(query).Rows()
	if err != nil {
		t.Logf("Warning: Failed to get table list: %v", err)
		return getDefaultTables()
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			t.Logf("Warning: Failed to scan table name: %v", err)
			continue
		}
		tables = append(tables, tableName)
	}

	// rows.Next returns false both at the end of the result set and on
	// error. Without this check a truncated list would be returned as if it
	// were complete.
	if err := rows.Err(); err != nil {
		t.Logf("Warning: Failed to iterate table list: %v", err)
		return getDefaultTables()
	}

	if len(tables) == 0 {
		return getDefaultTables()
	}
	return tables
}
// getDefaultTables returns the built-in list of table names that is used
// when the table list cannot be read from the database (T0049).
// Every call returns a new slice, so callers may modify the result.
func getDefaultTables() []string {
	defaults := [...]string{
		"messages",
		"playlist_tracks",
		"role_permissions",
		"user_roles",
		"permissions",
		"roles",
		"playlists",
		"tracks",
		"refresh_tokens",
		"room_members",
		"rooms",
		"users",
		"oauth_accounts",
		"user_profiles",
		"sessions",
		"audit_logs",
		"mfa_configs",
		"recovery_codes",
	}
	return defaults[:]
}
// RegisterCleanupHook registers hook to run when the test t (and all of its
// subtests) completes, by delegating to t.Cleanup (T0049).
//
// A nil hook is ignored. Registering it would make the test panic with a
// nil function call once its cleanups run.
func RegisterCleanupHook(t *testing.T, hook func()) {
	if hook == nil {
		return
	}
	t.Cleanup(hook)
}
// CleanupWithTransaction runs cleanupFunc on a transaction opened on db and
// then rolls that transaction back (T0049).
//
// The transaction is always rolled back, including when cleanupFunc panics;
// the panic then continues to propagate. The returned error is the one from
// starting the transaction or from the final rollback.
//
// A nil db is a no-op and returns nil, matching the other helpers in this
// file. A nil cleanupFunc is skipped. If the transaction cannot be started,
// cleanupFunc is not called, so it never runs on an unusable handle.
func CleanupWithTransaction(t *testing.T, db *gorm.DB, cleanupFunc func(*gorm.DB)) error {
	if db == nil {
		return nil
	}

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("failed to begin transaction: %w", tx.Error)
	}

	// Deferred functions run while a panic unwinds, so this guard rolls the
	// transaction back if cleanupFunc panics, without swallowing the panic.
	rolledBack := false
	defer func() {
		if !rolledBack {
			tx.Rollback()
		}
	}()

	if cleanupFunc != nil {
		cleanupFunc(tx)
	}

	rolledBack = true
	return tx.Rollback().Error
}
// CleanupSpecificTables empties only the given tables (T0049).
//
// It is a shorthand for CleanupDatabaseWithOptions with Cascade enabled,
// no transaction, and the foreign-key toggle step left active. The result
// of that call is returned unchanged.
func CleanupSpecificTables(t *testing.T, db *gorm.DB, tables []string) error {
	// UseTransaction and SkipForeignKeys keep their zero value (false).
	return CleanupDatabaseWithOptions(t, db, CleanupOptions{
		Cascade: true,
		Tables:  tables,
	})
}