package middleware

import (
	"context"
	"errors"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
	"gorm.io/gorm"
)

// maintenanceState carries the latest cached view of the platform-wide
// maintenance flag. It is refreshed lazily from `platform_settings` when a
// request comes in after the TTL has expired, so operators flipping the flag
// on one pod propagate to every other pod within a bounded window (10s).
type maintenanceState struct {
	mu        sync.RWMutex
	enabled   bool
	lastCheck time.Time
	db        *gorm.DB
	logger    *zap.Logger
	ttl       time.Duration

	// refreshing is true while a TTL-triggered refresh is in flight; other
	// callers keep serving the cached value instead of querying the DB too.
	refreshing bool
}

const (
	defaultMaintenanceCacheTTL = 10 * time.Second

	// maintenanceRefreshTimeout bounds each DB lookup of the flag. Lookups
	// run synchronously on the request path, so a slow or hung DB must not
	// stall requests indefinitely; on timeout the cached value is kept.
	maintenanceRefreshTimeout = 2 * time.Second
)

var (
	state             = &maintenanceState{ttl: defaultMaintenanceCacheTTL}
	maintenanceInitMu sync.Mutex
)

// init seeds the cached flag from the MAINTENANCE_MODE environment variable
// ("true" or "1" enable it; anything else leaves it disabled).
func init() {
	v := os.Getenv("MAINTENANCE_MODE")
	state.mu.Lock()
	state.enabled = v == "true" || v == "1"
	state.mu.Unlock()
}

// InitMaintenanceMode wires the DB pool so subsequent MaintenanceModeEnabled()
// calls refresh from `platform_settings.maintenance_mode` with a TTL cache.
// Safe to call more than once (last write wins). If db is nil the middleware
// falls back to the in-memory state seeded from MAINTENANCE_MODE. A nil logger
// is replaced with a no-op logger.
func InitMaintenanceMode(db *gorm.DB, logger *zap.Logger) {
	maintenanceInitMu.Lock()
	defer maintenanceInitMu.Unlock()

	if logger == nil {
		logger = zap.NewNop()
	}

	state.mu.Lock()
	state.db = db
	state.logger = logger
	state.lastCheck = time.Time{} // force refresh on first call
	state.mu.Unlock()

	// Prime the cache so the very first request doesn't see a stale value.
	// The lookup is bounded so an unreachable DB cannot hang startup.
	ctx, cancel := context.WithTimeout(context.Background(), maintenanceRefreshTimeout)
	defer cancel()
	refreshFromDB(ctx)
}

// refreshFromDB reads the current value from the DB and updates the cache.
// Never propagates errors to callers — a broken DB should not silently
// enable maintenance mode, so the previous cached value wins. A missing row
// also keeps the cached value; a NULL value_bool is treated as disabled.
// lastCheck is bumped on every attempt (success or failure) so a failing DB
// is retried at most once per TTL rather than on every request.
func refreshFromDB(ctx context.Context) {
	state.mu.RLock()
	db := state.db
	logger := state.logger
	state.mu.RUnlock()
	if db == nil {
		return
	}

	var row struct {
		ValueBool *bool `gorm:"column:value_bool"`
	}
	err := db.WithContext(ctx).
		Table("platform_settings").
		Select("value_bool").
		Where("key = ?", "maintenance_mode").
		Take(&row).Error

	state.mu.Lock()
	state.lastCheck = time.Now()
	if err == nil {
		state.enabled = row.ValueBool != nil && *row.ValueBool
	}
	state.mu.Unlock()

	// Log outside the lock; errors.Is also catches a wrapped ErrRecordNotFound.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) && logger != nil {
		logger.Warn("Failed to refresh maintenance flag from DB — keeping cached value",
			zap.Error(err),
		)
	}
}

// refreshMaintenanceFlagIfStale refreshes the cached flag from the DB when a
// DB pool is wired and the TTL has expired. Only one caller performs the
// refresh at a time; callers that lose the race return immediately and keep
// using the cached value, so a TTL expiry under load costs a single query.
func refreshMaintenanceFlagIfStale() {
	state.mu.RLock()
	stale := state.db != nil && !state.refreshing && time.Since(state.lastCheck) > state.ttl
	state.mu.RUnlock()
	if !stale {
		return
	}

	state.mu.Lock()
	// Re-check under the write lock: another goroutine may have claimed or
	// finished the refresh between the two lock acquisitions.
	if state.db == nil || state.refreshing || time.Since(state.lastCheck) <= state.ttl {
		state.mu.Unlock()
		return
	}
	state.refreshing = true
	state.mu.Unlock()

	// Release the claim even if the query panics, otherwise the flag would
	// never be refreshed again.
	defer func() {
		state.mu.Lock()
		state.refreshing = false
		state.mu.Unlock()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), maintenanceRefreshTimeout)
	defer cancel()
	refreshFromDB(ctx)
}

// MaintenanceModeEnabled returns the cached maintenance flag, refreshing from
// the DB first if the TTL has expired and a DB pool has been wired.
func MaintenanceModeEnabled() bool {
	refreshMaintenanceFlagIfStale()

	state.mu.RLock()
	defer state.mu.RUnlock()
	return state.enabled
}

// SetMaintenanceMode sets the in-memory flag without touching the DB. It is
// kept for tests and for cases where a caller already owns the DB write — it
// does not persist the value across pods. Use PlatformSettings to change
// state across a deployment.
//
// lastCheck is pushed one TTL into the future, so the DB is not consulted
// again until roughly two TTLs have elapsed. A DB refresh that is already in
// flight when this is called may still overwrite the value.
func SetMaintenanceMode(enabled bool) {
	state.mu.Lock()
	state.enabled = enabled
	state.lastCheck = time.Now().Add(state.ttl) // suppress the next DB refresh
	state.mu.Unlock()
}

// MaintenanceGin returns a Gin middleware for maintenance mode.
// Exempt paths: /health, /healthz, /readyz, /api/v1/health, /api/v1/admin, /swagger, /docs func MaintenanceGin() gin.HandlerFunc { return func(c *gin.Context) { if !MaintenanceModeEnabled() { c.Next() return } path := c.Request.URL.Path if isMaintenanceExempt(path) { c.Next() return } c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "Platform is under maintenance"}) } } func isMaintenanceExempt(path string) bool { path = strings.TrimSuffix(path, "/") exempts := []string{"/health", "/healthz", "/readyz", "/health/deep", "/metrics", "/swagger", "/docs", "/api/versions"} for _, exempt := range exempts { if path == exempt || strings.HasPrefix(path, exempt+"/") { return true } } if strings.Contains(path, "/api/v1/health") { return true } if strings.Contains(path, "/api/v1/admin") { return true } return false }