package middleware import ( "net/http" "strconv" "sync" "time" "veza-backend-api/internal/models" "github.com/gin-gonic/gin" ) // APIKeyRateLimiterConfig defines rate limits for API key authenticated requests. // Limits are per API key (not per IP) to give each developer their own quota. type APIKeyRateLimiterConfig struct { // ReadLimit is the max number of read (GET) requests per window per key ReadLimit int // WriteLimit is the max number of write (POST/PUT/DELETE) requests per window per key WriteLimit int // Window is the sliding window duration Window time.Duration } // DefaultAPIKeyRateLimiterConfig returns default limits: 1000 reads/hour, 200 writes/hour func DefaultAPIKeyRateLimiterConfig() *APIKeyRateLimiterConfig { return &APIKeyRateLimiterConfig{ ReadLimit: 1000, WriteLimit: 200, Window: time.Hour, } } type apiKeyEntry struct { timestamps []time.Time } // APIKeyRateLimiter rate-limits requests authenticated via API key. // It tracks read and write operations separately, keyed by API key ID. type APIKeyRateLimiter struct { config *APIKeyRateLimiterConfig readStore map[string]*apiKeyEntry writeStore map[string]*apiKeyEntry mu sync.Mutex stop chan struct{} } // NewAPIKeyRateLimiter creates a new API key rate limiter func NewAPIKeyRateLimiter(config *APIKeyRateLimiterConfig) *APIKeyRateLimiter { if config == nil { config = DefaultAPIKeyRateLimiterConfig() } rl := &APIKeyRateLimiter{ config: config, readStore: make(map[string]*apiKeyEntry), writeStore: make(map[string]*apiKeyEntry), stop: make(chan struct{}), } go rl.cleanup() return rl } // Middleware returns a Gin middleware that enforces API key rate limits. // It only applies to requests authenticated via API key (context has "api_key" set). // JWT-authenticated requests pass through without additional limiting. func (rl *APIKeyRateLimiter) Middleware() gin.HandlerFunc { return func(c *gin.Context) { // Only rate-limit API key requests apiKeyVal, exists := c.Get("api_key") if !exists { c.Next() return } // Extract key ID for per-key tracking key, ok := apiKeyVal.(*models.APIKey) if !ok { c.Next() return } keyID := key.ID.String() isWrite := c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead && c.Request.Method != http.MethodOptions var limit int var store map[string]*apiKeyEntry if isWrite { limit = rl.config.WriteLimit store = rl.writeStore } else { limit = rl.config.ReadLimit store = rl.readStore } rl.mu.Lock() now := time.Now() cutoff := now.Add(-rl.config.Window) entry, ok := store[keyID] if !ok { entry = &apiKeyEntry{} store[keyID] = entry } // Prune expired timestamps valid := make([]time.Time, 0, len(entry.timestamps)) for _, t := range entry.timestamps { if t.After(cutoff) { valid = append(valid, t) } } if len(valid) >= limit { entry.timestamps = valid rl.mu.Unlock() resetTime := now.Add(rl.config.Window).Unix() retryAfter := int(rl.config.Window.Seconds()) c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", "0") c.Header("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) c.Header("Retry-After", strconv.Itoa(retryAfter)) c.JSON(http.StatusTooManyRequests, gin.H{ "success": false, "error": gin.H{ "code": "RATE_LIMIT_EXCEEDED", "message": "API key rate limit exceeded. Please try again later.", "retry_after": retryAfter, "limit": limit, "remaining": 0, "reset": resetTime, }, }) c.Abort() return } valid = append(valid, now) entry.timestamps = valid remaining := limit - len(valid) rl.mu.Unlock() c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) c.Header("X-RateLimit-Reset", strconv.FormatInt(now.Add(rl.config.Window).Unix(), 10)) c.Next() } } // Stop signals the cleanup goroutine to exit func (rl *APIKeyRateLimiter) Stop() { close(rl.stop) } func (rl *APIKeyRateLimiter) cleanup() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: rl.mu.Lock() cutoff := time.Now().Add(-rl.config.Window) pruneStore(rl.readStore, cutoff) pruneStore(rl.writeStore, cutoff) rl.mu.Unlock() case <-rl.stop: return } } } func pruneStore(store map[string]*apiKeyEntry, cutoff time.Time) { for key, entry := range store { valid := make([]time.Time, 0, len(entry.timestamps)) for _, t := range entry.timestamps { if t.After(cutoff) { valid = append(valid, t) } } if len(valid) == 0 { delete(store, key) } else { entry.timestamps = valid } } }