package middleware

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"veza-backend-api/internal/models"
	"veza-backend-api/internal/services"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
)

// MockTwoFactorChecker is a testify mock that is installed on AuthMiddleware
// via SetTwoFactorChecker so tests can control a user's MFA status.
type MockTwoFactorChecker struct {
	mock.Mock
}

// GetTwoFactorStatus returns the values configured with
// On("GetTwoFactorStatus", ...): whether MFA is enabled, and an optional error.
func (m *MockTwoFactorChecker) GetTwoFactorStatus(ctx context.Context, userID uuid.UUID) (bool, error) {
	args := m.Called(ctx, userID)
	return args.Bool(0), args.Error(1)
}

// setupMFATestMiddleware builds an AuthMiddleware whose dependencies are all
// mocked. The returned user has the given role and the given MFA status. It
// returns the middleware, that user's ID, and an HS256 access token for the
// user signed with testJWTSecret.
func setupMFATestMiddleware(t *testing.T, role string, mfaEnabled bool) (*AuthMiddleware, uuid.UUID, string) {
	t.Helper()
	gin.SetMode(gin.TestMode)

	userID := uuid.New()

	mockPermissionChecker := new(MockPermissionChecker)
	mockPermissionChecker.On("HasRole", mock.Anything, userID, "admin").Return(role == "admin", nil)

	mockSessionService := new(MockSessionService)
	mockSessionService.On("ValidateSession", mock.Anything, mock.Anything).Return(&services.Session{
		ID:        uuid.New(),
		UserID:    userID,
		CreatedAt: time.Now(),
		ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
	}, nil)
	mockSessionService.On("RefreshSession", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()

	mockAuditService := new(MockAuditService)
	mockAuditService.On("LogAction", mock.Anything, mock.Anything).Return(nil).Maybe()

	mockUserRepository := new(MockUserRepository)
	mockUserRepository.On("GetByID", mock.Anything).Return(&models.User{
		ID:           userID,
		TokenVersion: 0,
		Role:         role,
	}, nil)

	mockTwoFactorChecker := new(MockTwoFactorChecker)
	mockTwoFactorChecker.On("GetTwoFactorStatus", mock.Anything, userID).Return(mfaEnabled, nil)

	jwtService := setupTestJWTService(t)
	userService := services.NewUserService(mockUserRepository)

	am := NewAuthMiddleware(mockSessionService, mockAuditService, mockPermissionChecker, jwtService, userService, nil, nil, zap.NewNop())
	am.SetTwoFactorChecker(mockTwoFactorChecker)

	// Generate a valid token. token_version matches the mocked user's
	// TokenVersion above.
	claims := jwt.MapClaims{
		"sub":           userID.String(),
		"exp":           time.Now().Add(5 * time.Minute).Unix(),
		"iat":           time.Now().Unix(),
		"iss":           "veza-api",
		"aud":           "veza-app",
		"token_version": 0,
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte(testJWTSecret))
	require.NoError(t, err)

	return am, userID, tokenString
}

// serveAdminMFARequest sends GET /admin/test through a fresh gin router. The
// router applies RequireAuth, RequireAdmin and RequireMFA in that order. The
// request carries tokenString as the access_token cookie, and the recorded
// response is returned.
func serveAdminMFARequest(am *AuthMiddleware, tokenString string) *httptest.ResponseRecorder {
	router := gin.New()
	router.Use(am.RequireAuth(), am.RequireAdmin(), am.RequireMFA())
	router.GET("/admin/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
	req.AddCookie(&http.Cookie{Name: "access_token", Value: tokenString})
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	return w
}

// TestRequireMFA_AdminWithoutMFA tests that admin without MFA is denied.
// SFIX-001: ORIGIN_SECURITY_FRAMEWORK.md Rule 5
func TestRequireMFA_AdminWithoutMFA(t *testing.T) {
	am, _, tokenString := setupMFATestMiddleware(t, "admin", false)

	w := serveAdminMFARequest(am, tokenString)

	// Stop here on a wrong status; the body checks below only make sense for a 403.
	require.Equal(t, http.StatusForbidden, w.Code)

	var resp map[string]any
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))

	// Comma-ok assertion: an unexpected body shape fails the test with a
	// message instead of panicking.
	errObj, ok := resp["error"].(map[string]any)
	require.True(t, ok, "expected an \"error\" object in response body, got %s", w.Body.String())
	assert.Equal(t, "mfa_setup_required", errObj["code"])
}

// TestRequireMFA_AdminWithMFA tests that admin with MFA is allowed.
func TestRequireMFA_AdminWithMFA(t *testing.T) {
	am, _, tokenString := setupMFATestMiddleware(t, "admin", true)

	w := serveAdminMFARequest(am, tokenString)

	assert.Equal(t, http.StatusOK, w.Code)
}

// TestRequireMFA_RegularUserNotAffected tests that non-privileged users bypass the MFA check.
func TestRequireMFA_RegularUserNotAffected(t *testing.T) { am, _, tokenString := setupMFATestMiddleware(t, "user", false) router := gin.New() router.Use(am.RequireAuth()) router.Use(am.RequireMFA()) router.GET("/protected/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) req := httptest.NewRequest("GET", "/protected/test", nil) req.AddCookie(&http.Cookie{Name: "access_token", Value: tokenString}) w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) } // TestRefreshTokenTTL_Is7Days validates that the JWT config has 7-day refresh TTL. // SFIX-002: ORIGIN_SECURITY_FRAMEWORK.md Rule 4 func TestRefreshTokenTTL_Is7Days(t *testing.T) { jwtService := setupTestJWTService(t) config := jwtService.GetConfig() expected := 7 * 24 * time.Hour assert.Equal(t, expected, config.RefreshTokenTTL, "RefreshTokenTTL should be 7 days") assert.Equal(t, expected, config.RememberMeRefreshTokenTTL, "RememberMeRefreshTokenTTL should be 7 days") } // TestAccessTokenTTL_Is5Minutes validates access token TTL. func TestAccessTokenTTL_Is5Minutes(t *testing.T) { jwtService := setupTestJWTService(t) config := jwtService.GetConfig() assert.Equal(t, 5*time.Minute, config.AccessTokenTTL, "AccessTokenTTL should be 5 minutes") }