- Add access token blacklist on logout (VEZA-SEC-006) - Extend OAuthService for mock provider injection in tests - Add oauth_google_test.go: full OAuth Google flow with mocked provider - Add oauth_github_test.go: OAuth GitHub flow with PKCE verification - Add token_refresh_test.go: E2E refresh via httpOnly cookies - Add logout_blacklist_test.go: E2E logout + token blacklist - Fix testutils import path in resume_upload_test, track_quota_test - Fix CreatorID -> UserID in track_quota_test - Add test:integration script to package.json Release: v0.911 Keystone
723 lines
22 KiB
Go
723 lines
22 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"veza-backend-api/internal/database"
|
|
"veza-backend-api/internal/models"
|
|
"veza-backend-api/internal/utils"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
)
|
|
|
|
// OAuthService handles OAuth authentication
|
|
type OAuthService struct {
|
|
db *database.Database
|
|
logger *zap.Logger
|
|
googleConfig *oauth2.Config
|
|
githubConfig *oauth2.Config
|
|
discordConfig *oauth2.Config
|
|
spotifyConfig *oauth2.Config
|
|
jwtService *JWTService
|
|
sessionService *SessionService
|
|
userService *UserService
|
|
circuitBreaker *CircuitBreakerHTTPClient
|
|
cryptoService *CryptoService // v0.902: encrypt OAuth provider tokens at rest (nil = store plaintext)
|
|
frontendURL string // v0.902: default redirect URL
|
|
allowedDomains []string // v0.902: whitelist for OAuth redirect (OAUTH_ALLOWED_REDIRECT_DOMAINS)
|
|
testTokenURL string // v0.911: override for integration tests (mock provider)
|
|
testUserInfoURL string // v0.911: override for integration tests (mock userinfo)
|
|
}
|
|
|
|
// OAuthAccount represents an OAuth account linking
|
|
// Mapped to federated_identities table
|
|
type OAuthAccount struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
UserID uuid.UUID `json:"user_id" db:"user_id"`
|
|
Provider string `json:"provider" db:"provider"`
|
|
ProviderID string `json:"provider_id" db:"provider_id"`
|
|
Email string `json:"email" db:"email"`
|
|
DisplayName string `json:"display_name" db:"display_name"`
|
|
AvatarURL string `json:"avatar_url" db:"avatar_url"`
|
|
AccessToken string `json:"-" db:"access_token"`
|
|
RefreshToken string `json:"-" db:"refresh_token"`
|
|
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
|
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
|
}
|
|
|
|
// OAuthState represents an OAuth state for CSRF protection and PKCE
|
|
type OAuthState struct {
|
|
ID int64 `db:"id"`
|
|
StateToken string `db:"state_token"`
|
|
Provider string `db:"provider"`
|
|
RedirectURL string `db:"redirect_url"`
|
|
CodeVerifier string `db:"code_verifier"`
|
|
ExpiresAt time.Time `db:"expires_at"`
|
|
CreatedAt time.Time `db:"created_at"`
|
|
}
|
|
|
|
// OAuthServiceConfig holds optional config for OAuth (v0.902)
|
|
// v0.911: TestTokenURL, TestUserInfoURL, TestHTTPClient for integration tests (mock provider)
|
|
type OAuthServiceConfig struct {
|
|
CryptoService *CryptoService
|
|
FrontendURL string
|
|
AllowedDomains []string
|
|
TestTokenURL string // Override token exchange URL (for tests)
|
|
TestUserInfoURL string // Override userinfo API URL (for tests)
|
|
TestHTTPClient *http.Client // Override HTTP client (for tests)
|
|
}
|
|
|
|
// NewOAuthService creates a new OAuth service
|
|
// cfg: optional config for crypto, redirect validation (v0.902), test overrides (v0.911)
|
|
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtService *JWTService, sessionService *SessionService, userService *UserService, cfg *OAuthServiceConfig) *OAuthService {
|
|
httpClient := &http.Client{Timeout: 10 * time.Second}
|
|
if cfg != nil && cfg.TestHTTPClient != nil {
|
|
httpClient = cfg.TestHTTPClient
|
|
}
|
|
svc := &OAuthService{
|
|
db: db,
|
|
logger: logger,
|
|
jwtService: jwtService,
|
|
sessionService: sessionService,
|
|
userService: userService,
|
|
circuitBreaker: NewCircuitBreakerHTTPClient(httpClient, "oauth-service", logger),
|
|
}
|
|
if cfg != nil {
|
|
svc.cryptoService = cfg.CryptoService
|
|
svc.frontendURL = cfg.FrontendURL
|
|
svc.allowedDomains = cfg.AllowedDomains
|
|
svc.testTokenURL = cfg.TestTokenURL
|
|
svc.testUserInfoURL = cfg.TestUserInfoURL
|
|
}
|
|
return svc
|
|
}
|
|
|
|
// InitializeConfigs initializes OAuth configurations
|
|
func (os *OAuthService) InitializeConfigs(googleClientID, googleClientSecret, githubClientID, githubClientSecret, discordClientID, discordClientSecret, spotifyClientID, spotifyClientSecret, baseURL string) {
|
|
testEndpoint := oauth2.Endpoint{}
|
|
if os.testTokenURL != "" {
|
|
testEndpoint = oauth2.Endpoint{TokenURL: os.testTokenURL, AuthURL: os.testTokenURL}
|
|
}
|
|
|
|
// Google OAuth
|
|
if googleClientID != "" && googleClientSecret != "" {
|
|
endpoint := google.Endpoint
|
|
if os.testTokenURL != "" {
|
|
endpoint = testEndpoint
|
|
}
|
|
os.googleConfig = &oauth2.Config{
|
|
ClientID: googleClientID,
|
|
ClientSecret: googleClientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/api/v1/auth/oauth/google/callback", baseURL),
|
|
Scopes: []string{
|
|
"https://www.googleapis.com/auth/userinfo.email",
|
|
"https://www.googleapis.com/auth/userinfo.profile",
|
|
},
|
|
Endpoint: endpoint,
|
|
}
|
|
}
|
|
|
|
// GitHub OAuth
|
|
if githubClientID != "" && githubClientSecret != "" {
|
|
endpoint := oauth2.Endpoint{
|
|
AuthURL: "https://github.com/login/oauth/authorize",
|
|
TokenURL: "https://github.com/login/oauth/access_token",
|
|
}
|
|
if os.testTokenURL != "" {
|
|
endpoint = testEndpoint
|
|
}
|
|
os.githubConfig = &oauth2.Config{
|
|
ClientID: githubClientID,
|
|
ClientSecret: githubClientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/api/v1/auth/oauth/github/callback", baseURL),
|
|
Scopes: []string{"user:email", "read:user"},
|
|
Endpoint: endpoint,
|
|
}
|
|
}
|
|
|
|
// Discord OAuth
|
|
if discordClientID != "" && discordClientSecret != "" {
|
|
os.discordConfig = &oauth2.Config{
|
|
ClientID: discordClientID,
|
|
ClientSecret: discordClientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/api/v1/auth/oauth/discord/callback", baseURL),
|
|
Scopes: []string{"identify", "email"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://discord.com/api/oauth2/authorize",
|
|
TokenURL: "https://discord.com/api/oauth2/token",
|
|
},
|
|
}
|
|
}
|
|
|
|
// Spotify OAuth
|
|
if spotifyClientID != "" && spotifyClientSecret != "" {
|
|
os.spotifyConfig = &oauth2.Config{
|
|
ClientID: spotifyClientID,
|
|
ClientSecret: spotifyClientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/api/v1/auth/oauth/spotify/callback", baseURL),
|
|
Scopes: []string{"user-read-email", "user-read-private"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://accounts.spotify.com/authorize",
|
|
TokenURL: "https://accounts.spotify.com/api/token",
|
|
},
|
|
}
|
|
}
|
|
|
|
os.logger.Info("OAuth configs initialized")
|
|
}
|
|
|
|
// GetAvailableProviders returns the list of configured OAuth providers
|
|
func (os *OAuthService) GetAvailableProviders() []string {
|
|
var providers []string
|
|
if os.googleConfig != nil {
|
|
providers = append(providers, "google")
|
|
}
|
|
if os.githubConfig != nil {
|
|
providers = append(providers, "github")
|
|
}
|
|
if os.discordConfig != nil {
|
|
providers = append(providers, "discord")
|
|
}
|
|
if os.spotifyConfig != nil {
|
|
providers = append(providers, "spotify")
|
|
}
|
|
return providers
|
|
}
|
|
|
|
// GenerateStateToken generates a secure state token for CSRF protection and PKCE code_verifier
|
|
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (stateToken, codeVerifier string, err error) {
|
|
// Generate PKCE code verifier (RFC 7636)
|
|
codeVerifier = oauth2.GenerateVerifier()
|
|
|
|
// Generate random state token
|
|
tokenBytes := make([]byte, 32)
|
|
_, err = rand.Read(tokenBytes)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
stateToken = base64.URLEncoding.EncodeToString(tokenBytes)
|
|
|
|
// Store in database with code_verifier for PKCE
|
|
ctx := context.Background()
|
|
expiresAt := time.Now().Add(10 * time.Minute)
|
|
_, err = os.db.ExecContext(ctx, `
|
|
INSERT INTO oauth_states (state_token, provider, redirect_url, code_verifier, expires_at)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
`, stateToken, provider, redirectURL, codeVerifier, expiresAt)
|
|
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
os.logger.Debug("State token generated", zap.String("provider", provider))
|
|
return stateToken, codeVerifier, nil
|
|
}
|
|
|
|
// ValidateStateToken validates and consumes a state token (returns OAuthState with CodeVerifier for PKCE)
|
|
func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
|
|
ctx := context.Background()
|
|
|
|
var state OAuthState
|
|
err := os.db.QueryRowContext(ctx, `
|
|
SELECT id, state_token, provider, redirect_url, code_verifier, expires_at, created_at
|
|
FROM oauth_states
|
|
WHERE state_token = $1
|
|
`, stateToken).Scan(
|
|
&state.ID,
|
|
&state.StateToken,
|
|
&state.Provider,
|
|
&state.RedirectURL,
|
|
&state.CodeVerifier,
|
|
&state.ExpiresAt,
|
|
&state.CreatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("invalid state token")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Check if expired
|
|
if time.Now().After(state.ExpiresAt) {
|
|
return nil, fmt.Errorf("state token expired")
|
|
}
|
|
|
|
// Delete used token
|
|
os.db.ExecContext(ctx, `DELETE FROM oauth_states WHERE id = $1`, state.ID)
|
|
|
|
return &state, nil
|
|
}
|
|
|
|
// GetAuthURL returns the OAuth provider authorization URL
|
|
func (os *OAuthService) GetAuthURL(provider string) (string, error) {
|
|
var config *oauth2.Config
|
|
var err error
|
|
|
|
switch provider {
|
|
case "google":
|
|
if os.googleConfig == nil {
|
|
return "", fmt.Errorf("google OAuth not configured")
|
|
}
|
|
config = os.googleConfig
|
|
case "github":
|
|
if os.githubConfig == nil {
|
|
return "", fmt.Errorf("GitHub OAuth not configured")
|
|
}
|
|
config = os.githubConfig
|
|
case "discord":
|
|
if os.discordConfig == nil {
|
|
return "", fmt.Errorf("discord OAuth not configured")
|
|
}
|
|
config = os.discordConfig
|
|
case "spotify":
|
|
if os.spotifyConfig == nil {
|
|
return "", fmt.Errorf("spotify OAuth not configured")
|
|
}
|
|
config = os.spotifyConfig
|
|
default:
|
|
return "", fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
// Generate state token with PKCE code_verifier
|
|
stateToken, codeVerifier, err := os.GenerateStateToken(provider, "")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Return authorization URL with PKCE (code_challenge, code_challenge_method=S256)
|
|
authURL := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
|
|
return authURL, nil
|
|
}
|
|
|
|
// HandleCallback processes the OAuth callback.
|
|
// ipAddress and userAgent are used for session creation (optional, can be empty).
|
|
// Returns redirectURL (validated) for the handler to redirect to (v0.902).
|
|
func (os *OAuthService) HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*OAuthUser, *models.TokenPair, string, error) {
|
|
// Validate state and get code_verifier for PKCE
|
|
oauthState, err := os.ValidateStateToken(state)
|
|
if err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
|
|
// v0.902: Validate redirect URL against whitelist (VEZA-SEC-010)
|
|
redirectURL := oauthState.RedirectURL
|
|
if redirectURL == "" {
|
|
redirectURL = os.frontendURL
|
|
}
|
|
if redirectURL != "" {
|
|
if err := os.validateRedirectURL(redirectURL); err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
}
|
|
|
|
var config *oauth2.Config
|
|
switch provider {
|
|
case "google":
|
|
config = os.googleConfig
|
|
case "github":
|
|
config = os.githubConfig
|
|
case "discord":
|
|
config = os.discordConfig
|
|
case "spotify":
|
|
config = os.spotifyConfig
|
|
default:
|
|
return nil, nil, "", fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
if config == nil {
|
|
return nil, nil, "", fmt.Errorf("%s OAuth not configured", provider)
|
|
}
|
|
|
|
// Exchange code for token with PKCE code_verifier
|
|
token, err := config.Exchange(ctx, code, oauth2.VerifierOption(oauthState.CodeVerifier))
|
|
if err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
|
|
// Get user info from provider
|
|
oauthUser, err := os.getUserInfo(provider, token.AccessToken)
|
|
if err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
|
|
// Check if user already exists (by provider account or email) — audit 1.8: OAuth ID lookup first
|
|
existingUser, err := os.getOrCreateUser(provider, oauthUser)
|
|
if err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
|
|
// Save/update OAuth account
|
|
err = os.saveOAuthAccount(provider, oauthUser, existingUser.ID, token)
|
|
if err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
|
|
// VEZA-SEC-001: Get full user for JWT (TokenVersion, Role, etc.)
|
|
user, err := os.userService.GetByID(existingUser.ID)
|
|
if err != nil {
|
|
return nil, nil, "", fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
// Generate tokens via JWTService (proper issuer, audience, token_version)
|
|
tokens, err := os.jwtService.GenerateTokenPair(user)
|
|
if err != nil {
|
|
return nil, nil, "", fmt.Errorf("failed to generate tokens: %w", err)
|
|
}
|
|
|
|
// Create session for refresh token validation
|
|
_, err = os.sessionService.CreateSession(ctx, &SessionCreateRequest{
|
|
UserID: user.ID,
|
|
Token: tokens.RefreshToken,
|
|
IPAddress: ipAddress,
|
|
UserAgent: userAgent,
|
|
ExpiresIn: os.jwtService.GetConfig().RefreshTokenTTL,
|
|
})
|
|
if err != nil {
|
|
os.logger.Warn("Failed to create session after OAuth callback",
|
|
zap.String("user_id", user.ID.String()),
|
|
zap.Error(err),
|
|
)
|
|
// Continue - tokens still valid, session is optional for some flows
|
|
}
|
|
|
|
return &OAuthUser{
|
|
ID: existingUser.ID,
|
|
Email: existingUser.Email,
|
|
}, tokens, redirectURL, nil
|
|
}
|
|
|
|
// validateRedirectURL checks that the redirect URL's origin is in the allowed domains whitelist
|
|
func (os *OAuthService) validateRedirectURL(redirectURL string) error {
|
|
parsed, err := url.Parse(redirectURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid redirect URL: %w", err)
|
|
}
|
|
if parsed.Scheme == "" || parsed.Host == "" {
|
|
return fmt.Errorf("invalid redirect URL: missing scheme or host")
|
|
}
|
|
redirectOrigin := parsed.Scheme + "://" + parsed.Host
|
|
|
|
if len(os.allowedDomains) == 0 {
|
|
// Dev fallback: allow localhost
|
|
if strings.HasPrefix(redirectOrigin, "http://localhost") || strings.HasPrefix(redirectOrigin, "http://127.0.0.1") {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("redirect URL not allowed: %s (no whitelist configured)", redirectURL)
|
|
}
|
|
|
|
for _, allowed := range os.allowedDomains {
|
|
allowed = strings.TrimSpace(allowed)
|
|
if allowed == "*" {
|
|
return nil
|
|
}
|
|
allowedParsed, err := url.Parse(allowed)
|
|
if err != nil || allowedParsed.Scheme == "" || allowedParsed.Host == "" {
|
|
continue
|
|
}
|
|
allowedOrigin := allowedParsed.Scheme + "://" + allowedParsed.Host
|
|
if redirectOrigin == allowedOrigin {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("redirect URL not allowed: %s", redirectURL)
|
|
}
|
|
|
|
// OAuthUser represents an OAuth authenticated user
|
|
type OAuthUser struct {
|
|
ID uuid.UUID `json:"id"`
|
|
Email string `json:"email"`
|
|
Username string `json:"username"`
|
|
Name string `json:"name"`
|
|
Avatar string `json:"avatar"`
|
|
ProviderID string `json:"-"` // Added to store provider ID
|
|
}
|
|
|
|
// OAuthUserInfo represents a user from the database
|
|
type OAuthUserInfo struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
Email string `json:"email" db:"email"`
|
|
Username string `json:"username" db:"username"`
|
|
}
|
|
|
|
// getUserInfo fetches user information from the OAuth provider
|
|
func (os *OAuthService) getUserInfo(provider, accessToken string) (*OAuthUser, error) {
|
|
var apiURL string
|
|
if os.testUserInfoURL != "" {
|
|
apiURL = os.testUserInfoURL
|
|
} else {
|
|
switch provider {
|
|
case "google":
|
|
apiURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
|
case "github":
|
|
apiURL = "https://api.github.com/user"
|
|
case "discord":
|
|
apiURL = "https://discord.com/api/users/@me"
|
|
case "spotify":
|
|
apiURL = "https://api.spotify.com/v1/me"
|
|
default:
|
|
return nil, fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", apiURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Add auth header
|
|
switch provider {
|
|
case "github":
|
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken))
|
|
case "discord":
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
|
default:
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
|
}
|
|
|
|
// MOD-P2-006: Retry avec backoff exponentiel pour requêtes HTTP externes
|
|
// MOD-P2-007: Circuit breaker pour protéger contre dépendances lentes
|
|
maxRetries := 3
|
|
backoff := time.Second
|
|
var resp *http.Response
|
|
|
|
for i := 0; i < maxRetries; i++ {
|
|
var err error
|
|
// MOD-P2-007: Utiliser circuit breaker pour protéger contre dépendances lentes
|
|
resp, err = os.circuitBreaker.Do(req)
|
|
if err == nil {
|
|
break // Succès
|
|
}
|
|
|
|
// Log retry
|
|
if i < maxRetries-1 {
|
|
time.Sleep(backoff)
|
|
backoff *= 2 // Exponential backoff: 1s, 2s, 4s
|
|
} else {
|
|
return nil, fmt.Errorf("OAuth API request failed after %d attempts: %w", maxRetries, err)
|
|
}
|
|
}
|
|
|
|
if resp == nil {
|
|
return nil, fmt.Errorf("OAuth API request failed: no response after %d attempts", maxRetries)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse response based on provider
|
|
var oauthUser OAuthUser
|
|
switch provider {
|
|
case "google":
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
oauthUser.Username = userInfo.Email
|
|
oauthUser.Email = userInfo.Email
|
|
oauthUser.Name = userInfo.Name
|
|
oauthUser.ProviderID = userInfo.ID
|
|
case "github":
|
|
var userInfo struct {
|
|
ID int `json:"id"`
|
|
Login string `json:"login"`
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
oauthUser.Username = userInfo.Login
|
|
oauthUser.Email = userInfo.Email
|
|
oauthUser.Name = userInfo.Name
|
|
oauthUser.ProviderID = fmt.Sprintf("%d", userInfo.ID)
|
|
case "discord":
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
Username string `json:"username"`
|
|
Email string `json:"email"`
|
|
Avatar string `json:"avatar"`
|
|
}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
oauthUser.Username = userInfo.Username
|
|
oauthUser.Email = userInfo.Email
|
|
oauthUser.Name = userInfo.Username
|
|
oauthUser.Avatar = userInfo.Avatar
|
|
oauthUser.ProviderID = userInfo.ID
|
|
case "spotify":
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
DisplayName string `json:"display_name"`
|
|
Email string `json:"email"`
|
|
Images []struct {
|
|
URL string `json:"url"`
|
|
} `json:"images"`
|
|
}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
oauthUser.Username = userInfo.DisplayName
|
|
if oauthUser.Username == "" {
|
|
oauthUser.Username = userInfo.ID
|
|
}
|
|
oauthUser.Email = userInfo.Email
|
|
if oauthUser.Email == "" {
|
|
oauthUser.Email = userInfo.ID + "@spotify.user"
|
|
}
|
|
oauthUser.Name = userInfo.DisplayName
|
|
oauthUser.ProviderID = userInfo.ID
|
|
if len(userInfo.Images) > 0 && userInfo.Images[0].URL != "" {
|
|
oauthUser.Avatar = userInfo.Images[0].URL
|
|
}
|
|
}
|
|
|
|
return &oauthUser, nil
|
|
}
|
|
|
|
// getOrCreateUser gets an existing user or creates a new one (audit 1.8: OAuth ID lookup first)
|
|
func (os *OAuthService) getOrCreateUser(provider string, oauthUser *OAuthUser) (*OAuthUserInfo, error) {
|
|
ctx := context.Background()
|
|
|
|
// Try OAuth ID lookup first to avoid duplicates when user changes email at provider
|
|
if oauthUser.ProviderID != "" {
|
|
dbUser, err := os.db.GetUserByOAuthID(oauthUser.ProviderID, provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if dbUser != nil {
|
|
return &OAuthUserInfo{ID: dbUser.ID, Email: dbUser.Email, Username: dbUser.Username}, nil
|
|
}
|
|
}
|
|
|
|
// Fallback: find existing user by email
|
|
var user OAuthUserInfo
|
|
err := os.db.QueryRowContext(ctx, `
|
|
SELECT id, email, username
|
|
FROM users
|
|
WHERE email = $1
|
|
`, oauthUser.Email).Scan(&user.ID, &user.Email, &user.Username)
|
|
|
|
if err == nil {
|
|
return &user, nil
|
|
}
|
|
|
|
if err != sql.ErrNoRows {
|
|
return nil, err
|
|
}
|
|
|
|
// T0219: Generate slug from username
|
|
slug := utils.Slugify(oauthUser.Username)
|
|
// Ensure slug is unique by appending a number if needed
|
|
baseSlug := slug
|
|
counter := 1
|
|
for {
|
|
var count int
|
|
err := os.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM users WHERE slug = $1", slug).Scan(&count)
|
|
if err == nil && count == 0 {
|
|
break
|
|
}
|
|
slug = fmt.Sprintf("%s%d", baseSlug, counter)
|
|
counter++
|
|
if counter > 1000 {
|
|
slug = fmt.Sprintf("user_%d", time.Now().Unix())
|
|
break
|
|
}
|
|
}
|
|
|
|
// Create new user
|
|
// ID est généré automatiquement par gen_random_uuid()
|
|
insertQuery := `
|
|
INSERT INTO users (email, username, slug, is_verified, is_active, created_at, updated_at)
|
|
VALUES ($1, $2, $3, TRUE, TRUE, NOW(), NOW())
|
|
RETURNING id, email, username
|
|
`
|
|
err = os.db.QueryRowContext(ctx, insertQuery, oauthUser.Email, oauthUser.Username, slug).Scan(
|
|
&user.ID,
|
|
&user.Email,
|
|
&user.Username,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
os.logger.Info("New user created via OAuth",
|
|
zap.String("email", oauthUser.Email),
|
|
zap.String("provider", "oauth"),
|
|
)
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// saveOAuthAccount saves or updates OAuth account information
|
|
// Uses federated_identities table. Tokens are encrypted at rest when CryptoService is set (v0.902)
|
|
func (os *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error {
|
|
ctx := context.Background()
|
|
|
|
accessToken := token.AccessToken
|
|
refreshToken := token.RefreshToken
|
|
if os.cryptoService != nil {
|
|
var errEnc error
|
|
accessToken, errEnc = os.cryptoService.EncryptString(token.AccessToken)
|
|
if errEnc != nil {
|
|
return fmt.Errorf("encrypt access token: %w", errEnc)
|
|
}
|
|
refreshToken, errEnc = os.cryptoService.EncryptString(token.RefreshToken)
|
|
if errEnc != nil {
|
|
return fmt.Errorf("encrypt refresh token: %w", errEnc)
|
|
}
|
|
}
|
|
|
|
// Check if OAuth account already exists
|
|
var existingID uuid.UUID
|
|
err := os.db.QueryRowContext(ctx, `
|
|
SELECT id FROM federated_identities
|
|
WHERE user_id = $1 AND provider_id = $2
|
|
`, userID, oauthUser.ProviderID).Scan(&existingID)
|
|
|
|
if err == nil {
|
|
// Update existing
|
|
_, err = os.db.ExecContext(ctx, `
|
|
UPDATE federated_identities
|
|
SET email = $1, display_name = $2, access_token = $3, refresh_token = $4, expires_at = $5, updated_at = NOW()
|
|
WHERE id = $6
|
|
`, oauthUser.Email, oauthUser.Name, accessToken, refreshToken, token.Expiry, existingID)
|
|
return err
|
|
}
|
|
|
|
if err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
|
|
// Insert new
|
|
_, err = os.db.ExecContext(ctx, `
|
|
INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at)
|
|
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
|
|
`, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, accessToken, refreshToken, token.Expiry)
|
|
|
|
return err
|
|
}
|