- PKCE (S256) in OAuth flow: code_verifier in oauth_states, code_challenge in auth URL - CryptoService: AES-256-GCM encryption for OAuth provider tokens at rest - OAuth redirect URL validated against OAUTH_ALLOWED_REDIRECT_DOMAINS - CHAT_JWT_SECRET must differ from JWT_SECRET in production - Migration script: cmd/tools/encrypt_oauth_tokens for existing tokens - Fixes: VEZA-SEC-003, VEZA-SEC-004, VEZA-SEC-009, VEZA-SEC-010
223 lines
6.4 KiB
Go
223 lines
6.4 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"veza-backend-api/internal/models"
|
|
"veza-backend-api/internal/services"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// MockOAuthService mocks the OAuthService interface
|
|
type MockOAuthService struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockOAuthService) GetAuthURL(provider string) (string, error) {
|
|
args := m.Called(provider)
|
|
return args.String(0), args.Error(1)
|
|
}
|
|
|
|
func (m *MockOAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error) {
|
|
args := m.Called(ctx, provider, code, state, ipAddress, userAgent)
|
|
if args.Get(0) == nil {
|
|
return nil, nil, "", args.Error(3)
|
|
}
|
|
if args.Get(1) == nil {
|
|
return args.Get(0).(*services.OAuthUser), nil, args.String(2), args.Error(3)
|
|
}
|
|
return args.Get(0).(*services.OAuthUser), args.Get(1).(*models.TokenPair), args.String(2), args.Error(3)
|
|
}
|
|
|
|
func (m *MockOAuthService) GetAvailableProviders() []string {
|
|
args := m.Called()
|
|
if args.Get(0) == nil {
|
|
return nil
|
|
}
|
|
return args.Get(0).([]string)
|
|
}
|
|
|
|
func setupTestOAuthRouter(mockService *MockOAuthService) *gin.Engine {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
|
|
logger := zap.NewNop()
|
|
handler := NewOAuthHandlerWithInterface(mockService, logger, nil)
|
|
|
|
api := router.Group("/api/v1/auth/oauth")
|
|
{
|
|
api.GET("/providers", handler.GetOAuthProviders)
|
|
api.GET("/:provider", handler.InitiateOAuth)
|
|
api.GET("/:provider/callback", handler.OAuthCallback)
|
|
}
|
|
|
|
return router
|
|
}
|
|
|
|
func TestOAuthHandlers_GetOAuthProviders_Success(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
mockService.On("GetAvailableProviders").Return([]string{"google", "github", "discord"})
|
|
router := setupTestOAuthRouter(mockService)
|
|
|
|
// Execute
|
|
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/providers", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var response map[string]interface{}
|
|
err := json.Unmarshal(w.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.True(t, response["success"].(bool))
|
|
|
|
data := response["data"].(map[string]interface{})
|
|
providers := data["providers"].([]interface{})
|
|
assert.Len(t, providers, 3)
|
|
|
|
// Verify provider structure
|
|
provider1 := providers[0].(map[string]interface{})
|
|
assert.Equal(t, "Google", provider1["name"])
|
|
assert.Equal(t, "google", provider1["id"])
|
|
}
|
|
|
|
func TestOAuthHandlers_InitiateOAuth_Success(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
router := setupTestOAuthRouter(mockService)
|
|
|
|
expectedAuthURL := "https://accounts.google.com/o/oauth2/auth?client_id=test&redirect_uri=test&response_type=code&scope=email+profile&state=test"
|
|
mockService.On("GetAuthURL", "google").Return(expectedAuthURL, nil)
|
|
|
|
// Execute
|
|
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
|
|
assert.Equal(t, expectedAuthURL, w.Header().Get("Location"))
|
|
mockService.AssertExpectations(t)
|
|
}
|
|
|
|
func TestOAuthHandlers_InitiateOAuth_InvalidProvider(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
router := setupTestOAuthRouter(mockService)
|
|
|
|
mockService.On("GetAuthURL", "invalid").Return("", assert.AnError)
|
|
|
|
// Execute
|
|
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/invalid", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
mockService.AssertExpectations(t)
|
|
}
|
|
|
|
func TestOAuthHandlers_OAuthCallback_Success(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
router := setupTestOAuthRouter(mockService)
|
|
|
|
userID := uuid.New()
|
|
mockUser := &services.OAuthUser{
|
|
ID: userID,
|
|
Email: "test@example.com",
|
|
}
|
|
tokens := &models.TokenPair{
|
|
AccessToken: "access-token",
|
|
RefreshToken: "refresh-token",
|
|
ExpiresIn: int(5 * time.Minute.Seconds()),
|
|
}
|
|
|
|
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(mockUser, tokens, "", nil)
|
|
|
|
// Execute
|
|
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
|
|
location := w.Header().Get("Location")
|
|
assert.Contains(t, location, "user_id="+userID.String())
|
|
assert.Contains(t, location, "/auth/callback")
|
|
// Tokens now in cookies, not URL
|
|
assert.NotContains(t, location, "token=")
|
|
mockService.AssertExpectations(t)
|
|
}
|
|
|
|
func TestOAuthHandlers_OAuthCallback_MissingCode(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
router := setupTestOAuthRouter(mockService)
|
|
|
|
// Execute - Missing code parameter
|
|
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?state=test-state", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
mockService.AssertNotCalled(t, "HandleCallback")
|
|
}
|
|
|
|
func TestOAuthHandlers_OAuthCallback_MissingState(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
router := setupTestOAuthRouter(mockService)
|
|
|
|
// Execute - Missing state parameter
|
|
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
mockService.AssertNotCalled(t, "HandleCallback")
|
|
}
|
|
|
|
func TestOAuthHandlers_OAuthCallback_ServiceError(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
router := setupTestOAuthRouter(mockService)
|
|
|
|
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(nil, nil, "", assert.AnError)
|
|
|
|
// Execute
|
|
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
// Assert
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
mockService.AssertExpectations(t)
|
|
}
|
|
|
|
func TestNewOAuthHandlerWithInterface(t *testing.T) {
|
|
// Setup
|
|
mockService := new(MockOAuthService)
|
|
logger := zap.NewNop()
|
|
|
|
// Execute
|
|
handler := NewOAuthHandlerWithInterface(mockService, logger, nil)
|
|
|
|
// Assert
|
|
assert.NotNil(t, handler)
|
|
assert.Equal(t, mockService, handler.oauthService)
|
|
}
|