package middleware import ( "fmt" "os" "strconv" "sync" "time" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" "go.uber.org/zap" apperrors "veza-backend-api/internal/errors" "veza-backend-api/internal/response" ) // RedisRateLimiter is a Redis-backed rate limiter with the same interface as SimpleRateLimiter. // Uses INCR + EXPIRE atomically via Lua script for counting. type RedisRateLimiter struct { client *redis.Client limit int window time.Duration mu sync.Mutex } // NewRedisRateLimiter creates a new Redis-backed rate limiter. // limit: maximum number of requests per window // window: time window (e.g. 1*time.Minute for 100 req/min) func NewRedisRateLimiter(client *redis.Client, limit int, window time.Duration) *RedisRateLimiter { return &RedisRateLimiter{ client: client, limit: limit, window: window, } } // Lua script for atomic INCR + EXPIRE. Returns {allowed, remaining}. // - allowed: 1 if request allowed, 0 if rate limit exceeded // - remaining: remaining requests in window var redisRateLimitScript = ` local key = KEYS[1] local limit = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = redis.call('GET', key) if current == false then redis.call('SET', key, 1, 'EX', window) return {1, limit - 1} end local count = tonumber(current) if count < limit then redis.call('INCR', key) return {1, limit - count - 1} else return {0, 0} end ` // Middleware returns the Gin middleware for rate limiting. // Same interface as SimpleRateLimiter. func (rl *RedisRateLimiter) Middleware() gin.HandlerFunc { return func(c *gin.Context) { if isExcludedPath(c.Request.URL.Path) { c.Next() return } if os.Getenv("APP_ENV") == "production" { // Continue to rate limit } else if os.Getenv("DISABLE_RATE_LIMIT_FOR_TESTS") == "true" { c.Next() return } ip := c.ClientIP() key := fmt.Sprintf("ratelimit:%s", ip) rl.mu.Lock() limit := rl.limit window := rl.window rl.mu.Unlock() windowSec := int(window.Seconds()) result, err := rl.client.Eval( c.Request.Context(), redisRateLimitScript, []string{key}, limit, windowSec, ).Result() if err != nil { zap.L().Error("Redis rate limit check failed", zap.Error(err), zap.String("ip", ip), zap.String("key", key)) resetTime := time.Now().Add(window).Unix() retryAfter := windowSec c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", "0") c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) c.Header("Retry-After", strconv.Itoa(retryAfter)) appErr := apperrors.New(apperrors.ErrCodeRateLimitExceeded, "Rate limit exceeded. Please try again later.") appErr.Details = []apperrors.ErrorDetail{{Field: "rate_limit", Message: fmt.Sprintf("You have exceeded the rate limit of %d requests per %v", limit, window)}} appErr.Context = map[string]interface{}{"retry_after": retryAfter, "limit": limit, "remaining": 0, "reset": resetTime} response.RespondWithAppError(c, appErr) c.Abort() return } results := result.([]interface{}) allowed := results[0].(int64) == 1 remaining := int(results[1].(int64)) now := time.Now() resetTime := now.Add(window).Unix() if !allowed { retryAfter := windowSec c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", "0") c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) c.Header("Retry-After", strconv.Itoa(retryAfter)) appErr := apperrors.New(apperrors.ErrCodeRateLimitExceeded, "Rate limit exceeded. Please try again later.") appErr.Details = []apperrors.ErrorDetail{{Field: "rate_limit", Message: fmt.Sprintf("You have exceeded the rate limit of %d requests per %v", limit, window)}} appErr.Context = map[string]interface{}{"retry_after": retryAfter, "limit": limit, "remaining": 0, "reset": resetTime} response.RespondWithAppError(c, appErr) c.Abort() return } c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) c.Next() } } // UpdateLimits updates the rate limit configuration (hot reload without restart). func (rl *RedisRateLimiter) UpdateLimits(limit int, window time.Duration) { rl.mu.Lock() defer rl.mu.Unlock() rl.limit = limit rl.window = window }