veza/veza-backend-api/internal/middleware/audit_test.go

113 lines
3.1 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
// TestAuditMiddleware_SkipsGET sends a read-only GET request through
// AuditMiddleware and checks that the route handler still runs and its status
// code reaches the client. The audit service is nil, so this test only covers
// pass-through behavior; it cannot observe whether an audit entry is written.
func TestAuditMiddleware_SkipsGET(t *testing.T) {
	gin.SetMode(gin.TestMode)

	engine := gin.New()
	engine.Use(AuditMiddleware(nil, zap.NewNop()))
	engine.GET("/api/v1/tracks", func(ctx *gin.Context) {
		ctx.Status(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/tracks", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestAuditMiddleware_LogsPOST sends a mutating POST request through
// AuditMiddleware. A nil audit service is supplied, so the middleware must
// no-op without panicking. The test only asserts that the handler's 201 status
// is returned; it does not observe whether anything is actually logged.
func TestAuditMiddleware_LogsPOST(t *testing.T) {
	gin.SetMode(gin.TestMode)

	engine := gin.New()
	engine.Use(AuditMiddleware(nil, zap.NewNop()))
	engine.POST("/api/v1/tracks", func(ctx *gin.Context) {
		ctx.Status(http.StatusCreated)
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/api/v1/tracks", nil))

	assert.Equal(t, http.StatusCreated, rec.Code)
}
// TestAuditMiddleware_SkipsHealth sends a POST to /health, a path excluded from
// auditing by shouldSkipAudit, and checks that the request is still routed to
// its handler. The audit service is nil, so only pass-through is verified.
func TestAuditMiddleware_SkipsHealth(t *testing.T) {
	gin.SetMode(gin.TestMode)

	engine := gin.New()
	engine.Use(AuditMiddleware(nil, zap.NewNop()))
	engine.POST("/health", func(ctx *gin.Context) {
		ctx.Status(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/health", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestAuditMiddleware_WithUserID checks that AuditMiddleware does not panic
// when the gin context already carries a "user_id" value and the audit service
// is nil (a no-op). It asserts that the handler's 201 status is returned.
func TestAuditMiddleware_WithUserID(t *testing.T) {
	gin.SetMode(gin.TestMode)
	uid := uuid.New()

	engine := gin.New()
	// Store "user_id" before AuditMiddleware runs, so the middleware sees a
	// request with a user ID in context (presumably mirroring an upstream auth
	// middleware).
	engine.Use(func(ctx *gin.Context) {
		ctx.Set("user_id", uid)
		ctx.Next()
	})
	engine.Use(AuditMiddleware(nil, zap.NewNop())) // nil service: middleware no-ops
	engine.POST("/api/v1/tracks", func(ctx *gin.Context) {
		ctx.Status(http.StatusCreated)
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/api/v1/tracks", nil))

	assert.Equal(t, http.StatusCreated, rec.Code)
}
// TestMapMethodToAction pins the HTTP-method-to-audit-action mapping. It is
// table-driven, and each assertion carries the input method, so a failure
// reports which method mapped incorrectly.
func TestMapMethodToAction(t *testing.T) {
	cases := []struct {
		method string
		want   string
	}{
		{http.MethodPost, "create"},
		{http.MethodPut, "update"},
		{http.MethodPatch, "update"},
		{http.MethodDelete, "delete"},
		{http.MethodGet, "get"},
	}

	for _, tc := range cases {
		assert.Equal(t, tc.want, mapMethodToAction(tc.method), "method %q", tc.method)
	}
}
// TestExtractResourceFromPath checks that extractResourceFromPath derives the
// singular resource name from an /api/v1 collection path, with or without a
// trailing ID or sub-path segment. Each assertion names its input path, so a
// failure pinpoints the offending case.
func TestExtractResourceFromPath(t *testing.T) {
	cases := []struct {
		path string
		want string
	}{
		{"/api/v1/tracks", "track"},
		{"/api/v1/tracks/123", "track"},
		{"/api/v1/users/me", "user"},
		{"/api/v1/playlists", "playlist"},
		{"/api/v1/conversations/abc", "conversation"},
	}

	for _, tc := range cases {
		assert.Equal(t, tc.want, extractResourceFromPath(tc.path), "path %q", tc.path)
	}
}
// TestShouldSkipAudit checks which request paths are excluded from auditing.
// The cases are health, metrics and swagger endpoints, which should be skipped,
// and ordinary API resources, which should not. Expected outcomes live in one
// table, and each assertion names its path, so a failure pinpoints the input.
func TestShouldSkipAudit(t *testing.T) {
	cases := []struct {
		path string
		skip bool
	}{
		{"/health", true},
		{"/health/deep", true},
		{"/metrics", true},
		{"/swagger/index.html", true},
		{"/api/v1/health", true},
		{"/api/v1/tracks", false},
		{"/api/v1/users", false},
	}

	for _, tc := range cases {
		assert.Equal(t, tc.skip, shouldSkipAudit(tc.path), "path %q", tc.path)
	}
}