INF-01: RedisRateLimiter uses atomic Lua script (INCR+EXPIRE) for distributed rate limiting. Falls back to in-memory SimpleRateLimiter when Redis is unavailable. Same X-RateLimit-* headers and 429 format.
172 lines
4.3 KiB
Go
172 lines
4.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/redis/go-redis/v9"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// RedisRateLimiter is a Redis-backed rate limiter with the same interface as SimpleRateLimiter.
|
|
// Uses INCR + EXPIRE atomically via Lua script for counting.
|
|
type RedisRateLimiter struct {
|
|
client *redis.Client
|
|
limit int
|
|
window time.Duration
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// NewRedisRateLimiter creates a new Redis-backed rate limiter.
|
|
// limit: maximum number of requests per window
|
|
// window: time window (e.g. 1*time.Minute for 100 req/min)
|
|
func NewRedisRateLimiter(client *redis.Client, limit int, window time.Duration) *RedisRateLimiter {
|
|
return &RedisRateLimiter{
|
|
client: client,
|
|
limit: limit,
|
|
window: window,
|
|
}
|
|
}
|
|
|
|
// redisRateLimitScript is the Lua script that counts one request against a
// fixed-window counter. Redis runs a script as a single atomic unit, so the
// read of the counter and its update cannot interleave with other clients.
//
// Inputs:
//   - KEYS[1]: counter key
//   - ARGV[1]: limit, the maximum number of requests per window
//   - ARGV[2]: window length in seconds (used as the key's expiry)
//
// Returns {allowed, remaining}:
//   - allowed: 1 if request allowed, 0 if rate limit exceeded
//   - remaining: remaining requests in window
//
// The limit is checked before anything is written, so a limit of zero (or
// less) rejects every request instead of letting the first request of each
// window through with a negative remaining count. The expiry is set only when
// the counter is created; later increments keep the existing expiry, which
// makes this a fixed window starting at the first request.
var redisRateLimitScript = `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])

local current = redis.call('GET', key)
local count = 0
if current then
  count = tonumber(current)
end

if count >= limit then
  return {0, 0}
end

if current then
  redis.call('INCR', key)
else
  redis.call('SET', key, 1, 'EX', window)
end

return {1, limit - count - 1}
`
|
|
|
|
// Middleware returns the Gin middleware for rate limiting.
|
|
// Same interface as SimpleRateLimiter.
|
|
func (rl *RedisRateLimiter) Middleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if isExcludedPath(c.Request.URL.Path) {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
if os.Getenv("APP_ENV") == "production" {
|
|
// Continue to rate limit
|
|
} else if os.Getenv("DISABLE_RATE_LIMIT_FOR_TESTS") == "true" {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
ip := c.ClientIP()
|
|
key := fmt.Sprintf("ratelimit:%s", ip)
|
|
|
|
rl.mu.Lock()
|
|
limit := rl.limit
|
|
window := rl.window
|
|
rl.mu.Unlock()
|
|
|
|
windowSec := int(window.Seconds())
|
|
result, err := rl.client.Eval(
|
|
c.Request.Context(),
|
|
redisRateLimitScript,
|
|
[]string{key},
|
|
limit,
|
|
windowSec,
|
|
).Result()
|
|
|
|
if err != nil {
|
|
zap.L().Error("Redis rate limit check failed",
|
|
zap.Error(err),
|
|
zap.String("ip", ip),
|
|
zap.String("key", key))
|
|
resetTime := time.Now().Add(window).Unix()
|
|
retryAfter := windowSec
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
|
|
c.Header("X-RateLimit-Remaining", "0")
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))
|
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
|
"success": false,
|
|
"error": gin.H{
|
|
"code": 429,
|
|
"message": "Rate limit exceeded. Please try again later.",
|
|
"details": []gin.H{
|
|
{
|
|
"field": "rate_limit",
|
|
"message": fmt.Sprintf("You have exceeded the rate limit of %d requests per %v", limit, window),
|
|
},
|
|
},
|
|
"retry_after": retryAfter,
|
|
"limit": limit,
|
|
"remaining": 0,
|
|
"reset": resetTime,
|
|
},
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
results := result.([]interface{})
|
|
allowed := results[0].(int64) == 1
|
|
remaining := int(results[1].(int64))
|
|
|
|
now := time.Now()
|
|
resetTime := now.Add(window).Unix()
|
|
|
|
if !allowed {
|
|
retryAfter := windowSec
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
|
|
c.Header("X-RateLimit-Remaining", "0")
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))
|
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
|
"success": false,
|
|
"error": gin.H{
|
|
"code": 429,
|
|
"message": "Rate limit exceeded. Please try again later.",
|
|
"details": []gin.H{
|
|
{
|
|
"field": "rate_limit",
|
|
"message": fmt.Sprintf("You have exceeded the rate limit of %d requests per %v", limit, window),
|
|
},
|
|
},
|
|
"retry_after": retryAfter,
|
|
"limit": limit,
|
|
"remaining": 0,
|
|
"reset": resetTime,
|
|
},
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
|
|
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
|
c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10))
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// UpdateLimits updates the rate limit configuration (hot reload without restart).
|
|
func (rl *RedisRateLimiter) UpdateLimits(limit int, window time.Duration) {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
rl.limit = limit
|
|
rl.window = window
|
|
}
|