veza/veza-backend-api/internal/middleware/user_rate_limiter.go
2026-03-06 19:13:16 +01:00

242 lines
6.8 KiB
Go

package middleware
import (
"context"
"fmt"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"golang.org/x/time/rate"
apperrors "veza-backend-api/internal/errors"
"veza-backend-api/internal/response"
)
// UserRateLimiterConfig holds the settings for the per-user rate limiter.
// BE-SVC-002: Implement rate limiting per user
type UserRateLimiterConfig struct {
// Per-user limits.
// RequestsPerMinute is the request budget handed to both the Redis check and
// the in-memory fallback. It is applied per Window, so it is only a true
// "per minute" figure when Window is one minute.
RequestsPerMinute int
// Burst is not referenced by any code in the visible part of this file.
Burst int
// Window is the length of the rate-limiting window. NewUserRateLimiter
// replaces a zero value with one minute.
Window time.Duration
// Redis configuration.
// RedisClient is the client used to evaluate the rate-limit script.
RedisClient *redis.Client
// KeyPrefix is prepended to every Redis key ("<prefix>:user:<uuid>").
// NewUserRateLimiter replaces an empty value with "rate_limit".
KeyPrefix string
// Logger is optional; the middleware checks it for nil before logging.
Logger *zap.Logger
}
// UserRateLimiter is a middleware that limits the request rate per user.
type UserRateLimiter struct {
// config is the configuration pointer passed to NewUserRateLimiter.
config *UserRateLimiterConfig
fallback sync.Map // map[string]*rate.Limiter for fail-secure when Redis fails
// fallbackMu is held while a new fallback limiter is created and stored.
fallbackMu sync.Mutex
}
// NewUserRateLimiter builds a per-user rate limiter around config.
//
// Missing settings are defaulted in place on the struct config points to:
// an empty KeyPrefix becomes "rate_limit" and a zero Window becomes one
// minute. The returned limiter keeps the same pointer; it does not copy it.
func NewUserRateLimiter(config *UserRateLimiterConfig) *UserRateLimiter {
	if config.KeyPrefix == "" {
		config.KeyPrefix = "rate_limit"
	}
	if config.Window == 0 {
		config.Window = time.Minute
	}

	limiter := new(UserRateLimiter)
	limiter.config = config
	return limiter
}
// Middleware returns the Gin handler that enforces the per-user rate limit.
//
// The handler reads "user_id" from the Gin context (a uuid.UUID, a string, or
// any value whose %v form parses as a UUID). A missing ID is answered with an
// unauthorized error and an unparsable one with a validation error; in both
// cases the request is aborted.
//
// The limit is checked in Redis first. If that check returns an error, a
// warning is logged (when a logger is configured) and an in-memory limiter
// for the same key decides instead. Either way the X-RateLimit-Limit,
// X-RateLimit-Remaining and X-RateLimit-Reset headers are set before the
// request is passed on or rejected with a rate-limit-exceeded error.
func (url *UserRateLimiter) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		rawID, found := c.Get("user_id")
		if !found {
			response.RespondWithAppError(c, apperrors.NewUnauthorizedError("Authentication required for rate limiting"))
			c.Abort()
			return
		}

		// Normalise whatever was stored in the context into a UUID.
		var (
			userID   uuid.UUID
			parseErr error
		)
		switch id := rawID.(type) {
		case uuid.UUID:
			userID = id
		case string:
			userID, parseErr = uuid.Parse(id)
		default:
			// Last resort: format the value and try to parse the result.
			userID, parseErr = uuid.Parse(fmt.Sprintf("%v", id))
		}
		if parseErr != nil {
			response.RespondWithAppError(c, apperrors.NewValidationError("Invalid user ID format"))
			c.Abort()
			return
		}

		limit := url.config.RequestsPerMinute
		key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String())

		allowed, remaining, resetTime, err := url.checkRedisLimit(
			c.Request.Context(),
			key,
			limit,
			int(url.config.Window.Seconds()),
		)
		if err != nil {
			// Fail-secure: Redis is unavailable, so decide locally instead of
			// letting the request through unchecked.
			if logger := url.config.Logger; logger != nil {
				logger.Warn("Redis rate limit check failed, using in-memory fallback",
					zap.Error(err),
					zap.String("user_id", userID.String()))
			}
			local := url.getFallbackLimiter(key, limit)
			allowed = local.Allow()
			remaining = int(local.Tokens())
			if remaining < 0 {
				remaining = 0
			}
			resetTime = time.Now().Add(url.config.Window).Unix()
		}

		// Both paths report the outcome through the same headers.
		c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
		c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
		c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))

		if allowed {
			c.Next()
			return
		}

		// Seconds until the window frees up, never negative.
		wait := resetTime - time.Now().Unix()
		if wait < 0 {
			wait = 0
		}
		appErr := apperrors.New(apperrors.ErrCodeRateLimitExceeded, "Rate limit exceeded")
		appErr.Context = map[string]interface{}{
			"retry_after": wait,
			"limit":       limit,
			"window":      url.config.Window.String(),
		}
		response.RespondWithAppError(c, appErr)
		c.Abort()
	}
}
// checkRedisLimit records one request for key in a Redis sorted set and
// reports whether it fits in the sliding window. The whole check runs as a
// single Lua script so that counting and inserting are atomic.
//
// Parameters:
//   - key: the sorted-set key for the user.
//   - limit: maximum number of requests allowed per window.
//   - windowSeconds: window length in seconds; values below 1 are raised to 1
//     because "EXPIRE key 0" would delete the key and disable limiting.
//
// Returns whether the request is allowed, how many requests remain in the
// window, the Unix time (seconds) at which the window frees up, and an error
// when Redis is not configured, the script fails, or its reply has an
// unexpected shape. On error the other results are zero values.
func (url *UserRateLimiter) checkRedisLimit(ctx context.Context, key string, limit, windowSeconds int) (bool, int, int64, error) {
	if url.config.RedisClient == nil {
		// Returning an error (instead of panicking on a nil client) lets the
		// caller switch to its fallback path.
		return false, 0, 0, fmt.Errorf("rate limiter: redis client is not configured")
	}
	if windowSeconds < 1 {
		windowSeconds = 1
	}

	// Lua script for atomicity (sliding window). ARGV[4] is a per-request
	// unique member: sorted-set members are unique, so using the
	// second-resolution timestamp as the member would collapse every request
	// made within the same second into a single entry and under-count.
	script := `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local member = ARGV[4]
-- Nettoyer les anciennes entrées (sliding window)
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
-- Compter les requêtes dans la fenêtre
local count = redis.call('ZCARD', key)
if count < limit then
-- Ajouter la requête actuelle
redis.call('ZADD', key, now, member)
redis.call('EXPIRE', key, window)
return {1, limit - count - 1, now + window}
else
-- Limite dépassée, retourner le temps de reset
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local resetTime = now + window
if #oldest > 0 then
resetTime = tonumber(oldest[2]) + window
end
return {0, 0, resetTime}
end
`
	now := time.Now().Unix()
	member := fmt.Sprintf("%d:%s", now, uuid.New().String())

	result, err := url.config.RedisClient.Eval(
		ctx,
		script,
		[]string{key},
		limit,
		windowSeconds,
		now,
		member,
	).Result()
	if err != nil {
		return false, 0, 0, err
	}

	// Validate the reply shape instead of asserting blindly, so a malformed
	// reply becomes an error rather than a panic inside the request handler.
	values, ok := result.([]interface{})
	if !ok || len(values) != 3 {
		return false, 0, 0, fmt.Errorf("rate limiter: unexpected script result %T", result)
	}
	allowedFlag, okAllowed := values[0].(int64)
	remaining, okRemaining := values[1].(int64)
	resetTime, okReset := values[2].(int64)
	if !okAllowed || !okRemaining || !okReset {
		return false, 0, 0, fmt.Errorf("rate limiter: unexpected script result fields %v", values)
	}

	return allowedFlag == 1, int(remaining), resetTime, nil
}
// getFallbackLimiter returns the in-memory rate.Limiter for key, creating and
// storing it on first use (fail-secure fallback for when Redis is unavailable).
//
// A new limiter refills at limit tokens per configured window and has a burst
// of limit. A non-positive limit yields a limiter that rejects every request:
// dividing the window by zero would panic, and a negative limit would make
// rate.Every return rate.Inf, which allows everything. A non-positive window
// falls back to one minute for the same reason.
//
// The limit only matters when the limiter is first created; later calls with
// the same key return the stored limiter unchanged.
func (url *UserRateLimiter) getFallbackLimiter(key string, limit int) *rate.Limiter {
	// Fast path: the limiter already exists.
	if v, ok := url.fallback.Load(key); ok {
		return v.(*rate.Limiter)
	}

	url.fallbackMu.Lock()
	defer url.fallbackMu.Unlock()

	// Re-check under the lock: another goroutine may have created it while
	// this one was waiting.
	if v, ok := url.fallback.Load(key); ok {
		return v.(*rate.Limiter)
	}

	var limiter *rate.Limiter
	if limit <= 0 {
		// Zero rate and zero burst: Allow always reports false.
		limiter = rate.NewLimiter(0, 0)
	} else {
		window := url.config.Window
		if window <= 0 {
			window = time.Minute
		}
		// Events per second, computed in floating point so that a limit larger
		// than the window's nanosecond count cannot truncate to a zero interval.
		refill := rate.Limit(float64(limit) / window.Seconds())
		limiter = rate.NewLimiter(refill, limit)
	}

	url.fallback.Store(key, limiter)
	return limiter
}
// GetUserRateLimitInfo reports the current rate-limit state for userID without
// recording a request.
//
// It runs a read-only Lua script against the user's sorted set, so asking for
// the state does not consume any of the user's quota. (Delegating to the
// request-recording check would add an entry to the window on every call.)
//
// Returns the number of requests still available in the current window
// (never negative), the Unix time in seconds at which the window frees up,
// and an error when Redis is not configured, the script fails, or its reply
// has an unexpected shape. On error the other results are zero.
func (url *UserRateLimiter) GetUserRateLimitInfo(ctx context.Context, userID uuid.UUID) (remaining int, resetTime int64, err error) {
	if url.config.RedisClient == nil {
		return 0, 0, fmt.Errorf("rate limiter: redis client is not configured")
	}

	key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String())
	limit := url.config.RequestsPerMinute
	windowSeconds := int(url.config.Window.Seconds())

	// Only entries with a score strictly greater than now-window are counted,
	// matching the range the recording script keeps after its cleanup step.
	const script = `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local min = '(' .. (now - window)
local count = redis.call('ZCOUNT', key, min, '+inf')
local remaining = limit - count
local resetTime = now + window
if remaining <= 0 then
remaining = 0
local oldest = redis.call('ZRANGEBYSCORE', key, min, '+inf', 'WITHSCORES', 'LIMIT', 0, 1)
if #oldest > 0 then
resetTime = tonumber(oldest[2]) + window
end
end
return {remaining, resetTime}
`
	now := time.Now().Unix()
	result, err := url.config.RedisClient.Eval(
		ctx,
		script,
		[]string{key},
		limit,
		windowSeconds,
		now,
	).Result()
	if err != nil {
		return 0, 0, err
	}

	values, ok := result.([]interface{})
	if !ok || len(values) != 2 {
		return 0, 0, fmt.Errorf("rate limiter: unexpected script result %T", result)
	}
	left, okLeft := values[0].(int64)
	reset, okReset := values[1].(int64)
	if !okLeft || !okReset {
		return 0, 0, fmt.Errorf("rate limiter: unexpected script result fields %v", values)
	}

	return int(left), reset, nil
}