package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"veza-backend-api/internal/models"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)

func init() {
	gin.SetMode(gin.TestMode)
}

// newAPIKeyRateLimitTestRouter builds a gin engine for the API-key rate
// limiter tests. When apiKey is non-nil, a first middleware stores it under
// the "api_key" context key (simulating the auth middleware). It then applies
// mw and registers a handler answering 200 {"ok": true} on /test for method.
func newAPIKeyRateLimitTestRouter(mw gin.HandlerFunc, apiKey *models.APIKey, method string) *gin.Engine {
	router := gin.New()
	if apiKey != nil {
		router.Use(func(c *gin.Context) {
			c.Set("api_key", apiKey)
			c.Next()
		})
	}
	router.Use(mw)
	router.Handle(method, "/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})
	return router
}

// serveAPIKeyRateLimitTestRequest sends a body-less request with the given
// method to /test on router and returns the recorded response.
// httptest.NewRequest is used because, unlike http.NewRequest, it has no
// error return that would otherwise have to be discarded.
func serveAPIKeyRateLimitTestRequest(router http.Handler, method string) *httptest.ResponseRecorder {
	w := httptest.NewRecorder()
	router.ServeHTTP(w, httptest.NewRequest(method, "/test", nil))
	return w
}

func TestAPIKeyRateLimiter_PassthroughWithoutAPIKey(t *testing.T) {
	rl := NewAPIKeyRateLimiter(DefaultAPIKeyRateLimiterConfig())
	defer rl.Stop()

	// No api_key in the context: the request should pass through.
	router := newAPIKeyRateLimitTestRouter(rl.Middleware(), nil, http.MethodGet)

	if w := serveAPIKeyRateLimitTestRequest(router, http.MethodGet); w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", w.Code)
	}
}

func TestAPIKeyRateLimiter_EnforcesReadLimit(t *testing.T) {
	config := &APIKeyRateLimiterConfig{
		ReadLimit:  3,
		WriteLimit: 2,
		Window:     time.Minute,
	}
	rl := NewAPIKeyRateLimiter(config)
	defer rl.Stop()

	apiKey := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "test"}
	router := newAPIKeyRateLimitTestRouter(rl.Middleware(), apiKey, http.MethodGet)

	// First 3 requests should pass.
	for i := 1; i <= 3; i++ {
		if w := serveAPIKeyRateLimitTestRequest(router, http.MethodGet); w.Code != http.StatusOK {
			t.Errorf("request %d: expected 200, got %d", i, w.Code)
		}
	}

	// 4th request should be rate limited.
	w := serveAPIKeyRateLimitTestRequest(router, http.MethodGet)
	if w.Code != http.StatusTooManyRequests {
		t.Errorf("4th request: expected 429, got %d", w.Code)
	}

	// The rejection must still carry the rate limit headers.
	if got := w.Header().Get("X-RateLimit-Limit"); got != "3" {
		t.Errorf("expected X-RateLimit-Limit=3, got %s", got)
	}
	if got := w.Header().Get("X-RateLimit-Remaining"); got != "0" {
		t.Errorf("expected X-RateLimit-Remaining=0, got %s", got)
	}
}

func TestAPIKeyRateLimiter_EnforcesWriteLimit(t *testing.T) {
	config := &APIKeyRateLimiterConfig{
		ReadLimit:  100,
		WriteLimit: 2,
		Window:     time.Minute,
	}
	rl := NewAPIKeyRateLimiter(config)
	defer rl.Stop()

	apiKey := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "test"}
	router := newAPIKeyRateLimitTestRouter(rl.Middleware(), apiKey, http.MethodPost)

	// First 2 POST requests should pass.
	for i := 1; i <= 2; i++ {
		if w := serveAPIKeyRateLimitTestRequest(router, http.MethodPost); w.Code != http.StatusOK {
			t.Errorf("request %d: expected 200, got %d", i, w.Code)
		}
	}

	// 3rd POST should be rate limited.
	if w := serveAPIKeyRateLimitTestRequest(router, http.MethodPost); w.Code != http.StatusTooManyRequests {
		t.Errorf("3rd POST: expected 429, got %d", w.Code)
	}
}

func TestAPIKeyRateLimiter_SeparateKeysHaveSeparateLimits(t *testing.T) {
	config := &APIKeyRateLimiterConfig{
		ReadLimit:  2,
		WriteLimit: 2,
		Window:     time.Minute,
	}
	rl := NewAPIKeyRateLimiter(config)
	defer rl.Stop()

	key1 := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "key1"}
	key2 := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "key2"}

	// Both routers share the same limiter instance; only the injected key differs.
	router1 := newAPIKeyRateLimitTestRouter(rl.Middleware(), key1, http.MethodGet)
	router2 := newAPIKeyRateLimitTestRouter(rl.Middleware(), key2, http.MethodGet)

	// Exhaust key1's limit. Check these responses too, so a broken precondition
	// is reported instead of silently skewing the assertions below.
	for i := 1; i <= 2; i++ {
		if w := serveAPIKeyRateLimitTestRequest(router1, http.MethodGet); w.Code != http.StatusOK {
			t.Errorf("key1 request %d: expected 200, got %d", i, w.Code)
		}
	}

	// key1 should be rate limited.
	if w := serveAPIKeyRateLimitTestRequest(router1, http.MethodGet); w.Code != http.StatusTooManyRequests {
		t.Errorf("key1 3rd request: expected 429, got %d", w.Code)
	}

	// key2 should still work.
	if w := serveAPIKeyRateLimitTestRequest(router2, http.MethodGet); w.Code != http.StatusOK {
		t.Errorf("key2 1st request: expected 200, got %d", w.Code)
	}
}

func TestAPIKeyRateLimiter_ReturnsRateLimitHeaders(t *testing.T) {
	config := &APIKeyRateLimiterConfig{
		ReadLimit:  10,
		WriteLimit: 5,
		Window:     time.Hour,
	}
	rl := NewAPIKeyRateLimiter(config)
	defer rl.Stop()

	apiKey := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "test"}
	router := newAPIKeyRateLimitTestRouter(rl.Middleware(), apiKey, http.MethodGet)

	w := serveAPIKeyRateLimitTestRequest(router, http.MethodGet)

	if got := w.Header().Get("X-RateLimit-Limit"); got != "10" {
		t.Errorf("expected X-RateLimit-Limit=10, got %s", got)
	}
	if got := w.Header().Get("X-RateLimit-Remaining"); got != "9" {
		t.Errorf("expected X-RateLimit-Remaining=9, got %s", got)
	}
	if w.Header().Get("X-RateLimit-Reset") == "" {
		t.Error("expected X-RateLimit-Reset header to be set")
	}
}