veza/veza-backend-api/internal/testutils/setup.go

88 lines
2.3 KiB
Go

package testutils
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
// Shared state for the single Postgres test container used by this package.
// setupPostgresContainer populates it, and GetTestContainerDB triggers that
// setup through containerOnce so it runs at most once per process.
var (
	// containerOnce guards the one-time container setup.
	containerOnce sync.Once

	// pgContainer is the running container handle (nil until started).
	pgContainer *postgres.PostgresContainer
	// pgDSN is the connection string handed out to tests.
	pgDSN string
	// pgErr records the outcome of the one-time setup.
	pgErr error
)
// GetTestContainerDB ensures the postgres container is running and returns the DSN.
// It uses a singleton pattern to start the container only once per test run.
func GetTestContainerDB(ctx context.Context) (string, error) {
containerOnce.Do(func() {
pgErr = setupPostgresContainer(ctx)
})
return pgDSN, pgErr
}
func setupPostgresContainer(ctx context.Context) error {
// Find project root relative to this file
// This file is in internal/testutils/setup.go
_, filename, _, _ := runtime.Caller(0)
projectRoot := filepath.Join(filepath.Dir(filename), "../..")
migrationsDir := filepath.Join(projectRoot, "migrations")
// Collect migration files
files, err := os.ReadDir(migrationsDir)
if err != nil {
return fmt.Errorf("failed to read migrations dir: %w", err)
}
var migrationFiles []string
for _, f := range files {
if strings.HasSuffix(f.Name(), ".sql") {
migrationFiles = append(migrationFiles, filepath.Join(migrationsDir, f.Name()))
}
}
sort.Strings(migrationFiles) // Ensure alphabetical order (001_, 002_, ...)
// Start Postgres container
var containerErr error
pgContainer, containerErr = postgres.Run(ctx,
"postgres:15-alpine",
postgres.WithDatabase("veza_test"),
postgres.WithUsername("veza"),
postgres.WithPassword("veza"),
postgres.WithInitScripts(migrationFiles...),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(60*time.Second)),
)
if containerErr != nil {
return fmt.Errorf("failed to start postgres container: %w", containerErr)
}
var dsnErr error
pgDSN, dsnErr = pgContainer.ConnectionString(ctx, "sslmode=disable")
if dsnErr != nil {
return fmt.Errorf("failed to get connection string: %w", dsnErr)
}
return nil
}
// TerminateContainer terminates the shared Postgres container, if one was
// created. It is intended for manual cleanup (for example from TestMain). It
// returns nil when there is no container.
//
// After a successful termination the handle is cleared, so a repeated call is
// a no-op rather than a second Terminate on an already-terminated container.
// If Terminate fails, the handle is kept so the call can be retried.
//
// It does not reset containerOnce. Later GetTestContainerDB calls keep
// returning the original, now stale, DSN.
func TerminateContainer(ctx context.Context) error {
	if pgContainer == nil {
		return nil
	}
	if err := pgContainer.Terminate(ctx); err != nil {
		return err
	}
	pgContainer = nil
	return nil
}