veza/veza-backend-api/internal/middleware/rate_limiter.go
2025-12-03 20:29:37 +01:00

240 lines
5.8 KiB
Go

package middleware
import (
"context"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"golang.org/x/time/rate"
)
// RateLimiterConfig configuration pour le rate limiter
type RateLimiterConfig struct {
// Limites par IP (non authentifié)
IPRequestsPerMinute int
IPBurst int
// Limites par utilisateur authentifié
UserRequestsPerMinute int
UserBurst int
// Configuration Redis
RedisClient *redis.Client
KeyPrefix string
}
// RateLimiter middleware pour limiter le taux de requêtes
type RateLimiter struct {
config *RateLimiterConfig
ipLimiter *rate.Limiter
userLimiter *rate.Limiter
}
// NewRateLimiter crée un nouveau rate limiter
func NewRateLimiter(config *RateLimiterConfig) *RateLimiter {
return &RateLimiter{
config: config,
ipLimiter: rate.NewLimiter(
rate.Every(time.Minute/time.Duration(config.IPRequestsPerMinute)),
config.IPBurst,
),
userLimiter: rate.NewLimiter(
rate.Every(time.Minute/time.Duration(config.UserRequestsPerMinute)),
config.UserBurst,
),
}
}
// RateLimitMiddleware middleware principal de rate limiting
func (rl *RateLimiter) RateLimitMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Déterminer si l'utilisateur est authentifié
userID, isAuthenticated := c.Get("user_id")
var limiter *rate.Limiter
var key string
var limit int
if isAuthenticated {
// Utilisateur authentifié - limite plus élevée
limiter = rl.userLimiter
key = fmt.Sprintf("%s:user:%v", rl.config.KeyPrefix, userID)
limit = rl.config.UserRequestsPerMinute
} else {
// IP non authentifiée - limite plus stricte
limiter = rl.ipLimiter
key = fmt.Sprintf("%s:ip:%s", rl.config.KeyPrefix, c.ClientIP())
limit = rl.config.IPRequestsPerMinute
}
// Vérifier la limite avec Redis pour persistance
allowed, remaining, err := rl.checkRedisLimit(c.Request.Context(), key, limit)
if err != nil {
// En cas d'erreur Redis, utiliser le limiter local
allowed = limiter.Allow()
remaining = int(limiter.Tokens())
}
// Ajouter les headers de rate limiting
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10))
if !allowed {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"retry_after": 60,
})
c.Abort()
return
}
c.Next()
}
}
// checkRedisLimit vérifie la limite dans Redis
func (rl *RateLimiter) checkRedisLimit(ctx context.Context, key string, limit int) (bool, int, error) {
// Utiliser un script Lua pour l'atomicité
script := `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call('GET', key)
if current == false then
redis.call('SET', key, 1, 'EX', window)
return {1, limit - 1}
end
local count = tonumber(current)
if count < limit then
redis.call('INCR', key)
return {1, limit - count - 1}
else
return {0, 0}
end
`
result, err := rl.config.RedisClient.Eval(
ctx,
script,
[]string{key},
limit,
60, // 60 secondes
).Result()
if err != nil {
return false, 0, err
}
results := result.([]interface{})
allowed := results[0].(int64) == 1
remaining := int(results[1].(int64))
return allowed, remaining, nil
}
// RateLimitByIP middleware pour limiter par IP uniquement
func (rl *RateLimiter) RateLimitByIP() gin.HandlerFunc {
return func(c *gin.Context) {
key := fmt.Sprintf("%s:ip:%s", rl.config.KeyPrefix, c.ClientIP())
allowed, remaining, err := rl.checkRedisLimit(c.Request.Context(), key, rl.config.IPRequestsPerMinute)
if err != nil {
allowed = rl.ipLimiter.Allow()
remaining = int(rl.ipLimiter.Tokens())
}
c.Header("X-RateLimit-Limit", strconv.Itoa(rl.config.IPRequestsPerMinute))
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
if !allowed {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"retry_after": 60,
})
c.Abort()
return
}
c.Next()
}
}
// UploadRateLimit middleware pour limiter les uploads de tracks par utilisateur
// Limite: 10 uploads par heure par utilisateur
func UploadRateLimit(redisClient *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
userID := c.GetInt64("user_id")
if userID == 0 {
// Si pas d'utilisateur authentifié, passer au suivant
c.Next()
return
}
// Clé Redis pour cet utilisateur
key := fmt.Sprintf("upload_rate_limit:%d", userID)
limit := 10 // 10 uploads par heure
window := time.Hour
// Script Lua pour l'atomicité
script := `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call('GET', key)
if current == false then
redis.call('SET', key, 1, 'EX', window)
return {1, limit - 1}
end
local count = tonumber(current)
if count < limit then
redis.call('INCR', key)
return {1, limit - count - 1}
else
return {0, 0}
end
`
result, err := redisClient.Eval(
c.Request.Context(),
script,
[]string{key},
limit,
int(window.Seconds()),
).Result()
if err != nil {
// En cas d'erreur Redis, autoriser la requête (fail-open)
c.Next()
return
}
results := result.([]interface{})
allowed := results[0].(int64) == 1
remaining := int(results[1].(int64))
// Ajouter les headers de rate limiting
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(window).Unix(), 10))
if !allowed {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "upload rate limit exceeded",
"retry_after": int(window.Seconds()),
})
c.Abort()
return
}
c.Next()
}
}