242 lines
6.8 KiB
Go
242 lines
6.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/redis/go-redis/v9"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/time/rate"
|
|
|
|
apperrors "veza-backend-api/internal/errors"
|
|
"veza-backend-api/internal/response"
|
|
)
|
|
|
|
// UserRateLimiterConfig configures the per-user rate limiter.
// BE-SVC-002: Implement rate limiting per user
type UserRateLimiterConfig struct {
	// Per-user limits.

	// RequestsPerMinute is the number of requests a single user may make per
	// Window. The middleware uses it as the limit for whatever Window is
	// configured, so it only means "per minute" when Window is one minute.
	RequestsPerMinute int

	// NOTE(review): Burst is not read anywhere in the code visible here; the
	// in-memory fallback uses RequestsPerMinute as its burst. Confirm whether
	// it is used elsewhere before relying on it.
	Burst int

	// Window is the time window over which RequestsPerMinute is enforced.
	// NewUserRateLimiter replaces a zero value with one minute. The Redis check
	// works at whole-second granularity.
	Window time.Duration

	// Redis configuration.

	// RedisClient stores the per-user request log. If a Redis check returns an
	// error, the middleware falls back to an in-memory limiter.
	RedisClient *redis.Client

	// KeyPrefix prefixes every Redis key ("<KeyPrefix>:user:<uuid>").
	// NewUserRateLimiter defaults an empty value to "rate_limit".
	KeyPrefix string

	// Logger is optional; when nil, Redis fallback warnings are not logged.
	Logger *zap.Logger
}
|
|
|
|
// UserRateLimiter is a middleware that limits the request rate per user.
// Construct it with NewUserRateLimiter so configuration defaults are applied.
type UserRateLimiter struct {
	config *UserRateLimiterConfig

	// fallback holds one in-memory limiter per Redis key. It is used when the
	// Redis check fails, so limits are still enforced (fail-secure) rather
	// than skipped. Entries are created on demand by getFallbackLimiter.
	// NOTE(review): nothing visible here ever removes entries, so the map
	// grows with the number of distinct users seen during Redis outages.
	// Confirm whether that is acceptable.
	fallback sync.Map // map[string]*rate.Limiter for fail-secure when Redis fails

	// fallbackMu serializes creation of new fallback limiters so each key is
	// stored only once even under concurrent first use.
	fallbackMu sync.Mutex
}
|
|
|
|
// NewUserRateLimiter crée un nouveau rate limiter par utilisateur
|
|
func NewUserRateLimiter(config *UserRateLimiterConfig) *UserRateLimiter {
|
|
if config.Window == 0 {
|
|
config.Window = time.Minute
|
|
}
|
|
if config.KeyPrefix == "" {
|
|
config.KeyPrefix = "rate_limit"
|
|
}
|
|
return &UserRateLimiter{
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
// Middleware retourne le middleware Gin pour le rate limiting par utilisateur
|
|
func (url *UserRateLimiter) Middleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Récupérer l'ID utilisateur depuis le contexte
|
|
userIDInterface, exists := c.Get("user_id")
|
|
if !exists {
|
|
response.RespondWithAppError(c, apperrors.NewUnauthorizedError("Authentication required for rate limiting"))
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Convertir l'ID utilisateur en UUID
|
|
var userID uuid.UUID
|
|
switch v := userIDInterface.(type) {
|
|
case uuid.UUID:
|
|
userID = v
|
|
case string:
|
|
var err error
|
|
userID, err = uuid.Parse(v)
|
|
if err != nil {
|
|
response.RespondWithAppError(c, apperrors.NewValidationError("Invalid user ID format"))
|
|
c.Abort()
|
|
return
|
|
}
|
|
default:
|
|
// Essayer de convertir en string puis en UUID
|
|
userIDStr := fmt.Sprintf("%v", v)
|
|
var err error
|
|
userID, err = uuid.Parse(userIDStr)
|
|
if err != nil {
|
|
response.RespondWithAppError(c, apperrors.NewValidationError("Invalid user ID format"))
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
|
|
// Construire la clé Redis pour cet utilisateur
|
|
key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String())
|
|
limit := url.config.RequestsPerMinute
|
|
windowSeconds := int(url.config.Window.Seconds())
|
|
|
|
// Vérifier la limite avec Redis
|
|
allowed, remaining, resetTime, err := url.checkRedisLimit(c.Request.Context(), key, limit, windowSeconds)
|
|
if err != nil {
|
|
// Fail-secure: use in-memory fallback when Redis fails
|
|
if url.config.Logger != nil {
|
|
url.config.Logger.Warn("Redis rate limit check failed, using in-memory fallback",
|
|
zap.Error(err),
|
|
zap.String("user_id", userID.String()))
|
|
}
|
|
limiter := url.getFallbackLimiter(key, limit)
|
|
allowed = limiter.Allow()
|
|
remaining = int(limiter.Tokens())
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
resetTime = time.Now().Add(url.config.Window).Unix()
|
|
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
|
|
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))
|
|
|
|
if !allowed {
|
|
appErr := apperrors.New(apperrors.ErrCodeRateLimitExceeded, "Rate limit exceeded")
|
|
appErr.Context = map[string]interface{}{
|
|
"retry_after": func() int64 {
|
|
r := resetTime - time.Now().Unix()
|
|
if r < 0 {
|
|
return 0
|
|
}
|
|
return r
|
|
}(),
|
|
"limit": limit,
|
|
"window": url.config.Window.String(),
|
|
}
|
|
response.RespondWithAppError(c, appErr)
|
|
c.Abort()
|
|
return
|
|
}
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// Ajouter les headers de rate limiting
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
|
|
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))
|
|
|
|
if !allowed {
|
|
retryAfter := resetTime - time.Now().Unix()
|
|
if retryAfter < 0 {
|
|
retryAfter = 0
|
|
}
|
|
appErr := apperrors.New(apperrors.ErrCodeRateLimitExceeded, "Rate limit exceeded")
|
|
appErr.Context = map[string]interface{}{"retry_after": retryAfter, "limit": limit, "window": url.config.Window.String()}
|
|
response.RespondWithAppError(c, appErr)
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// checkRedisLimit vérifie la limite dans Redis avec un script Lua atomique
|
|
func (url *UserRateLimiter) checkRedisLimit(ctx context.Context, key string, limit, windowSeconds int) (bool, int, int64, error) {
|
|
// Script Lua pour l'atomicité (sliding window)
|
|
script := `
|
|
local key = KEYS[1]
|
|
local limit = tonumber(ARGV[1])
|
|
local window = tonumber(ARGV[2])
|
|
local now = tonumber(ARGV[3])
|
|
|
|
-- Nettoyer les anciennes entrées (sliding window)
|
|
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
|
|
|
|
-- Compter les requêtes dans la fenêtre
|
|
local count = redis.call('ZCARD', key)
|
|
|
|
if count < limit then
|
|
-- Ajouter la requête actuelle
|
|
redis.call('ZADD', key, now, now)
|
|
redis.call('EXPIRE', key, window)
|
|
return {1, limit - count - 1, now + window}
|
|
else
|
|
-- Limite dépassée, retourner le temps de reset
|
|
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
|
local resetTime = now + window
|
|
if #oldest > 0 then
|
|
resetTime = tonumber(oldest[2]) + window
|
|
end
|
|
return {0, 0, resetTime}
|
|
end
|
|
`
|
|
|
|
now := time.Now().Unix()
|
|
result, err := url.config.RedisClient.Eval(
|
|
ctx,
|
|
script,
|
|
[]string{key},
|
|
limit,
|
|
windowSeconds,
|
|
now,
|
|
).Result()
|
|
|
|
if err != nil {
|
|
return false, 0, 0, err
|
|
}
|
|
|
|
results := result.([]interface{})
|
|
allowed := results[0].(int64) == 1
|
|
remaining := int(results[1].(int64))
|
|
resetTime := results[2].(int64)
|
|
|
|
return allowed, remaining, resetTime, nil
|
|
}
|
|
|
|
// getFallbackLimiter returns or creates an in-memory rate.Limiter for the given key (fail-secure fallback)
|
|
func (url *UserRateLimiter) getFallbackLimiter(key string, limit int) *rate.Limiter {
|
|
if v, ok := url.fallback.Load(key); ok {
|
|
return v.(*rate.Limiter)
|
|
}
|
|
url.fallbackMu.Lock()
|
|
defer url.fallbackMu.Unlock()
|
|
if v, ok := url.fallback.Load(key); ok {
|
|
return v.(*rate.Limiter)
|
|
}
|
|
window := url.config.Window
|
|
if window == 0 {
|
|
window = time.Minute
|
|
}
|
|
limiter := rate.NewLimiter(rate.Every(window/time.Duration(limit)), limit)
|
|
url.fallback.Store(key, limiter)
|
|
return limiter
|
|
}
|
|
|
|
// GetUserRateLimitInfo récupère les informations de rate limit pour un utilisateur
|
|
func (url *UserRateLimiter) GetUserRateLimitInfo(ctx context.Context, userID uuid.UUID) (remaining int, resetTime int64, err error) {
|
|
key := fmt.Sprintf("%s:user:%s", url.config.KeyPrefix, userID.String())
|
|
limit := url.config.RequestsPerMinute
|
|
windowSeconds := int(url.config.Window.Seconds())
|
|
|
|
_, remaining, resetTime, err = url.checkRedisLimit(ctx, key, limit, windowSeconds)
|
|
return remaining, resetTime, err
|
|
}
|