package middleware import ( "context" "fmt" "net/http" "strconv" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) // UserRateLimiterConfig configuration pour le rate limiter par utilisateur // BE-SVC-002: Implement rate limiting per user type UserRateLimiterConfig struct { // Limites par utilisateur RequestsPerMinute int Burst int // Fenêtre de temps pour le rate limiting Window time.Duration // Configuration Redis RedisClient *redis.Client KeyPrefix string // Logger Logger *zap.Logger } // UserRateLimiter middleware pour limiter le taux de requêtes par utilisateur type UserRateLimiter struct { config *UserRateLimiterConfig } // NewUserRateLimiter crée un nouveau rate limiter par utilisateur func NewUserRateLimiter(config *UserRateLimiterConfig) *UserRateLimiter { if config.Window == 0 { config.Window = time.Minute } if config.KeyPrefix == "" { config.KeyPrefix = "rate_limit" } return &UserRateLimiter{ config: config, } } // Middleware retourne le middleware Gin pour le rate limiting par utilisateur func (url *UserRateLimiter) Middleware() gin.HandlerFunc { return func(c *gin.Context) { // Récupérer l'ID utilisateur depuis le contexte userIDInterface, exists := c.Get("user_id") if !exists { // Si pas d'utilisateur authentifié, passer au suivant // (ce middleware est pour les utilisateurs authentifiés uniquement) c.JSON(http.StatusUnauthorized, gin.H{ "error": "Authentication required for rate limiting", }) c.Abort() return } // Convertir l'ID utilisateur en UUID var userID uuid.UUID switch v := userIDInterface.(type) { case uuid.UUID: userID = v case string: var err error userID, err = uuid.Parse(v) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": "Invalid user ID format", }) c.Abort() return } default: // Essayer de convertir en string puis en UUID userIDStr := fmt.Sprintf("%v", v) var err error userID, err = uuid.Parse(userIDStr) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": "Invalid user ID format", }) c.Abort() return } } // Construire la clé Redis pour cet utilisateur key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String()) limit := url.config.RequestsPerMinute windowSeconds := int(url.config.Window.Seconds()) // Vérifier la limite avec Redis allowed, remaining, resetTime, err := url.checkRedisLimit(c.Request.Context(), key, limit, windowSeconds) if err != nil { // En cas d'erreur Redis, logger l'erreur mais autoriser la requête (fail-open) if url.config.Logger != nil { url.config.Logger.Warn("Redis rate limit check failed, allowing request", zap.Error(err), zap.String("user_id", userID.String())) } c.Next() return } // Ajouter les headers de rate limiting c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) if !allowed { retryAfter := resetTime - time.Now().Unix() if retryAfter < 0 { retryAfter = 0 } c.JSON(http.StatusTooManyRequests, gin.H{ "error": "Rate limit exceeded", "retry_after": retryAfter, "limit": limit, "window": url.config.Window.String(), }) c.Abort() return } c.Next() } } // checkRedisLimit vérifie la limite dans Redis avec un script Lua atomique func (url *UserRateLimiter) checkRedisLimit(ctx context.Context, key string, limit, windowSeconds int) (bool, int, int64, error) { // Script Lua pour l'atomicité (sliding window) script := ` local key = KEYS[1] local limit = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) -- Nettoyer les anciennes entrées (sliding window) redis.call('ZREMRANGEBYSCORE', key, 0, now - window) -- Compter les requêtes dans la fenêtre local count = redis.call('ZCARD', key) if count < limit then -- Ajouter la requête actuelle redis.call('ZADD', key, now, now) redis.call('EXPIRE', key, window) return {1, limit - count - 1, now + window} else -- Limite dépassée, retourner le temps de reset local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') local resetTime = now + window if #oldest > 0 then resetTime = tonumber(oldest[2]) + window end return {0, 0, resetTime} end ` now := time.Now().Unix() result, err := url.config.RedisClient.Eval( ctx, script, []string{key}, limit, windowSeconds, now, ).Result() if err != nil { return false, 0, 0, err } results := result.([]interface{}) allowed := results[0].(int64) == 1 remaining := int(results[1].(int64)) resetTime := results[2].(int64) return allowed, remaining, resetTime, nil } // GetUserRateLimitInfo récupère les informations de rate limit pour un utilisateur func (url *UserRateLimiter) GetUserRateLimitInfo(ctx context.Context, userID uuid.UUID) (remaining int, resetTime int64, err error) { key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String()) limit := url.config.RequestsPerMinute windowSeconds := int(url.config.Window.Seconds()) _, remaining, resetTime, err = url.checkRedisLimit(ctx, key, limit, windowSeconds) return remaining, resetTime, err }