package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// newCORSTestRouter returns a gin engine in test mode that runs mw as global
// middleware in front of a single GET /test route answering 200 {"ok": true}.
func newCORSTestRouter(t *testing.T, mw gin.HandlerFunc) *gin.Engine {
	t.Helper()
	// gin.SetMode mutates package-level state, so tests using this helper
	// should not call t.Parallel().
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(mw)
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})
	return router
}

// serveCORSRequest sends a request with the given method to /test on h and
// returns the recorded response. The Origin header is only set when origin is
// non-empty, so passing "" models a request that carries no Origin header.
func serveCORSRequest(h http.Handler, method, origin string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(method, "/test", nil)
	if origin != "" {
		req.Header.Set("Origin", origin)
	}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, req)
	return w
}

// TestCORS_AllowedOrigin checks that an allowed origin is echoed back together
// with the allowed methods, allowed headers and the credentials flag.
func TestCORS_AllowedOrigin(t *testing.T) {
	router := newCORSTestRouter(t, CORS([]string{"http://localhost:3000", "https://example.com"}))

	w := serveCORSRequest(router, http.MethodGet, "http://localhost:3000")

	assert.Equal(t, http.StatusOK, w.Code)
	h := w.Header()
	assert.Equal(t, "http://localhost:3000", h.Get("Access-Control-Allow-Origin"))
	assert.Equal(t, "GET, POST, PUT, PATCH, DELETE, OPTIONS", h.Get("Access-Control-Allow-Methods"))
	assert.Equal(t, "Authorization, Content-Type, X-Requested-With, X-CSRF-Token", h.Get("Access-Control-Allow-Headers"))
	assert.Equal(t, "true", h.Get("Access-Control-Allow-Credentials"))
}

// TestCORS_DisallowedOrigin checks that a request from an origin outside the
// allow list still reaches the handler but gets no Access-Control-Allow-Origin.
func TestCORS_DisallowedOrigin(t *testing.T) {
	router := newCORSTestRouter(t, CORS([]string{"http://localhost:3000"}))

	w := serveCORSRequest(router, http.MethodGet, "http://evil.com")

	assert.Equal(t, http.StatusOK, w.Code)
	// A disallowed origin must not be reflected in the response header.
	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
}

// TestCORS_Wildcard checks that "*" in the allow list accepts any origin and
// that the concrete request origin (not a literal "*") is echoed back.
func TestCORS_Wildcard(t *testing.T) {
	router := newCORSTestRouter(t, CORS([]string{"*"}))

	w := serveCORSRequest(router, http.MethodGet, "http://any-origin.com")

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "http://any-origin.com", w.Header().Get("Access-Control-Allow-Origin"))
}

// TestCORS_NoOriginHeader checks that a request without an Origin header is
// served normally and does not receive Access-Control-Allow-Origin.
func TestCORS_NoOriginHeader(t *testing.T) {
	router := newCORSTestRouter(t, CORS([]string{"http://localhost:3000"}))

	// Empty origin: the helper leaves the Origin header unset.
	w := serveCORSRequest(router, http.MethodGet, "")

	assert.Equal(t, http.StatusOK, w.Code)
	// Without an Origin header, Access-Control-Allow-Origin must not be set.
	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
}

// TestCORS_OPTIONSRequest checks that an OPTIONS request from an allowed origin
// is answered with 204 No Content and the CORS headers, even though only
// GET /test is registered on the router.
func TestCORS_OPTIONSRequest(t *testing.T) {
	router := newCORSTestRouter(t, CORS([]string{"http://localhost:3000"}))

	w := serveCORSRequest(router, http.MethodOptions, "http://localhost:3000")

	assert.Equal(t, http.StatusNoContent, w.Code)
	assert.Equal(t, "http://localhost:3000", w.Header().Get("Access-Control-Allow-Origin"))
	assert.Equal(t, "GET, POST, PUT, PATCH, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods"))
}

// TestCORS_MultipleAllowedOrigins checks that every origin in a multi-entry
// allow list is echoed back on its own request.
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
	allowedOrigins := []string{"http://localhost:3000", "https://example.com", "https://app.example.com"}
	router := newCORSTestRouter(t, CORS(allowedOrigins))

	for _, origin := range allowedOrigins {
		w := serveCORSRequest(router, http.MethodGet, origin)
		assert.Equal(t, origin, w.Header().Get("Access-Control-Allow-Origin"), "origin %q", origin)
	}
}

// TestIsAllowedOrigin exercises the origin matcher directly: exact match,
// miss, wildcard, empty origin, empty allow list and multi-entry lists.
func TestIsAllowedOrigin(t *testing.T) {
	tests := []struct {
		name     string
		origin   string
		allowed  []string
		expected bool
	}{
		{
			name:     "origin exact match",
			origin:   "http://localhost:3000",
			allowed:  []string{"http://localhost:3000"},
			expected: true,
		},
		{
			name:     "origin not in list",
			origin:   "http://evil.com",
			allowed:  []string{"http://localhost:3000"},
			expected: false,
		},
		{
			name:     "wildcard allows all",
			origin:   "http://any-origin.com",
			allowed:  []string{"*"},
			expected: true,
		},
		{
			name:     "empty origin",
			origin:   "",
			allowed:  []string{"http://localhost:3000"},
			expected: false,
		},
		{
			name:     "empty allowed list",
			origin:   "http://localhost:3000",
			allowed:  []string{},
			expected: false,
		},
		{
			name:     "multiple allowed origins",
			origin:   "https://example.com",
			allowed:  []string{"http://localhost:3000", "https://example.com"},
			expected: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := isAllowedOrigin(tt.origin, tt.allowed)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestCORSDefault checks that the default CORS configuration accepts an
// arbitrary origin and echoes it back.
func TestCORSDefault(t *testing.T) {
	router := newCORSTestRouter(t, CORSDefault())

	w := serveCORSRequest(router, http.MethodGet, "http://any-origin.com")

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "http://any-origin.com", w.Header().Get("Access-Control-Allow-Origin"))
}