package chat import ( "context" "fmt" "sync" "time" "github.com/google/uuid" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) type rateConfig struct { maxRequests int window time.Duration } // inMemoryRateLimiter is the original per-process rate limiter kept as fallback. type inMemoryRateLimiter struct { limits map[string]rateConfig entries map[string]*rateBucket mu sync.Mutex } type rateBucket struct { count int windowAt time.Time } func newInMemoryRateLimiter(limits map[string]rateConfig) *inMemoryRateLimiter { return &inMemoryRateLimiter{ limits: limits, entries: make(map[string]*rateBucket), } } func (rl *inMemoryRateLimiter) allow(userID uuid.UUID, action string) bool { cfg, ok := rl.limits[action] if !ok { return true } key := userID.String() + ":" + action now := time.Now() rl.mu.Lock() defer rl.mu.Unlock() bucket, exists := rl.entries[key] if !exists || now.Sub(bucket.windowAt) > cfg.window { rl.entries[key] = &rateBucket{count: 1, windowAt: now} return true } if bucket.count >= cfg.maxRequests { return false } bucket.count++ return true } // RateLimiter is a Redis-backed sliding window rate limiter with in-memory fallback. // When Redis is unavailable (nil client or connection error), it transparently // falls back to the in-memory implementation. type RateLimiter struct { redis *redis.Client limits map[string]rateConfig logger *zap.Logger fallback *inMemoryRateLimiter } // Lua script: sliding window using sorted sets. // KEYS[1] = rate limit key // ARGV[1] = window in milliseconds // ARGV[2] = max requests // ARGV[3] = current timestamp in milliseconds // ARGV[4] = unique member (timestamp:nonce) var slidingWindowScript = redis.NewScript(` local key = KEYS[1] local window_ms = tonumber(ARGV[1]) local max_requests = tonumber(ARGV[2]) local now_ms = tonumber(ARGV[3]) local member = ARGV[4] local window_start = now_ms - window_ms redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start) local count = redis.call('ZCARD', key) if count < max_requests then redis.call('ZADD', key, now_ms, member) redis.call('PEXPIRE', key, window_ms) return 1 end return 0 `) var defaultLimits = map[string]rateConfig{ "send_message": {maxRequests: 10, window: time.Second}, "send_live_message": {maxRequests: 1, window: 3 * time.Second}, // F474: live chat rate limit "typing": {maxRequests: 5, window: time.Second}, "search": {maxRequests: 2, window: time.Second}, "fetch_history": {maxRequests: 5, window: time.Second}, } func NewRateLimiter(redisClient *redis.Client, logger *zap.Logger) *RateLimiter { if logger == nil { logger = zap.NewNop() } return &RateLimiter{ redis: redisClient, limits: defaultLimits, logger: logger, fallback: newInMemoryRateLimiter(defaultLimits), } } func (rl *RateLimiter) Allow(userID uuid.UUID, action string) bool { cfg, ok := rl.limits[action] if !ok { return true } if rl.redis == nil { return rl.fallback.allow(userID, action) } key := fmt.Sprintf("chat:ratelimit:%s:%s", userID.String(), action) nowMs := time.Now().UnixMilli() member := fmt.Sprintf("%d:%s", nowMs, uuid.New().String()[:8]) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() result, err := slidingWindowScript.Run(ctx, rl.redis, []string{key}, cfg.window.Milliseconds(), cfg.maxRequests, nowMs, member, ).Int64() if err != nil { rl.logger.Warn("Redis rate limit failed, using in-memory fallback", zap.Error(err), zap.String("user_id", userID.String()), zap.String("action", action)) return rl.fallback.allow(userID, action) } return result == 1 }