release(v0.902): Sentinel - PKCE OAuth, token encryption, redirect validation, CHAT_JWT_SECRET
Some checks failed
Backend API CI / test-unit (push) Failing after 0s
Backend API CI / test-integration (push) Failing after 0s

- PKCE (S256) in OAuth flow: code_verifier in oauth_states, code_challenge in auth URL
- CryptoService: AES-256-GCM encryption for OAuth provider tokens at rest
- OAuth redirect URL validated against OAUTH_ALLOWED_REDIRECT_DOMAINS
- CHAT_JWT_SECRET must differ from JWT_SECRET in production
- Migration script: cmd/tools/encrypt_oauth_tokens for existing tokens
- Fixes: VEZA-SEC-003, VEZA-SEC-004, VEZA-SEC-009, VEZA-SEC-010
This commit is contained in:
senke 2026-02-26 19:49:15 +01:00
parent 51984e9a1f
commit 6823e5a30d
13 changed files with 734 additions and 81 deletions

View file

@ -1 +1 @@
0.901
0.902

View file

@ -0,0 +1,148 @@
// encrypt_oauth_tokens encrypts existing OAuth provider tokens in federated_identities (v0.902).
// Idempotent: skips tokens already prefixed with veza_enc_v1:
// Usage: DATABASE_URL=... OAUTH_ENCRYPTION_KEY=... go run ./cmd/tools/encrypt_oauth_tokens [-dry-run]
package main
import (
"context"
"database/sql"
"flag"
"log"
"os"
"strings"
"time"
_ "github.com/lib/pq"
"veza-backend-api/internal/services"
)
const encryptedPrefix = "veza_enc_v1:"
func main() {
dryRun := flag.Bool("dry-run", false, "Show what would be updated without making changes")
flag.Parse()
dbURL := os.Getenv("DATABASE_URL")
if dbURL == "" {
log.Fatal("DATABASE_URL is required")
}
encKey := os.Getenv("OAUTH_ENCRYPTION_KEY")
if encKey == "" {
log.Fatal("OAUTH_ENCRYPTION_KEY is required (32+ bytes, base64 or raw)")
}
cryptoService, err := services.NewCryptoServiceFromBase64(encKey)
if err != nil {
keyBytes := []byte(encKey)
if len(keyBytes) >= 32 {
cryptoService, err = services.NewCryptoService(keyBytes)
}
}
if err != nil {
log.Fatalf("CryptoService: %v", err)
}
db, err := sql.Open("postgres", dbURL)
if err != nil {
log.Fatalf("DB connect: %v", err)
}
defer db.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
if err := db.PingContext(ctx); err != nil {
log.Fatalf("DB ping: %v", err)
}
rows, err := db.QueryContext(ctx, `
SELECT id::text, access_token, refresh_token
FROM federated_identities
WHERE (access_token IS NOT NULL AND access_token != '')
OR (refresh_token IS NOT NULL AND refresh_token != '')
`)
if err != nil {
log.Fatalf("Query: %v", err)
}
defer rows.Close()
var id string
var accessToken, refreshToken sql.NullString
updated := 0
skipped := 0
errors := 0
for rows.Next() {
if err := rows.Scan(&id, &accessToken, &refreshToken); err != nil {
log.Printf("Scan error: %v", err)
errors++
continue
}
needUpdate := false
newAccess := accessToken.String
newRefresh := refreshToken.String
if accessToken.Valid && accessToken.String != "" && !strings.HasPrefix(accessToken.String, encryptedPrefix) {
newAccess, err = cryptoService.EncryptString(accessToken.String)
if err != nil {
log.Printf("Encrypt access_token for id %s: %v", id, err)
errors++
continue
}
needUpdate = true
}
if refreshToken.Valid && refreshToken.String != "" && !strings.HasPrefix(refreshToken.String, encryptedPrefix) {
newRefresh, err = cryptoService.EncryptString(refreshToken.String)
if err != nil {
log.Printf("Encrypt refresh_token for id %s: %v", id, err)
errors++
continue
}
needUpdate = true
}
if !needUpdate {
skipped++
continue
}
if *dryRun {
log.Printf("[dry-run] Would encrypt tokens for id %s", id)
updated++
continue
}
_, err = db.ExecContext(ctx, `
UPDATE federated_identities
SET access_token = $1, refresh_token = $2, updated_at = NOW()
WHERE id = $3
`, nullIfEmpty(newAccess), nullIfEmpty(newRefresh), id)
if err != nil {
log.Printf("UPDATE id %s: %v", id, err)
errors++
continue
}
updated++
}
if err := rows.Err(); err != nil {
log.Fatalf("Rows: %v", err)
}
mode := ""
if *dryRun {
mode = " [dry-run]"
}
log.Printf("Done%s: updated=%d skipped=%d errors=%d", mode, updated, skipped, errors)
if errors > 0 {
os.Exit(1)
}
}
func nullIfEmpty(s string) interface{} {
if s == "" {
return nil
}
return s
}

View file

@ -114,8 +114,30 @@ func (r *APIRouter) setupAuthRoutes(router *gin.RouterGroup) error {
}
checkUsernameGroup.GET("/check-username", handlers.CheckUsername(authService))
// BE-API-042: OAuth routes
oauthService := services.NewOAuthService(r.db, r.logger, jwtService, sessionService, userService)
// BE-API-042: OAuth routes (v0.902: CryptoService, redirect validation)
oauthCfg := &services.OAuthServiceConfig{
FrontendURL: r.config.FrontendURL,
AllowedDomains: r.config.OAuthAllowedRedirectDomains,
}
if r.config.OAuthEncryptionKey != "" {
var cryptoService *services.CryptoService
var err error
cryptoService, err = services.NewCryptoServiceFromBase64(r.config.OAuthEncryptionKey)
if err != nil {
// Fallback: use raw bytes if key is long enough
keyBytes := []byte(r.config.OAuthEncryptionKey)
if len(keyBytes) >= 32 {
cryptoService, err = services.NewCryptoService(keyBytes)
}
}
if err != nil {
return fmt.Errorf("OAuth CryptoService: %w", err)
}
if cryptoService != nil {
oauthCfg.CryptoService = cryptoService
}
}
oauthService := services.NewOAuthService(r.db, r.logger, jwtService, sessionService, userService, oauthCfg)
baseURL := os.Getenv("BASE_URL")
if baseURL == "" {
appDomain := os.Getenv("APP_DOMAIN")

View file

@ -81,6 +81,10 @@ type Config struct {
CORSOrigins []string // Liste des origines CORS autorisées
FrontendURL string // URL du frontend (OAuth redirect, password reset links). FRONTEND_URL ou VITE_FRONTEND_URL
// OAuth Security (v0.902 Sentinel)
OAuthEncryptionKey string // OAUTH_ENCRYPTION_KEY: 32 bytes for AES-256-GCM (required in production)
OAuthAllowedRedirectDomains []string // OAUTH_ALLOWED_REDIRECT_DOMAINS: whitelist for OAuth redirect URLs
// HLS Streaming Configuration (v0.503)
HLSEnabled bool // Enable HLS streaming routes
HLSStorageDir string // Directory for HLS segment storage
@ -304,6 +308,10 @@ func NewConfig() (*Config, error) {
CORSOrigins: corsOrigins,
FrontendURL: getFrontendURL(), // OAuth callback, password reset, email links
// OAuth Security (v0.902 Sentinel)
OAuthEncryptionKey: getEnv("OAUTH_ENCRYPTION_KEY", ""),
OAuthAllowedRedirectDomains: getOAuthAllowedRedirectDomains(env, getEnvStringSlice("OAUTH_ALLOWED_REDIRECT_DOMAINS", nil), corsOrigins, getFrontendURL()),
// HLS Streaming (v0.503)
HLSEnabled: getEnvBool("HLS_STREAMING", false),
HLSStorageDir: getEnv("HLS_STORAGE_DIR", "/tmp/veza-hls"),
@ -843,6 +851,16 @@ func (c *Config) ValidateForEnvironment() error {
return fmt.Errorf("CLAMAV_REQUIRED must be true in production. Virus scanning is mandatory for uploads")
}
// 6. v0.902: CHAT_JWT_SECRET must differ from JWT_SECRET in production (VEZA-SEC-009)
if c.ChatJWTSecret == c.JWTSecret {
return fmt.Errorf("CHAT_JWT_SECRET must be different from JWT_SECRET in production. Use a separate secret for the Chat Server")
}
// 7. v0.902: OAUTH_ENCRYPTION_KEY required in production for OAuth token encryption (VEZA-SEC-004)
if len(c.OAuthEncryptionKey) < 32 {
return fmt.Errorf("OAUTH_ENCRYPTION_KEY is required in production (min 32 bytes for AES-256). Set OAUTH_ENCRYPTION_KEY with a 32-byte hex or base64 key")
}
case EnvTest:
// TEST: Validation adaptée aux tests
// CORS peut être vide ou configuré explicitement

View file

@ -1,7 +1,9 @@
package config
import (
"net/url"
"net/http"
"strings"
)
// getCookieSecure détermine si les cookies doivent être Secure
@ -104,3 +106,27 @@ func getCORSOrigins(env string, appDomain string) []string {
return []string{"http://" + appDomain + ":3000", "http://" + appDomain + ":5173"}
}
}
// getOAuthAllowedRedirectDomains returns the whitelist of domains allowed for OAuth redirect (v0.902).
// If OAUTH_ALLOWED_REDIRECT_DOMAINS is set, use it. Otherwise derive from CORSOrigins or FrontendURL.
func getOAuthAllowedRedirectDomains(env string, explicit []string, corsOrigins []string, frontendURL string) []string {
if len(explicit) > 0 {
return explicit
}
if len(corsOrigins) > 0 {
return corsOrigins
}
if frontendURL != "" {
if u, err := url.Parse(frontendURL); err == nil {
origin := strings.TrimSuffix(u.String(), "/")
if u.Scheme != "" && u.Host != "" {
return []string{origin}
}
}
}
// Dev fallback
if env == EnvDevelopment || env == EnvStaging {
return []string{"http://localhost:5173", "http://localhost:3000", "http://127.0.0.1:5173", "http://127.0.0.1:3000"}
}
return []string{}
}

View file

@ -425,3 +425,28 @@ func TestValidateForEnvironment_ClamAVRequiredInProduction(t *testing.T) {
require.NoError(t, err)
})
}
// TestValidateForEnvironment_ChatJWTSecretInProduction verifies CHAT_JWT_SECRET must differ from JWT_SECRET in production (v0.902)
func TestValidateForEnvironment_ChatJWTSecretInProduction(t *testing.T) {
secret := strings.Repeat("a", 32)
cfg := &Config{
Env: EnvProduction,
AppPort: 8080,
JWTSecret: secret,
ChatJWTSecret: secret, // Same as JWT_SECRET - should fail
DatabaseURL: "postgresql://user:pass@localhost:5432/db",
RedisURL: "redis://localhost:6379",
RateLimitLimit: 100,
RateLimitWindow: 60,
CORSOrigins: []string{"https://example.com"},
LogLevel: "INFO",
OAuthEncryptionKey: strings.Repeat("b", 32),
}
logger, _ := zap.NewDevelopment()
cfg.Logger = logger
os.Setenv("CLAMAV_REQUIRED", "true")
err := cfg.ValidateForEnvironment()
require.Error(t, err)
assert.Contains(t, err.Error(), "CHAT_JWT_SECRET must be different from JWT_SECRET")
}

View file

@ -18,7 +18,7 @@ import (
// OAuthServiceInterface defines the methods needed for OAuth handlers
type OAuthServiceInterface interface {
GetAuthURL(provider string) (string, error)
HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, error)
HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error)
GetAvailableProviders() []string
}
@ -134,16 +134,22 @@ func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) {
return
}
// Handle callback (VEZA-SEC-001: returns TokenPair, creates session)
user, tokens, err := oh.oauthService.HandleCallback(c.Request.Context(), provider, code, state, c.ClientIP(), c.Request.UserAgent())
// Handle callback (VEZA-SEC-001: returns TokenPair, creates session; v0.902: returns validated redirectURL)
user, tokens, redirectURL, err := oh.oauthService.HandleCallback(c.Request.Context(), provider, code, state, c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// SECURITY: Validate redirect URL against allowlist to prevent open redirect
frontendURL := oh.frontendURL
if !oh.isAllowedRedirectOrigin(frontendURL) {
// Use validated redirect URL from service as base, or fallback to frontendURL
baseURL := oh.frontendURL
if redirectURL != "" {
frontendURLParsed, _ := url.Parse(redirectURL)
if frontendURLParsed != nil && frontendURLParsed.Scheme != "" && frontendURLParsed.Host != "" {
baseURL = frontendURLParsed.Scheme + "://" + frontendURLParsed.Host
}
}
if !oh.isAllowedRedirectOrigin(baseURL) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid redirect configuration"})
return
}
@ -181,8 +187,8 @@ func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) {
}
// Redirect to frontend (tokens in cookies, not URL)
redirectURL := fmt.Sprintf("%s/auth/callback?user_id=%s", strings.TrimSuffix(frontendURL, "/"), user.ID.String())
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
finalRedirect := fmt.Sprintf("%s/auth/callback?user_id=%s", strings.TrimSuffix(baseURL, "/"), user.ID.String())
c.Redirect(http.StatusTemporaryRedirect, finalRedirect)
}
// isAllowedRedirectOrigin validates that the frontend URL is in the allowlist. Returns true if allowed.

View file

@ -28,15 +28,15 @@ func (m *MockOAuthService) GetAuthURL(provider string) (string, error) {
return args.String(0), args.Error(1)
}
func (m *MockOAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, error) {
func (m *MockOAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error) {
args := m.Called(ctx, provider, code, state, ipAddress, userAgent)
if args.Get(0) == nil {
return nil, nil, args.Error(2)
return nil, nil, "", args.Error(3)
}
if args.Get(1) == nil {
return args.Get(0).(*services.OAuthUser), nil, args.Error(2)
return args.Get(0).(*services.OAuthUser), nil, args.String(2), args.Error(3)
}
return args.Get(0).(*services.OAuthUser), args.Get(1).(*models.TokenPair), args.Error(2)
return args.Get(0).(*services.OAuthUser), args.Get(1).(*models.TokenPair), args.String(2), args.Error(3)
}
func (m *MockOAuthService) GetAvailableProviders() []string {
@ -145,7 +145,7 @@ func TestOAuthHandlers_OAuthCallback_Success(t *testing.T) {
ExpiresIn: int(5 * time.Minute.Seconds()),
}
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(mockUser, tokens, nil)
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(mockUser, tokens, "", nil)
// Execute
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil)
@ -197,7 +197,7 @@ func TestOAuthHandlers_OAuthCallback_ServiceError(t *testing.T) {
mockService := new(MockOAuthService)
router := setupTestOAuthRouter(mockService)
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(nil, nil, assert.AnError)
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(nil, nil, "", assert.AnError)
// Execute
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil)

View file

@ -0,0 +1,128 @@
package services
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
)
const (
// NonceSize for AES-GCM (12 bytes recommended by NIST)
gcmNonceSize = 12
// Prefix for encrypted tokens stored in DB (detect already-encrypted)
encryptedTokenPrefix = "veza_enc_v1:"
)
// CryptoService provides AES-256-GCM encryption for sensitive data (e.g. OAuth tokens at rest)
type CryptoService struct {
aead cipher.AEAD
}
// NewCryptoService creates a CryptoService with the given key (32 bytes for AES-256)
func NewCryptoService(key []byte) (*CryptoService, error) {
if len(key) < 32 {
return nil, errors.New("encryption key must be at least 32 bytes for AES-256")
}
// Use first 32 bytes
key32 := key
if len(key) > 32 {
key32 = key[:32]
}
block, err := aes.NewCipher(key32)
if err != nil {
return nil, fmt.Errorf("aes new cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("gcm: %w", err)
}
return &CryptoService{aead: aead}, nil
}
// NewCryptoServiceFromBase64 creates a CryptoService from a base64-encoded key
func NewCryptoServiceFromBase64(keyBase64 string) (*CryptoService, error) {
if keyBase64 == "" {
return nil, errors.New("encryption key must not be empty")
}
key, err := base64.RawStdEncoding.DecodeString(keyBase64)
if err != nil {
// Try standard base64
key, err = base64.StdEncoding.DecodeString(keyBase64)
if err != nil {
return nil, fmt.Errorf("decode key: %w", err)
}
}
return NewCryptoService(key)
}
// NewCryptoServiceFromHex creates a CryptoService from a hex-encoded key (optional, for future)
// For now we use base64. Key can also be raw bytes if passed as string - we'll decode.
// Encrypt encrypts plaintext with AES-256-GCM. Returns base64-encoded result.
func (c *CryptoService) Encrypt(plaintext []byte) ([]byte, error) {
nonce := make([]byte, gcmNonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, fmt.Errorf("rand nonce: %w", err)
}
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
// Prepend nonce: nonce || ciphertext (ciphertext includes tag)
out := make([]byte, 0, len(nonce)+len(ciphertext))
out = append(out, nonce...)
out = append(out, ciphertext...)
return out, nil
}
// Decrypt decrypts ciphertext (format: nonce || sealed). Returns plaintext.
func (c *CryptoService) Decrypt(ciphertext []byte) ([]byte, error) {
if len(ciphertext) < gcmNonceSize {
return nil, errors.New("ciphertext too short")
}
nonce := ciphertext[:gcmNonceSize]
sealed := ciphertext[gcmNonceSize:]
return c.aead.Open(nil, nonce, sealed, nil)
}
// EncryptString encrypts a string and returns the prefixed base64 result for DB storage
func (c *CryptoService) EncryptString(plaintext string) (string, error) {
if plaintext == "" {
return "", nil
}
enc, err := c.Encrypt([]byte(plaintext))
if err != nil {
return "", err
}
return encryptedTokenPrefix + base64.RawStdEncoding.EncodeToString(enc), nil
}
// DecryptString decrypts a string stored with EncryptString (checks prefix)
func (c *CryptoService) DecryptString(stored string) (string, error) {
if stored == "" {
return "", nil
}
if len(stored) < len(encryptedTokenPrefix) || stored[:len(encryptedTokenPrefix)] != encryptedTokenPrefix {
// Not encrypted (legacy plaintext)
return stored, nil
}
b64 := stored[len(encryptedTokenPrefix):]
enc, err := base64.RawStdEncoding.DecodeString(b64)
if err != nil {
enc, err = base64.StdEncoding.DecodeString(b64)
if err != nil {
return "", fmt.Errorf("decode stored: %w", err)
}
}
dec, err := c.Decrypt(enc)
if err != nil {
return "", err
}
return string(dec), nil
}
// EncryptedTokenPrefix returns the prefix used for encrypted tokens
func EncryptedTokenPrefix() string {
return encryptedTokenPrefix
}

View file

@ -0,0 +1,161 @@
package services
import (
"encoding/base64"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewCryptoService_InvalidKey(t *testing.T) {
_, err := NewCryptoService([]byte("short"))
require.Error(t, err)
assert.Contains(t, err.Error(), "32 bytes")
}
func TestNewCryptoService_ValidKey(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
svc, err := NewCryptoService(key)
require.NoError(t, err)
assert.NotNil(t, svc)
}
func TestCryptoService_EncryptDecrypt_Roundtrip(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
svc, err := NewCryptoService(key)
require.NoError(t, err)
plaintext := []byte("sensitive-oauth-token-value")
enc, err := svc.Encrypt(plaintext)
require.NoError(t, err)
assert.NotEqual(t, plaintext, enc)
assert.NotEmpty(t, enc)
dec, err := svc.Decrypt(enc)
require.NoError(t, err)
assert.Equal(t, plaintext, dec)
}
func TestCryptoService_EncryptDecrypt_DifferentEachTime(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
svc, err := NewCryptoService(key)
require.NoError(t, err)
plaintext := []byte("token")
enc1, err := svc.Encrypt(plaintext)
require.NoError(t, err)
enc2, err := svc.Encrypt(plaintext)
require.NoError(t, err)
assert.NotEqual(t, enc1, enc2, "encryption should be non-deterministic (random nonce)")
dec1, err := svc.Decrypt(enc1)
require.NoError(t, err)
dec2, err := svc.Decrypt(enc2)
require.NoError(t, err)
assert.Equal(t, plaintext, dec1)
assert.Equal(t, plaintext, dec2)
}
func TestCryptoService_EncryptString_DecryptString_Roundtrip(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
svc, err := NewCryptoService(key)
require.NoError(t, err)
plaintext := "my-access-token-123"
enc, err := svc.EncryptString(plaintext)
require.NoError(t, err)
assert.True(t, strings.HasPrefix(enc, EncryptedTokenPrefix()))
assert.NotEqual(t, plaintext, enc)
dec, err := svc.DecryptString(enc)
require.NoError(t, err)
assert.Equal(t, plaintext, dec)
}
func TestCryptoService_DecryptString_Plaintext(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
svc, err := NewCryptoService(key)
require.NoError(t, err)
// Legacy plaintext (no prefix) -> returned as-is
plain := "legacy-token"
dec, err := svc.DecryptString(plain)
require.NoError(t, err)
assert.Equal(t, plain, dec)
}
func TestCryptoService_DecryptString_Empty(t *testing.T) {
key := make([]byte, 32)
svc, err := NewCryptoService(key)
require.NoError(t, err)
dec, err := svc.DecryptString("")
require.NoError(t, err)
assert.Empty(t, dec)
}
func TestCryptoService_EncryptString_Empty(t *testing.T) {
key := make([]byte, 32)
svc, err := NewCryptoService(key)
require.NoError(t, err)
enc, err := svc.EncryptString("")
require.NoError(t, err)
assert.Empty(t, enc)
}
func TestCryptoService_Decrypt_ModifiedCiphertext(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
svc, err := NewCryptoService(key)
require.NoError(t, err)
enc, err := svc.Encrypt([]byte("token"))
require.NoError(t, err)
enc[20] ^= 0xff // flip a byte
_, err = svc.Decrypt(enc)
require.Error(t, err)
}
func TestNewCryptoServiceFromBase64(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
keyB64 := base64.RawStdEncoding.EncodeToString(key)
svc, err := NewCryptoServiceFromBase64(keyB64)
require.NoError(t, err)
assert.NotNil(t, svc)
enc, err := svc.EncryptString("test")
require.NoError(t, err)
dec, err := svc.DecryptString(enc)
require.NoError(t, err)
assert.Equal(t, "test", dec)
}
func TestNewCryptoServiceFromBase64_Empty(t *testing.T) {
_, err := NewCryptoServiceFromBase64("")
require.Error(t, err)
}

View file

@ -9,6 +9,8 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"veza-backend-api/internal/database"
@ -33,6 +35,9 @@ type OAuthService struct {
sessionService *SessionService
userService *UserService
circuitBreaker *CircuitBreakerHTTPClient
cryptoService *CryptoService // v0.902: encrypt OAuth provider tokens at rest (nil = store plaintext)
frontendURL string // v0.902: default redirect URL
allowedDomains []string // v0.902: whitelist for OAuth redirect (OAUTH_ALLOWED_REDIRECT_DOMAINS)
}
// OAuthAccount represents an OAuth account linking
@ -52,20 +57,29 @@ type OAuthAccount struct {
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// OAuthState represents an OAuth state for CSRF protection
// OAuthState represents an OAuth state for CSRF protection and PKCE
type OAuthState struct {
ID int64 `db:"id"`
StateToken string `db:"state_token"`
Provider string `db:"provider"`
RedirectURL string `db:"redirect_url"`
ExpiresAt time.Time `db:"expires_at"`
CreatedAt time.Time `db:"created_at"`
ID int64 `db:"id"`
StateToken string `db:"state_token"`
Provider string `db:"provider"`
RedirectURL string `db:"redirect_url"`
CodeVerifier string `db:"code_verifier"`
ExpiresAt time.Time `db:"expires_at"`
CreatedAt time.Time `db:"created_at"`
}
// OAuthServiceConfig holds optional config for OAuth (v0.902)
type OAuthServiceConfig struct {
CryptoService *CryptoService
FrontendURL string
AllowedDomains []string
}
// NewOAuthService creates a new OAuth service
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService) *OAuthService {
// cfg: optional config for crypto, redirect validation (v0.902)
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService, cfg *OAuthServiceConfig) *OAuthService {
httpClient := &http.Client{Timeout: 10 * time.Second}
return &OAuthService{
svc := &OAuthService{
db: db,
logger: logger,
jwtService: jwtService,
@ -73,6 +87,12 @@ func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTS
userService: userService,
circuitBreaker: NewCircuitBreakerHTTPClient(httpClient, "oauth-service", logger),
}
if cfg != nil {
svc.cryptoService = cfg.CryptoService
svc.frontendURL = cfg.FrontendURL
svc.allowedDomains = cfg.AllowedDomains
}
return svc
}
// InitializeConfigs initializes OAuth configurations
@ -154,39 +174,42 @@ func (os *OAuthService) GetAvailableProviders() []string {
return providers
}
// GenerateStateToken generates a secure state token for CSRF protection
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (string, error) {
// Generate random token
tokenBytes := make([]byte, 32)
_, err := rand.Read(tokenBytes)
if err != nil {
return "", err
}
stateToken := base64.URLEncoding.EncodeToString(tokenBytes)
// GenerateStateToken generates a secure state token for CSRF protection and PKCE code_verifier
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (stateToken, codeVerifier string, err error) {
// Generate PKCE code verifier (RFC 7636)
codeVerifier = oauth2.GenerateVerifier()
// Store in database
// Generate random state token
tokenBytes := make([]byte, 32)
_, err = rand.Read(tokenBytes)
if err != nil {
return "", "", err
}
stateToken = base64.URLEncoding.EncodeToString(tokenBytes)
// Store in database with code_verifier for PKCE
ctx := context.Background()
expiresAt := time.Now().Add(10 * time.Minute)
_, err = os.db.ExecContext(ctx, `
INSERT INTO oauth_states (state_token, provider, redirect_url, expires_at)
VALUES ($1, $2, $3, $4)
`, stateToken, provider, redirectURL, expiresAt)
INSERT INTO oauth_states (state_token, provider, redirect_url, code_verifier, expires_at)
VALUES ($1, $2, $3, $4, $5)
`, stateToken, provider, redirectURL, codeVerifier, expiresAt)
if err != nil {
return "", err
return "", "", err
}
os.logger.Debug("State token generated", zap.String("provider", provider))
return stateToken, nil
return stateToken, codeVerifier, nil
}
// ValidateStateToken validates and consumes a state token
// ValidateStateToken validates and consumes a state token (returns OAuthState with CodeVerifier for PKCE)
func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
ctx := context.Background()
var state OAuthState
err := os.db.QueryRowContext(ctx, `
SELECT id, state_token, provider, redirect_url, expires_at, created_at
SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at
FROM oauth_states
WHERE state_token = $1
`, stateToken).Scan(
@ -194,6 +217,7 @@ func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, erro
&state.StateToken,
&state.Provider,
&state.RedirectURL,
&state.CodeVerifier,
&state.ExpiresAt,
&state.CreatedAt,
)
@ -246,24 +270,36 @@ func (os *OAuthService) GetAuthURL(provider string) (string, error) {
return "", fmt.Errorf("unknown provider: %s", provider)
}
// Generate state token
stateToken, err := os.GenerateStateToken(provider, "")
// Generate state token with PKCE code_verifier
stateToken, codeVerifier, err := os.GenerateStateToken(provider, "")
if err != nil {
return "", err
}
// Return authorization URL
url := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline)
return url, nil
// Return authorization URL with PKCE (code_challenge, code_challenge_method=S256)
authURL := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
return authURL, nil
}
// HandleCallback processes the OAuth callback.
// ipAddress and userAgent are used for session creation (optional, can be empty).
func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, error) {
// Validate state
_, err := os.ValidateStateToken(state)
// Returns redirectURL (validated) for the handler to redirect to (v0.902).
func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, string, error) {
// Validate state and get code_verifier for PKCE
oauthState, err := os.ValidateStateToken(state)
if err != nil {
return nil, nil, err
return nil, nil, "", err
}
// v0.902: Validate redirect URL against whitelist (VEZA-SEC-010)
redirectURL := oauthState.RedirectURL
if redirectURL == "" {
redirectURL = os.frontendURL
}
if redirectURL != "" {
if err := os.validateRedirectURL(redirectURL); err != nil {
return nil, nil, "", err
}
}
var config *oauth2.Config
@ -277,47 +313,47 @@ func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, stat
case "spotify":
config = os.spotifyConfig
default:
return nil, nil, fmt.Errorf("unknown provider: %s", provider)
return nil, nil, "", fmt.Errorf("unknown provider: %s", provider)
}
if config == nil {
return nil, nil, fmt.Errorf("%s OAuth not configured", provider)
return nil, nil, "", fmt.Errorf("%s OAuth not configured", provider)
}
// Exchange code for token
token, err := config.Exchange(ctx, code)
// Exchange code for token with PKCE code_verifier
token, err := config.Exchange(ctx, code, oauth2.VerifierOption(oauthState.CodeVerifier))
if err != nil {
return nil, nil, err
return nil, nil, "", err
}
// Get user info from provider
oauthUser, err := os.getUserInfo(provider, token.AccessToken)
if err != nil {
return nil, nil, err
return nil, nil, "", err
}
// Check if user already exists (by provider account or email) — audit 1.8: OAuth ID lookup first
existingUser, err := os.getOrCreateUser(provider, oauthUser)
if err != nil {
return nil, nil, err
return nil, nil, "", err
}
// Save/update OAuth account
err = os.saveOAuthAccount(provider, oauthUser, existingUser.ID, token)
if err != nil {
return nil, nil, err
return nil, nil, "", err
}
// VEZA-SEC-001: Get full user for JWT (TokenVersion, Role, etc.)
user, err := os.userService.GetByID(existingUser.ID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get user: %w", err)
return nil, nil, "", fmt.Errorf("failed to get user: %w", err)
}
// Generate tokens via JWTService (proper issuer, audience, token_version)
tokens, err := os.jwtService.GenerateTokenPair(user)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate tokens: %w", err)
return nil, nil, "", fmt.Errorf("failed to generate tokens: %w", err)
}
// Create session for refresh token validation
@ -339,7 +375,43 @@ func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, stat
return &OAuthUser{
ID: existingUser.ID,
Email: existingUser.Email,
}, tokens, nil
}, tokens, redirectURL, nil
}
// validateRedirectURL checks that the redirect URL's origin is in the allowed domains whitelist
func (os *OAuthService) validateRedirectURL(redirectURL string) error {
parsed, err := url.Parse(redirectURL)
if err != nil {
return fmt.Errorf("invalid redirect URL: %w", err)
}
if parsed.Scheme == "" || parsed.Host == "" {
return fmt.Errorf("invalid redirect URL: missing scheme or host")
}
redirectOrigin := parsed.Scheme + "://" + parsed.Host
if len(os.allowedDomains) == 0 {
// Dev fallback: allow localhost
if strings.HasPrefix(redirectOrigin, "http://localhost") || strings.HasPrefix(redirectOrigin, "http://127.0.0.1") {
return nil
}
return fmt.Errorf("redirect URL not allowed: %s (no whitelist configured)", redirectURL)
}
for _, allowed := range os.allowedDomains {
allowed = strings.TrimSpace(allowed)
if allowed == "*" {
return nil
}
allowedParsed, err := url.Parse(allowed)
if err != nil || allowedParsed.Scheme == "" || allowedParsed.Host == "" {
continue
}
allowedOrigin := allowedParsed.Scheme + "://" + allowedParsed.Host
if redirectOrigin == allowedOrigin {
return nil
}
}
return fmt.Errorf("redirect URL not allowed: %s", redirectURL)
}
// OAuthUser represents an OAuth authenticated user
@ -574,10 +646,24 @@ func (os *OAuthService) getOrCreateUser(provider string, oauthUser *OAuthUser) (
}
// saveOAuthAccount saves or updates OAuth account information
// Uses federated_identities table
// Uses federated_identities table. Tokens are encrypted at rest when CryptoService is set (v0.902)
func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error {
ctx := context.Background()
accessToken := token.AccessToken
refreshToken := token.RefreshToken
if os.cryptoService != nil {
var errEnc error
accessToken, errEnc = os.cryptoService.EncryptString(token.AccessToken)
if errEnc != nil {
return fmt.Errorf("encrypt access token: %w", errEnc)
}
refreshToken, errEnc = os.cryptoService.EncryptString(token.RefreshToken)
if errEnc != nil {
return fmt.Errorf("encrypt refresh token: %w", errEnc)
}
}
// Check if OAuth account already exists
var existingID uuid.UUID
err := os.db.QueryRowContext(ctx, `
@ -591,7 +677,7 @@ func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser,
UPDATE federated_identities
SET email = $1, display_name = $2, access_token = $3, refresh_token = $4, expires_at = $5, updated_at = NOW()
WHERE id = $6
`, oauthUser.Email, oauthUser.Name, token.AccessToken, token.RefreshToken, token.Expiry, existingID)
`, oauthUser.Email, oauthUser.Name, accessToken, refreshToken, token.Expiry, existingID)
return err
}
@ -603,7 +689,7 @@ func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser,
_, err = os.db.ExecContext(ctx, `
INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at)
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
`, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, token.AccessToken, token.RefreshToken, token.Expiry)
`, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, accessToken, refreshToken, token.Expiry)
return err
}

View file

@ -32,7 +32,7 @@ func setupOAuthServiceForTests(t *testing.T, db *database.Database) *OAuthServic
userRepo := repositories.NewGormUserRepository(db.GormDB)
userService = NewUserServiceWithDB(userRepo, db.GormDB)
}
return NewOAuthService(db, zap.NewNop(), jwtService, sessionService, userService)
return NewOAuthService(db, zap.NewNop(), jwtService, sessionService, userService, nil)
}
// Helper to setup mock DB
@ -62,17 +62,18 @@ func TestOAuthService_GenerateStateToken_Success(t *testing.T) {
provider := "google"
redirectURL := "http://example.com"
// Expectation
// Expectation: state_token, provider, redirect_url, code_verifier, expires_at
mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)).
WithArgs(sqlmock.AnyArg(), provider, redirectURL, sqlmock.AnyArg()).
WithArgs(sqlmock.AnyArg(), provider, redirectURL, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
// Execute
token, err := service.GenerateStateToken(provider, redirectURL)
stateToken, codeVerifier, err := service.GenerateStateToken(provider, redirectURL)
// Assert
assert.NoError(t, err)
assert.NotEmpty(t, token)
assert.NotEmpty(t, stateToken)
assert.NotEmpty(t, codeVerifier)
assert.NoError(t, mock.ExpectationsWereMet())
}
@ -90,11 +91,11 @@ func TestOAuthService_ValidateStateToken_Success(t *testing.T) {
token := "valid_token"
now := time.Now()
// Expectation
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "expires_at", "created_at"}).
AddRow(1, token, "google", "http://example.com", now.Add(time.Hour), now)
// Expectation: include code_verifier for PKCE
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "code_verifier", "expires_at", "created_at"}).
AddRow(1, token, "google", "http://example.com", "pkce_verifier_123", now.Add(time.Hour), now)
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
WithArgs(token).
WillReturnRows(rows)
@ -109,6 +110,7 @@ func TestOAuthService_ValidateStateToken_Success(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, state)
assert.Equal(t, token, state.StateToken)
assert.Equal(t, "pkce_verifier_123", state.CodeVerifier)
assert.NoError(t, mock.ExpectationsWereMet())
}
@ -126,7 +128,7 @@ func TestOAuthService_ValidateStateToken_NotFound(t *testing.T) {
token := "invalid_token"
// Expectation
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
WithArgs(token).
WillReturnError(sql.ErrNoRows)
@ -155,10 +157,10 @@ func TestOAuthService_ValidateStateToken_Expired(t *testing.T) {
now := time.Now()
// Expectation
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "expires_at", "created_at"}).
AddRow(1, token, "google", "http://example.com", now.Add(-time.Hour), now.Add(-2*time.Hour))
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "code_verifier", "expires_at", "created_at"}).
AddRow(1, token, "google", "http://example.com", "verifier", now.Add(-time.Hour), now.Add(-2*time.Hour))
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
WithArgs(token).
WillReturnRows(rows)
@ -180,7 +182,7 @@ func TestOAuthService_GetAuthURL_Discord(t *testing.T) {
svc.InitializeConfigs("", "", "", "", "discord-client", "discord-secret", "", "", "http://localhost:8080")
mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)).
WithArgs(sqlmock.AnyArg(), "discord", "", sqlmock.AnyArg()).
WithArgs(sqlmock.AnyArg(), "discord", "", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
url, err := svc.GetAuthURL("discord")
@ -191,6 +193,7 @@ func TestOAuthService_GetAuthURL_Discord(t *testing.T) {
assert.Contains(t, url, "identify")
assert.Contains(t, url, "email")
assert.Contains(t, url, "discord-client")
assert.Contains(t, url, "code_challenge")
assert.NoError(t, mock.ExpectationsWereMet())
}
@ -202,7 +205,7 @@ func TestOAuthService_GetAuthURL_Spotify(t *testing.T) {
svc.InitializeConfigs("", "", "", "", "", "", "spotify-client", "spotify-secret", "http://localhost:8080")
mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)).
WithArgs(sqlmock.AnyArg(), "spotify", "", sqlmock.AnyArg()).
WithArgs(sqlmock.AnyArg(), "spotify", "", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
url, err := svc.GetAuthURL("spotify")
@ -213,6 +216,8 @@ func TestOAuthService_GetAuthURL_Spotify(t *testing.T) {
assert.Contains(t, url, "user-read-email")
assert.Contains(t, url, "user-read-private")
assert.Contains(t, url, "spotify-client")
assert.Contains(t, url, "code_challenge")
assert.Contains(t, url, "code_challenge_method=S256")
assert.NoError(t, mock.ExpectationsWereMet())
}
@ -295,6 +300,16 @@ func TestOAuthService_GetUserInfo_Spotify(t *testing.T) {
assert.Equal(t, "https://avatar.url", user.Avatar)
}
func TestOAuthService_ValidateRedirectURL_EvilDomain(t *testing.T) {
// v0.902: redirect to evil.com should be rejected
svc := NewOAuthService(nil, zap.NewNop(), nil, nil, nil, &OAuthServiceConfig{
AllowedDomains: []string{"https://app.veza.com", "https://veza.fr:5173"},
})
err := svc.validateRedirectURL("https://evil.com/auth/callback")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not allowed")
}
func TestOAuthService_GetUserInfo_Spotify_FallbackEmail(t *testing.T) {
// Spotify without email - should fallback to id@spotify.user
spotifyJSON := `{"id":"spot456","display_name":"","images":[]}`

View file

@ -0,0 +1,18 @@
-- 936_oauth_states_pkce.sql
-- OAuth states table with PKCE code_verifier support (v0.902 Sentinel)
CREATE TABLE IF NOT EXISTS public.oauth_states (
id BIGSERIAL PRIMARY KEY,
state_token VARCHAR(255) NOT NULL UNIQUE,
provider VARCHAR(50) NOT NULL,
redirect_url TEXT,
code_verifier VARCHAR(255),
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_oauth_states_state_token ON public.oauth_states(state_token);
CREATE INDEX IF NOT EXISTS idx_oauth_states_expires_at ON public.oauth_states(expires_at);
-- If table already exists (without code_verifier), add the column
ALTER TABLE public.oauth_states ADD COLUMN IF NOT EXISTS code_verifier VARCHAR(255);