veza/veza-backend-api/internal/middleware/endpoint_limiter.go
2026-03-05 23:03:43 +01:00

352 lines
9.8 KiB
Go

package middleware
import (
"context"
"fmt"
"net/http"
"os"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// EndpointLimiterConfig holds the storage settings shared by all endpoint limits.
type EndpointLimiterConfig struct {
	// RedisClient stores the counters. When it is nil, checkLimit uses the
	// in-process fallback counter instead of Redis.
	RedisClient *redis.Client
	// KeyPrefix is prepended to every counter key, e.g. "<prefix>:login:ip:<ip>"
	// or "<prefix>:upload:user:<id>".
	KeyPrefix string
}
// EndpointLimits defines the request budget and window for each limited endpoint.
// The figures quoted below are the values set by DefaultEndpointLimits.
type EndpointLimits struct {
	// Login: 5 attempts per 15 minutes, per client IP.
	LoginAttempts int
	LoginWindow   time.Duration
	// Register: 3 accounts per hour, per client IP.
	RegisterAttempts int
	RegisterWindow   time.Duration
	// Password reset: 3 attempts per hour, per client IP.
	PasswordResetAttempts int
	PasswordResetWindow   time.Duration
	// Upload: 10 upload requests per hour, per authenticated user
	// (the "user_id" value in the gin context).
	UploadAttempts int
	UploadWindow   time.Duration
}
// DefaultEndpointLimits returns the stock limits: 5 logins per 15 minutes,
// 3 registrations per hour, 3 password resets per hour and 10 uploads per hour.
// A fresh value is allocated on every call, so callers may modify it freely.
func DefaultEndpointLimits() *EndpointLimits {
	limits := new(EndpointLimits)

	limits.LoginAttempts, limits.LoginWindow = 5, 15*time.Minute
	limits.RegisterAttempts, limits.RegisterWindow = 3, time.Hour
	limits.PasswordResetAttempts, limits.PasswordResetWindow = 3, time.Hour
	limits.UploadAttempts, limits.UploadWindow = 10, time.Hour

	return limits
}
// endpointInMemoryEntry is the fixed-window counter used by checkLimitInMemory
// when Redis is not configured or a Redis call fails.
type endpointInMemoryEntry struct {
	count       int       // requests seen in the current window, denied ones included
	windowStart time.Time // when the current window began; reset once it is older than the window
}
// EndpointLimiter builds gin middlewares that enforce EndpointLimits, keeping
// counters in Redis when a client is configured and in process memory otherwise
// (or whenever a Redis call fails).
type EndpointLimiter struct {
	config *EndpointLimiterConfig
	limits *EndpointLimits
	// inMemoryStore maps counter keys to fallback counters. No code visible here
	// removes entries, so it grows with the number of distinct keys (client IPs,
	// user IDs). NOTE(review): consider periodic eviction of expired entries.
	inMemoryStore map[string]*endpointInMemoryEntry
	inMemoryMu    sync.RWMutex // guards inMemoryStore
}
// NewEndpointLimiter creates an EndpointLimiter.
//
// config may carry a nil RedisClient, in which case all counters live in
// process memory. A nil config is replaced by an empty one: checkLimit already
// tolerates a missing config, but every middleware reads config.KeyPrefix on
// each request and would otherwise panic. A nil limits falls back to
// DefaultEndpointLimits() for the same reason.
func NewEndpointLimiter(config *EndpointLimiterConfig, limits *EndpointLimits) *EndpointLimiter {
	if config == nil {
		config = &EndpointLimiterConfig{}
	}
	if limits == nil {
		limits = DefaultEndpointLimits()
	}
	return &EndpointLimiter{
		config:        config,
		limits:        limits,
		inMemoryStore: make(map[string]*endpointInMemoryEntry),
	}
}
// LoginRateLimit returns a middleware that caps login attempts per client IP,
// using LoginAttempts and LoginWindow from the configured limits.
func (el *EndpointLimiter) LoginRateLimit() gin.HandlerFunc {
	maxAttempts, window := el.limits.LoginAttempts, el.limits.LoginWindow
	return el.createEndpointLimit("login", maxAttempts, window, "Too many login attempts")
}
// RegisterRateLimit returns a middleware that caps account registrations per
// client IP, using RegisterAttempts and RegisterWindow from the configured limits.
func (el *EndpointLimiter) RegisterRateLimit() gin.HandlerFunc {
	maxAttempts, window := el.limits.RegisterAttempts, el.limits.RegisterWindow
	return el.createEndpointLimit("register", maxAttempts, window, "Too many registration attempts")
}
// PasswordResetRateLimit returns a middleware that caps password-reset requests
// per client IP, using PasswordResetAttempts and PasswordResetWindow.
func (el *EndpointLimiter) PasswordResetRateLimit() gin.HandlerFunc {
	maxAttempts, window := el.limits.PasswordResetAttempts, el.limits.PasswordResetWindow
	return el.createEndpointLimit("password_reset", maxAttempts, window, "Too many password reset attempts")
}
// VerifyEmailRateLimit returns a middleware that caps email-verification
// attempts at 5 per hour per client IP.
// BE-SEC-005: Implement rate limiting for authentication endpoints
func (el *EndpointLimiter) VerifyEmailRateLimit() gin.HandlerFunc {
	const (
		maxAttempts = 5
		window      = time.Hour
	)
	return el.createEndpointLimit("verify_email", maxAttempts, window, "Too many email verification attempts")
}
// ResendVerificationRateLimit returns a middleware that caps verification-email
// resends at 3 per hour per client IP.
// BE-SEC-005: Implement rate limiting for authentication endpoints
func (el *EndpointLimiter) ResendVerificationRateLimit() gin.HandlerFunc {
	const (
		maxAttempts = 3
		window      = time.Hour
	)
	return el.createEndpointLimit("resend_verification", maxAttempts, window, "Too many verification resend attempts")
}
// CheckUsernameRateLimit returns a middleware that throttles username
// availability checks to 30 per minute per client IP, to slow down username
// enumeration (SEC-009).
func (el *EndpointLimiter) CheckUsernameRateLimit() gin.HandlerFunc {
	const (
		maxAttempts = 30
		window      = time.Minute
	)
	return el.createEndpointLimit("check_username", maxAttempts, window, "Too many username check attempts")
}
// RefreshRateLimit returns a middleware that throttles /auth/refresh to
// 10 requests per minute per client IP, to limit token grinding (SEC-010).
func (el *EndpointLimiter) RefreshRateLimit() gin.HandlerFunc {
	const (
		maxAttempts = 10
		window      = time.Minute
	)
	return el.createEndpointLimit("refresh", maxAttempts, window, "Too many token refresh attempts")
}
// ValidateRateLimit returns a middleware that throttles POST /validate to
// 10 requests per minute per client IP (A01).
func (el *EndpointLimiter) ValidateRateLimit() gin.HandlerFunc {
	const (
		maxAttempts = 10
		window      = time.Minute
	)
	return el.createEndpointLimit("validate", maxAttempts, window, "Too many validation requests")
}
// UploadRateLimit returns a middleware that limits uploads per authenticated
// user, using UploadAttempts and UploadWindow from the configured limits.
//
// The user is identified by the "user_id" value in the gin context; requests
// without it are rejected with 401. The outcome is reported in the
// X-UploadLimit-{Limit,Remaining,Reset} headers and an exhausted budget yields
// 429. Unlike the per-IP limits, this middleware does not honour
// DISABLE_RATE_LIMIT_FOR_TESTS.
func (el *EndpointLimiter) UploadRateLimit() gin.HandlerFunc {
	return func(c *gin.Context) {
		userID, authenticated := c.Get("user_id")
		if !authenticated {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
			c.Abort()
			return
		}

		key := fmt.Sprintf("%s:upload:user:%v", el.config.KeyPrefix, userID)
		limit := el.limits.UploadAttempts
		window := el.limits.UploadWindow

		allowed, remaining, err := el.checkLimit(c.Request.Context(), key, limit, window)
		if err != nil {
			// Redis failed: keep enforcing the limit from process memory (fail-secure).
			allowed, remaining = el.checkLimitInMemory(key, limit, window)
		}

		resetAt := time.Now().Add(window).Unix()
		c.Header("X-UploadLimit-Limit", strconv.Itoa(limit))
		c.Header("X-UploadLimit-Remaining", strconv.Itoa(remaining))
		c.Header("X-UploadLimit-Reset", strconv.FormatInt(resetAt, 10))

		if allowed {
			c.Next()
			return
		}
		c.JSON(http.StatusTooManyRequests, gin.H{
			"error":       "Upload limit exceeded",
			"retry_after": int(window.Seconds()),
		})
		c.Abort()
	}
}
// createEndpointLimit builds a per-client-IP rate-limiting middleware for the
// named endpoint.
//
// Each request consumes one of attempts slots in a window keyed by
// "<prefix>:<endpoint>:ip:<client IP>". The outcome is exposed in
// X-<Endpoint>Limit-{Limit,Remaining,Reset} headers; once the budget is spent
// the request is aborted with 429 and errorMessage. Redis errors fall back to
// the in-memory counter instead of letting the request through.
//
// NOTE(review): the Reset header and retry_after report a full window from
// "now", not the actual end of the current window.
func (el *EndpointLimiter) createEndpointLimit(
	endpoint string,
	attempts int,
	window time.Duration,
	errorMessage string,
) gin.HandlerFunc {
	// The header prefix depends only on the endpoint name, so compute it once.
	limitHeader := fmt.Sprintf("X-%sLimit", capitalize(endpoint))

	return func(c *gin.Context) {
		// SEC-011: the test bypass is honoured only outside production; in
		// production rate limiting always applies.
		if os.Getenv("APP_ENV") != "production" && os.Getenv("DISABLE_RATE_LIMIT_FOR_TESTS") == "true" {
			c.Next()
			return
		}

		key := fmt.Sprintf("%s:%s:ip:%s", el.config.KeyPrefix, endpoint, c.ClientIP())
		allowed, remaining, err := el.checkLimit(c.Request.Context(), key, attempts, window)
		if err != nil {
			// Redis failed: keep enforcing the limit from process memory (fail-secure).
			allowed, remaining = el.checkLimitInMemory(key, attempts, window)
		}

		resetAt := time.Now().Add(window).Unix()
		c.Header(limitHeader+"-Limit", strconv.Itoa(attempts))
		c.Header(limitHeader+"-Remaining", strconv.Itoa(remaining))
		c.Header(limitHeader+"-Reset", strconv.FormatInt(resetAt, 10))

		if allowed {
			c.Next()
			return
		}
		c.JSON(http.StatusTooManyRequests, gin.H{
			"error":       errorMessage,
			"retry_after": int(window.Seconds()),
		})
		c.Abort()
	}
}
// checkLimit vérifie si une limite est respectée
func (el *EndpointLimiter) checkLimit(ctx context.Context, key string, attempts int, window time.Duration) (bool, int, error) {
// Use in-memory fallback when Redis is not configured (e.g. integration tests)
if el.config == nil || el.config.RedisClient == nil {
allowed, remaining := el.checkLimitInMemory(key, attempts, window)
return allowed, remaining, nil
}
// Script Lua pour l'atomicité
script := `
local key = KEYS[1]
local attempts = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call('GET', key)
if current == false then
redis.call('SET', key, 1, 'EX', window)
return {1, attempts - 1}
end
local count = tonumber(current)
if count < attempts then
redis.call('INCR', key)
return {1, attempts - count - 1}
else
return {0, 0}
end
`
result, err := el.config.RedisClient.Eval(
ctx,
script,
[]string{key},
attempts,
int(window.Seconds()),
).Result()
if err != nil {
return false, 0, err
}
results := result.([]interface{})
allowed := results[0].(int64) == 1
remaining := int(results[1].(int64))
return allowed, remaining, nil
}
// checkLimitInMemory is the process-local fixed-window limiter used when Redis
// is absent or failing, so requests stay limited rather than unlimited
// (fail-secure).
//
// Every call, allowed or not, increments the counter for key; the window
// restarts once it is at least `window` old. It returns whether the request is
// within budget and how many requests remain (never negative).
//
// NOTE(review): entries are never evicted, so memory grows with the number of
// distinct keys seen.
func (el *EndpointLimiter) checkLimitInMemory(key string, attempts int, window time.Duration) (bool, int) {
	el.inMemoryMu.Lock()
	defer el.inMemoryMu.Unlock()

	now := time.Now()
	entry := el.inMemoryStore[key]
	if entry == nil {
		entry = &endpointInMemoryEntry{windowStart: now}
		el.inMemoryStore[key] = entry
	} else if now.Sub(entry.windowStart) >= window {
		entry.count, entry.windowStart = 0, now
	}

	entry.count++
	withinBudget := entry.count <= attempts
	left := 0
	if withinBudget {
		left = attempts - entry.count
	}
	return withinBudget, left
}
// capitalize upper-cases the first byte of s when it is an ASCII lowercase
// letter and returns s unchanged otherwise.
//
// The previous version unconditionally subtracted 32 from the first byte,
// turning e.g. "Login" into "!ogin" and "_x" into "?x". Only ASCII is handled
// because the result is used to build HTTP header names.
func capitalize(s string) string {
	if s == "" || s[0] < 'a' || s[0] > 'z' {
		return s
	}
	return string(s[0]-('a'-'A')) + s[1:]
}
// RateLimitByUser returns a middleware that limits each authenticated user to
// attempts requests per window, for use on arbitrary endpoints.
//
// The user is identified by the "user_id" value in the gin context; requests
// without it are rejected with 401. The outcome is reported in the
// X-UserLimit-{Limit,Remaining,Reset} headers and an exhausted budget yields
// 429 with errorMessage. Redis errors fall back to the in-memory counter.
//
// The counter key is scoped by attempts and window as well as the user. Every
// instance used to share "<prefix>:user:<id>", so limiters with different
// budgets consumed each other's quota and whichever created the Redis key first
// dictated its TTL. Instances with identical settings still share one budget.
func (el *EndpointLimiter) RateLimitByUser(attempts int, window time.Duration, errorMessage string) gin.HandlerFunc {
	// Distinguishes this limiter's counters from differently configured ones.
	scope := fmt.Sprintf("%d:%d", attempts, window.Milliseconds())

	return func(c *gin.Context) {
		userID, exists := c.Get("user_id")
		if !exists {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
			c.Abort()
			return
		}

		key := fmt.Sprintf("%s:user:%v:%s", el.config.KeyPrefix, userID, scope)
		allowed, remaining, err := el.checkLimit(c.Request.Context(), key, attempts, window)
		if err != nil {
			// Fallback in-memory when Redis fails (fail-secure).
			allowed, remaining = el.checkLimitInMemory(key, attempts, window)
		}

		c.Header("X-UserLimit-Limit", strconv.Itoa(attempts))
		c.Header("X-UserLimit-Remaining", strconv.Itoa(remaining))
		c.Header("X-UserLimit-Reset", strconv.FormatInt(time.Now().Add(window).Unix(), 10))

		if !allowed {
			c.JSON(http.StatusTooManyRequests, gin.H{
				"error":       errorMessage,
				"retry_after": int(window.Seconds()),
			})
			c.Abort()
			return
		}
		c.Next()
	}
}