package middleware import ( "context" "fmt" "net/http" "os" "strconv" "sync" "time" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" ) // EndpointLimiterConfig configuration pour les limites par endpoint type EndpointLimiterConfig struct { RedisClient *redis.Client KeyPrefix string } // EndpointLimits définit les limites pour chaque endpoint type EndpointLimits struct { // Login: 5 tentatives/15min par IP LoginAttempts int LoginWindow time.Duration // Register: 3 comptes/heure par IP RegisterAttempts int RegisterWindow time.Duration // Password reset: 3 tentatives/heure PasswordResetAttempts int PasswordResetWindow time.Duration // Upload: 10 fichiers/heure par user UploadAttempts int UploadWindow time.Duration } // DefaultEndpointLimits retourne les limites par défaut func DefaultEndpointLimits() *EndpointLimits { return &EndpointLimits{ LoginAttempts: 5, LoginWindow: 15 * time.Minute, RegisterAttempts: 3, RegisterWindow: time.Hour, PasswordResetAttempts: 3, PasswordResetWindow: time.Hour, UploadAttempts: 10, UploadWindow: time.Hour, } } // endpointInMemoryEntry holds in-memory rate limit state for fallback type endpointInMemoryEntry struct { count int windowStart time.Time } // EndpointLimiter gère les limites par endpoint type EndpointLimiter struct { config *EndpointLimiterConfig limits *EndpointLimits inMemoryStore map[string]*endpointInMemoryEntry inMemoryMu sync.RWMutex } // NewEndpointLimiter crée un nouveau endpoint limiter func NewEndpointLimiter(config *EndpointLimiterConfig, limits *EndpointLimits) *EndpointLimiter { return &EndpointLimiter{ config: config, limits: limits, inMemoryStore: make(map[string]*endpointInMemoryEntry), } } // LoginRateLimit middleware pour limiter les tentatives de login func (el *EndpointLimiter) LoginRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "login", el.limits.LoginAttempts, el.limits.LoginWindow, "Too many login attempts", ) } // RegisterRateLimit middleware pour limiter les inscriptions func (el *EndpointLimiter) RegisterRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "register", el.limits.RegisterAttempts, el.limits.RegisterWindow, "Too many registration attempts", ) } // PasswordResetRateLimit middleware pour limiter les reset de mot de passe func (el *EndpointLimiter) PasswordResetRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "password_reset", el.limits.PasswordResetAttempts, el.limits.PasswordResetWindow, "Too many password reset attempts", ) } // VerifyEmailRateLimit middleware pour limiter les tentatives de vérification d'email // BE-SEC-005: Implement rate limiting for authentication endpoints func (el *EndpointLimiter) VerifyEmailRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "verify_email", 5, // 5 tentatives par heure time.Hour, // Fenêtre de 1 heure "Too many email verification attempts", ) } // ResendVerificationRateLimit middleware pour limiter les renvois de vérification // BE-SEC-005: Implement rate limiting for authentication endpoints func (el *EndpointLimiter) ResendVerificationRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "resend_verification", 3, // 3 tentatives par heure time.Hour, // Fenêtre de 1 heure "Too many verification resend attempts", ) } // CheckUsernameRateLimit middleware pour limiter l'énumération de noms d'utilisateur (SEC-009) func (el *EndpointLimiter) CheckUsernameRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "check_username", 30, // 30 requêtes par minute time.Minute, // Fenêtre de 1 minute "Too many username check attempts", ) } // RefreshRateLimit middleware pour limiter le token grinding sur /auth/refresh (SEC-010) func (el *EndpointLimiter) RefreshRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "refresh", 10, // 10 requêtes par minute time.Minute, // Fenêtre de 1 minute "Too many token refresh attempts", ) } // ValidateRateLimit middleware pour limiter les appels à POST /validate (A01) func (el *EndpointLimiter) ValidateRateLimit() gin.HandlerFunc { return el.createEndpointLimit( "validate", 10, // 10 requêtes par minute par IP time.Minute, // Fenêtre de 1 minute "Too many validation requests", ) } // UploadRateLimit middleware pour limiter les uploads par utilisateur func (el *EndpointLimiter) UploadRateLimit() gin.HandlerFunc { return func(c *gin.Context) { // Récupérer l'ID utilisateur depuis le contexte userID, exists := c.Get("user_id") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) c.Abort() return } key := fmt.Sprintf("%s:upload:user:%v", el.config.KeyPrefix, userID) allowed, remaining, err := el.checkLimit(c.Request.Context(), key, el.limits.UploadAttempts, el.limits.UploadWindow) if err != nil { // Fallback in-memory when Redis fails (fail-secure) allowed, remaining = el.checkLimitInMemory(key, el.limits.UploadAttempts, el.limits.UploadWindow) } c.Header("X-UploadLimit-Limit", strconv.Itoa(el.limits.UploadAttempts)) c.Header("X-UploadLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-UploadLimit-Reset", strconv.FormatInt(time.Now().Add(el.limits.UploadWindow).Unix(), 10)) if !allowed { c.JSON(http.StatusTooManyRequests, gin.H{ "error": "Upload limit exceeded", "retry_after": int(el.limits.UploadWindow.Seconds()), }) c.Abort() return } c.Next() } } // createEndpointLimit crée un middleware de limitation pour un endpoint func (el *EndpointLimiter) createEndpointLimit( endpoint string, attempts int, window time.Duration, errorMessage string, ) gin.HandlerFunc { return func(c *gin.Context) { // SEC-011: Never bypass rate limiting in production. // E2E: Completely disable in test environment (APP_ENV=test) to prevent flaky tests. if os.Getenv("APP_ENV") == "production" { // Continue to rate limit — NEVER bypass in production } else if os.Getenv("APP_ENV") == "test" || os.Getenv("DISABLE_RATE_LIMIT_FOR_TESTS") == "true" { c.Next() return } key := fmt.Sprintf("%s:%s:ip:%s", el.config.KeyPrefix, endpoint, c.ClientIP()) allowed, remaining, err := el.checkLimit(c.Request.Context(), key, attempts, window) if err != nil { // Fallback in-memory when Redis fails (fail-secure) allowed, remaining = el.checkLimitInMemory(key, attempts, window) } headerPrefix := fmt.Sprintf("X-%sLimit", capitalize(endpoint)) c.Header(headerPrefix+"-Limit", strconv.Itoa(attempts)) c.Header(headerPrefix+"-Remaining", strconv.Itoa(remaining)) c.Header(headerPrefix+"-Reset", strconv.FormatInt(time.Now().Add(window).Unix(), 10)) if !allowed { c.JSON(http.StatusTooManyRequests, gin.H{ "error": errorMessage, "retry_after": int(window.Seconds()), }) c.Abort() return } c.Next() } } // checkLimit vérifie si une limite est respectée func (el *EndpointLimiter) checkLimit(ctx context.Context, key string, attempts int, window time.Duration) (bool, int, error) { // Use in-memory fallback when Redis is not configured (e.g. integration tests) if el.config == nil || el.config.RedisClient == nil { allowed, remaining := el.checkLimitInMemory(key, attempts, window) return allowed, remaining, nil } // Script Lua pour l'atomicité script := ` local key = KEYS[1] local attempts = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = redis.call('GET', key) if current == false then redis.call('SET', key, 1, 'EX', window) return {1, attempts - 1} end local count = tonumber(current) if count < attempts then redis.call('INCR', key) return {1, attempts - count - 1} else return {0, 0} end ` result, err := el.config.RedisClient.Eval( ctx, script, []string{key}, attempts, int(window.Seconds()), ).Result() if err != nil { return false, 0, err } results := result.([]interface{}) allowed := results[0].(int64) == 1 remaining := int(results[1].(int64)) return allowed, remaining, nil } // checkLimitInMemory implements rate limiting in-memory when Redis fails (fail-secure) func (el *EndpointLimiter) checkLimitInMemory(key string, attempts int, window time.Duration) (bool, int) { el.inMemoryMu.Lock() defer el.inMemoryMu.Unlock() now := time.Now() entry := el.inMemoryStore[key] if entry == nil { entry = &endpointInMemoryEntry{count: 0, windowStart: now} el.inMemoryStore[key] = entry } if now.Sub(entry.windowStart) >= window { entry.count = 0 entry.windowStart = now } entry.count++ remaining := attempts - entry.count if remaining < 0 { remaining = 0 } return entry.count <= attempts, remaining } // capitalize met en majuscule la première lettre func capitalize(s string) string { if len(s) == 0 { return s } return string(s[0]-32) + s[1:] } // RateLimitByUser middleware pour limiter par utilisateur (pour endpoints génériques) func (el *EndpointLimiter) RateLimitByUser(attempts int, window time.Duration, errorMessage string) gin.HandlerFunc { return func(c *gin.Context) { userID, exists := c.Get("user_id") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) c.Abort() return } key := fmt.Sprintf("%s:user:%v", el.config.KeyPrefix, userID) allowed, remaining, err := el.checkLimit(c.Request.Context(), key, attempts, window) if err != nil { // Fallback in-memory when Redis fails (fail-secure) allowed, remaining = el.checkLimitInMemory(key, attempts, window) } c.Header("X-UserLimit-Limit", strconv.Itoa(attempts)) c.Header("X-UserLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-UserLimit-Reset", strconv.FormatInt(time.Now().Add(window).Unix(), 10)) if !allowed { c.JSON(http.StatusTooManyRequests, gin.H{ "error": errorMessage, "retry_after": int(window.Seconds()), }) c.Abort() return } c.Next() } }