veza/veza-backend-api/internal/middleware/api_key_rate_limiter_test.go

217 lines
5.5 KiB
Go
Raw Normal View History

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"veza-backend-api/internal/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// init switches gin out of its default debug mode for every test in this
// package so the routers built below run in gin's test mode.
func init() { gin.SetMode(gin.TestMode) }
// TestAPIKeyRateLimiter_PassthroughWithoutAPIKey verifies that the middleware
// lets a request reach the handler when nothing has stored an "api_key" value
// on the gin context (no upstream middleware is installed here to do so).
func TestAPIKeyRateLimiter_PassthroughWithoutAPIKey(t *testing.T) {
	rl := NewAPIKeyRateLimiter(DefaultAPIKeyRateLimiterConfig())
	defer rl.Stop()

	router := gin.New()
	router.Use(rl.Middleware())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	// Request without api_key in context — should pass through.
	// httptest.NewRequest panics on a malformed request rather than returning
	// an error that would otherwise have to be checked (or silently dropped).
	w := httptest.NewRecorder()
	router.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/test", nil))
	if w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", w.Code)
	}
}
func TestAPIKeyRateLimiter_EnforcesReadLimit(t *testing.T) {
config := &APIKeyRateLimiterConfig{
ReadLimit: 3,
WriteLimit: 2,
Window: time.Minute,
}
rl := NewAPIKeyRateLimiter(config)
defer rl.Stop()
keyID := uuid.New()
apiKey := &models.APIKey{ID: keyID, UserID: uuid.New(), Name: "test"}
router := gin.New()
// Simulate auth middleware setting api_key
router.Use(func(c *gin.Context) {
c.Set("api_key", apiKey)
c.Next()
})
router.Use(rl.Middleware())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
// First 3 requests should pass
for i := 0; i < 3; i++ {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("request %d: expected 200, got %d", i+1, w.Code)
}
}
// 4th request should be rate limited
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("4th request: expected 429, got %d", w.Code)
}
// Check rate limit headers
if w.Header().Get("X-RateLimit-Limit") != "3" {
t.Errorf("expected X-RateLimit-Limit=3, got %s", w.Header().Get("X-RateLimit-Limit"))
}
if w.Header().Get("X-RateLimit-Remaining") != "0" {
t.Errorf("expected X-RateLimit-Remaining=0, got %s", w.Header().Get("X-RateLimit-Remaining"))
}
}
func TestAPIKeyRateLimiter_EnforcesWriteLimit(t *testing.T) {
config := &APIKeyRateLimiterConfig{
ReadLimit: 100,
WriteLimit: 2,
Window: time.Minute,
}
rl := NewAPIKeyRateLimiter(config)
defer rl.Stop()
keyID := uuid.New()
apiKey := &models.APIKey{ID: keyID, UserID: uuid.New(), Name: "test"}
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("api_key", apiKey)
c.Next()
})
router.Use(rl.Middleware())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
// First 2 POST requests should pass
for i := 0; i < 2; i++ {
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/test", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("request %d: expected 200, got %d", i+1, w.Code)
}
}
// 3rd POST should be rate limited
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/test", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("3rd POST: expected 429, got %d", w.Code)
}
}
func TestAPIKeyRateLimiter_SeparateKeysHaveSeparateLimits(t *testing.T) {
config := &APIKeyRateLimiterConfig{
ReadLimit: 2,
WriteLimit: 2,
Window: time.Minute,
}
rl := NewAPIKeyRateLimiter(config)
defer rl.Stop()
key1 := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "key1"}
key2 := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "key2"}
makeRouter := func(apiKey *models.APIKey) *gin.Engine {
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("api_key", apiKey)
c.Next()
})
r.Use(rl.Middleware())
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
return r
}
router1 := makeRouter(key1)
router2 := makeRouter(key2)
// Exhaust key1's limit
for i := 0; i < 2; i++ {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
router1.ServeHTTP(w, req)
}
// key1 should be rate limited
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
router1.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("key1 3rd request: expected 429, got %d", w.Code)
}
// key2 should still work
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/test", nil)
router2.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("key2 1st request: expected 200, got %d", w.Code)
}
}
// TestAPIKeyRateLimiter_ReturnsRateLimitHeaders verifies that a request within
// the limit succeeds and carries X-RateLimit-Limit (the configured ReadLimit),
// X-RateLimit-Remaining (one less after the first GET) and a non-empty
// X-RateLimit-Reset header.
func TestAPIKeyRateLimiter_ReturnsRateLimitHeaders(t *testing.T) {
	config := &APIKeyRateLimiterConfig{
		ReadLimit:  10,
		WriteLimit: 5,
		Window:     time.Hour,
	}
	rl := NewAPIKeyRateLimiter(config)
	defer rl.Stop()

	apiKey := &models.APIKey{ID: uuid.New(), UserID: uuid.New(), Name: "test"}
	router := gin.New()
	// Simulate the auth middleware that stores the authenticated key.
	router.Use(func(c *gin.Context) {
		c.Set("api_key", apiKey)
		c.Next()
	})
	router.Use(rl.Middleware())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	w := httptest.NewRecorder()
	router.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/test", nil))

	// The first request is well within the limit, so it must not be rejected.
	if w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", w.Code)
	}
	if got := w.Header().Get("X-RateLimit-Limit"); got != "10" {
		t.Errorf("expected X-RateLimit-Limit=10, got %s", got)
	}
	if got := w.Header().Get("X-RateLimit-Remaining"); got != "9" {
		t.Errorf("expected X-RateLimit-Remaining=9, got %s", got)
	}
	if w.Header().Get("X-RateLimit-Reset") == "" {
		t.Error("expected X-RateLimit-Reset header to be set")
	}
}