diff --git a/veza-backend-api/internal/middleware/ratelimit_redis.go b/veza-backend-api/internal/middleware/ratelimit_redis.go new file mode 100644 index 000000000..5bffc6d96 --- /dev/null +++ b/veza-backend-api/internal/middleware/ratelimit_redis.go @@ -0,0 +1,172 @@ +package middleware + +import ( + "fmt" + "net/http" + "os" + "strconv" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +// RedisRateLimiter is a Redis-backed rate limiter with the same interface as SimpleRateLimiter. +// Uses INCR + EXPIRE atomically via Lua script for counting. +type RedisRateLimiter struct { + client *redis.Client + limit int + window time.Duration + mu sync.Mutex +} + +// NewRedisRateLimiter creates a new Redis-backed rate limiter. +// limit: maximum number of requests per window +// window: time window (e.g. 1*time.Minute for 100 req/min) +func NewRedisRateLimiter(client *redis.Client, limit int, window time.Duration) *RedisRateLimiter { + return &RedisRateLimiter{ + client: client, + limit: limit, + window: window, + } +} + +// Lua script for atomic INCR + EXPIRE. Returns {allowed, remaining}. +// - allowed: 1 if request allowed, 0 if rate limit exceeded +// - remaining: remaining requests in window +var redisRateLimitScript = ` + local key = KEYS[1] + local limit = tonumber(ARGV[1]) + local window = tonumber(ARGV[2]) + + local current = redis.call('GET', key) + if current == false then + redis.call('SET', key, 1, 'EX', window) + return {1, limit - 1} + end + + local count = tonumber(current) + if count < limit then + redis.call('INCR', key) + return {1, limit - count - 1} + else + return {0, 0} + end +` + +// Middleware returns the Gin middleware for rate limiting. +// Same interface as SimpleRateLimiter. +func (rl *RedisRateLimiter) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + if isExcludedPath(c.Request.URL.Path) { + c.Next() + return + } + + if os.Getenv("APP_ENV") == "production" { + // Continue to rate limit + } else if os.Getenv("DISABLE_RATE_LIMIT_FOR_TESTS") == "true" { + c.Next() + return + } + + ip := c.ClientIP() + key := fmt.Sprintf("ratelimit:%s", ip) + + rl.mu.Lock() + limit := rl.limit + window := rl.window + rl.mu.Unlock() + + windowSec := int(window.Seconds()) + result, err := rl.client.Eval( + c.Request.Context(), + redisRateLimitScript, + []string{key}, + limit, + windowSec, + ).Result() + + if err != nil { + zap.L().Error("Redis rate limit check failed", + zap.Error(err), + zap.String("ip", ip), + zap.String("key", key)) + resetTime := time.Now().Add(window).Unix() + retryAfter := windowSec + c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) + c.Header("X-RateLimit-Remaining", "0") + c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) + c.Header("Retry-After", strconv.Itoa(retryAfter)) + c.JSON(http.StatusTooManyRequests, gin.H{ + "success": false, + "error": gin.H{ + "code": 429, + "message": "Rate limit exceeded. Please try again later.", + "details": []gin.H{ + { + "field": "rate_limit", + "message": fmt.Sprintf("You have exceeded the rate limit of %d requests per %v", limit, window), + }, + }, + "retry_after": retryAfter, + "limit": limit, + "remaining": 0, + "reset": resetTime, + }, + }) + c.Abort() + return + } + + results := result.([]interface{}) + allowed := results[0].(int64) == 1 + remaining := int(results[1].(int64)) + + now := time.Now() + resetTime := now.Add(window).Unix() + + if !allowed { + retryAfter := windowSec + c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) + c.Header("X-RateLimit-Remaining", "0") + c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) + c.Header("Retry-After", strconv.Itoa(retryAfter)) + c.JSON(http.StatusTooManyRequests, gin.H{ + "success": false, + "error": gin.H{ + "code": 429, + "message": "Rate limit exceeded. Please try again later.", + "details": []gin.H{ + { + "field": "rate_limit", + "message": fmt.Sprintf("You have exceeded the rate limit of %d requests per %v", limit, window), + }, + }, + "retry_after": retryAfter, + "limit": limit, + "remaining": 0, + "reset": resetTime, + }, + }) + c.Abort() + return + } + + c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) + c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) + c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) + + c.Next() + } +} + +// UpdateLimits updates the rate limit configuration (hot reload without restart). +func (rl *RedisRateLimiter) UpdateLimits(limit int, window time.Duration) { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.limit = limit + rl.window = window +}