package middleware

import (
	"context"
	"fmt"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"veza-backend-api/internal/metrics"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/redis/go-redis/v9"
	"golang.org/x/time/rate"
)

// uploadRateLimitFallback stores the per-user in-memory limiters used by the
// upload rate limiter when Redis is unavailable (fail-secure).
var (
	uploadRateLimitFallback   sync.Map // map[string]*rate.Limiter
	uploadRateLimitFallbackMu sync.Mutex
)

// RateLimiterConfig configures a RateLimiter.
//
// TASK-SEC-003: hourly limits (100 unauthenticated, 1000 authenticated) are
// expressed through WindowSeconds.
type RateLimiterConfig struct {
	// IPLimit is the number of requests allowed per window for an
	// unauthenticated client, keyed by client IP. Values <= 0 mean 100.
	IPLimit int

	// UserLimit is the number of requests allowed per window for an
	// authenticated user. Values <= 0 mean 1000.
	UserLimit int

	// WindowSeconds is the window length in seconds (3600 = 1 hour for
	// TASK-SEC-003). Values <= 0 mean 3600.
	WindowSeconds int

	// RedisClient holds the shared counters. When nil, only the in-memory
	// fallback limiters are used.
	RedisClient *redis.Client

	// KeyPrefix is prepended to every Redis key built by the limiter.
	KeyPrefix string
}

// RateLimiter is a request rate limiter backed by Redis, with an in-memory
// fallback used whenever Redis does not deliver a verdict.
type RateLimiter struct {
	config *RateLimiterConfig

	// ipLimiter and userLimiter hold the fallback rate and burst for
	// unauthenticated and authenticated traffic. They act as templates from
	// which the per-key fallback limiters are derived.
	ipLimiter   *rate.Limiter
	userLimiter *rate.Limiter

	// fallback maps a rate-limit key to its own in-memory limiter so that one
	// client cannot drain the budget of every other client while Redis is
	// down. Entries are never evicted.
	fallback sync.Map // map[string]*rate.Limiter
}

// NewRateLimiter builds a RateLimiter from config. Non-positive limits and
// window fall back to 100 requests per IP, 1000 per user and a one-hour window.
func NewRateLimiter(config *RateLimiterConfig) *RateLimiter {
	rl := &RateLimiter{config: config}
	ipLimit, userLimit, windowSec := rl.effectiveLimits()
	window := time.Duration(windowSec) * time.Second

	// Token buckets refill evenly over the window and allow a full-window burst.
	rl.ipLimiter = rate.NewLimiter(rate.Every(window/time.Duration(ipLimit)), ipLimit)
	rl.userLimiter = rate.NewLimiter(rate.Every(window/time.Duration(userLimit)), userLimit)
	return rl
}

// effectiveLimits returns the configured limits and window with defaults
// applied, so that the Redis path, the in-memory fallback and the response
// headers all agree on the same numbers.
func (rl *RateLimiter) effectiveLimits() (ipLimit, userLimit, windowSec int) {
	ipLimit = rl.config.IPLimit
	if ipLimit <= 0 {
		ipLimit = 100
	}
	userLimit = rl.config.UserLimit
	if userLimit <= 0 {
		userLimit = 1000
	}
	windowSec = rl.config.WindowSeconds
	if windowSec <= 0 {
		windowSec = 3600
	}
	return ipLimit, userLimit, windowSec
}

// fallbackLimiter returns the in-memory limiter dedicated to key, creating it
// with the rate and burst of template on first use.
func (rl *RateLimiter) fallbackLimiter(key string, template *rate.Limiter) *rate.Limiter {
	if v, ok := rl.fallback.Load(key); ok {
		return v.(*rate.Limiter)
	}
	v, _ := rl.fallback.LoadOrStore(key, rate.NewLimiter(template.Limit(), template.Burst()))
	return v.(*rate.Limiter)
}

// limiterRemaining reports the whole tokens currently left in limiter, never
// less than zero.
func limiterRemaining(limiter *rate.Limiter) int {
	if tokens := int(limiter.Tokens()); tokens > 0 {
		return tokens
	}
	return 0
}

// rateLimitBypassedForTests reports whether rate limiting is switched off
// through the environment (P1.6).
//
// NOTE(review): the variables are honoured in whatever environment sets them;
// nothing here prevents DISABLE_RATE_LIMIT_FOR_TESTS=true in production.
func rateLimitBypassedForTests() bool {
	return os.Getenv("APP_ENV") == "test" || os.Getenv("DISABLE_RATE_LIMIT_FOR_TESTS") == "true"
}

// excludedRateLimitPathsRedis lists the path prefixes that are never rate
// limited (critical routes).
var excludedRateLimitPathsRedis = []string{
	"/health",
	"/healthz",
	"/readyz",
	"/api/v1/health",
	"/api/v1/healthz",
	"/api/v1/readyz",
	"/api/v1/csrf-token",
	"/api/v1/auth/register",
	"/api/v1/auth/login",
	"/api/v1/auth/refresh",
	"/api/v1/auth/verify-email",
	"/api/v1/auth/resend-verification",
	"/api/v1/auth/check-username",
	"/swagger",
	"/docs",
}

// isExcludedPathRedis reports whether path is exempt from rate limiting
// (Redis version). The match is a plain string-prefix match, so "/health"
// also exempts e.g. "/healthcheck".
func isExcludedPathRedis(path string) bool {
	for _, excluded := range excludedRateLimitPathsRedis {
		if strings.HasPrefix(path, excluded) {
			return true
		}
	}
	return false
}

// DDoS rate limit constants (SEC1-04): global 1000 req/s, per-IP 100 req/s.
const (
	ddosGlobalLimit   = 1000
	ddosPerIPLimit    = 100
	ddosWindowSeconds = 1
)

// In-memory limiters used by the DDoS middleware when Redis fails (fail-secure).
var (
	ddosGlobalFallback   *rate.Limiter
	ddosFallbackMu       sync.Mutex
	ddosPerIPFallbackMap sync.Map
)

// fixedWindowLimitScript atomically counts one request against KEYS[1].
// ARGV[1] is the limit and ARGV[2] the window in seconds. The first request of
// a window creates the counter with that expiry; later requests increment it
// until the limit is reached. It returns {allowed (1 or 0), remaining}.
const fixedWindowLimitScript = `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])

local current = redis.call('GET', key)
if current == false then
	redis.call('SET', key, 1, 'EX', window)
	return {1, limit - 1}
end

local count = tonumber(current)
if count < limit then
	redis.call('INCR', key)
	return {1, limit - count - 1}
else
	return {0, 0}
end
`

// evalFixedWindowLimit runs fixedWindowLimitScript for key and decodes its
// reply into (allowed, remaining). A reply of unexpected shape is reported as
// an error rather than a panic, so callers can use their in-memory fallback.
func evalFixedWindowLimit(ctx context.Context, client *redis.Client, key string, limit, windowSec int) (bool, int, error) {
	result, err := client.Eval(ctx, fixedWindowLimitScript, []string{key}, limit, windowSec).Result()
	if err != nil {
		return false, 0, err
	}
	reply, ok := result.([]interface{})
	if !ok || len(reply) < 2 {
		return false, 0, fmt.Errorf("unexpected rate limit script reply: %v", result)
	}
	allowedFlag, okAllowed := reply[0].(int64)
	remaining, okRemaining := reply[1].(int64)
	if !okAllowed || !okRemaining {
		return false, 0, fmt.Errorf("unexpected rate limit script reply: %v", result)
	}
	return allowedFlag == 1, int(remaining), nil
}

// DDoSRateLimitMiddleware applies SEC1-04 DDoS protection: global 1000 req/s,
// per-IP 100 req/s, each counted in a fixed 1-second window.
// Excludes health, swagger and critical auth paths.
// Must run before the main RateLimitMiddleware. When Redis is nil or fails,
// in-memory limiters are used instead.
func DDoSRateLimitMiddleware(redisClient *redis.Client) gin.HandlerFunc {
	return func(c *gin.Context) {
		if isExcludedPathRedis(c.Request.URL.Path) || rateLimitBypassedForTests() {
			c.Next()
			return
		}

		ctx := c.Request.Context()

		// Global budget shared by every client.
		const globalKey = "rate:ddos:global"
		globalAllowed, _, err := checkRedisLimit1s(ctx, redisClient, globalKey, ddosGlobalLimit)
		if err != nil {
			globalAllowed = getDDoSFallbackLimiter(globalKey, ddosGlobalLimit).Allow()
		}
		if !globalAllowed {
			c.Header("Retry-After", strconv.Itoa(ddosWindowSeconds))
			c.JSON(http.StatusTooManyRequests, gin.H{
				"error":       "Global rate limit exceeded (DDoS protection)",
				"retry_after": ddosWindowSeconds,
			})
			c.Abort()
			return
		}

		// Per-IP budget.
		ipKey := "rate:ddos:ip:" + c.ClientIP()
		ipAllowed, remaining, err := checkRedisLimit1s(ctx, redisClient, ipKey, ddosPerIPLimit)
		if err != nil {
			limiter := getDDoSPerIPFallbackLimiter(ipKey)
			ipAllowed = limiter.Allow()
			remaining = limiterRemaining(limiter)
		}

		c.Header("X-RateLimit-Limit", strconv.Itoa(ddosPerIPLimit))
		c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
		if !ipAllowed {
			c.Header("Retry-After", strconv.Itoa(ddosWindowSeconds))
			c.JSON(http.StatusTooManyRequests, gin.H{
				"error":       "Rate limit exceeded. Please try again in a moment.",
				"retry_after": ddosWindowSeconds,
			})
			c.Abort()
			return
		}

		c.Next()
	}
}

// checkRedisLimit1s counts one request against key in Redis using a fixed
// 1-second window and returns (allowed, remaining, error). A non-nil error
// means Redis gave no verdict and the caller must use its in-memory limiter.
func checkRedisLimit1s(ctx context.Context, redisClient *redis.Client, key string, limit int) (bool, int, error) {
	if redisClient == nil {
		return false, 0, fmt.Errorf("redis not configured")
	}

	allowed, remaining, err := evalFixedWindowLimit(ctx, redisClient, key, limit, ddosWindowSeconds)
	if err != nil {
		// "miss" here = Redis didn't deliver a verdict.
		metrics.RecordCacheMiss("rate_limiter")
		return false, 0, err
	}
	metrics.RecordCacheHit("rate_limiter")
	return allowed, remaining, nil
}

// getDDoSFallbackLimiter returns the process-wide in-memory limiter for the
// global DDoS budget, creating it on first use with the given limit per
// second. key is ignored: there is a single global limiter, and limit only
// takes effect on the call that creates it.
func getDDoSFallbackLimiter(key string, limit int) *rate.Limiter {
	ddosFallbackMu.Lock()
	defer ddosFallbackMu.Unlock()

	if ddosGlobalFallback == nil {
		ddosGlobalFallback = rate.NewLimiter(rate.Every(time.Second/time.Duration(limit)), limit)
	}
	return ddosGlobalFallback
}

// getDDoSPerIPFallbackLimiter returns the in-memory per-IP DDoS limiter stored
// under key, creating it (ddosPerIPLimit requests per second) on first use.
func getDDoSPerIPFallbackLimiter(key string) *rate.Limiter {
	if v, ok := ddosPerIPFallbackMap.Load(key); ok {
		return v.(*rate.Limiter)
	}
	limiter := rate.NewLimiter(rate.Every(time.Second/time.Duration(ddosPerIPLimit)), ddosPerIPLimit)
	// LoadOrStore keeps whichever limiter won a concurrent creation race.
	v, _ := ddosPerIPFallbackMap.LoadOrStore(key, limiter)
	return v.(*rate.Limiter)
}

// RateLimitMiddleware is the main rate limiting middleware. Authenticated
// requests (a "user_id" value is present in the gin context) are limited per
// user with UserLimit; all others are limited per client IP with IPLimit.
// Excluded paths and test environments bypass the limiter.
func (rl *RateLimiter) RateLimitMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Critical routes are never rate limited; P1.6 bypasses tests.
		if isExcludedPathRedis(c.Request.URL.Path) || rateLimitBypassedForTests() {
			c.Next()
			return
		}

		ipLimit, userLimit, windowSec := rl.effectiveLimits()

		var (
			template *rate.Limiter
			key      string
			limit    int
		)
		if rawUserID, isAuthenticated := c.Get("user_id"); isAuthenticated {
			// Authenticated user: higher limit.
			// BE-SVC-002: the user ID may be a UUID or a string.
			var userIDStr string
			switch v := rawUserID.(type) {
			case uuid.UUID:
				userIDStr = v.String()
			case string:
				userIDStr = v
			default:
				userIDStr = fmt.Sprintf("%v", v)
			}
			template = rl.userLimiter
			key = fmt.Sprintf("%s:user:%s", rl.config.KeyPrefix, userIDStr)
			limit = userLimit
		} else {
			// Unauthenticated client: stricter limit, keyed by IP.
			template = rl.ipLimiter
			key = fmt.Sprintf("%s:ip:%s", rl.config.KeyPrefix, c.ClientIP())
			limit = ipLimit
		}

		// Redis holds the shared counters; on any Redis failure fall back to
		// the in-memory limiter dedicated to this key.
		allowed, remaining, err := rl.checkRedisLimit(c.Request.Context(), key, limit, windowSec)
		if err != nil {
			limiter := rl.fallbackLimiter(key, template)
			allowed = limiter.Allow()
			remaining = limiterRemaining(limiter)
		}

		// The reset time is now + window, i.e. an upper bound rather than the
		// exact expiry of the counter.
		resetTime := time.Now().Add(time.Duration(windowSec) * time.Second).Unix()
		c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
		c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
		c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))

		if !allowed {
			c.Header("Retry-After", strconv.Itoa(windowSec))
			c.JSON(http.StatusTooManyRequests, gin.H{
				"error":       "Rate limit exceeded",
				"retry_after": windowSec,
			})
			c.Abort()
			return
		}

		c.Next()
	}
}

// checkRedisLimit counts one request against key in Redis for a window of
// windowSec seconds (3600 when non-positive) and returns
// (allowed, remaining, error). It returns an error when Redis is not
// configured (e.g. integration tests), unreachable, or replies unexpectedly;
// callers then use the in-memory fallback.
func (rl *RateLimiter) checkRedisLimit(ctx context.Context, key string, limit int, windowSec int) (bool, int, error) {
	if rl.config == nil || rl.config.RedisClient == nil {
		return false, 0, fmt.Errorf("redis not configured")
	}
	if windowSec <= 0 {
		windowSec = 3600
	}
	// The Lua script makes the read-check-increment sequence atomic
	// (TASK-SEC-003: one-hour window).
	return evalFixedWindowLimit(ctx, rl.config.RedisClient, key, limit, windowSec)
}

// RateLimitByIP limits requests per client IP only, using IPLimit. Unlike
// RateLimitMiddleware it does not skip excluded paths or test environments.
func (rl *RateLimiter) RateLimitByIP() gin.HandlerFunc {
	return func(c *gin.Context) {
		ipLimit, _, windowSec := rl.effectiveLimits()
		key := fmt.Sprintf("%s:ip:%s", rl.config.KeyPrefix, c.ClientIP())

		allowed, remaining, err := rl.checkRedisLimit(c.Request.Context(), key, ipLimit, windowSec)
		if err != nil {
			limiter := rl.fallbackLimiter(key, rl.ipLimiter)
			allowed = limiter.Allow()
			remaining = limiterRemaining(limiter)
		}

		// INT-013: standardized rate limit response format.
		resetTime := time.Now().Add(time.Duration(windowSec) * time.Second).Unix()
		c.Header("X-RateLimit-Limit", strconv.Itoa(ipLimit))
		c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
		c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))

		if !allowed {
			// Only claim "per hour" when the window really is one hour.
			detail := fmt.Sprintf("You have exceeded the rate limit of %d requests per hour", ipLimit)
			if windowSec != 3600 {
				detail = fmt.Sprintf("You have exceeded the rate limit of %d requests per %d seconds", ipLimit, windowSec)
			}

			c.Header("Retry-After", strconv.Itoa(windowSec))
			c.JSON(http.StatusTooManyRequests, gin.H{
				"success": false,
				"error": gin.H{
					"code":    429,
					"message": "Rate limit exceeded. Please try again later.",
					"details": []gin.H{
						{
							"field":   "rate_limit",
							"message": detail,
						},
					},
					"retry_after": windowSec,
					"limit":       ipLimit,
					"remaining":   0,
					"reset":       resetTime,
				},
			})
			c.Abort()
			return
		}

		c.Next()
	}
}

// frontendLogRateLimitFallback stores per-IP in-memory limiters for the logs
// endpoint when Redis fails.
var (
	frontendLogRateLimitFallback   sync.Map // map[string]*rate.Limiter
	frontendLogRateLimitFallbackMu sync.Mutex
)

// getFrontendLogFallbackLimiter returns or creates an in-memory rate.Limiter for the given key.
// Limit: 60 requests per minute per IP.
func getFrontendLogFallbackLimiter(key string) *rate.Limiter {
	if v, ok := frontendLogRateLimitFallback.Load(key); ok {
		return v.(*rate.Limiter)
	}

	frontendLogRateLimitFallbackMu.Lock()
	defer frontendLogRateLimitFallbackMu.Unlock()

	// Re-check under the lock: another goroutine may have created the limiter
	// between the first lookup and the lock acquisition.
	if v, ok := frontendLogRateLimitFallback.Load(key); ok {
		return v.(*rate.Limiter)
	}
	limiter := rate.NewLimiter(rate.Every(time.Minute/60), 60)
	frontendLogRateLimitFallback.Store(key, limiter)
	return limiter
}

// FrontendLogRateLimit middleware limits POST /api/v1/logs/frontend by IP.
// Limit: 60 requests per minute per IP. Prevents log flooding (A01, A04).
func FrontendLogRateLimit(redisClient *redis.Client) gin.HandlerFunc { const limit = 60 const windowSec = 60 return func(c *gin.Context) { key := fmt.Sprintf("logs_frontend:ip:%s", c.ClientIP()) if redisClient != nil { script := ` local key = KEYS[1] local limit = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = redis.call('GET', key) if current == false then redis.call('SET', key, 1, 'EX', window) return {1, limit - 1} end local count = tonumber(current) if count < limit then redis.call('INCR', key) return {1, limit - count - 1} else return {0, 0} end ` result, err := redisClient.Eval( c.Request.Context(), script, []string{key}, limit, windowSec, ).Result() if err == nil { metrics.RecordCacheHit("rate_limiter") results := result.([]interface{}) allowed := results[0].(int64) == 1 remaining := int(results[1].(int64)) c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) if !allowed { c.JSON(http.StatusTooManyRequests, gin.H{ "error": "Rate limit exceeded for frontend logging", "retry_after": windowSec, }) c.Abort() return } c.Next() return } metrics.RecordCacheMiss("rate_limiter") } // Fail-secure: Redis error or nil — use in-memory fallback limiter := getFrontendLogFallbackLimiter(key) allowed := limiter.Allow() remaining := int(limiter.Tokens()) if remaining < 0 { remaining = 0 } c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) if !allowed { c.JSON(http.StatusTooManyRequests, gin.H{ "error": "Rate limit exceeded for frontend logging", "retry_after": windowSec, }) c.Abort() return } c.Next() } } // getUploadFallbackLimiter returns or creates an in-memory rate.Limiter for the given key. // Limit: 10 requests per hour (aligned with Redis config). 
func getUploadFallbackLimiter(key string) *rate.Limiter { if v, ok := uploadRateLimitFallback.Load(key); ok { return v.(*rate.Limiter) } uploadRateLimitFallbackMu.Lock() defer uploadRateLimitFallbackMu.Unlock() // Double-check after acquiring lock if v, ok := uploadRateLimitFallback.Load(key); ok { return v.(*rate.Limiter) } // 10 per hour = 1 every 6 minutes limiter := rate.NewLimiter(rate.Every(time.Hour/10), 10) uploadRateLimitFallback.Store(key, limiter) return limiter } // UploadRateLimit middleware pour limiter les uploads de tracks par utilisateur // Limite: 10 uploads par heure par utilisateur // Fail-secure: uses in-memory fallback when Redis fails (rejects if over limit) func UploadRateLimit(redisClient *redis.Client) gin.HandlerFunc { limit := 10 window := time.Hour return func(c *gin.Context) { userIDInterface, exists := c.Get("user_id") if !exists { c.Next() return } var userIDStr string switch v := userIDInterface.(type) { case uuid.UUID: userIDStr = v.String() case string: userIDStr = v default: userIDStr = fmt.Sprintf("%v", v) } key := fmt.Sprintf("upload_rate_limit:%s", userIDStr) if redisClient != nil { script := ` local key = KEYS[1] local limit = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = redis.call('GET', key) if current == false then redis.call('SET', key, 1, 'EX', window) return {1, limit - 1} end local count = tonumber(current) if count < limit then redis.call('INCR', key) return {1, limit - count - 1} else return {0, 0} end ` result, err := redisClient.Eval( c.Request.Context(), script, []string{key}, limit, int(window.Seconds()), ).Result() if err == nil { metrics.RecordCacheHit("rate_limiter") results := result.([]interface{}) allowed := results[0].(int64) == 1 remaining := int(results[1].(int64)) c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(window).Unix(), 10)) if !allowed { 
c.JSON(http.StatusTooManyRequests, gin.H{ "error": "upload rate limit exceeded", "retry_after": int(window.Seconds()), }) c.Abort() return } c.Next() return } metrics.RecordCacheMiss("rate_limiter") } // Fail-secure: Redis error or nil — use in-memory fallback limiter := getUploadFallbackLimiter(key) allowed := limiter.Allow() remaining := int(limiter.Tokens()) if remaining < 0 { remaining = 0 } c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(window).Unix(), 10)) if !allowed { c.JSON(http.StatusTooManyRequests, gin.H{ "error": "upload rate limit exceeded", "retry_after": int(window.Seconds()), }) c.Abort() return } c.Next() } }