package middleware import ( "context" "fmt" "strconv" "sync" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/redis/go-redis/v9" "go.uber.org/zap" "golang.org/x/time/rate" apperrors "veza-backend-api/internal/errors" "veza-backend-api/internal/response" ) // UserRateLimiterConfig configuration pour le rate limiter par utilisateur // BE-SVC-002: Implement rate limiting per user type UserRateLimiterConfig struct { // Limites par utilisateur RequestsPerMinute int Burst int // Fenêtre de temps pour le rate limiting Window time.Duration // Configuration Redis RedisClient *redis.Client KeyPrefix string // Logger Logger *zap.Logger } // UserRateLimiter middleware pour limiter le taux de requêtes par utilisateur type UserRateLimiter struct { config *UserRateLimiterConfig fallback sync.Map // map[string]*rate.Limiter for fail-secure when Redis fails fallbackMu sync.Mutex } // NewUserRateLimiter crée un nouveau rate limiter par utilisateur func NewUserRateLimiter(config *UserRateLimiterConfig) *UserRateLimiter { if config.Window == 0 { config.Window = time.Minute } if config.KeyPrefix == "" { config.KeyPrefix = "rate_limit" } return &UserRateLimiter{ config: config, } } // Middleware retourne le middleware Gin pour le rate limiting par utilisateur func (url *UserRateLimiter) Middleware() gin.HandlerFunc { return func(c *gin.Context) { // Récupérer l'ID utilisateur depuis le contexte userIDInterface, exists := c.Get("user_id") if !exists { response.RespondWithAppError(c, apperrors.NewUnauthorizedError("Authentication required for rate limiting")) c.Abort() return } // Convertir l'ID utilisateur en UUID var userID uuid.UUID switch v := userIDInterface.(type) { case uuid.UUID: userID = v case string: var err error userID, err = uuid.Parse(v) if err != nil { response.RespondWithAppError(c, apperrors.NewValidationError("Invalid user ID format")) c.Abort() return } default: // Essayer de convertir en string puis en UUID userIDStr := fmt.Sprintf("%v", v) var err error userID, err = uuid.Parse(userIDStr) if err != nil { response.RespondWithAppError(c, apperrors.NewValidationError("Invalid user ID format")) c.Abort() return } } // Construire la clé Redis pour cet utilisateur key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String()) limit := url.config.RequestsPerMinute windowSeconds := int(url.config.Window.Seconds()) // Vérifier la limite avec Redis allowed, remaining, resetTime, err := url.checkRedisLimit(c.Request.Context(), key, limit, windowSeconds) if err != nil { // Fail-secure: use in-memory fallback when Redis fails if url.config.Logger != nil { url.config.Logger.Warn("Redis rate limit check failed, using in-memory fallback", zap.Error(err), zap.String("user_id", userID.String())) } limiter := url.getFallbackLimiter(key, limit) allowed = limiter.Allow() remaining = int(limiter.Tokens()) if remaining < 0 { remaining = 0 } resetTime = time.Now().Add(url.config.Window).Unix() c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) if !allowed { appErr := apperrors.New(apperrors.ErrCodeRateLimitExceeded, "Rate limit exceeded") appErr.Context = map[string]interface{}{ "retry_after": func() int64 { r := resetTime - time.Now().Unix() if r < 0 { return 0 } return r }(), "limit": limit, "window": url.config.Window.String(), } response.RespondWithAppError(c, appErr) c.Abort() return } c.Next() return } // Ajouter les headers de rate limiting c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) if !allowed { retryAfter := resetTime - time.Now().Unix() if retryAfter < 0 { retryAfter = 0 } appErr := apperrors.New(apperrors.ErrCodeRateLimitExceeded, "Rate limit exceeded") appErr.Context = map[string]interface{}{"retry_after": retryAfter, "limit": limit, "window": url.config.Window.String()} response.RespondWithAppError(c, appErr) c.Abort() return } c.Next() } } // checkRedisLimit vérifie la limite dans Redis avec un script Lua atomique func (url *UserRateLimiter) checkRedisLimit(ctx context.Context, key string, limit, windowSeconds int) (bool, int, int64, error) { // Script Lua pour l'atomicité (sliding window) script := ` local key = KEYS[1] local limit = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) -- Nettoyer les anciennes entrées (sliding window) redis.call('ZREMRANGEBYSCORE', key, 0, now - window) -- Compter les requêtes dans la fenêtre local count = redis.call('ZCARD', key) if count < limit then -- Ajouter la requête actuelle redis.call('ZADD', key, now, now) redis.call('EXPIRE', key, window) return {1, limit - count - 1, now + window} else -- Limite dépassée, retourner le temps de reset local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') local resetTime = now + window if #oldest > 0 then resetTime = tonumber(oldest[2]) + window end return {0, 0, resetTime} end ` now := time.Now().Unix() result, err := url.config.RedisClient.Eval( ctx, script, []string{key}, limit, windowSeconds, now, ).Result() if err != nil { return false, 0, 0, err } results := result.([]interface{}) allowed := results[0].(int64) == 1 remaining := int(results[1].(int64)) resetTime := results[2].(int64) return allowed, remaining, resetTime, nil } // getFallbackLimiter returns or creates an in-memory rate.Limiter for the given key (fail-secure fallback) func (url *UserRateLimiter) getFallbackLimiter(key string, limit int) *rate.Limiter { if v, ok := url.fallback.Load(key); ok { return v.(*rate.Limiter) } url.fallbackMu.Lock() defer url.fallbackMu.Unlock() if v, ok := url.fallback.Load(key); ok { return v.(*rate.Limiter) } window := url.config.Window if window == 0 { window = time.Minute } limiter := rate.NewLimiter(rate.Every(window/time.Duration(limit)), limit) url.fallback.Store(key, limiter) return limiter } // GetUserRateLimitInfo récupère les informations de rate limit pour un utilisateur func (url *UserRateLimiter) GetUserRateLimitInfo(ctx context.Context, userID uuid.UUID) (remaining int, resetTime int64, err error) { key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String()) limit := url.config.RequestsPerMinute windowSeconds := int(url.config.Window.Seconds()) _, remaining, resetTime, err = url.checkRedisLimit(ctx, key, limit, windowSeconds) return remaining, resetTime, err }