diff --git a/VERSION b/VERSION index 6026868be..c1ba3fa76 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.901 +0.902 diff --git a/veza-backend-api/cmd/tools/encrypt_oauth_tokens/main.go b/veza-backend-api/cmd/tools/encrypt_oauth_tokens/main.go new file mode 100644 index 000000000..b5e678446 --- /dev/null +++ b/veza-backend-api/cmd/tools/encrypt_oauth_tokens/main.go @@ -0,0 +1,148 @@ +// encrypt_oauth_tokens encrypts existing OAuth provider tokens in federated_identities (v0.902). +// Idempotent: skips tokens already prefixed with veza_enc_v1: +// Usage: DATABASE_URL=... OAUTH_ENCRYPTION_KEY=... go run ./cmd/tools/encrypt_oauth_tokens [-dry-run] +package main + +import ( + "context" + "database/sql" + "flag" + "log" + "os" + "strings" + "time" + + _ "github.com/lib/pq" + "veza-backend-api/internal/services" +) + +const encryptedPrefix = "veza_enc_v1:" + +func main() { + dryRun := flag.Bool("dry-run", false, "Show what would be updated without making changes") + flag.Parse() + + dbURL := os.Getenv("DATABASE_URL") + if dbURL == "" { + log.Fatal("DATABASE_URL is required") + } + encKey := os.Getenv("OAUTH_ENCRYPTION_KEY") + if encKey == "" { + log.Fatal("OAUTH_ENCRYPTION_KEY is required (32+ bytes, base64 or raw)") + } + + cryptoService, err := services.NewCryptoServiceFromBase64(encKey) + if err != nil { + keyBytes := []byte(encKey) + if len(keyBytes) >= 32 { + cryptoService, err = services.NewCryptoService(keyBytes) + } + } + if err != nil { + log.Fatalf("CryptoService: %v", err) + } + + db, err := sql.Open("postgres", dbURL) + if err != nil { + log.Fatalf("DB connect: %v", err) + } + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + log.Fatalf("DB ping: %v", err) + } + + rows, err := db.QueryContext(ctx, ` + SELECT id::text, access_token, refresh_token + FROM federated_identities + WHERE (access_token IS NOT NULL AND access_token != '') + OR (refresh_token IS NOT NULL AND refresh_token != '') + `) + if err != nil { + log.Fatalf("Query: %v", err) + } + defer rows.Close() + + var id string + var accessToken, refreshToken sql.NullString + updated := 0 + skipped := 0 + errors := 0 + + for rows.Next() { + if err := rows.Scan(&id, &accessToken, &refreshToken); err != nil { + log.Printf("Scan error: %v", err) + errors++ + continue + } + + needUpdate := false + newAccess := accessToken.String + newRefresh := refreshToken.String + + if accessToken.Valid && accessToken.String != "" && !strings.HasPrefix(accessToken.String, encryptedPrefix) { + newAccess, err = cryptoService.EncryptString(accessToken.String) + if err != nil { + log.Printf("Encrypt access_token for id %s: %v", id, err) + errors++ + continue + } + needUpdate = true + } + if refreshToken.Valid && refreshToken.String != "" && !strings.HasPrefix(refreshToken.String, encryptedPrefix) { + newRefresh, err = cryptoService.EncryptString(refreshToken.String) + if err != nil { + log.Printf("Encrypt refresh_token for id %s: %v", id, err) + errors++ + continue + } + needUpdate = true + } + + if !needUpdate { + skipped++ + continue + } + + if *dryRun { + log.Printf("[dry-run] Would encrypt tokens for id %s", id) + updated++ + continue + } + + _, err = db.ExecContext(ctx, ` + UPDATE federated_identities + SET access_token = $1, refresh_token = $2, updated_at = NOW() + WHERE id = $3 + `, nullIfEmpty(newAccess), nullIfEmpty(newRefresh), id) + if err != nil { + log.Printf("UPDATE id %s: %v", id, err) + errors++ + continue + } + updated++ + } + + if err := rows.Err(); err != nil { + log.Fatalf("Rows: %v", err) + } + + mode := "" + if *dryRun { + mode = " [dry-run]" + } + log.Printf("Done%s: updated=%d skipped=%d errors=%d", mode, updated, skipped, errors) + if errors > 0 { + os.Exit(1) + } +} + +func nullIfEmpty(s string) interface{} { + if s == "" { + return nil + } + return s +} diff --git a/veza-backend-api/internal/api/routes_auth.go b/veza-backend-api/internal/api/routes_auth.go index 107a5a758..77c71f384 100644 --- a/veza-backend-api/internal/api/routes_auth.go +++ b/veza-backend-api/internal/api/routes_auth.go @@ -114,8 +114,30 @@ func (r *APIRouter) setupAuthRoutes(router *gin.RouterGroup) error { } checkUsernameGroup.GET("/check-username", handlers.CheckUsername(authService)) - // BE-API-042: OAuth routes - oauthService := services.NewOAuthService(r.db, r.logger, jwtService, sessionService, userService) + // BE-API-042: OAuth routes (v0.902: CryptoService, redirect validation) + oauthCfg := &services.OAuthServiceConfig{ + FrontendURL: r.config.FrontendURL, + AllowedDomains: r.config.OAuthAllowedRedirectDomains, + } + if r.config.OAuthEncryptionKey != "" { + var cryptoService *services.CryptoService + var err error + cryptoService, err = services.NewCryptoServiceFromBase64(r.config.OAuthEncryptionKey) + if err != nil { + // Fallback: use raw bytes if key is long enough + keyBytes := []byte(r.config.OAuthEncryptionKey) + if len(keyBytes) >= 32 { + cryptoService, err = services.NewCryptoService(keyBytes) + } + } + if err != nil { + return fmt.Errorf("OAuth CryptoService: %w", err) + } + if cryptoService != nil { + oauthCfg.CryptoService = cryptoService + } + } + oauthService := services.NewOAuthService(r.db, r.logger, jwtService, sessionService, userService, oauthCfg) baseURL := os.Getenv("BASE_URL") if baseURL == "" { appDomain := os.Getenv("APP_DOMAIN") diff --git a/veza-backend-api/internal/config/config.go b/veza-backend-api/internal/config/config.go index 96408f19c..8c66b3848 100644 --- a/veza-backend-api/internal/config/config.go +++ b/veza-backend-api/internal/config/config.go @@ -81,6 +81,10 @@ type Config struct { CORSOrigins []string // Liste des origines CORS autorisées FrontendURL string // URL du frontend (OAuth redirect, password reset links). FRONTEND_URL ou VITE_FRONTEND_URL + // OAuth Security (v0.902 Sentinel) + OAuthEncryptionKey string // OAUTH_ENCRYPTION_KEY: 32 bytes for AES-256-GCM (required in production) + OAuthAllowedRedirectDomains []string // OAUTH_ALLOWED_REDIRECT_DOMAINS: whitelist for OAuth redirect URLs + // HLS Streaming Configuration (v0.503) HLSEnabled bool // Enable HLS streaming routes HLSStorageDir string // Directory for HLS segment storage @@ -304,6 +308,10 @@ func NewConfig() (*Config, error) { CORSOrigins: corsOrigins, FrontendURL: getFrontendURL(), // OAuth callback, password reset, email links + // OAuth Security (v0.902 Sentinel) + OAuthEncryptionKey: getEnv("OAUTH_ENCRYPTION_KEY", ""), + OAuthAllowedRedirectDomains: getOAuthAllowedRedirectDomains(env, getEnvStringSlice("OAUTH_ALLOWED_REDIRECT_DOMAINS", nil), corsOrigins, getFrontendURL()), + // HLS Streaming (v0.503) HLSEnabled: getEnvBool("HLS_STREAMING", false), HLSStorageDir: getEnv("HLS_STORAGE_DIR", "/tmp/veza-hls"), @@ -843,6 +851,16 @@ func (c *Config) ValidateForEnvironment() error { return fmt.Errorf("CLAMAV_REQUIRED must be true in production. Virus scanning is mandatory for uploads") } + // 6. v0.902: CHAT_JWT_SECRET must differ from JWT_SECRET in production (VEZA-SEC-009) + if c.ChatJWTSecret == c.JWTSecret { + return fmt.Errorf("CHAT_JWT_SECRET must be different from JWT_SECRET in production. Use a separate secret for the Chat Server") + } + + // 7. v0.902: OAUTH_ENCRYPTION_KEY required in production for OAuth token encryption (VEZA-SEC-004) + if len(c.OAuthEncryptionKey) < 32 { + return fmt.Errorf("OAUTH_ENCRYPTION_KEY is required in production (min 32 bytes for AES-256). Set OAUTH_ENCRYPTION_KEY with a 32-byte hex or base64 key") + } + case EnvTest: // TEST: Validation adaptée aux tests // CORS peut être vide ou configuré explicitement diff --git a/veza-backend-api/internal/config/cors.go b/veza-backend-api/internal/config/cors.go index d1eaec9a9..53f9e7b63 100644 --- a/veza-backend-api/internal/config/cors.go +++ b/veza-backend-api/internal/config/cors.go @@ -1,7 +1,9 @@ package config import ( + "net/url" "net/http" + "strings" ) // getCookieSecure détermine si les cookies doivent être Secure @@ -104,3 +106,27 @@ func getCORSOrigins(env string, appDomain string) []string { return []string{"http://" + appDomain + ":3000", "http://" + appDomain + ":5173"} } } + +// getOAuthAllowedRedirectDomains returns the whitelist of domains allowed for OAuth redirect (v0.902). +// If OAUTH_ALLOWED_REDIRECT_DOMAINS is set, use it. Otherwise derive from CORSOrigins or FrontendURL. +func getOAuthAllowedRedirectDomains(env string, explicit []string, corsOrigins []string, frontendURL string) []string { + if len(explicit) > 0 { + return explicit + } + if len(corsOrigins) > 0 { + return corsOrigins + } + if frontendURL != "" { + if u, err := url.Parse(frontendURL); err == nil { + origin := strings.TrimSuffix(u.String(), "/") + if u.Scheme != "" && u.Host != "" { + return []string{origin} + } + } + } + // Dev fallback + if env == EnvDevelopment || env == EnvStaging { + return []string{"http://localhost:5173", "http://localhost:3000", "http://127.0.0.1:5173", "http://127.0.0.1:3000"} + } + return []string{} +} diff --git a/veza-backend-api/internal/config/validation_test.go b/veza-backend-api/internal/config/validation_test.go index aca39ff77..2d27b0e84 100644 --- a/veza-backend-api/internal/config/validation_test.go +++ b/veza-backend-api/internal/config/validation_test.go @@ -425,3 +425,28 @@ func TestValidateForEnvironment_ClamAVRequiredInProduction(t *testing.T) { require.NoError(t, err) }) } + +// TestValidateForEnvironment_ChatJWTSecretInProduction verifies CHAT_JWT_SECRET must differ from JWT_SECRET in production (v0.902) +func TestValidateForEnvironment_ChatJWTSecretInProduction(t *testing.T) { + secret := strings.Repeat("a", 32) + cfg := &Config{ + Env: EnvProduction, + AppPort: 8080, + JWTSecret: secret, + ChatJWTSecret: secret, // Same as JWT_SECRET - should fail + DatabaseURL: "postgresql://user:pass@localhost:5432/db", + RedisURL: "redis://localhost:6379", + RateLimitLimit: 100, + RateLimitWindow: 60, + CORSOrigins: []string{"https://example.com"}, + LogLevel: "INFO", + OAuthEncryptionKey: strings.Repeat("b", 32), + } + logger, _ := zap.NewDevelopment() + cfg.Logger = logger + + os.Setenv("CLAMAV_REQUIRED", "true") + err := cfg.ValidateForEnvironment() + require.Error(t, err) + assert.Contains(t, err.Error(), "CHAT_JWT_SECRET must be different from JWT_SECRET") +} diff --git a/veza-backend-api/internal/handlers/oauth_handlers.go b/veza-backend-api/internal/handlers/oauth_handlers.go index b3060798c..495b9440d 100644 --- a/veza-backend-api/internal/handlers/oauth_handlers.go +++ b/veza-backend-api/internal/handlers/oauth_handlers.go @@ -18,7 +18,7 @@ import ( // OAuthServiceInterface defines the methods needed for OAuth handlers type OAuthServiceInterface interface { GetAuthURL(provider string) (string, error) - HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, error) + HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error) GetAvailableProviders() []string } @@ -134,16 +134,22 @@ func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) { return } - // Handle callback (VEZA-SEC-001: returns TokenPair, creates session) - user, tokens, err := oh.oauthService.HandleCallback(c.Request.Context(), provider, code, state, c.ClientIP(), c.Request.UserAgent()) + // Handle callback (VEZA-SEC-001: returns TokenPair, creates session; v0.902: returns validated redirectURL) + user, tokens, redirectURL, err := oh.oauthService.HandleCallback(c.Request.Context(), provider, code, state, c.ClientIP(), c.Request.UserAgent()) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // SECURITY: Validate redirect URL against allowlist to prevent open redirect - frontendURL := oh.frontendURL - if !oh.isAllowedRedirectOrigin(frontendURL) { + // Use validated redirect URL from service as base, or fallback to frontendURL + baseURL := oh.frontendURL + if redirectURL != "" { + frontendURLParsed, _ := url.Parse(redirectURL) + if frontendURLParsed != nil && frontendURLParsed.Scheme != "" && frontendURLParsed.Host != "" { + baseURL = frontendURLParsed.Scheme + "://" + frontendURLParsed.Host + } + } + if !oh.isAllowedRedirectOrigin(baseURL) { c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid redirect configuration"}) return } @@ -181,8 +187,8 @@ func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) { } // Redirect to frontend (tokens in cookies, not URL) - redirectURL := fmt.Sprintf("%s/auth/callback?user_id=%s", strings.TrimSuffix(frontendURL, "/"), user.ID.String()) - c.Redirect(http.StatusTemporaryRedirect, redirectURL) + finalRedirect := fmt.Sprintf("%s/auth/callback?user_id=%s", strings.TrimSuffix(baseURL, "/"), user.ID.String()) + c.Redirect(http.StatusTemporaryRedirect, finalRedirect) } // isAllowedRedirectOrigin validates that the frontend URL is in the allowlist. Returns true if allowed. diff --git a/veza-backend-api/internal/handlers/oauth_handlers_test.go b/veza-backend-api/internal/handlers/oauth_handlers_test.go index 26c817f7b..cc7e226b8 100644 --- a/veza-backend-api/internal/handlers/oauth_handlers_test.go +++ b/veza-backend-api/internal/handlers/oauth_handlers_test.go @@ -28,15 +28,15 @@ func (m *MockOAuthService) GetAuthURL(provider string) (string, error) { return args.String(0), args.Error(1) } -func (m *MockOAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, error) { +func (m *MockOAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error) { args := m.Called(ctx, provider, code, state, ipAddress, userAgent) if args.Get(0) == nil { - return nil, nil, args.Error(2) + return nil, nil, "", args.Error(3) } if args.Get(1) == nil { - return args.Get(0).(*services.OAuthUser), nil, args.Error(2) + return args.Get(0).(*services.OAuthUser), nil, args.String(2), args.Error(3) } - return args.Get(0).(*services.OAuthUser), args.Get(1).(*models.TokenPair), args.Error(2) + return args.Get(0).(*services.OAuthUser), args.Get(1).(*models.TokenPair), args.String(2), args.Error(3) } func (m *MockOAuthService) GetAvailableProviders() []string { @@ -145,7 +145,7 @@ func TestOAuthHandlers_OAuthCallback_Success(t *testing.T) { ExpiresIn: int(5 * time.Minute.Seconds()), } - mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(mockUser, tokens, nil) + mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(mockUser, tokens, "", nil) // Execute req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil) @@ -197,7 +197,7 @@ func TestOAuthHandlers_OAuthCallback_ServiceError(t *testing.T) { mockService := new(MockOAuthService) router := setupTestOAuthRouter(mockService) - mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(nil, nil, assert.AnError) + mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(nil, nil, "", assert.AnError) // Execute req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil) diff --git a/veza-backend-api/internal/services/crypto_service.go b/veza-backend-api/internal/services/crypto_service.go new file mode 100644 index 000000000..4f0cc465d --- /dev/null +++ b/veza-backend-api/internal/services/crypto_service.go @@ -0,0 +1,128 @@ +package services + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" +) + +const ( + // NonceSize for AES-GCM (12 bytes recommended by NIST) + gcmNonceSize = 12 + // Prefix for encrypted tokens stored in DB (detect already-encrypted) + encryptedTokenPrefix = "veza_enc_v1:" +) + +// CryptoService provides AES-256-GCM encryption for sensitive data (e.g. OAuth tokens at rest) +type CryptoService struct { + aead cipher.AEAD +} + +// NewCryptoService creates a CryptoService with the given key (32 bytes for AES-256) +func NewCryptoService(key []byte) (*CryptoService, error) { + if len(key) < 32 { + return nil, errors.New("encryption key must be at least 32 bytes for AES-256") + } + // Use first 32 bytes + key32 := key + if len(key) > 32 { + key32 = key[:32] + } + block, err := aes.NewCipher(key32) + if err != nil { + return nil, fmt.Errorf("aes new cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("gcm: %w", err) + } + return &CryptoService{aead: aead}, nil +} + +// NewCryptoServiceFromBase64 creates a CryptoService from a base64-encoded key +func NewCryptoServiceFromBase64(keyBase64 string) (*CryptoService, error) { + if keyBase64 == "" { + return nil, errors.New("encryption key must not be empty") + } + key, err := base64.RawStdEncoding.DecodeString(keyBase64) + if err != nil { + // Try standard base64 + key, err = base64.StdEncoding.DecodeString(keyBase64) + if err != nil { + return nil, fmt.Errorf("decode key: %w", err) + } + } + return NewCryptoService(key) +} + +// NewCryptoServiceFromHex creates a CryptoService from a hex-encoded key (optional, for future) +// For now we use base64. Key can also be raw bytes if passed as string - we'll decode. + +// Encrypt encrypts plaintext with AES-256-GCM. Returns base64-encoded result. +func (c *CryptoService) Encrypt(plaintext []byte) ([]byte, error) { + nonce := make([]byte, gcmNonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, fmt.Errorf("rand nonce: %w", err) + } + ciphertext := c.aead.Seal(nil, nonce, plaintext, nil) + // Prepend nonce: nonce || ciphertext (ciphertext includes tag) + out := make([]byte, 0, len(nonce)+len(ciphertext)) + out = append(out, nonce...) + out = append(out, ciphertext...) + return out, nil +} + +// Decrypt decrypts ciphertext (format: nonce || sealed). Returns plaintext. +func (c *CryptoService) Decrypt(ciphertext []byte) ([]byte, error) { + if len(ciphertext) < gcmNonceSize { + return nil, errors.New("ciphertext too short") + } + nonce := ciphertext[:gcmNonceSize] + sealed := ciphertext[gcmNonceSize:] + return c.aead.Open(nil, nonce, sealed, nil) +} + +// EncryptString encrypts a string and returns the prefixed base64 result for DB storage +func (c *CryptoService) EncryptString(plaintext string) (string, error) { + if plaintext == "" { + return "", nil + } + enc, err := c.Encrypt([]byte(plaintext)) + if err != nil { + return "", err + } + return encryptedTokenPrefix + base64.RawStdEncoding.EncodeToString(enc), nil +} + +// DecryptString decrypts a string stored with EncryptString (checks prefix) +func (c *CryptoService) DecryptString(stored string) (string, error) { + if stored == "" { + return "", nil + } + if len(stored) < len(encryptedTokenPrefix) || stored[:len(encryptedTokenPrefix)] != encryptedTokenPrefix { + // Not encrypted (legacy plaintext) + return stored, nil + } + b64 := stored[len(encryptedTokenPrefix):] + enc, err := base64.RawStdEncoding.DecodeString(b64) + if err != nil { + enc, err = base64.StdEncoding.DecodeString(b64) + if err != nil { + return "", fmt.Errorf("decode stored: %w", err) + } + } + dec, err := c.Decrypt(enc) + if err != nil { + return "", err + } + return string(dec), nil +} + +// EncryptedTokenPrefix returns the prefix used for encrypted tokens +func EncryptedTokenPrefix() string { + return encryptedTokenPrefix +} diff --git a/veza-backend-api/internal/services/crypto_service_test.go b/veza-backend-api/internal/services/crypto_service_test.go new file mode 100644 index 000000000..a8a47be1f --- /dev/null +++ b/veza-backend-api/internal/services/crypto_service_test.go @@ -0,0 +1,161 @@ +package services + +import ( + "encoding/base64" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCryptoService_InvalidKey(t *testing.T) { + _, err := NewCryptoService([]byte("short")) + require.Error(t, err) + assert.Contains(t, err.Error(), "32 bytes") +} + +func TestNewCryptoService_ValidKey(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + svc, err := NewCryptoService(key) + require.NoError(t, err) + assert.NotNil(t, svc) +} + +func TestCryptoService_EncryptDecrypt_Roundtrip(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + svc, err := NewCryptoService(key) + require.NoError(t, err) + + plaintext := []byte("sensitive-oauth-token-value") + enc, err := svc.Encrypt(plaintext) + require.NoError(t, err) + assert.NotEqual(t, plaintext, enc) + assert.NotEmpty(t, enc) + + dec, err := svc.Decrypt(enc) + require.NoError(t, err) + assert.Equal(t, plaintext, dec) +} + +func TestCryptoService_EncryptDecrypt_DifferentEachTime(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + svc, err := NewCryptoService(key) + require.NoError(t, err) + + plaintext := []byte("token") + enc1, err := svc.Encrypt(plaintext) + require.NoError(t, err) + enc2, err := svc.Encrypt(plaintext) + require.NoError(t, err) + assert.NotEqual(t, enc1, enc2, "encryption should be non-deterministic (random nonce)") + + dec1, err := svc.Decrypt(enc1) + require.NoError(t, err) + dec2, err := svc.Decrypt(enc2) + require.NoError(t, err) + assert.Equal(t, plaintext, dec1) + assert.Equal(t, plaintext, dec2) +} + +func TestCryptoService_EncryptString_DecryptString_Roundtrip(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + svc, err := NewCryptoService(key) + require.NoError(t, err) + + plaintext := "my-access-token-123" + enc, err := svc.EncryptString(plaintext) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(enc, EncryptedTokenPrefix())) + assert.NotEqual(t, plaintext, enc) + + dec, err := svc.DecryptString(enc) + require.NoError(t, err) + assert.Equal(t, plaintext, dec) +} + +func TestCryptoService_DecryptString_Plaintext(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + svc, err := NewCryptoService(key) + require.NoError(t, err) + + // Legacy plaintext (no prefix) -> returned as-is + plain := "legacy-token" + dec, err := svc.DecryptString(plain) + require.NoError(t, err) + assert.Equal(t, plain, dec) +} + +func TestCryptoService_DecryptString_Empty(t *testing.T) { + key := make([]byte, 32) + svc, err := NewCryptoService(key) + require.NoError(t, err) + + dec, err := svc.DecryptString("") + require.NoError(t, err) + assert.Empty(t, dec) +} + +func TestCryptoService_EncryptString_Empty(t *testing.T) { + key := make([]byte, 32) + svc, err := NewCryptoService(key) + require.NoError(t, err) + + enc, err := svc.EncryptString("") + require.NoError(t, err) + assert.Empty(t, enc) +} + +func TestCryptoService_Decrypt_ModifiedCiphertext(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + svc, err := NewCryptoService(key) + require.NoError(t, err) + + enc, err := svc.Encrypt([]byte("token")) + require.NoError(t, err) + enc[20] ^= 0xff // flip a byte + + _, err = svc.Decrypt(enc) + require.Error(t, err) +} + +func TestNewCryptoServiceFromBase64(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + keyB64 := base64.RawStdEncoding.EncodeToString(key) + + svc, err := NewCryptoServiceFromBase64(keyB64) + require.NoError(t, err) + assert.NotNil(t, svc) + + enc, err := svc.EncryptString("test") + require.NoError(t, err) + dec, err := svc.DecryptString(enc) + require.NoError(t, err) + assert.Equal(t, "test", dec) +} + +func TestNewCryptoServiceFromBase64_Empty(t *testing.T) { + _, err := NewCryptoServiceFromBase64("") + require.Error(t, err) +} diff --git a/veza-backend-api/internal/services/oauth_service.go b/veza-backend-api/internal/services/oauth_service.go index ae26f17e5..15539fe5c 100644 --- a/veza-backend-api/internal/services/oauth_service.go +++ b/veza-backend-api/internal/services/oauth_service.go @@ -9,6 +9,8 @@ import ( "fmt" "io" "net/http" + "net/url" + "strings" "time" "veza-backend-api/internal/database" @@ -33,6 +35,9 @@ type OAuthService struct { sessionService *SessionService userService *UserService circuitBreaker *CircuitBreakerHTTPClient + cryptoService *CryptoService // v0.902: encrypt OAuth provider tokens at rest (nil = store plaintext) + frontendURL string // v0.902: default redirect URL + allowedDomains []string // v0.902: whitelist for OAuth redirect (OAUTH_ALLOWED_REDIRECT_DOMAINS) } // OAuthAccount represents an OAuth account linking @@ -52,20 +57,29 @@ type OAuthAccount struct { UpdatedAt time.Time `json:"updated_at" db:"updated_at"` } -// OAuthState represents an OAuth state for CSRF protection +// OAuthState represents an OAuth state for CSRF protection and PKCE type OAuthState struct { - ID int64 `db:"id"` - StateToken string `db:"state_token"` - Provider string `db:"provider"` - RedirectURL string `db:"redirect_url"` - ExpiresAt time.Time `db:"expires_at"` - CreatedAt time.Time `db:"created_at"` + ID int64 `db:"id"` + StateToken string `db:"state_token"` + Provider string `db:"provider"` + RedirectURL string `db:"redirect_url"` + CodeVerifier string `db:"code_verifier"` + ExpiresAt time.Time `db:"expires_at"` + CreatedAt time.Time `db:"created_at"` +} + +// OAuthServiceConfig holds optional config for OAuth (v0.902) +type OAuthServiceConfig struct { + CryptoService *CryptoService + FrontendURL string + AllowedDomains []string } // NewOAuthService creates a new OAuth service -func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService) *OAuthService { +// cfg: optional config for crypto, redirect validation (v0.902) +func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService, cfg *OAuthServiceConfig) *OAuthService { httpClient := &http.Client{Timeout: 10 * time.Second} - return &OAuthService{ + svc := &OAuthService{ db: db, logger: logger, jwtService: jwtService, @@ -73,6 +87,12 @@ func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTS userService: userService, circuitBreaker: NewCircuitBreakerHTTPClient(httpClient, "oauth-service", logger), } + if cfg != nil { + svc.cryptoService = cfg.CryptoService + svc.frontendURL = cfg.FrontendURL + svc.allowedDomains = cfg.AllowedDomains + } + return svc } // InitializeConfigs initializes OAuth configurations @@ -154,39 +174,42 @@ func (os *OAuthService) GetAvailableProviders() []string { return providers } -// GenerateStateToken generates a secure state token for CSRF protection -func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (string, error) { - // Generate random token - tokenBytes := make([]byte, 32) - _, err := rand.Read(tokenBytes) - if err != nil { - return "", err - } - stateToken := base64.URLEncoding.EncodeToString(tokenBytes) +// GenerateStateToken generates a secure state token for CSRF protection and PKCE code_verifier +func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (stateToken, codeVerifier string, err error) { + // Generate PKCE code verifier (RFC 7636) + codeVerifier = oauth2.GenerateVerifier() - // Store in database + // Generate random state token + tokenBytes := make([]byte, 32) + _, err = rand.Read(tokenBytes) + if err != nil { + return "", "", err + } + stateToken = base64.URLEncoding.EncodeToString(tokenBytes) + + // Store in database with code_verifier for PKCE ctx := context.Background() expiresAt := time.Now().Add(10 * time.Minute) _, err = os.db.ExecContext(ctx, ` - INSERT INTO oauth_states (state_token, provider, redirect_url, expires_at) - VALUES ($1, $2, $3, $4) - `, stateToken, provider, redirectURL, expiresAt) + INSERT INTO oauth_states (state_token, provider, redirect_url, code_verifier, expires_at) + VALUES ($1, $2, $3, $4, $5) + `, stateToken, provider, redirectURL, codeVerifier, expiresAt) if err != nil { - return "", err + return "", "", err } os.logger.Debug("State token generated", zap.String("provider", provider)) - return stateToken, nil + return stateToken, codeVerifier, nil } -// ValidateStateToken validates and consumes a state token +// ValidateStateToken validates and consumes a state token (returns OAuthState with CodeVerifier for PKCE) func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) { ctx := context.Background() var state OAuthState err := os.db.QueryRowContext(ctx, ` - SELECT id, state_token, provider, redirect_url, expires_at, created_at + SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1 `, stateToken).Scan( @@ -194,6 +217,7 @@ func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, erro &state.StateToken, &state.Provider, &state.RedirectURL, + &state.CodeVerifier, &state.ExpiresAt, &state.CreatedAt, ) @@ -246,24 +270,36 @@ func (os *OAuthService) GetAuthURL(provider string) (string, error) { return "", fmt.Errorf("unknown provider: %s", provider) } - // Generate state token - stateToken, err := os.GenerateStateToken(provider, "") + // Generate state token with PKCE code_verifier + stateToken, codeVerifier, err := os.GenerateStateToken(provider, "") if err != nil { return "", err } - // Return authorization URL - url := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline) - return url, nil + // Return authorization URL with PKCE (code_challenge, code_challenge_method=S256) + authURL := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier)) + return authURL, nil } // HandleCallback processes the OAuth callback. // ipAddress and userAgent are used for session creation (optional, can be empty). -func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, error) { - // Validate state - _, err := os.ValidateStateToken(state) +// Returns redirectURL (validated) for the handler to redirect to (v0.902). +func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, string, error) { + // Validate state and get code_verifier for PKCE + oauthState, err := os.ValidateStateToken(state) if err != nil { - return nil, nil, err + return nil, nil, "", err + } + + // v0.902: Validate redirect URL against whitelist (VEZA-SEC-010) + redirectURL := oauthState.RedirectURL + if redirectURL == "" { + redirectURL = os.frontendURL + } + if redirectURL != "" { + if err := os.validateRedirectURL(redirectURL); err != nil { + return nil, nil, "", err + } } var config *oauth2.Config @@ -277,47 +313,47 @@ func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, stat case "spotify": config = os.spotifyConfig default: - return nil, nil, fmt.Errorf("unknown provider: %s", provider) + return nil, nil, "", fmt.Errorf("unknown provider: %s", provider) } if config == nil { - return nil, nil, fmt.Errorf("%s OAuth not configured", provider) + return nil, nil, "", fmt.Errorf("%s OAuth not configured", provider) } - // Exchange code for token - token, err := config.Exchange(ctx, code) + // Exchange code for token with PKCE code_verifier + token, err := config.Exchange(ctx, code, oauth2.VerifierOption(oauthState.CodeVerifier)) if err != nil { - return nil, nil, err + return nil, nil, "", err } // Get user info from provider oauthUser, err := os.getUserInfo(provider, token.AccessToken) if err != nil { - return nil, nil, err + return nil, nil, "", err } // Check if user already exists (by provider account or email) — audit 1.8: OAuth ID lookup first existingUser, err := os.getOrCreateUser(provider, oauthUser) if err != nil { - return nil, nil, err + return nil, nil, "", err } // Save/update OAuth account err = os.saveOAuthAccount(provider, oauthUser, existingUser.ID, token) if err != nil { - return nil, nil, err + return nil, nil, "", err } // VEZA-SEC-001: Get full user for JWT (TokenVersion, Role, etc.) user, err := os.userService.GetByID(existingUser.ID) if err != nil { - return nil, nil, fmt.Errorf("failed to get user: %w", err) + return nil, nil, "", fmt.Errorf("failed to get user: %w", err) } // Generate tokens via JWTService (proper issuer, audience, token_version) tokens, err := os.jwtService.GenerateTokenPair(user) if err != nil { - return nil, nil, fmt.Errorf("failed to generate tokens: %w", err) + return nil, nil, "", fmt.Errorf("failed to generate tokens: %w", err) } // Create session for refresh token validation @@ -339,7 +375,43 @@ func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, stat return &OAuthUser{ ID: existingUser.ID, Email: existingUser.Email, - }, tokens, nil + }, tokens, redirectURL, nil +} + +// validateRedirectURL checks that the redirect URL's origin is in the allowed domains whitelist +func (os *OAuthService) validateRedirectURL(redirectURL string) error { + parsed, err := url.Parse(redirectURL) + if err != nil { + return fmt.Errorf("invalid redirect URL: %w", err) + } + if parsed.Scheme == "" || parsed.Host == "" { + return fmt.Errorf("invalid redirect URL: missing scheme or host") + } + redirectOrigin := parsed.Scheme + "://" + parsed.Host + + if len(os.allowedDomains) == 0 { + // Dev fallback: allow localhost + if strings.HasPrefix(redirectOrigin, "http://localhost") || strings.HasPrefix(redirectOrigin, "http://127.0.0.1") { + return nil + } + return fmt.Errorf("redirect URL not allowed: %s (no whitelist configured)", redirectURL) + } + + for _, allowed := range os.allowedDomains { + allowed = strings.TrimSpace(allowed) + if allowed == "*" { + return nil + } + allowedParsed, err := url.Parse(allowed) + if err != nil || allowedParsed.Scheme == "" || allowedParsed.Host == "" { + continue + } + allowedOrigin := allowedParsed.Scheme + "://" + allowedParsed.Host + if redirectOrigin == allowedOrigin { + return nil + } + } + return fmt.Errorf("redirect URL not allowed: %s", redirectURL) } // OAuthUser represents an OAuth authenticated user @@ -574,10 +646,24 @@ func (os *OAuthService) getOrCreateUser(provider string, oauthUser *OAuthUser) ( } // saveOAuthAccount saves or updates OAuth account information -// Uses federated_identities table +// Uses federated_identities table. Tokens are encrypted at rest when CryptoService is set (v0.902) func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error { ctx := context.Background() + accessToken := token.AccessToken + refreshToken := token.RefreshToken + if os.cryptoService != nil { + var errEnc error + accessToken, errEnc = os.cryptoService.EncryptString(token.AccessToken) + if errEnc != nil { + return fmt.Errorf("encrypt access token: %w", errEnc) + } + refreshToken, errEnc = os.cryptoService.EncryptString(token.RefreshToken) + if errEnc != nil { + return fmt.Errorf("encrypt refresh token: %w", errEnc) + } + } + // Check if OAuth account already exists var existingID uuid.UUID err := os.db.QueryRowContext(ctx, ` @@ -591,7 +677,7 @@ func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, UPDATE federated_identities SET email = $1, display_name = $2, access_token = $3, refresh_token = $4, expires_at = $5, updated_at = NOW() WHERE id = $6 - `, oauthUser.Email, oauthUser.Name, token.AccessToken, token.RefreshToken, token.Expiry, existingID) + `, oauthUser.Email, oauthUser.Name, accessToken, refreshToken, token.Expiry, existingID) return err } @@ -603,7 +689,7 @@ func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, _, err = os.db.ExecContext(ctx, ` INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at) VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW()) - `, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, token.AccessToken, token.RefreshToken, token.Expiry) + `, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, accessToken, refreshToken, token.Expiry) return err } diff --git a/veza-backend-api/internal/services/oauth_service_test.go b/veza-backend-api/internal/services/oauth_service_test.go index 89bfbd208..ee127e0c0 100644 --- a/veza-backend-api/internal/services/oauth_service_test.go +++ b/veza-backend-api/internal/services/oauth_service_test.go @@ -32,7 +32,7 @@ func setupOAuthServiceForTests(t *testing.T, db *database.Database) *OAuthServic userRepo := repositories.NewGormUserRepository(db.GormDB) userService = NewUserServiceWithDB(userRepo, db.GormDB) } - return NewOAuthService(db, zap.NewNop(), jwtService, sessionService, userService) + return NewOAuthService(db, zap.NewNop(), jwtService, sessionService, userService, nil) } // Helper to setup mock DB @@ -62,17 +62,18 @@ func TestOAuthService_GenerateStateToken_Success(t *testing.T) { provider := "google" redirectURL := "http://example.com" - // Expectation + // Expectation: state_token, provider, redirect_url, code_verifier, expires_at mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)). - WithArgs(sqlmock.AnyArg(), provider, redirectURL, sqlmock.AnyArg()). + WithArgs(sqlmock.AnyArg(), provider, redirectURL, sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) // Execute - token, err := service.GenerateStateToken(provider, redirectURL) + stateToken, codeVerifier, err := service.GenerateStateToken(provider, redirectURL) // Assert assert.NoError(t, err) - assert.NotEmpty(t, token) + assert.NotEmpty(t, stateToken) + assert.NotEmpty(t, codeVerifier) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -90,11 +91,11 @@ func TestOAuthService_ValidateStateToken_Success(t *testing.T) { token := "valid_token" now := time.Now() - // Expectation - rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "expires_at", "created_at"}). - AddRow(1, token, "google", "http://example.com", now.Add(time.Hour), now) + // Expectation: include code_verifier for PKCE + rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "code_verifier", "expires_at", "created_at"}). + AddRow(1, token, "google", "http://example.com", "pkce_verifier_123", now.Add(time.Hour), now) - mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)). + mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)). WithArgs(token). WillReturnRows(rows) @@ -109,6 +110,7 @@ func TestOAuthService_ValidateStateToken_Success(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, state) assert.Equal(t, token, state.StateToken) + assert.Equal(t, "pkce_verifier_123", state.CodeVerifier) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -126,7 +128,7 @@ func TestOAuthService_ValidateStateToken_NotFound(t *testing.T) { token := "invalid_token" // Expectation - mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)). + mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)). WithArgs(token). WillReturnError(sql.ErrNoRows) @@ -155,10 +157,10 @@ func TestOAuthService_ValidateStateToken_Expired(t *testing.T) { now := time.Now() // Expectation - rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "expires_at", "created_at"}). - AddRow(1, token, "google", "http://example.com", now.Add(-time.Hour), now.Add(-2*time.Hour)) + rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "code_verifier", "expires_at", "created_at"}). + AddRow(1, token, "google", "http://example.com", "verifier", now.Add(-time.Hour), now.Add(-2*time.Hour)) - mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)). + mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)). WithArgs(token). WillReturnRows(rows) @@ -180,7 +182,7 @@ func TestOAuthService_GetAuthURL_Discord(t *testing.T) { svc.InitializeConfigs("", "", "", "", "discord-client", "discord-secret", "", "", "http://localhost:8080") mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)). - WithArgs(sqlmock.AnyArg(), "discord", "", sqlmock.AnyArg()). + WithArgs(sqlmock.AnyArg(), "discord", "", sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) url, err := svc.GetAuthURL("discord") @@ -191,6 +193,7 @@ func TestOAuthService_GetAuthURL_Discord(t *testing.T) { assert.Contains(t, url, "identify") assert.Contains(t, url, "email") assert.Contains(t, url, "discord-client") + assert.Contains(t, url, "code_challenge") assert.NoError(t, mock.ExpectationsWereMet()) } @@ -202,7 +205,7 @@ func TestOAuthService_GetAuthURL_Spotify(t *testing.T) { svc.InitializeConfigs("", "", "", "", "", "", "spotify-client", "spotify-secret", "http://localhost:8080") mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)). - WithArgs(sqlmock.AnyArg(), "spotify", "", sqlmock.AnyArg()). + WithArgs(sqlmock.AnyArg(), "spotify", "", sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) url, err := svc.GetAuthURL("spotify") @@ -213,6 +216,8 @@ func TestOAuthService_GetAuthURL_Spotify(t *testing.T) { assert.Contains(t, url, "user-read-email") assert.Contains(t, url, "user-read-private") assert.Contains(t, url, "spotify-client") + assert.Contains(t, url, "code_challenge") + assert.Contains(t, url, "code_challenge_method=S256") assert.NoError(t, mock.ExpectationsWereMet()) } @@ -295,6 +300,16 @@ func TestOAuthService_GetUserInfo_Spotify(t *testing.T) { assert.Equal(t, "https://avatar.url", user.Avatar) } +func TestOAuthService_ValidateRedirectURL_EvilDomain(t *testing.T) { + // v0.902: redirect to evil.com should be rejected + svc := NewOAuthService(nil, zap.NewNop(), nil, nil, nil, &OAuthServiceConfig{ + AllowedDomains: []string{"https://app.veza.com", "https://veza.fr:5173"}, + }) + err := svc.validateRedirectURL("https://evil.com/auth/callback") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not allowed") +} + func TestOAuthService_GetUserInfo_Spotify_FallbackEmail(t *testing.T) { // Spotify without email - should fallback to id@spotify.user spotifyJSON := `{"id":"spot456","display_name":"","images":[]}` diff --git a/veza-backend-api/migrations/936_oauth_states_pkce.sql b/veza-backend-api/migrations/936_oauth_states_pkce.sql new file mode 100644 index 000000000..e966be007 --- /dev/null +++ b/veza-backend-api/migrations/936_oauth_states_pkce.sql @@ -0,0 +1,18 @@ +-- 936_oauth_states_pkce.sql +-- OAuth states table with PKCE code_verifier support (v0.902 Sentinel) + +CREATE TABLE IF NOT EXISTS public.oauth_states ( + id BIGSERIAL PRIMARY KEY, + state_token VARCHAR(255) NOT NULL UNIQUE, + provider VARCHAR(50) NOT NULL, + redirect_url TEXT, + code_verifier VARCHAR(255), + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_oauth_states_state_token ON public.oauth_states(state_token); +CREATE INDEX IF NOT EXISTS idx_oauth_states_expires_at ON public.oauth_states(expires_at); + +-- If table already exists (without code_verifier), add the column +ALTER TABLE public.oauth_states ADD COLUMN IF NOT EXISTS code_verifier VARCHAR(255);