package services

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"veza-backend-api/internal/database"
	"veza-backend-api/internal/models"
	"veza-backend-api/internal/utils"

	"github.com/google/uuid"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)

// OAuthService handles OAuth authentication against the configured providers
// (Google, GitHub, Discord, Spotify).
type OAuthService struct {
	db              *database.Database
	logger          *zap.Logger
	googleConfig    *oauth2.Config
	githubConfig    *oauth2.Config
	discordConfig   *oauth2.Config
	spotifyConfig   *oauth2.Config
	jwtService      *JWTService
	sessionService  *SessionService
	userService     *UserService
	circuitBreaker  *CircuitBreakerHTTPClient
	cryptoService   *CryptoService // v0.902: encrypt OAuth provider tokens at rest (nil = store plaintext)
	frontendURL     string         // v0.902: default redirect URL
	allowedDomains  []string       // v0.902: whitelist for OAuth redirect (OAUTH_ALLOWED_REDIRECT_DOMAINS)
	testTokenURL    string         // v0.911: override for integration tests (mock provider)
	testUserInfoURL string         // v0.911: override for integration tests (mock userinfo)
}

// OAuthAccount represents an OAuth account linking.
// It is mapped to the federated_identities table.
type OAuthAccount struct {
	ID           uuid.UUID `json:"id" db:"id"`
	UserID       uuid.UUID `json:"user_id" db:"user_id"`
	Provider     string    `json:"provider" db:"provider"`
	ProviderID   string    `json:"provider_id" db:"provider_id"`
	Email        string    `json:"email" db:"email"`
	DisplayName  string    `json:"display_name" db:"display_name"`
	AvatarURL    string    `json:"avatar_url" db:"avatar_url"`
	AccessToken  string    `json:"-" db:"access_token"`
	RefreshToken string    `json:"-" db:"refresh_token"`
	ExpiresAt    time.Time `json:"expires_at" db:"expires_at"`
	CreatedAt    time.Time `json:"created_at" db:"created_at"`
	UpdatedAt    time.Time `json:"updated_at" db:"updated_at"`
}

// OAuthState represents an OAuth state for CSRF protection and PKCE
// (a row of the oauth_states table).
type OAuthState struct {
	ID           int64     `db:"id"`
	StateToken   string    `db:"state_token"`
	Provider     string    `db:"provider"`
	RedirectURL  string    `db:"redirect_url"`
	CodeVerifier string    `db:"code_verifier"`
	ExpiresAt    time.Time `db:"expires_at"`
	CreatedAt    time.Time `db:"created_at"`
}

// OAuthServiceConfig holds optional config for OAuth (v0.902).
// v0.911: TestTokenURL, TestUserInfoURL and TestHTTPClient let integration
// tests point the service at a mock provider.
type OAuthServiceConfig struct {
	CryptoService   *CryptoService
	FrontendURL     string
	AllowedDomains  []string
	TestTokenURL    string       // Override token exchange URL (for tests)
	TestUserInfoURL string       // Override userinfo API URL (for tests)
	TestHTTPClient  *http.Client // Override HTTP client (for tests)
}

// NewOAuthService creates a new OAuth service.
// cfg is optional (may be nil): crypto, redirect validation (v0.902) and test
// overrides (v0.911). Provider configs are set later by InitializeConfigs.
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService, cfg *OAuthServiceConfig) *OAuthService {
	httpClient := &http.Client{Timeout: 10 * time.Second}
	if cfg != nil && cfg.TestHTTPClient != nil {
		httpClient = cfg.TestHTTPClient
	}

	svc := &OAuthService{
		db:             db,
		logger:         logger,
		jwtService:     jwtService,
		sessionService: sessionService,
		userService:    userService,
		circuitBreaker: NewCircuitBreakerHTTPClient(httpClient, "oauth-service", logger),
	}
	if cfg != nil {
		svc.cryptoService = cfg.CryptoService
		svc.frontendURL = cfg.FrontendURL
		svc.allowedDomains = cfg.AllowedDomains
		svc.testTokenURL = cfg.TestTokenURL
		svc.testUserInfoURL = cfg.TestUserInfoURL
	}
	return svc
}

// InitializeConfigs initializes OAuth configurations.
// A provider is enabled only when both its client ID and secret are non-empty.
// Callback URLs are built as "<baseURL>/api/v1/auth/oauth/<provider>/callback".
func (os *OAuthService) InitializeConfigs(googleClientID, googleClientSecret, githubClientID, githubClientSecret, discordClientID, discordClientSecret, spotifyClientID, spotifyClientSecret, baseURL string) {
	// endpointFor returns the provider's real endpoint, or the mock endpoint when
	// a test token URL is configured (v0.911). Applied to every provider so the
	// override behaves the same way the userinfo override does.
	endpointFor := func(real oauth2.Endpoint) oauth2.Endpoint {
		if os.testTokenURL != "" {
			return oauth2.Endpoint{TokenURL: os.testTokenURL, AuthURL: os.testTokenURL}
		}
		return real
	}

	// Google OAuth
	if googleClientID != "" && googleClientSecret != "" {
		os.googleConfig = &oauth2.Config{
			ClientID:     googleClientID,
			ClientSecret: googleClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/google/callback", baseURL),
			Scopes: []string{
				"https://www.googleapis.com/auth/userinfo.email",
				"https://www.googleapis.com/auth/userinfo.profile",
			},
			Endpoint: endpointFor(google.Endpoint),
		}
	}

	// GitHub OAuth
	if githubClientID != "" && githubClientSecret != "" {
		os.githubConfig = &oauth2.Config{
			ClientID:     githubClientID,
			ClientSecret: githubClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/github/callback", baseURL),
			Scopes:       []string{"user:email", "read:user"},
			Endpoint: endpointFor(oauth2.Endpoint{
				AuthURL:  "https://github.com/login/oauth/authorize",
				TokenURL: "https://github.com/login/oauth/access_token",
			}),
		}
	}

	// Discord OAuth
	if discordClientID != "" && discordClientSecret != "" {
		os.discordConfig = &oauth2.Config{
			ClientID:     discordClientID,
			ClientSecret: discordClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/discord/callback", baseURL),
			Scopes:       []string{"identify", "email"},
			Endpoint: endpointFor(oauth2.Endpoint{
				AuthURL:  "https://discord.com/api/oauth2/authorize",
				TokenURL: "https://discord.com/api/oauth2/token",
			}),
		}
	}

	// Spotify OAuth
	if spotifyClientID != "" && spotifyClientSecret != "" {
		os.spotifyConfig = &oauth2.Config{
			ClientID:     spotifyClientID,
			ClientSecret: spotifyClientSecret,
			RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/spotify/callback", baseURL),
			Scopes:       []string{"user-read-email", "user-read-private"},
			Endpoint: endpointFor(oauth2.Endpoint{
				AuthURL:  "https://accounts.spotify.com/authorize",
				TokenURL: "https://accounts.spotify.com/api/token",
			}),
		}
	}

	os.logger.Info("OAuth configs initialized")
}

// GetAvailableProviders returns the names of the configured OAuth providers,
// in the fixed order google, github, discord, spotify (nil if none).
func (os *OAuthService) GetAvailableProviders() []string {
	var providers []string
	if os.googleConfig != nil {
		providers = append(providers, "google")
	}
	if os.githubConfig != nil {
		providers = append(providers, "github")
	}
	if os.discordConfig != nil {
		providers = append(providers, "discord")
	}
	if os.spotifyConfig != nil {
		providers = append(providers, "spotify")
	}
	return providers
}

// GenerateStateToken generates a random state token (CSRF protection) and a
// PKCE code_verifier (RFC 7636), and persists both in oauth_states with a
// 10-minute expiry. redirectURL is stored as-is and validated at callback time.
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (stateToken, codeVerifier string, err error) {
	// Generate PKCE code verifier (RFC 7636)
	codeVerifier = oauth2.GenerateVerifier()

	// 32 bytes from crypto/rand, URL-safe base64 encoded
	tokenBytes := make([]byte, 32)
	if _, err = rand.Read(tokenBytes); err != nil {
		return "", "", err
	}
	stateToken = base64.URLEncoding.EncodeToString(tokenBytes)

	// Store in database with code_verifier for PKCE
	ctx := context.Background()
	expiresAt := time.Now().Add(10 * time.Minute)
	_, err = os.db.ExecContext(ctx, `
		INSERT INTO oauth_states (state_token, provider, redirect_url, code_verifier, expires_at)
		VALUES ($1, $2, $3, $4, $5)
	`, stateToken, provider, redirectURL, codeVerifier, expiresAt)
	if err != nil {
		return "", "", err
	}

	os.logger.Debug("State token generated", zap.String("provider", provider))
	return stateToken, codeVerifier, nil
}

// ValidateStateToken validates and consumes a state token. It returns the
// stored OAuthState, including the PKCE CodeVerifier.
//
// The row is removed with a single DELETE ... RETURNING, so at most one caller
// can ever obtain a given state (no replay under concurrent callbacks).
// Expired rows are removed too, but are reported as "state token expired".
func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
	ctx := context.Background()

	var state OAuthState
	err := os.db.QueryRowContext(ctx, `
		DELETE FROM oauth_states
		WHERE state_token = $1
		RETURNING id, state_token, provider, redirect_url, code_verifier, expires_at, created_at
	`, stateToken).Scan(
		&state.ID,
		&state.StateToken,
		&state.Provider,
		&state.RedirectURL,
		&state.CodeVerifier,
		&state.ExpiresAt,
		&state.CreatedAt,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, fmt.Errorf("invalid state token")
		}
		return nil, err
	}

	if time.Now().After(state.ExpiresAt) {
		return nil, fmt.Errorf("state token expired")
	}

	return &state, nil
}

// GetAuthURL returns the provider's authorization URL for a new login flow.
// It creates and stores a fresh state token and PKCE verifier (no redirect URL),
// and requests offline access with an S256 code_challenge.
func (os *OAuthService) GetAuthURL(provider string) (string, error) {
	var config *oauth2.Config

	switch provider {
	case "google":
		if os.googleConfig == nil {
			return "", fmt.Errorf("google OAuth not configured")
		}
		config = os.googleConfig
	case "github":
		if os.githubConfig == nil {
			return "", fmt.Errorf("GitHub OAuth not configured")
		}
		config = os.githubConfig
	case "discord":
		if os.discordConfig == nil {
			return "", fmt.Errorf("discord OAuth not configured")
		}
		config = os.discordConfig
	case "spotify":
		if os.spotifyConfig == nil {
			return "", fmt.Errorf("spotify OAuth not configured")
		}
		config = os.spotifyConfig
	default:
		return "", fmt.Errorf("unknown provider: %s", provider)
	}

	// Generate state token with PKCE code_verifier
	stateToken, codeVerifier, err := os.GenerateStateToken(provider, "")
	if err != nil {
		return "", err
	}

	// Authorization URL with PKCE (code_challenge, code_challenge_method=S256)
	return config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier)), nil
}

// HandleCallback processes the OAuth callback. It consumes the state token,
// exchanges the authorization code with the PKCE verifier, resolves or creates
// the local user, stores the provider tokens, issues a JWT pair, and creates a
// session on a best-effort basis.
// ipAddress and userAgent are used for session creation (optional, can be empty).
// Returns redirectURL (validated) for the handler to redirect to (v0.902).
func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, string, error) { // Validate state and get code_verifier for PKCE oauthState, err := os.ValidateStateToken(state) if err != nil { return nil, nil, "", err } // v0.902: Validate redirect URL against whitelist (VEZA-SEC-010) redirectURL := oauthState.RedirectURL if redirectURL == "" { redirectURL = os.frontendURL } if redirectURL != "" { if err := os.validateRedirectURL(redirectURL); err != nil { return nil, nil, "", err } } var config *oauth2.Config switch provider { case "google": config = os.googleConfig case "github": config = os.githubConfig case "discord": config = os.discordConfig case "spotify": config = os.spotifyConfig default: return nil, nil, "", fmt.Errorf("unknown provider: %s", provider) } if config == nil { return nil, nil, "", fmt.Errorf("%s OAuth not configured", provider) } // Exchange code for token with PKCE code_verifier token, err := config.Exchange(ctx, code, oauth2.VerifierOption(oauthState.CodeVerifier)) if err != nil { return nil, nil, "", err } // Get user info from provider oauthUser, err := os.getUserInfo(provider, token.AccessToken) if err != nil { return nil, nil, "", err } // Check if user already exists (by provider account or email) — audit 1.8: OAuth ID lookup first existingUser, err := os.getOrCreateUser(provider, oauthUser) if err != nil { return nil, nil, "", err } // Save/update OAuth account err = os.saveOAuthAccount(provider, oauthUser, existingUser.ID, token) if err != nil { return nil, nil, "", err } // VEZA-SEC-001: Get full user for JWT (TokenVersion, Role, etc.) 
user, err := os.userService.GetByID(ctx, existingUser.ID) if err != nil { return nil, nil, "", fmt.Errorf("failed to get user: %w", err) } // Generate tokens via JWTService (proper issuer, audience, token_version) tokens, err := os.jwtService.GenerateTokenPair(user) if err != nil { return nil, nil, "", fmt.Errorf("failed to generate tokens: %w", err) } // Create session for refresh token validation _, err = os.sessionService.CreateSession(ctx, &SessionCreateRequest{ UserID: user.ID, Token: tokens.RefreshToken, IPAddress: ipAddress, UserAgent: userAgent, ExpiresIn: os.jwtService.GetConfig().RefreshTokenTTL, }) if err != nil { os.logger.Warn("Failed to create session after OAuth callback", zap.String("user_id", user.ID.String()), zap.Error(err), ) // Continue - tokens still valid, session is optional for some flows } return &OAuthUser{ ID: existingUser.ID, Email: existingUser.Email, }, tokens, redirectURL, nil } // validateRedirectURL checks that the redirect URL's origin is in the allowed domains whitelist func (os *OAuthService) validateRedirectURL(redirectURL string) error { parsed, err := url.Parse(redirectURL) if err != nil { return fmt.Errorf("invalid redirect URL: %w", err) } if parsed.Scheme == "" || parsed.Host == "" { return fmt.Errorf("invalid redirect URL: missing scheme or host") } redirectOrigin := parsed.Scheme + "://" + parsed.Host if len(os.allowedDomains) == 0 { // Dev fallback: allow localhost if strings.HasPrefix(redirectOrigin, "http://localhost") || strings.HasPrefix(redirectOrigin, "http://127.0.0.1") { return nil } return fmt.Errorf("redirect URL not allowed: %s (no whitelist configured)", redirectURL) } for _, allowed := range os.allowedDomains { allowed = strings.TrimSpace(allowed) if allowed == "*" { return nil } allowedParsed, err := url.Parse(allowed) if err != nil || allowedParsed.Scheme == "" || allowedParsed.Host == "" { continue } allowedOrigin := allowedParsed.Scheme + "://" + allowedParsed.Host if redirectOrigin == allowedOrigin { 
return nil } } return fmt.Errorf("redirect URL not allowed: %s", redirectURL) } // OAuthUser represents an OAuth authenticated user type OAuthUser struct { ID uuid.UUID `json:"id"` Email string `json:"email"` Username string `json:"username"` Name string `json:"name"` Avatar string `json:"avatar"` ProviderID string `json:"-"` // Added to store provider ID } // OAuthUserInfo represents a user from the database type OAuthUserInfo struct { ID uuid.UUID `json:"id" db:"id"` Email string `json:"email" db:"email"` Username string `json:"username" db:"username"` } // getUserInfo fetches user information from the OAuth provider func (os *OAuthService) getUserInfo(provider, accessToken string) (*OAuthUser, error) { var apiURL string if os.testUserInfoURL != "" { apiURL = os.testUserInfoURL } else { switch provider { case "google": apiURL = "https://www.googleapis.com/oauth2/v2/userinfo" case "github": apiURL = "https://api.github.com/user" case "discord": apiURL = "https://discord.com/api/users/@me" case "spotify": apiURL = "https://api.spotify.com/v1/me" default: return nil, fmt.Errorf("unknown provider: %s", provider) } } req, err := http.NewRequest("GET", apiURL, nil) if err != nil { return nil, err } // Add auth header switch provider { case "github": req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) case "discord": req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) default: req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) } // MOD-P2-006: Retry avec backoff exponentiel pour requêtes HTTP externes // MOD-P2-007: Circuit breaker pour protéger contre dépendances lentes maxRetries := 3 backoff := time.Second var resp *http.Response for i := 0; i < maxRetries; i++ { var err error // MOD-P2-007: Utiliser circuit breaker pour protéger contre dépendances lentes resp, err = os.circuitBreaker.Do(req) if err == nil { break // Succès } // Log retry if i < maxRetries-1 { time.Sleep(backoff) backoff *= 2 // Exponential 
backoff: 1s, 2s, 4s } else { return nil, fmt.Errorf("OAuth API request failed after %d attempts: %w", maxRetries, err) } } if resp == nil { return nil, fmt.Errorf("OAuth API request failed: no response after %d attempts", maxRetries) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } // Parse response based on provider var oauthUser OAuthUser switch provider { case "google": var userInfo struct { ID string `json:"id"` Email string `json:"email"` Name string `json:"name"` } if err := json.Unmarshal(body, &userInfo); err != nil { return nil, err } oauthUser.Username = userInfo.Email oauthUser.Email = userInfo.Email oauthUser.Name = userInfo.Name oauthUser.ProviderID = userInfo.ID case "github": var userInfo struct { ID int `json:"id"` Login string `json:"login"` Email string `json:"email"` Name string `json:"name"` } if err := json.Unmarshal(body, &userInfo); err != nil { return nil, err } oauthUser.Username = userInfo.Login oauthUser.Email = userInfo.Email oauthUser.Name = userInfo.Name oauthUser.ProviderID = fmt.Sprintf("%d", userInfo.ID) case "discord": var userInfo struct { ID string `json:"id"` Username string `json:"username"` Email string `json:"email"` Avatar string `json:"avatar"` } if err := json.Unmarshal(body, &userInfo); err != nil { return nil, err } oauthUser.Username = userInfo.Username oauthUser.Email = userInfo.Email oauthUser.Name = userInfo.Username oauthUser.Avatar = userInfo.Avatar oauthUser.ProviderID = userInfo.ID case "spotify": var userInfo struct { ID string `json:"id"` DisplayName string `json:"display_name"` Email string `json:"email"` Images []struct { URL string `json:"url"` } `json:"images"` } if err := json.Unmarshal(body, &userInfo); err != nil { return nil, err } oauthUser.Username = userInfo.DisplayName if oauthUser.Username == "" { oauthUser.Username = userInfo.ID } oauthUser.Email = userInfo.Email if oauthUser.Email == "" { oauthUser.Email = userInfo.ID + "@spotify.user" } 
oauthUser.Name = userInfo.DisplayName oauthUser.ProviderID = userInfo.ID if len(userInfo.Images) > 0 && userInfo.Images[0].URL != "" { oauthUser.Avatar = userInfo.Images[0].URL } } return &oauthUser, nil } // getOrCreateUser gets an existing user or creates a new one (audit 1.8: OAuth ID lookup first) func (os *OAuthService) getOrCreateUser(provider string, oauthUser *OAuthUser) (*OAuthUserInfo, error) { ctx := context.Background() // Try OAuth ID lookup first to avoid duplicates when user changes email at provider if oauthUser.ProviderID != "" { dbUser, err := os.db.GetUserByOAuthID(oauthUser.ProviderID, provider) if err != nil { return nil, err } if dbUser != nil { return &OAuthUserInfo{ID: dbUser.ID, Email: dbUser.Email, Username: dbUser.Username}, nil } } // Fallback: find existing user by email var user OAuthUserInfo err := os.db.QueryRowContext(ctx, ` SELECT id, email, username FROM users WHERE email = $1 `, oauthUser.Email).Scan(&user.ID, &user.Email, &user.Username) if err == nil { return &user, nil } if err != sql.ErrNoRows { return nil, err } // T0219: Generate slug from username slug := utils.Slugify(oauthUser.Username) // Ensure slug is unique by appending a number if needed baseSlug := slug counter := 1 for { var count int err := os.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM users WHERE slug = $1", slug).Scan(&count) if err == nil && count == 0 { break } slug = fmt.Sprintf("%s%d", baseSlug, counter) counter++ if counter > 1000 { slug = fmt.Sprintf("user_%d", time.Now().Unix()) break } } // Create new user // ID est généré automatiquement par gen_random_uuid() insertQuery := ` INSERT INTO users (email, username, slug, is_verified, is_active, created_at, updated_at) VALUES ($1, $2, $3, TRUE, TRUE, NOW(), NOW()) RETURNING id, email, username ` err = os.db.QueryRowContext(ctx, insertQuery, oauthUser.Email, oauthUser.Username, slug).Scan( &user.ID, &user.Email, &user.Username, ) if err != nil { return nil, err } os.logger.Info("New user created via 
OAuth", zap.String("email", oauthUser.Email), zap.String("provider", "oauth"), ) return &user, nil } // saveOAuthAccount saves or updates OAuth account information // Uses federated_identities table. Tokens are encrypted at rest when CryptoService is set (v0.902) func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error { ctx := context.Background() accessToken := token.AccessToken refreshToken := token.RefreshToken if os.cryptoService != nil { var errEnc error accessToken, errEnc = os.cryptoService.EncryptString(token.AccessToken) if errEnc != nil { return fmt.Errorf("encrypt access token: %w", errEnc) } refreshToken, errEnc = os.cryptoService.EncryptString(token.RefreshToken) if errEnc != nil { return fmt.Errorf("encrypt refresh token: %w", errEnc) } } // Check if OAuth account already exists var existingID uuid.UUID err := os.db.QueryRowContext(ctx, ` SELECT id FROM federated_identities WHERE user_id = $1 AND provider_id = $2 `, userID, oauthUser.ProviderID).Scan(&existingID) if err == nil { // Update existing _, err = os.db.ExecContext(ctx, ` UPDATE federated_identities SET email = $1, display_name = $2, access_token = $3, refresh_token = $4, expires_at = $5, updated_at = NOW() WHERE id = $6 `, oauthUser.Email, oauthUser.Name, accessToken, refreshToken, token.Expiry, existingID) return err } if err != sql.ErrNoRows { return err } // Insert new _, err = os.db.ExecContext(ctx, ` INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at) VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW()) `, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, accessToken, refreshToken, token.Expiry) return err }