veza/veza-backend-api/internal/services/oauth_service.go
senke 4720bb20b2
Some checks failed
Backend API CI / test-unit (push) Failing after 0s
Backend API CI / test-integration (push) Failing after 0s
feat(auth): v0.911 Keystone - OAuth and auth integration tests
- Add access token blacklist on logout (VEZA-SEC-006)
- Extend OAuthService for mock provider injection in tests
- Add oauth_google_test.go: full OAuth Google flow with mocked provider
- Add oauth_github_test.go: OAuth GitHub flow with PKCE verification
- Add token_refresh_test.go: E2E refresh via httpOnly cookies
- Add logout_blacklist_test.go: E2E logout + token blacklist
- Fix testutils import path in resume_upload_test, track_quota_test
- Fix CreatorID -> UserID in track_quota_test
- Add test:integration script to package.json

Release: v0.911 Keystone
2026-02-27 09:58:53 +01:00

723 lines
22 KiB
Go

package services
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"veza-backend-api/internal/database"
"veza-backend-api/internal/models"
"veza-backend-api/internal/utils"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// OAuthService drives the OAuth login flows for the supported providers
// (Google, GitHub, Discord, Spotify) and bridges them to the application's
// own user, JWT and session services.
type OAuthService struct {
	db     *database.Database
	logger *zap.Logger

	// Per-provider OAuth2 configuration. A nil config means the provider has
	// not been configured and is reported as unavailable.
	googleConfig  *oauth2.Config
	githubConfig  *oauth2.Config
	discordConfig *oauth2.Config
	spotifyConfig *oauth2.Config

	jwtService     *JWTService
	sessionService *SessionService
	userService    *UserService

	// circuitBreaker wraps the HTTP client used for provider userinfo calls.
	circuitBreaker *CircuitBreakerHTTPClient

	// v0.902: encrypts provider tokens before they are stored; when nil the
	// tokens are stored in plaintext.
	cryptoService *CryptoService
	// v0.902: redirect target used when the OAuth state carries none.
	frontendURL string
	// v0.902: whitelist of origins accepted as OAuth redirect targets
	// (OAUTH_ALLOWED_REDIRECT_DOMAINS).
	allowedDomains []string

	// v0.911: integration-test override for the token endpoint (mock provider).
	testTokenURL string
	// v0.911: integration-test override for the userinfo endpoint (mock provider).
	testUserInfoURL string
}
// OAuthAccount is the link between a local user and an identity at an OAuth
// provider. It maps to the federated_identities table.
//
// AccessToken and RefreshToken are never serialized to JSON.
type OAuthAccount struct {
	ID           uuid.UUID `json:"id" db:"id"`
	UserID       uuid.UUID `json:"user_id" db:"user_id"`
	Provider     string    `json:"provider" db:"provider"`
	ProviderID   string    `json:"provider_id" db:"provider_id"`
	Email        string    `json:"email" db:"email"`
	DisplayName  string    `json:"display_name" db:"display_name"`
	AvatarURL    string    `json:"avatar_url" db:"avatar_url"`
	AccessToken  string    `json:"-" db:"access_token"`
	RefreshToken string    `json:"-" db:"refresh_token"`
	ExpiresAt    time.Time `json:"expires_at" db:"expires_at"`
	CreatedAt    time.Time `json:"created_at" db:"created_at"`
	UpdatedAt    time.Time `json:"updated_at" db:"updated_at"`
}
// OAuthState is one row of the oauth_states table: the one-time state token
// handed to the provider for CSRF protection, together with the PKCE code
// verifier and the redirect URL associated with that login attempt.
type OAuthState struct {
	ID           int64     `db:"id"`
	StateToken   string    `db:"state_token"`
	Provider     string    `db:"provider"`
	RedirectURL  string    `db:"redirect_url"`
	CodeVerifier string    `db:"code_verifier"`
	ExpiresAt    time.Time `db:"expires_at"`
	CreatedAt    time.Time `db:"created_at"`
}
// OAuthServiceConfig carries the optional settings accepted by
// NewOAuthService.
//
// v0.902 added CryptoService, FrontendURL and AllowedDomains.
// v0.911 added the Test* fields, which let integration tests point the
// service at a mock provider.
type OAuthServiceConfig struct {
	CryptoService  *CryptoService
	FrontendURL    string
	AllowedDomains []string

	// TestTokenURL overrides the token exchange URL (tests only).
	TestTokenURL string
	// TestUserInfoURL overrides the userinfo API URL (tests only).
	TestUserInfoURL string
	// TestHTTPClient overrides the HTTP client (tests only).
	TestHTTPClient *http.Client
}
// NewOAuthService builds an OAuthService around the given dependencies.
//
// cfg is optional and may be nil. When present it supplies the crypto
// service, default frontend URL and redirect whitelist (v0.902) as well as
// the test overrides (v0.911). Provider HTTP calls go through a circuit
// breaker wrapping cfg.TestHTTPClient when set, otherwise a client with a
// 10 second timeout.
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService, cfg *OAuthServiceConfig) *OAuthService {
	var client *http.Client
	if cfg == nil || cfg.TestHTTPClient == nil {
		client = &http.Client{Timeout: 10 * time.Second}
	} else {
		client = cfg.TestHTTPClient
	}

	service := &OAuthService{
		db:             db,
		logger:         logger,
		jwtService:     jwtService,
		sessionService: sessionService,
		userService:    userService,
		circuitBreaker: NewCircuitBreakerHTTPClient(client, "oauth-service", logger),
	}
	if cfg == nil {
		return service
	}

	service.cryptoService = cfg.CryptoService
	service.frontendURL = cfg.FrontendURL
	service.allowedDomains = cfg.AllowedDomains
	service.testTokenURL = cfg.TestTokenURL
	service.testUserInfoURL = cfg.TestUserInfoURL
	return service
}
// InitializeConfigs builds the oauth2.Config of every provider for which both
// a client ID and a client secret are supplied; providers with missing
// credentials are left unconfigured. baseURL is the prefix of the callback
// URLs registered with the providers.
//
// When a test token URL is set, the Google and GitHub endpoints are replaced
// by an endpoint whose auth and token URLs both point at that URL. Discord
// and Spotify always use their real endpoints.
func (os *OAuthService) InitializeConfigs(googleClientID, googleClientSecret, githubClientID, githubClientSecret, discordClientID, discordClientSecret, spotifyClientID, spotifyClientSecret, baseURL string) {
	// withTestOverride returns the mock endpoint when one is configured and
	// the provider's real endpoint otherwise.
	withTestOverride := func(real oauth2.Endpoint) oauth2.Endpoint {
		if os.testTokenURL == "" {
			return real
		}
		return oauth2.Endpoint{TokenURL: os.testTokenURL, AuthURL: os.testTokenURL}
	}

	if googleClientID != "" && googleClientSecret != "" {
		os.googleConfig = &oauth2.Config{
			ClientID:     googleClientID,
			ClientSecret: googleClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/google/callback", baseURL),
			Scopes: []string{
				"https://www.googleapis.com/auth/userinfo.email",
				"https://www.googleapis.com/auth/userinfo.profile",
			},
			Endpoint: withTestOverride(google.Endpoint),
		}
	}

	if githubClientID != "" && githubClientSecret != "" {
		githubEndpoint := oauth2.Endpoint{
			AuthURL:  "https://github.com/login/oauth/authorize",
			TokenURL: "https://github.com/login/oauth/access_token",
		}
		os.githubConfig = &oauth2.Config{
			ClientID:     githubClientID,
			ClientSecret: githubClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/github/callback", baseURL),
			Scopes:       []string{"user:email", "read:user"},
			Endpoint:     withTestOverride(githubEndpoint),
		}
	}

	if discordClientID != "" && discordClientSecret != "" {
		discordEndpoint := oauth2.Endpoint{
			AuthURL:  "https://discord.com/api/oauth2/authorize",
			TokenURL: "https://discord.com/api/oauth2/token",
		}
		os.discordConfig = &oauth2.Config{
			ClientID:     discordClientID,
			ClientSecret: discordClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/discord/callback", baseURL),
			Scopes:       []string{"identify", "email"},
			Endpoint:     discordEndpoint,
		}
	}

	if spotifyClientID != "" && spotifyClientSecret != "" {
		spotifyEndpoint := oauth2.Endpoint{
			AuthURL:  "https://accounts.spotify.com/authorize",
			TokenURL: "https://accounts.spotify.com/api/token",
		}
		os.spotifyConfig = &oauth2.Config{
			ClientID:     spotifyClientID,
			ClientSecret: spotifyClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/spotify/callback", baseURL),
			Scopes:       []string{"user-read-email", "user-read-private"},
			Endpoint:     spotifyEndpoint,
		}
	}

	os.logger.Info("OAuth configs initialized")
}
// GetAvailableProviders lists the names of the providers that currently have
// a configuration, always in the order google, github, discord, spotify.
// The result is nil when no provider is configured.
func (os *OAuthService) GetAvailableProviders() []string {
	candidates := [...]struct {
		name   string
		config *oauth2.Config
	}{
		{"google", os.googleConfig},
		{"github", os.githubConfig},
		{"discord", os.discordConfig},
		{"spotify", os.spotifyConfig},
	}

	var available []string
	for _, candidate := range candidates {
		if candidate.config == nil {
			continue
		}
		available = append(available, candidate.name)
	}
	return available
}
// GenerateStateToken creates a random state token (CSRF protection) and a
// PKCE code verifier (RFC 7636), and stores both in oauth_states together
// with the provider and redirect URL. The stored row expires after ten
// minutes.
//
// It returns the state token and the code verifier, or an error if random
// data could not be read or the row could not be inserted.
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (stateToken, codeVerifier string, err error) {
	verifier := oauth2.GenerateVerifier()

	// 32 random bytes, URL-safe base64 encoded, form the state token.
	raw := make([]byte, 32)
	if _, readErr := rand.Read(raw); readErr != nil {
		return "", "", readErr
	}
	token := base64.URLEncoding.EncodeToString(raw)

	// Persist the state alongside the verifier so the callback can finish PKCE.
	deadline := time.Now().Add(10 * time.Minute)
	_, insertErr := os.db.ExecContext(context.Background(), `
INSERT INTO oauth_states (state_token, provider, redirect_url, code_verifier, expires_at)
VALUES ($1, $2, $3, $4, $5)
`, token, provider, redirectURL, verifier, deadline)
	if insertErr != nil {
		return "", "", insertErr
	}

	os.logger.Debug("State token generated", zap.String("provider", provider))
	return token, verifier, nil
}
// ValidateStateToken looks up a state token, checks that it has not expired
// and consumes it so it cannot be used again. The returned OAuthState carries
// the CodeVerifier needed for the PKCE token exchange.
//
// Errors:
//   - "invalid state token" when the token is unknown or was already consumed
//     (including by a concurrent request),
//   - "state token expired" when the token is past its expiry,
//   - the underlying database error otherwise.
func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
	ctx := context.Background()

	var state OAuthState
	err := os.db.QueryRowContext(ctx, `
SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at
FROM oauth_states
WHERE state_token = $1
`, stateToken).Scan(
		&state.ID,
		&state.StateToken,
		&state.Provider,
		&state.RedirectURL,
		&state.CodeVerifier,
		&state.ExpiresAt,
		&state.CreatedAt,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, fmt.Errorf("invalid state token")
		}
		return nil, err
	}

	if time.Now().After(state.ExpiresAt) {
		return nil, fmt.Errorf("state token expired")
	}

	// Consume the token. A failed delete must not be ignored: if the row
	// survives, the same state could be replayed.
	result, err := os.db.ExecContext(ctx, `DELETE FROM oauth_states WHERE id = $1`, state.ID)
	if err != nil {
		return nil, fmt.Errorf("consume state token: %w", err)
	}
	// When two callbacks race on the same token only one of them actually
	// deletes the row; the other sees zero affected rows and is rejected.
	// If the driver cannot report the count, the token is accepted as before.
	if affected, rowsErr := result.RowsAffected(); rowsErr == nil && affected == 0 {
		return nil, fmt.Errorf("invalid state token")
	}

	return &state, nil
}
// GetAuthURL builds the provider authorization URL the client should be sent
// to. A fresh state token and PKCE verifier are generated and stored, and the
// URL carries the state plus the S256 code challenge derived from the
// verifier.
//
// It fails when the provider is unknown, not configured, or when the state
// cannot be generated.
func (os *OAuthService) GetAuthURL(provider string) (string, error) {
	var providerConfig *oauth2.Config
	switch provider {
	case "google":
		providerConfig = os.googleConfig
		if providerConfig == nil {
			return "", fmt.Errorf("google OAuth not configured")
		}
	case "github":
		providerConfig = os.githubConfig
		if providerConfig == nil {
			return "", fmt.Errorf("GitHub OAuth not configured")
		}
	case "discord":
		providerConfig = os.discordConfig
		if providerConfig == nil {
			return "", fmt.Errorf("discord OAuth not configured")
		}
	case "spotify":
		providerConfig = os.spotifyConfig
		if providerConfig == nil {
			return "", fmt.Errorf("spotify OAuth not configured")
		}
	default:
		return "", fmt.Errorf("unknown provider: %s", provider)
	}

	// No redirect URL is attached to the state at this point.
	state, verifier, err := os.GenerateStateToken(provider, "")
	if err != nil {
		return "", err
	}

	// PKCE: send code_challenge with code_challenge_method=S256.
	challenge := oauth2.S256ChallengeOption(verifier)
	return providerConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, challenge), nil
}
// HandleCallback processes the OAuth callback for provider.
//
// Steps: validate and consume the state token, check that the state was
// issued for this provider, validate the redirect URL against the whitelist
// (v0.902, VEZA-SEC-010), exchange the code for a provider token using the
// PKCE verifier, fetch the provider profile, find or create the local user,
// store the provider account, issue the application token pair, and create a
// session for the refresh token.
//
// ipAddress and userAgent are only used for session creation and may be
// empty. A session creation failure is logged and does not fail the call.
//
// It returns the authenticated user (ID and email), the token pair and the
// validated redirect URL the handler should redirect to.
func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, string, error) {
	// Validate state and get code_verifier for PKCE.
	oauthState, err := os.ValidateStateToken(state)
	if err != nil {
		return nil, nil, "", err
	}

	// The state is bound to the provider it was generated for; refuse to use
	// it on another provider's callback.
	if oauthState.Provider != provider {
		return nil, nil, "", fmt.Errorf("state token provider mismatch")
	}

	// v0.902: validate redirect URL against whitelist (VEZA-SEC-010).
	redirectURL := oauthState.RedirectURL
	if redirectURL == "" {
		redirectURL = os.frontendURL
	}
	if redirectURL != "" {
		if err := os.validateRedirectURL(redirectURL); err != nil {
			return nil, nil, "", err
		}
	}

	var config *oauth2.Config
	switch provider {
	case "google":
		config = os.googleConfig
	case "github":
		config = os.githubConfig
	case "discord":
		config = os.discordConfig
	case "spotify":
		config = os.spotifyConfig
	default:
		return nil, nil, "", fmt.Errorf("unknown provider: %s", provider)
	}
	if config == nil {
		return nil, nil, "", fmt.Errorf("%s OAuth not configured", provider)
	}

	// Exchange code for token with PKCE code_verifier.
	token, err := config.Exchange(ctx, code, oauth2.VerifierOption(oauthState.CodeVerifier))
	if err != nil {
		return nil, nil, "", err
	}

	// Get user info from provider.
	oauthUser, err := os.getUserInfo(provider, token.AccessToken)
	if err != nil {
		return nil, nil, "", err
	}

	// Audit 1.8: look the user up by provider account first, then by email.
	existingUser, err := os.getOrCreateUser(provider, oauthUser)
	if err != nil {
		return nil, nil, "", err
	}

	// Save/update OAuth account.
	if err := os.saveOAuthAccount(provider, oauthUser, existingUser.ID, token); err != nil {
		return nil, nil, "", err
	}

	// VEZA-SEC-001: load the full user for the JWT (TokenVersion, Role, etc.).
	user, err := os.userService.GetByID(existingUser.ID)
	if err != nil {
		return nil, nil, "", fmt.Errorf("failed to get user: %w", err)
	}

	// Generate tokens via JWTService (proper issuer, audience, token_version).
	tokens, err := os.jwtService.GenerateTokenPair(user)
	if err != nil {
		return nil, nil, "", fmt.Errorf("failed to generate tokens: %w", err)
	}

	// Create session for refresh token validation.
	_, err = os.sessionService.CreateSession(ctx, &SessionCreateRequest{
		UserID:    user.ID,
		Token:     tokens.RefreshToken,
		IPAddress: ipAddress,
		UserAgent: userAgent,
		ExpiresIn: os.jwtService.GetConfig().RefreshTokenTTL,
	})
	if err != nil {
		// Deliberately best-effort: tokens are still valid, and the session is
		// optional for some flows.
		os.logger.Warn("Failed to create session after OAuth callback",
			zap.String("user_id", user.ID.String()),
			zap.Error(err),
		)
	}

	return &OAuthUser{
		ID:    existingUser.ID,
		Email: existingUser.Email,
	}, tokens, redirectURL, nil
}
// validateRedirectURL checks that the origin (scheme://host[:port]) of
// redirectURL is permitted.
//
// With a non-empty whitelist the origin must equal the origin of one of the
// allowed entries, or the whitelist must contain "*". Entries that do not
// parse to a scheme and host are skipped.
//
// With an empty whitelist only plain-HTTP URLs whose hostname is exactly
// "localhost" or "127.0.0.1" (any port) are accepted, as a development
// fallback. The hostname is compared exactly rather than by prefix, so hosts
// such as "localhost.evil.com" are rejected.
//
// It returns nil when the URL is allowed and a descriptive error otherwise.
func (os *OAuthService) validateRedirectURL(redirectURL string) error {
	parsed, err := url.Parse(redirectURL)
	if err != nil {
		return fmt.Errorf("invalid redirect URL: %w", err)
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return fmt.Errorf("invalid redirect URL: missing scheme or host")
	}

	if len(os.allowedDomains) == 0 {
		// Dev fallback: allow localhost only.
		host := parsed.Hostname()
		if parsed.Scheme == "http" && (host == "localhost" || host == "127.0.0.1") {
			return nil
		}
		return fmt.Errorf("redirect URL not allowed: %s (no whitelist configured)", redirectURL)
	}

	redirectOrigin := parsed.Scheme + "://" + parsed.Host
	for _, allowed := range os.allowedDomains {
		allowed = strings.TrimSpace(allowed)
		if allowed == "*" {
			return nil
		}
		allowedParsed, err := url.Parse(allowed)
		if err != nil || allowedParsed.Scheme == "" || allowedParsed.Host == "" {
			continue
		}
		if redirectOrigin == allowedParsed.Scheme+"://"+allowedParsed.Host {
			return nil
		}
	}
	return fmt.Errorf("redirect URL not allowed: %s", redirectURL)
}
// OAuthUser is the profile of a user authenticated through an OAuth provider.
// ProviderID holds the user's identifier at the provider and is excluded from
// JSON output.
type OAuthUser struct {
	ID         uuid.UUID `json:"id"`
	Email      string    `json:"email"`
	Username   string    `json:"username"`
	Name       string    `json:"name"`
	Avatar     string    `json:"avatar"`
	ProviderID string    `json:"-"`
}
// OAuthUserInfo is the subset of a users row (id, email, username) that the
// OAuth flow reads from the database.
type OAuthUserInfo struct {
	ID       uuid.UUID `json:"id" db:"id"`
	Email    string    `json:"email" db:"email"`
	Username string    `json:"username" db:"username"`
}
// getUserInfo fetches the authenticated user's profile from the provider's
// userinfo API using accessToken and maps it onto an OAuthUser.
//
// The request goes to the test userinfo URL when one is configured, otherwise
// to the provider's real endpoint. It is sent through the circuit breaker and
// retried up to three times on transport errors, sleeping 1s and then 2s
// between attempts. A non-2xx response is an error: the body of an error
// response must not be parsed as a (mostly empty) profile.
//
// It returns an error for an unknown provider, a failed request, a non-2xx
// status, or a body that cannot be decoded.
func (os *OAuthService) getUserInfo(provider, accessToken string) (*OAuthUser, error) {
	apiURL := os.testUserInfoURL
	if apiURL == "" {
		switch provider {
		case "google":
			apiURL = "https://www.googleapis.com/oauth2/v2/userinfo"
		case "github":
			apiURL = "https://api.github.com/user"
		case "discord":
			apiURL = "https://discord.com/api/users/@me"
		case "spotify":
			apiURL = "https://api.spotify.com/v1/me"
		default:
			return nil, fmt.Errorf("unknown provider: %s", provider)
		}
	}

	req, err := http.NewRequest("GET", apiURL, nil)
	if err != nil {
		return nil, err
	}

	// GitHub uses the "token" authorization scheme here; every other provider
	// uses "Bearer".
	authScheme := "Bearer"
	if provider == "github" {
		authScheme = "token"
	}
	req.Header.Set("Authorization", authScheme+" "+accessToken)

	// MOD-P2-006: retry external HTTP requests with exponential backoff.
	// MOD-P2-007: go through the circuit breaker to protect against slow
	// dependencies.
	const maxRetries = 3
	backoff := time.Second
	var resp *http.Response
	for attempt := 1; attempt <= maxRetries; attempt++ {
		resp, err = os.circuitBreaker.Do(req)
		if err == nil {
			break
		}
		if attempt == maxRetries {
			return nil, fmt.Errorf("OAuth API request failed after %d attempts: %w", maxRetries, err)
		}
		time.Sleep(backoff)
		backoff *= 2
	}
	if resp == nil {
		return nil, fmt.Errorf("OAuth API request failed: no response after %d attempts", maxRetries)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("OAuth API request failed: %s userinfo returned status %d", provider, resp.StatusCode)
	}

	// Profile documents are small; cap the read so a misbehaving endpoint
	// cannot make us buffer an unbounded body.
	const maxBodyBytes = 1 << 20
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
	if err != nil {
		return nil, err
	}

	// Parse response based on provider.
	var oauthUser OAuthUser
	switch provider {
	case "google":
		var userInfo struct {
			ID    string `json:"id"`
			Email string `json:"email"`
			Name  string `json:"name"`
		}
		if err := json.Unmarshal(body, &userInfo); err != nil {
			return nil, err
		}
		// Google's response has no username field here; the email doubles as one.
		oauthUser.Username = userInfo.Email
		oauthUser.Email = userInfo.Email
		oauthUser.Name = userInfo.Name
		oauthUser.ProviderID = userInfo.ID
	case "github":
		var userInfo struct {
			ID    int    `json:"id"`
			Login string `json:"login"`
			Email string `json:"email"`
			Name  string `json:"name"`
		}
		if err := json.Unmarshal(body, &userInfo); err != nil {
			return nil, err
		}
		oauthUser.Username = userInfo.Login
		oauthUser.Email = userInfo.Email
		oauthUser.Name = userInfo.Name
		oauthUser.ProviderID = fmt.Sprintf("%d", userInfo.ID)
	case "discord":
		var userInfo struct {
			ID       string `json:"id"`
			Username string `json:"username"`
			Email    string `json:"email"`
			Avatar   string `json:"avatar"`
		}
		if err := json.Unmarshal(body, &userInfo); err != nil {
			return nil, err
		}
		oauthUser.Username = userInfo.Username
		oauthUser.Email = userInfo.Email
		oauthUser.Name = userInfo.Username
		oauthUser.Avatar = userInfo.Avatar
		oauthUser.ProviderID = userInfo.ID
	case "spotify":
		var userInfo struct {
			ID          string `json:"id"`
			DisplayName string `json:"display_name"`
			Email       string `json:"email"`
			Images      []struct {
				URL string `json:"url"`
			} `json:"images"`
		}
		if err := json.Unmarshal(body, &userInfo); err != nil {
			return nil, err
		}
		// Fall back to the Spotify ID when no display name or email is returned.
		oauthUser.Username = userInfo.DisplayName
		if oauthUser.Username == "" {
			oauthUser.Username = userInfo.ID
		}
		oauthUser.Email = userInfo.Email
		if oauthUser.Email == "" {
			oauthUser.Email = userInfo.ID + "@spotify.user"
		}
		oauthUser.Name = userInfo.DisplayName
		oauthUser.ProviderID = userInfo.ID
		if len(userInfo.Images) > 0 && userInfo.Images[0].URL != "" {
			oauthUser.Avatar = userInfo.Images[0].URL
		}
	}
	return &oauthUser, nil
}
// getOrCreateUser resolves the local user for an OAuth profile, creating one
// when none exists.
//
// Lookup order (audit 1.8):
//  1. by provider account (ProviderID + provider), so a user who changed
//     their email at the provider is still matched;
//  2. by email in the users table;
//  3. otherwise a new, verified and active user is inserted with a unique
//     slug derived from the username.
//
// A profile without an email address is only accepted in step 1. It is never
// matched or created by email, because an empty email would collide with any
// other account that has none.
func (os *OAuthService) getOrCreateUser(provider string, oauthUser *OAuthUser) (*OAuthUserInfo, error) {
	ctx := context.Background()

	// Try OAuth ID lookup first to avoid duplicates when the user changes
	// email at the provider.
	if oauthUser.ProviderID != "" {
		dbUser, err := os.db.GetUserByOAuthID(oauthUser.ProviderID, provider)
		if err != nil {
			return nil, err
		}
		if dbUser != nil {
			return &OAuthUserInfo{ID: dbUser.ID, Email: dbUser.Email, Username: dbUser.Username}, nil
		}
	}

	if oauthUser.Email == "" {
		return nil, fmt.Errorf("%s OAuth account has no email address", provider)
	}

	// Fallback: find existing user by email.
	var user OAuthUserInfo
	err := os.db.QueryRowContext(ctx, `
SELECT id, email, username
FROM users
WHERE email = $1
`, oauthUser.Email).Scan(&user.ID, &user.Email, &user.Username)
	if err == nil {
		return &user, nil
	}
	if err != sql.ErrNoRows {
		return nil, err
	}

	// T0219: generate the slug from the username, appending a counter until
	// it is unique. After 1000 attempts fall back to a timestamp-based slug.
	slug := utils.Slugify(oauthUser.Username)
	baseSlug := slug
	counter := 1
	for {
		var count int
		if err := os.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM users WHERE slug = $1", slug).Scan(&count); err != nil {
			// A failing query says nothing about uniqueness; stop instead of
			// cycling through candidates.
			return nil, fmt.Errorf("check slug uniqueness: %w", err)
		}
		if count == 0 {
			break
		}
		slug = fmt.Sprintf("%s%d", baseSlug, counter)
		counter++
		if counter > 1000 {
			slug = fmt.Sprintf("user_%d", time.Now().Unix())
			break
		}
	}

	// Create new user. The ID is generated by the database (gen_random_uuid()).
	insertQuery := `
INSERT INTO users (email, username, slug, is_verified, is_active, created_at, updated_at)
VALUES ($1, $2, $3, TRUE, TRUE, NOW(), NOW())
RETURNING id, email, username
`
	err = os.db.QueryRowContext(ctx, insertQuery, oauthUser.Email, oauthUser.Username, slug).Scan(
		&user.ID,
		&user.Email,
		&user.Username,
	)
	if err != nil {
		return nil, err
	}

	os.logger.Info("New user created via OAuth",
		zap.String("email", oauthUser.Email),
		zap.String("provider", provider),
	)
	return &user, nil
}
// saveOAuthAccount inserts or updates the federated_identities row linking
// userID to the given provider account, storing the provider's access token,
// refresh token and expiry.
//
// When a CryptoService is configured (v0.902) both tokens are encrypted
// before being stored; otherwise they are stored as received.
//
// An existing row is identified by user, provider and provider ID. The
// provider is part of the match so that identical IDs issued by two
// different providers never overwrite each other.
func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error {
	ctx := context.Background()

	accessToken := token.AccessToken
	refreshToken := token.RefreshToken
	if os.cryptoService != nil {
		var errEnc error
		accessToken, errEnc = os.cryptoService.EncryptString(token.AccessToken)
		if errEnc != nil {
			return fmt.Errorf("encrypt access token: %w", errEnc)
		}
		refreshToken, errEnc = os.cryptoService.EncryptString(token.RefreshToken)
		if errEnc != nil {
			return fmt.Errorf("encrypt refresh token: %w", errEnc)
		}
	}

	// Check if this OAuth account is already linked to the user.
	var existingID uuid.UUID
	err := os.db.QueryRowContext(ctx, `
SELECT id FROM federated_identities
WHERE user_id = $1 AND provider = $2 AND provider_id = $3
`, userID, provider, oauthUser.ProviderID).Scan(&existingID)
	if err == nil {
		// Update existing.
		_, err = os.db.ExecContext(ctx, `
UPDATE federated_identities
SET email = $1, display_name = $2, access_token = $3, refresh_token = $4, expires_at = $5, updated_at = NOW()
WHERE id = $6
`, oauthUser.Email, oauthUser.Name, accessToken, refreshToken, token.Expiry, existingID)
		return err
	}
	if err != sql.ErrNoRows {
		return err
	}

	// Insert new.
	_, err = os.db.ExecContext(ctx, `
INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at)
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
`, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, accessToken, refreshToken, token.Expiry)
	return err
}