release(v0.902): Sentinel - PKCE OAuth, token encryption, redirect validation, CHAT_JWT_SECRET
- PKCE (S256) in OAuth flow: code_verifier in oauth_states, code_challenge in auth URL - CryptoService: AES-256-GCM encryption for OAuth provider tokens at rest - OAuth redirect URL validated against OAUTH_ALLOWED_REDIRECT_DOMAINS - CHAT_JWT_SECRET must differ from JWT_SECRET in production - Migration script: cmd/tools/encrypt_oauth_tokens for existing tokens - Fixes: VEZA-SEC-003, VEZA-SEC-004, VEZA-SEC-009, VEZA-SEC-010
This commit is contained in:
parent
51984e9a1f
commit
6823e5a30d
13 changed files with 734 additions and 81 deletions
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
|||
0.901
|
||||
0.902
|
||||
|
|
|
|||
148
veza-backend-api/cmd/tools/encrypt_oauth_tokens/main.go
Normal file
148
veza-backend-api/cmd/tools/encrypt_oauth_tokens/main.go
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
// encrypt_oauth_tokens encrypts existing OAuth provider tokens in federated_identities (v0.902).
|
||||
// Idempotent: skips tokens already prefixed with veza_enc_v1:
|
||||
// Usage: DATABASE_URL=... OAUTH_ENCRYPTION_KEY=... go run ./cmd/tools/encrypt_oauth_tokens [-dry-run]
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"veza-backend-api/internal/services"
|
||||
)
|
||||
|
||||
const encryptedPrefix = "veza_enc_v1:"
|
||||
|
||||
func main() {
|
||||
dryRun := flag.Bool("dry-run", false, "Show what would be updated without making changes")
|
||||
flag.Parse()
|
||||
|
||||
dbURL := os.Getenv("DATABASE_URL")
|
||||
if dbURL == "" {
|
||||
log.Fatal("DATABASE_URL is required")
|
||||
}
|
||||
encKey := os.Getenv("OAUTH_ENCRYPTION_KEY")
|
||||
if encKey == "" {
|
||||
log.Fatal("OAUTH_ENCRYPTION_KEY is required (32+ bytes, base64 or raw)")
|
||||
}
|
||||
|
||||
cryptoService, err := services.NewCryptoServiceFromBase64(encKey)
|
||||
if err != nil {
|
||||
keyBytes := []byte(encKey)
|
||||
if len(keyBytes) >= 32 {
|
||||
cryptoService, err = services.NewCryptoService(keyBytes)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("CryptoService: %v", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dbURL)
|
||||
if err != nil {
|
||||
log.Fatalf("DB connect: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
log.Fatalf("DB ping: %v", err)
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(ctx, `
|
||||
SELECT id::text, access_token, refresh_token
|
||||
FROM federated_identities
|
||||
WHERE (access_token IS NOT NULL AND access_token != '')
|
||||
OR (refresh_token IS NOT NULL AND refresh_token != '')
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatalf("Query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var id string
|
||||
var accessToken, refreshToken sql.NullString
|
||||
updated := 0
|
||||
skipped := 0
|
||||
errors := 0
|
||||
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&id, &accessToken, &refreshToken); err != nil {
|
||||
log.Printf("Scan error: %v", err)
|
||||
errors++
|
||||
continue
|
||||
}
|
||||
|
||||
needUpdate := false
|
||||
newAccess := accessToken.String
|
||||
newRefresh := refreshToken.String
|
||||
|
||||
if accessToken.Valid && accessToken.String != "" && !strings.HasPrefix(accessToken.String, encryptedPrefix) {
|
||||
newAccess, err = cryptoService.EncryptString(accessToken.String)
|
||||
if err != nil {
|
||||
log.Printf("Encrypt access_token for id %s: %v", id, err)
|
||||
errors++
|
||||
continue
|
||||
}
|
||||
needUpdate = true
|
||||
}
|
||||
if refreshToken.Valid && refreshToken.String != "" && !strings.HasPrefix(refreshToken.String, encryptedPrefix) {
|
||||
newRefresh, err = cryptoService.EncryptString(refreshToken.String)
|
||||
if err != nil {
|
||||
log.Printf("Encrypt refresh_token for id %s: %v", id, err)
|
||||
errors++
|
||||
continue
|
||||
}
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if !needUpdate {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
if *dryRun {
|
||||
log.Printf("[dry-run] Would encrypt tokens for id %s", id)
|
||||
updated++
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = db.ExecContext(ctx, `
|
||||
UPDATE federated_identities
|
||||
SET access_token = $1, refresh_token = $2, updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`, nullIfEmpty(newAccess), nullIfEmpty(newRefresh), id)
|
||||
if err != nil {
|
||||
log.Printf("UPDATE id %s: %v", id, err)
|
||||
errors++
|
||||
continue
|
||||
}
|
||||
updated++
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Fatalf("Rows: %v", err)
|
||||
}
|
||||
|
||||
mode := ""
|
||||
if *dryRun {
|
||||
mode = " [dry-run]"
|
||||
}
|
||||
log.Printf("Done%s: updated=%d skipped=%d errors=%d", mode, updated, skipped, errors)
|
||||
if errors > 0 {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func nullIfEmpty(s string) interface{} {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
|
@ -114,8 +114,30 @@ func (r *APIRouter) setupAuthRoutes(router *gin.RouterGroup) error {
|
|||
}
|
||||
checkUsernameGroup.GET("/check-username", handlers.CheckUsername(authService))
|
||||
|
||||
// BE-API-042: OAuth routes
|
||||
oauthService := services.NewOAuthService(r.db, r.logger, jwtService, sessionService, userService)
|
||||
// BE-API-042: OAuth routes (v0.902: CryptoService, redirect validation)
|
||||
oauthCfg := &services.OAuthServiceConfig{
|
||||
FrontendURL: r.config.FrontendURL,
|
||||
AllowedDomains: r.config.OAuthAllowedRedirectDomains,
|
||||
}
|
||||
if r.config.OAuthEncryptionKey != "" {
|
||||
var cryptoService *services.CryptoService
|
||||
var err error
|
||||
cryptoService, err = services.NewCryptoServiceFromBase64(r.config.OAuthEncryptionKey)
|
||||
if err != nil {
|
||||
// Fallback: use raw bytes if key is long enough
|
||||
keyBytes := []byte(r.config.OAuthEncryptionKey)
|
||||
if len(keyBytes) >= 32 {
|
||||
cryptoService, err = services.NewCryptoService(keyBytes)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("OAuth CryptoService: %w", err)
|
||||
}
|
||||
if cryptoService != nil {
|
||||
oauthCfg.CryptoService = cryptoService
|
||||
}
|
||||
}
|
||||
oauthService := services.NewOAuthService(r.db, r.logger, jwtService, sessionService, userService, oauthCfg)
|
||||
baseURL := os.Getenv("BASE_URL")
|
||||
if baseURL == "" {
|
||||
appDomain := os.Getenv("APP_DOMAIN")
|
||||
|
|
|
|||
|
|
@ -81,6 +81,10 @@ type Config struct {
|
|||
CORSOrigins []string // Liste des origines CORS autorisées
|
||||
FrontendURL string // URL du frontend (OAuth redirect, password reset links). FRONTEND_URL ou VITE_FRONTEND_URL
|
||||
|
||||
// OAuth Security (v0.902 Sentinel)
|
||||
OAuthEncryptionKey string // OAUTH_ENCRYPTION_KEY: 32 bytes for AES-256-GCM (required in production)
|
||||
OAuthAllowedRedirectDomains []string // OAUTH_ALLOWED_REDIRECT_DOMAINS: whitelist for OAuth redirect URLs
|
||||
|
||||
// HLS Streaming Configuration (v0.503)
|
||||
HLSEnabled bool // Enable HLS streaming routes
|
||||
HLSStorageDir string // Directory for HLS segment storage
|
||||
|
|
@ -304,6 +308,10 @@ func NewConfig() (*Config, error) {
|
|||
CORSOrigins: corsOrigins,
|
||||
FrontendURL: getFrontendURL(), // OAuth callback, password reset, email links
|
||||
|
||||
// OAuth Security (v0.902 Sentinel)
|
||||
OAuthEncryptionKey: getEnv("OAUTH_ENCRYPTION_KEY", ""),
|
||||
OAuthAllowedRedirectDomains: getOAuthAllowedRedirectDomains(env, getEnvStringSlice("OAUTH_ALLOWED_REDIRECT_DOMAINS", nil), corsOrigins, getFrontendURL()),
|
||||
|
||||
// HLS Streaming (v0.503)
|
||||
HLSEnabled: getEnvBool("HLS_STREAMING", false),
|
||||
HLSStorageDir: getEnv("HLS_STORAGE_DIR", "/tmp/veza-hls"),
|
||||
|
|
@ -843,6 +851,16 @@ func (c *Config) ValidateForEnvironment() error {
|
|||
return fmt.Errorf("CLAMAV_REQUIRED must be true in production. Virus scanning is mandatory for uploads")
|
||||
}
|
||||
|
||||
// 6. v0.902: CHAT_JWT_SECRET must differ from JWT_SECRET in production (VEZA-SEC-009)
|
||||
if c.ChatJWTSecret == c.JWTSecret {
|
||||
return fmt.Errorf("CHAT_JWT_SECRET must be different from JWT_SECRET in production. Use a separate secret for the Chat Server")
|
||||
}
|
||||
|
||||
// 7. v0.902: OAUTH_ENCRYPTION_KEY required in production for OAuth token encryption (VEZA-SEC-004)
|
||||
if len(c.OAuthEncryptionKey) < 32 {
|
||||
return fmt.Errorf("OAUTH_ENCRYPTION_KEY is required in production (min 32 bytes for AES-256). Set OAUTH_ENCRYPTION_KEY with a 32-byte hex or base64 key")
|
||||
}
|
||||
|
||||
case EnvTest:
|
||||
// TEST: Validation adaptée aux tests
|
||||
// CORS peut être vide ou configuré explicitement
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// getCookieSecure détermine si les cookies doivent être Secure
|
||||
|
|
@ -104,3 +106,27 @@ func getCORSOrigins(env string, appDomain string) []string {
|
|||
return []string{"http://" + appDomain + ":3000", "http://" + appDomain + ":5173"}
|
||||
}
|
||||
}
|
||||
|
||||
// getOAuthAllowedRedirectDomains returns the whitelist of domains allowed for OAuth redirect (v0.902).
|
||||
// If OAUTH_ALLOWED_REDIRECT_DOMAINS is set, use it. Otherwise derive from CORSOrigins or FrontendURL.
|
||||
func getOAuthAllowedRedirectDomains(env string, explicit []string, corsOrigins []string, frontendURL string) []string {
|
||||
if len(explicit) > 0 {
|
||||
return explicit
|
||||
}
|
||||
if len(corsOrigins) > 0 {
|
||||
return corsOrigins
|
||||
}
|
||||
if frontendURL != "" {
|
||||
if u, err := url.Parse(frontendURL); err == nil {
|
||||
origin := strings.TrimSuffix(u.String(), "/")
|
||||
if u.Scheme != "" && u.Host != "" {
|
||||
return []string{origin}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Dev fallback
|
||||
if env == EnvDevelopment || env == EnvStaging {
|
||||
return []string{"http://localhost:5173", "http://localhost:3000", "http://127.0.0.1:5173", "http://127.0.0.1:3000"}
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -425,3 +425,28 @@ func TestValidateForEnvironment_ClamAVRequiredInProduction(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestValidateForEnvironment_ChatJWTSecretInProduction verifies CHAT_JWT_SECRET must differ from JWT_SECRET in production (v0.902)
|
||||
func TestValidateForEnvironment_ChatJWTSecretInProduction(t *testing.T) {
|
||||
secret := strings.Repeat("a", 32)
|
||||
cfg := &Config{
|
||||
Env: EnvProduction,
|
||||
AppPort: 8080,
|
||||
JWTSecret: secret,
|
||||
ChatJWTSecret: secret, // Same as JWT_SECRET - should fail
|
||||
DatabaseURL: "postgresql://user:pass@localhost:5432/db",
|
||||
RedisURL: "redis://localhost:6379",
|
||||
RateLimitLimit: 100,
|
||||
RateLimitWindow: 60,
|
||||
CORSOrigins: []string{"https://example.com"},
|
||||
LogLevel: "INFO",
|
||||
OAuthEncryptionKey: strings.Repeat("b", 32),
|
||||
}
|
||||
logger, _ := zap.NewDevelopment()
|
||||
cfg.Logger = logger
|
||||
|
||||
os.Setenv("CLAMAV_REQUIRED", "true")
|
||||
err := cfg.ValidateForEnvironment()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "CHAT_JWT_SECRET must be different from JWT_SECRET")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import (
|
|||
// OAuthServiceInterface defines the methods needed for OAuth handlers
|
||||
type OAuthServiceInterface interface {
|
||||
GetAuthURL(provider string) (string, error)
|
||||
HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, error)
|
||||
HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error)
|
||||
GetAvailableProviders() []string
|
||||
}
|
||||
|
||||
|
|
@ -134,16 +134,22 @@ func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
// Handle callback (VEZA-SEC-001: returns TokenPair, creates session)
|
||||
user, tokens, err := oh.oauthService.HandleCallback(c.Request.Context(), provider, code, state, c.ClientIP(), c.Request.UserAgent())
|
||||
// Handle callback (VEZA-SEC-001: returns TokenPair, creates session; v0.902: returns validated redirectURL)
|
||||
user, tokens, redirectURL, err := oh.oauthService.HandleCallback(c.Request.Context(), provider, code, state, c.ClientIP(), c.Request.UserAgent())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// SECURITY: Validate redirect URL against allowlist to prevent open redirect
|
||||
frontendURL := oh.frontendURL
|
||||
if !oh.isAllowedRedirectOrigin(frontendURL) {
|
||||
// Use validated redirect URL from service as base, or fallback to frontendURL
|
||||
baseURL := oh.frontendURL
|
||||
if redirectURL != "" {
|
||||
frontendURLParsed, _ := url.Parse(redirectURL)
|
||||
if frontendURLParsed != nil && frontendURLParsed.Scheme != "" && frontendURLParsed.Host != "" {
|
||||
baseURL = frontendURLParsed.Scheme + "://" + frontendURLParsed.Host
|
||||
}
|
||||
}
|
||||
if !oh.isAllowedRedirectOrigin(baseURL) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid redirect configuration"})
|
||||
return
|
||||
}
|
||||
|
|
@ -181,8 +187,8 @@ func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) {
|
|||
}
|
||||
|
||||
// Redirect to frontend (tokens in cookies, not URL)
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?user_id=%s", strings.TrimSuffix(frontendURL, "/"), user.ID.String())
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
finalRedirect := fmt.Sprintf("%s/auth/callback?user_id=%s", strings.TrimSuffix(baseURL, "/"), user.ID.String())
|
||||
c.Redirect(http.StatusTemporaryRedirect, finalRedirect)
|
||||
}
|
||||
|
||||
// isAllowedRedirectOrigin validates that the frontend URL is in the allowlist. Returns true if allowed.
|
||||
|
|
|
|||
|
|
@ -28,15 +28,15 @@ func (m *MockOAuthService) GetAuthURL(provider string) (string, error) {
|
|||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockOAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, error) {
|
||||
func (m *MockOAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error) {
|
||||
args := m.Called(ctx, provider, code, state, ipAddress, userAgent)
|
||||
if args.Get(0) == nil {
|
||||
return nil, nil, args.Error(2)
|
||||
return nil, nil, "", args.Error(3)
|
||||
}
|
||||
if args.Get(1) == nil {
|
||||
return args.Get(0).(*services.OAuthUser), nil, args.Error(2)
|
||||
return args.Get(0).(*services.OAuthUser), nil, args.String(2), args.Error(3)
|
||||
}
|
||||
return args.Get(0).(*services.OAuthUser), args.Get(1).(*models.TokenPair), args.Error(2)
|
||||
return args.Get(0).(*services.OAuthUser), args.Get(1).(*models.TokenPair), args.String(2), args.Error(3)
|
||||
}
|
||||
|
||||
func (m *MockOAuthService) GetAvailableProviders() []string {
|
||||
|
|
@ -145,7 +145,7 @@ func TestOAuthHandlers_OAuthCallback_Success(t *testing.T) {
|
|||
ExpiresIn: int(5 * time.Minute.Seconds()),
|
||||
}
|
||||
|
||||
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(mockUser, tokens, nil)
|
||||
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(mockUser, tokens, "", nil)
|
||||
|
||||
// Execute
|
||||
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil)
|
||||
|
|
@ -197,7 +197,7 @@ func TestOAuthHandlers_OAuthCallback_ServiceError(t *testing.T) {
|
|||
mockService := new(MockOAuthService)
|
||||
router := setupTestOAuthRouter(mockService)
|
||||
|
||||
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(nil, nil, assert.AnError)
|
||||
mockService.On("HandleCallback", mock.Anything, "google", "test-code", "test-state", mock.Anything, mock.Anything).Return(nil, nil, "", assert.AnError)
|
||||
|
||||
// Execute
|
||||
req, _ := http.NewRequest("GET", "/api/v1/auth/oauth/google/callback?code=test-code&state=test-state", nil)
|
||||
|
|
|
|||
128
veza-backend-api/internal/services/crypto_service.go
Normal file
128
veza-backend-api/internal/services/crypto_service.go
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// NonceSize for AES-GCM (12 bytes recommended by NIST)
|
||||
gcmNonceSize = 12
|
||||
// Prefix for encrypted tokens stored in DB (detect already-encrypted)
|
||||
encryptedTokenPrefix = "veza_enc_v1:"
|
||||
)
|
||||
|
||||
// CryptoService provides AES-256-GCM encryption for sensitive data (e.g. OAuth tokens at rest)
|
||||
type CryptoService struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
// NewCryptoService creates a CryptoService with the given key (32 bytes for AES-256)
|
||||
func NewCryptoService(key []byte) (*CryptoService, error) {
|
||||
if len(key) < 32 {
|
||||
return nil, errors.New("encryption key must be at least 32 bytes for AES-256")
|
||||
}
|
||||
// Use first 32 bytes
|
||||
key32 := key
|
||||
if len(key) > 32 {
|
||||
key32 = key[:32]
|
||||
}
|
||||
block, err := aes.NewCipher(key32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aes new cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gcm: %w", err)
|
||||
}
|
||||
return &CryptoService{aead: aead}, nil
|
||||
}
|
||||
|
||||
// NewCryptoServiceFromBase64 creates a CryptoService from a base64-encoded key
|
||||
func NewCryptoServiceFromBase64(keyBase64 string) (*CryptoService, error) {
|
||||
if keyBase64 == "" {
|
||||
return nil, errors.New("encryption key must not be empty")
|
||||
}
|
||||
key, err := base64.RawStdEncoding.DecodeString(keyBase64)
|
||||
if err != nil {
|
||||
// Try standard base64
|
||||
key, err = base64.StdEncoding.DecodeString(keyBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode key: %w", err)
|
||||
}
|
||||
}
|
||||
return NewCryptoService(key)
|
||||
}
|
||||
|
||||
// NewCryptoServiceFromHex creates a CryptoService from a hex-encoded key (optional, for future)
|
||||
// For now we use base64. Key can also be raw bytes if passed as string - we'll decode.
|
||||
|
||||
// Encrypt encrypts plaintext with AES-256-GCM. Returns base64-encoded result.
|
||||
func (c *CryptoService) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
nonce := make([]byte, gcmNonceSize)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, fmt.Errorf("rand nonce: %w", err)
|
||||
}
|
||||
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
|
||||
// Prepend nonce: nonce || ciphertext (ciphertext includes tag)
|
||||
out := make([]byte, 0, len(nonce)+len(ciphertext))
|
||||
out = append(out, nonce...)
|
||||
out = append(out, ciphertext...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext (format: nonce || sealed). Returns plaintext.
|
||||
func (c *CryptoService) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) < gcmNonceSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
nonce := ciphertext[:gcmNonceSize]
|
||||
sealed := ciphertext[gcmNonceSize:]
|
||||
return c.aead.Open(nil, nonce, sealed, nil)
|
||||
}
|
||||
|
||||
// EncryptString encrypts a string and returns the prefixed base64 result for DB storage
|
||||
func (c *CryptoService) EncryptString(plaintext string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
enc, err := c.Encrypt([]byte(plaintext))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return encryptedTokenPrefix + base64.RawStdEncoding.EncodeToString(enc), nil
|
||||
}
|
||||
|
||||
// DecryptString decrypts a string stored with EncryptString (checks prefix)
|
||||
func (c *CryptoService) DecryptString(stored string) (string, error) {
|
||||
if stored == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(stored) < len(encryptedTokenPrefix) || stored[:len(encryptedTokenPrefix)] != encryptedTokenPrefix {
|
||||
// Not encrypted (legacy plaintext)
|
||||
return stored, nil
|
||||
}
|
||||
b64 := stored[len(encryptedTokenPrefix):]
|
||||
enc, err := base64.RawStdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
enc, err = base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode stored: %w", err)
|
||||
}
|
||||
}
|
||||
dec, err := c.Decrypt(enc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(dec), nil
|
||||
}
|
||||
|
||||
// EncryptedTokenPrefix returns the prefix used for encrypted tokens
|
||||
func EncryptedTokenPrefix() string {
|
||||
return encryptedTokenPrefix
|
||||
}
|
||||
161
veza-backend-api/internal/services/crypto_service_test.go
Normal file
161
veza-backend-api/internal/services/crypto_service_test.go
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewCryptoService_InvalidKey(t *testing.T) {
|
||||
_, err := NewCryptoService([]byte("short"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "32 bytes")
|
||||
}
|
||||
|
||||
func TestNewCryptoService_ValidKey(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, svc)
|
||||
}
|
||||
|
||||
func TestCryptoService_EncryptDecrypt_Roundtrip(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("sensitive-oauth-token-value")
|
||||
enc, err := svc.Encrypt(plaintext)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, plaintext, enc)
|
||||
assert.NotEmpty(t, enc)
|
||||
|
||||
dec, err := svc.Decrypt(enc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, dec)
|
||||
}
|
||||
|
||||
func TestCryptoService_EncryptDecrypt_DifferentEachTime(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("token")
|
||||
enc1, err := svc.Encrypt(plaintext)
|
||||
require.NoError(t, err)
|
||||
enc2, err := svc.Encrypt(plaintext)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, enc1, enc2, "encryption should be non-deterministic (random nonce)")
|
||||
|
||||
dec1, err := svc.Decrypt(enc1)
|
||||
require.NoError(t, err)
|
||||
dec2, err := svc.Decrypt(enc2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, dec1)
|
||||
assert.Equal(t, plaintext, dec2)
|
||||
}
|
||||
|
||||
func TestCryptoService_EncryptString_DecryptString_Roundtrip(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := "my-access-token-123"
|
||||
enc, err := svc.EncryptString(plaintext)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(enc, EncryptedTokenPrefix()))
|
||||
assert.NotEqual(t, plaintext, enc)
|
||||
|
||||
dec, err := svc.DecryptString(enc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, dec)
|
||||
}
|
||||
|
||||
func TestCryptoService_DecryptString_Plaintext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Legacy plaintext (no prefix) -> returned as-is
|
||||
plain := "legacy-token"
|
||||
dec, err := svc.DecryptString(plain)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plain, dec)
|
||||
}
|
||||
|
||||
func TestCryptoService_DecryptString_Empty(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
dec, err := svc.DecryptString("")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, dec)
|
||||
}
|
||||
|
||||
func TestCryptoService_EncryptString_Empty(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
enc, err := svc.EncryptString("")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, enc)
|
||||
}
|
||||
|
||||
func TestCryptoService_Decrypt_ModifiedCiphertext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
svc, err := NewCryptoService(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
enc, err := svc.Encrypt([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
enc[20] ^= 0xff // flip a byte
|
||||
|
||||
_, err = svc.Decrypt(enc)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewCryptoServiceFromBase64(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
keyB64 := base64.RawStdEncoding.EncodeToString(key)
|
||||
|
||||
svc, err := NewCryptoServiceFromBase64(keyB64)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, svc)
|
||||
|
||||
enc, err := svc.EncryptString("test")
|
||||
require.NoError(t, err)
|
||||
dec, err := svc.DecryptString(enc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test", dec)
|
||||
}
|
||||
|
||||
func TestNewCryptoServiceFromBase64_Empty(t *testing.T) {
|
||||
_, err := NewCryptoServiceFromBase64("")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
|
@ -9,6 +9,8 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"veza-backend-api/internal/database"
|
||||
|
|
@ -33,6 +35,9 @@ type OAuthService struct {
|
|||
sessionService *SessionService
|
||||
userService *UserService
|
||||
circuitBreaker *CircuitBreakerHTTPClient
|
||||
cryptoService *CryptoService // v0.902: encrypt OAuth provider tokens at rest (nil = store plaintext)
|
||||
frontendURL string // v0.902: default redirect URL
|
||||
allowedDomains []string // v0.902: whitelist for OAuth redirect (OAUTH_ALLOWED_REDIRECT_DOMAINS)
|
||||
}
|
||||
|
||||
// OAuthAccount represents an OAuth account linking
|
||||
|
|
@ -52,20 +57,29 @@ type OAuthAccount struct {
|
|||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
}
|
||||
|
||||
// OAuthState represents an OAuth state for CSRF protection
|
||||
// OAuthState represents an OAuth state for CSRF protection and PKCE
|
||||
type OAuthState struct {
|
||||
ID int64 `db:"id"`
|
||||
StateToken string `db:"state_token"`
|
||||
Provider string `db:"provider"`
|
||||
RedirectURL string `db:"redirect_url"`
|
||||
ExpiresAt time.Time `db:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
ID int64 `db:"id"`
|
||||
StateToken string `db:"state_token"`
|
||||
Provider string `db:"provider"`
|
||||
RedirectURL string `db:"redirect_url"`
|
||||
CodeVerifier string `db:"code_verifier"`
|
||||
ExpiresAt time.Time `db:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
}
|
||||
|
||||
// OAuthServiceConfig holds optional config for OAuth (v0.902)
|
||||
type OAuthServiceConfig struct {
|
||||
CryptoService *CryptoService
|
||||
FrontendURL string
|
||||
AllowedDomains []string
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService) *OAuthService {
|
||||
// cfg: optional config for crypto, redirect validation (v0.902)
|
||||
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService, cfg *OAuthServiceConfig) *OAuthService {
|
||||
httpClient := &http.Client{Timeout: 10 * time.Second}
|
||||
return &OAuthService{
|
||||
svc := &OAuthService{
|
||||
db: db,
|
||||
logger: logger,
|
||||
jwtService: jwtService,
|
||||
|
|
@ -73,6 +87,12 @@ func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTS
|
|||
userService: userService,
|
||||
circuitBreaker: NewCircuitBreakerHTTPClient(httpClient, "oauth-service", logger),
|
||||
}
|
||||
if cfg != nil {
|
||||
svc.cryptoService = cfg.CryptoService
|
||||
svc.frontendURL = cfg.FrontendURL
|
||||
svc.allowedDomains = cfg.AllowedDomains
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
// InitializeConfigs initializes OAuth configurations
|
||||
|
|
@ -154,39 +174,42 @@ func (os *OAuthService) GetAvailableProviders() []string {
|
|||
return providers
|
||||
}
|
||||
|
||||
// GenerateStateToken generates a secure state token for CSRF protection
|
||||
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (string, error) {
|
||||
// Generate random token
|
||||
tokenBytes := make([]byte, 32)
|
||||
_, err := rand.Read(tokenBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
stateToken := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||
// GenerateStateToken generates a secure state token for CSRF protection and PKCE code_verifier
|
||||
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (stateToken, codeVerifier string, err error) {
|
||||
// Generate PKCE code verifier (RFC 7636)
|
||||
codeVerifier = oauth2.GenerateVerifier()
|
||||
|
||||
// Store in database
|
||||
// Generate random state token
|
||||
tokenBytes := make([]byte, 32)
|
||||
_, err = rand.Read(tokenBytes)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
stateToken = base64.URLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Store in database with code_verifier for PKCE
|
||||
ctx := context.Background()
|
||||
expiresAt := time.Now().Add(10 * time.Minute)
|
||||
_, err = os.db.ExecContext(ctx, `
|
||||
INSERT INTO oauth_states (state_token, provider, redirect_url, expires_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`, stateToken, provider, redirectURL, expiresAt)
|
||||
INSERT INTO oauth_states (state_token, provider, redirect_url, code_verifier, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, stateToken, provider, redirectURL, codeVerifier, expiresAt)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
os.logger.Debug("State token generated", zap.String("provider", provider))
|
||||
return stateToken, nil
|
||||
return stateToken, codeVerifier, nil
|
||||
}
|
||||
|
||||
// ValidateStateToken validates and consumes a state token
|
||||
// ValidateStateToken validates and consumes a state token (returns OAuthState with CodeVerifier for PKCE)
|
||||
func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
var state OAuthState
|
||||
err := os.db.QueryRowContext(ctx, `
|
||||
SELECT id, state_token, provider, redirect_url, expires_at, created_at
|
||||
SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at
|
||||
FROM oauth_states
|
||||
WHERE state_token = $1
|
||||
`, stateToken).Scan(
|
||||
|
|
@ -194,6 +217,7 @@ func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, erro
|
|||
&state.StateToken,
|
||||
&state.Provider,
|
||||
&state.RedirectURL,
|
||||
&state.CodeVerifier,
|
||||
&state.ExpiresAt,
|
||||
&state.CreatedAt,
|
||||
)
|
||||
|
|
@ -246,24 +270,36 @@ func (os *OAuthService) GetAuthURL(provider string) (string, error) {
|
|||
return "", fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
|
||||
// Generate state token
|
||||
stateToken, err := os.GenerateStateToken(provider, "")
|
||||
// Generate state token with PKCE code_verifier
|
||||
stateToken, codeVerifier, err := os.GenerateStateToken(provider, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Return authorization URL
|
||||
url := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline)
|
||||
return url, nil
|
||||
// Return authorization URL with PKCE (code_challenge, code_challenge_method=S256)
|
||||
authURL := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
// HandleCallback processes the OAuth callback.
|
||||
// ipAddress and userAgent are used for session creation (optional, can be empty).
|
||||
func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, error) {
|
||||
// Validate state
|
||||
_, err := os.ValidateStateToken(state)
|
||||
// Returns redirectURL (validated) for the handler to redirect to (v0.902).
|
||||
func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, string, error) {
|
||||
// Validate state and get code_verifier for PKCE
|
||||
oauthState, err := os.ValidateStateToken(state)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
// v0.902: Validate redirect URL against whitelist (VEZA-SEC-010)
|
||||
redirectURL := oauthState.RedirectURL
|
||||
if redirectURL == "" {
|
||||
redirectURL = os.frontendURL
|
||||
}
|
||||
if redirectURL != "" {
|
||||
if err := os.validateRedirectURL(redirectURL); err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
var config *oauth2.Config
|
||||
|
|
@ -277,47 +313,47 @@ func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, stat
|
|||
case "spotify":
|
||||
config = os.spotifyConfig
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unknown provider: %s", provider)
|
||||
return nil, nil, "", fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
return nil, nil, fmt.Errorf("%s OAuth not configured", provider)
|
||||
return nil, nil, "", fmt.Errorf("%s OAuth not configured", provider)
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := config.Exchange(ctx, code)
|
||||
// Exchange code for token with PKCE code_verifier
|
||||
token, err := config.Exchange(ctx, code, oauth2.VerifierOption(oauthState.CodeVerifier))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
// Get user info from provider
|
||||
oauthUser, err := os.getUserInfo(provider, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
// Check if user already exists (by provider account or email) — audit 1.8: OAuth ID lookup first
|
||||
existingUser, err := os.getOrCreateUser(provider, oauthUser)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
// Save/update OAuth account
|
||||
err = os.saveOAuthAccount(provider, oauthUser, existingUser.ID, token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
// VEZA-SEC-001: Get full user for JWT (TokenVersion, Role, etc.)
|
||||
user, err := os.userService.GetByID(existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get user: %w", err)
|
||||
return nil, nil, "", fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Generate tokens via JWTService (proper issuer, audience, token_version)
|
||||
tokens, err := os.jwtService.GenerateTokenPair(user)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate tokens: %w", err)
|
||||
return nil, nil, "", fmt.Errorf("failed to generate tokens: %w", err)
|
||||
}
|
||||
|
||||
// Create session for refresh token validation
|
||||
|
|
@ -339,7 +375,43 @@ func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, stat
|
|||
return &OAuthUser{
|
||||
ID: existingUser.ID,
|
||||
Email: existingUser.Email,
|
||||
}, tokens, nil
|
||||
}, tokens, redirectURL, nil
|
||||
}
|
||||
|
||||
// validateRedirectURL checks that the redirect URL's origin is in the allowed domains whitelist
|
||||
func (os *OAuthService) validateRedirectURL(redirectURL string) error {
|
||||
parsed, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid redirect URL: %w", err)
|
||||
}
|
||||
if parsed.Scheme == "" || parsed.Host == "" {
|
||||
return fmt.Errorf("invalid redirect URL: missing scheme or host")
|
||||
}
|
||||
redirectOrigin := parsed.Scheme + "://" + parsed.Host
|
||||
|
||||
if len(os.allowedDomains) == 0 {
|
||||
// Dev fallback: allow localhost
|
||||
if strings.HasPrefix(redirectOrigin, "http://localhost") || strings.HasPrefix(redirectOrigin, "http://127.0.0.1") {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("redirect URL not allowed: %s (no whitelist configured)", redirectURL)
|
||||
}
|
||||
|
||||
for _, allowed := range os.allowedDomains {
|
||||
allowed = strings.TrimSpace(allowed)
|
||||
if allowed == "*" {
|
||||
return nil
|
||||
}
|
||||
allowedParsed, err := url.Parse(allowed)
|
||||
if err != nil || allowedParsed.Scheme == "" || allowedParsed.Host == "" {
|
||||
continue
|
||||
}
|
||||
allowedOrigin := allowedParsed.Scheme + "://" + allowedParsed.Host
|
||||
if redirectOrigin == allowedOrigin {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("redirect URL not allowed: %s", redirectURL)
|
||||
}
|
||||
|
||||
// OAuthUser represents an OAuth authenticated user
|
||||
|
|
@ -574,10 +646,24 @@ func (os *OAuthService) getOrCreateUser(provider string, oauthUser *OAuthUser) (
|
|||
}
|
||||
|
||||
// saveOAuthAccount saves or updates OAuth account information
|
||||
// Uses federated_identities table
|
||||
// Uses federated_identities table. Tokens are encrypted at rest when CryptoService is set (v0.902)
|
||||
func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error {
|
||||
ctx := context.Background()
|
||||
|
||||
accessToken := token.AccessToken
|
||||
refreshToken := token.RefreshToken
|
||||
if os.cryptoService != nil {
|
||||
var errEnc error
|
||||
accessToken, errEnc = os.cryptoService.EncryptString(token.AccessToken)
|
||||
if errEnc != nil {
|
||||
return fmt.Errorf("encrypt access token: %w", errEnc)
|
||||
}
|
||||
refreshToken, errEnc = os.cryptoService.EncryptString(token.RefreshToken)
|
||||
if errEnc != nil {
|
||||
return fmt.Errorf("encrypt refresh token: %w", errEnc)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if OAuth account already exists
|
||||
var existingID uuid.UUID
|
||||
err := os.db.QueryRowContext(ctx, `
|
||||
|
|
@ -591,7 +677,7 @@ func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser,
|
|||
UPDATE federated_identities
|
||||
SET email = $1, display_name = $2, access_token = $3, refresh_token = $4, expires_at = $5, updated_at = NOW()
|
||||
WHERE id = $6
|
||||
`, oauthUser.Email, oauthUser.Name, token.AccessToken, token.RefreshToken, token.Expiry, existingID)
|
||||
`, oauthUser.Email, oauthUser.Name, accessToken, refreshToken, token.Expiry, existingID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -603,7 +689,7 @@ func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser,
|
|||
_, err = os.db.ExecContext(ctx, `
|
||||
INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at)
|
||||
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
|
||||
`, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, token.AccessToken, token.RefreshToken, token.Expiry)
|
||||
`, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, accessToken, refreshToken, token.Expiry)
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func setupOAuthServiceForTests(t *testing.T, db *database.Database) *OAuthServic
|
|||
userRepo := repositories.NewGormUserRepository(db.GormDB)
|
||||
userService = NewUserServiceWithDB(userRepo, db.GormDB)
|
||||
}
|
||||
return NewOAuthService(db, zap.NewNop(), jwtService, sessionService, userService)
|
||||
return NewOAuthService(db, zap.NewNop(), jwtService, sessionService, userService, nil)
|
||||
}
|
||||
|
||||
// Helper to setup mock DB
|
||||
|
|
@ -62,17 +62,18 @@ func TestOAuthService_GenerateStateToken_Success(t *testing.T) {
|
|||
provider := "google"
|
||||
redirectURL := "http://example.com"
|
||||
|
||||
// Expectation
|
||||
// Expectation: state_token, provider, redirect_url, code_verifier, expires_at
|
||||
mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)).
|
||||
WithArgs(sqlmock.AnyArg(), provider, redirectURL, sqlmock.AnyArg()).
|
||||
WithArgs(sqlmock.AnyArg(), provider, redirectURL, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
// Execute
|
||||
token, err := service.GenerateStateToken(provider, redirectURL)
|
||||
stateToken, codeVerifier, err := service.GenerateStateToken(provider, redirectURL)
|
||||
|
||||
// Assert
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.NotEmpty(t, stateToken)
|
||||
assert.NotEmpty(t, codeVerifier)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
|
|
@ -90,11 +91,11 @@ func TestOAuthService_ValidateStateToken_Success(t *testing.T) {
|
|||
token := "valid_token"
|
||||
now := time.Now()
|
||||
|
||||
// Expectation
|
||||
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "expires_at", "created_at"}).
|
||||
AddRow(1, token, "google", "http://example.com", now.Add(time.Hour), now)
|
||||
// Expectation: include code_verifier for PKCE
|
||||
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "code_verifier", "expires_at", "created_at"}).
|
||||
AddRow(1, token, "google", "http://example.com", "pkce_verifier_123", now.Add(time.Hour), now)
|
||||
|
||||
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
|
||||
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
|
||||
WithArgs(token).
|
||||
WillReturnRows(rows)
|
||||
|
||||
|
|
@ -109,6 +110,7 @@ func TestOAuthService_ValidateStateToken_Success(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.NotNil(t, state)
|
||||
assert.Equal(t, token, state.StateToken)
|
||||
assert.Equal(t, "pkce_verifier_123", state.CodeVerifier)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
|
|
@ -126,7 +128,7 @@ func TestOAuthService_ValidateStateToken_NotFound(t *testing.T) {
|
|||
token := "invalid_token"
|
||||
|
||||
// Expectation
|
||||
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
|
||||
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
|
||||
WithArgs(token).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
|
|
@ -155,10 +157,10 @@ func TestOAuthService_ValidateStateToken_Expired(t *testing.T) {
|
|||
now := time.Now()
|
||||
|
||||
// Expectation
|
||||
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "expires_at", "created_at"}).
|
||||
AddRow(1, token, "google", "http://example.com", now.Add(-time.Hour), now.Add(-2*time.Hour))
|
||||
rows := sqlmock.NewRows([]string{"id", "state_token", "provider", "redirect_url", "code_verifier", "expires_at", "created_at"}).
|
||||
AddRow(1, token, "google", "http://example.com", "verifier", now.Add(-time.Hour), now.Add(-2*time.Hour))
|
||||
|
||||
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
|
||||
mock.ExpectQuery(regexp.QuoteMeta(`SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at FROM oauth_states WHERE state_token = $1`)).
|
||||
WithArgs(token).
|
||||
WillReturnRows(rows)
|
||||
|
||||
|
|
@ -180,7 +182,7 @@ func TestOAuthService_GetAuthURL_Discord(t *testing.T) {
|
|||
svc.InitializeConfigs("", "", "", "", "discord-client", "discord-secret", "", "", "http://localhost:8080")
|
||||
|
||||
mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)).
|
||||
WithArgs(sqlmock.AnyArg(), "discord", "", sqlmock.AnyArg()).
|
||||
WithArgs(sqlmock.AnyArg(), "discord", "", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
url, err := svc.GetAuthURL("discord")
|
||||
|
|
@ -191,6 +193,7 @@ func TestOAuthService_GetAuthURL_Discord(t *testing.T) {
|
|||
assert.Contains(t, url, "identify")
|
||||
assert.Contains(t, url, "email")
|
||||
assert.Contains(t, url, "discord-client")
|
||||
assert.Contains(t, url, "code_challenge")
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
|
|
@ -202,7 +205,7 @@ func TestOAuthService_GetAuthURL_Spotify(t *testing.T) {
|
|||
svc.InitializeConfigs("", "", "", "", "", "", "spotify-client", "spotify-secret", "http://localhost:8080")
|
||||
|
||||
mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO oauth_states`)).
|
||||
WithArgs(sqlmock.AnyArg(), "spotify", "", sqlmock.AnyArg()).
|
||||
WithArgs(sqlmock.AnyArg(), "spotify", "", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
url, err := svc.GetAuthURL("spotify")
|
||||
|
|
@ -213,6 +216,8 @@ func TestOAuthService_GetAuthURL_Spotify(t *testing.T) {
|
|||
assert.Contains(t, url, "user-read-email")
|
||||
assert.Contains(t, url, "user-read-private")
|
||||
assert.Contains(t, url, "spotify-client")
|
||||
assert.Contains(t, url, "code_challenge")
|
||||
assert.Contains(t, url, "code_challenge_method=S256")
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
|
|
@ -295,6 +300,16 @@ func TestOAuthService_GetUserInfo_Spotify(t *testing.T) {
|
|||
assert.Equal(t, "https://avatar.url", user.Avatar)
|
||||
}
|
||||
|
||||
func TestOAuthService_ValidateRedirectURL_EvilDomain(t *testing.T) {
|
||||
// v0.902: redirect to evil.com should be rejected
|
||||
svc := NewOAuthService(nil, zap.NewNop(), nil, nil, nil, &OAuthServiceConfig{
|
||||
AllowedDomains: []string{"https://app.veza.com", "https://veza.fr:5173"},
|
||||
})
|
||||
err := svc.validateRedirectURL("https://evil.com/auth/callback")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not allowed")
|
||||
}
|
||||
|
||||
func TestOAuthService_GetUserInfo_Spotify_FallbackEmail(t *testing.T) {
|
||||
// Spotify without email - should fallback to id@spotify.user
|
||||
spotifyJSON := `{"id":"spot456","display_name":"","images":[]}`
|
||||
|
|
|
|||
18
veza-backend-api/migrations/936_oauth_states_pkce.sql
Normal file
18
veza-backend-api/migrations/936_oauth_states_pkce.sql
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
-- 936_oauth_states_pkce.sql
|
||||
-- OAuth states table with PKCE code_verifier support (v0.902 Sentinel)
|
||||
|
||||
CREATE TABLE IF NOT EXISTS public.oauth_states (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
state_token VARCHAR(255) NOT NULL UNIQUE,
|
||||
provider VARCHAR(50) NOT NULL,
|
||||
redirect_url TEXT,
|
||||
code_verifier VARCHAR(255),
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_states_state_token ON public.oauth_states(state_token);
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_states_expires_at ON public.oauth_states(expires_at);
|
||||
|
||||
-- If table already exists (without code_verifier), add the column
|
||||
ALTER TABLE public.oauth_states ADD COLUMN IF NOT EXISTS code_verifier VARCHAR(255);
|
||||
Loading…
Reference in a new issue