// This file implements OAuth sign-in (Google, GitHub, Discord) for the
// services package.

package services
|
|
|
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"

	"veza-backend-api/internal/database"
	"veza-backend-api/internal/utils"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)
|
|
|
|
// OAuthService handles OAuth authentication
|
|
type OAuthService struct {
|
|
db *database.Database
|
|
logger *zap.Logger
|
|
googleConfig *oauth2.Config
|
|
githubConfig *oauth2.Config
|
|
discordConfig *oauth2.Config
|
|
jwtSecret []byte
|
|
circuitBreaker *CircuitBreakerHTTPClient
|
|
}
|
|
|
|
// OAuthAccount represents an OAuth account linking
|
|
// Mapped to federated_identities table
|
|
type OAuthAccount struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
UserID uuid.UUID `json:"user_id" db:"user_id"`
|
|
Provider string `json:"provider" db:"provider"`
|
|
ProviderID string `json:"provider_id" db:"provider_id"`
|
|
Email string `json:"email" db:"email"`
|
|
DisplayName string `json:"display_name" db:"display_name"`
|
|
AvatarURL string `json:"avatar_url" db:"avatar_url"`
|
|
AccessToken string `json:"-" db:"access_token"`
|
|
RefreshToken string `json:"-" db:"refresh_token"`
|
|
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
|
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
|
}
|
|
|
|
// OAuthState is a row of the oauth_states table: a single-use, time-limited
// token handed to the provider as the OAuth "state" parameter so that the
// callback can be tied to a login request this service started (CSRF
// protection).
type OAuthState struct {
	ID          int64     `db:"id"`
	StateToken  string    `db:"state_token"`  // random value round-tripped through the provider
	Provider    string    `db:"provider"`     // provider the token was issued for
	RedirectURL string    `db:"redirect_url"` // redirect target recorded when the token was issued
	ExpiresAt   time.Time `db:"expires_at"`   // the token is rejected after this instant
	CreatedAt   time.Time `db:"created_at"`
}
|
|
|
|
// NewOAuthService creates a new OAuth service
|
|
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtSecret []byte) *OAuthService {
|
|
httpClient := &http.Client{Timeout: 10 * time.Second}
|
|
return &OAuthService{
|
|
db: db,
|
|
logger: logger,
|
|
jwtSecret: jwtSecret,
|
|
circuitBreaker: NewCircuitBreakerHTTPClient(httpClient, "oauth-service", logger),
|
|
}
|
|
}
|
|
|
|
// InitializeConfigs initializes OAuth configurations
|
|
func (os *OAuthService) InitializeConfigs(googleClientID, googleClientSecret, githubClientID, githubClientSecret, discordClientID, discordClientSecret, baseURL string) {
|
|
// Google OAuth
|
|
os.googleConfig = &oauth2.Config{
|
|
ClientID: googleClientID,
|
|
ClientSecret: googleClientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/api/v1/auth/oauth/google/callback", baseURL),
|
|
Scopes: []string{
|
|
"https://www.googleapis.com/auth/userinfo.email",
|
|
"https://www.googleapis.com/auth/userinfo.profile",
|
|
},
|
|
Endpoint: google.Endpoint,
|
|
}
|
|
|
|
// GitHub OAuth
|
|
os.githubConfig = &oauth2.Config{
|
|
ClientID: githubClientID,
|
|
ClientSecret: githubClientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/api/v1/auth/oauth/github/callback", baseURL),
|
|
Scopes: []string{"user:email", "read:user"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://github.com/login/oauth/authorize",
|
|
TokenURL: "https://github.com/login/oauth/access_token",
|
|
},
|
|
}
|
|
|
|
// Discord OAuth
|
|
os.discordConfig = &oauth2.Config{
|
|
ClientID: discordClientID,
|
|
ClientSecret: discordClientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/api/v1/auth/oauth/discord/callback", baseURL),
|
|
Scopes: []string{"identify", "email"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://discord.com/api/oauth2/authorize",
|
|
TokenURL: "https://discord.com/api/oauth2/token",
|
|
},
|
|
}
|
|
|
|
os.logger.Info("OAuth configs initialized")
|
|
}
|
|
|
|
// GenerateStateToken generates a secure state token for CSRF protection
|
|
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (string, error) {
|
|
// Generate random token
|
|
tokenBytes := make([]byte, 32)
|
|
_, err := rand.Read(tokenBytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
stateToken := base64.URLEncoding.EncodeToString(tokenBytes)
|
|
|
|
// Store in database
|
|
ctx := context.Background()
|
|
expiresAt := time.Now().Add(10 * time.Minute)
|
|
_, err = os.db.ExecContext(ctx, `
|
|
INSERT INTO oauth_states (state_token, provider, redirect_url, expires_at)
|
|
VALUES ($1, $2, $3, $4)
|
|
`, stateToken, provider, redirectURL, expiresAt)
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
os.logger.Debug("State token generated", zap.String("provider", provider))
|
|
return stateToken, nil
|
|
}
|
|
|
|
// ValidateStateToken validates and consumes a state token
|
|
func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
|
|
ctx := context.Background()
|
|
|
|
var state OAuthState
|
|
err := os.db.QueryRowContext(ctx, `
|
|
SELECT id, state_token, provider, redirect_url, expires_at, created_at
|
|
FROM oauth_states
|
|
WHERE state_token = $1
|
|
`, stateToken).Scan(
|
|
&state.ID,
|
|
&state.StateToken,
|
|
&state.Provider,
|
|
&state.RedirectURL,
|
|
&state.ExpiresAt,
|
|
&state.CreatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("invalid state token")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Check if expired
|
|
if time.Now().After(state.ExpiresAt) {
|
|
return nil, fmt.Errorf("state token expired")
|
|
}
|
|
|
|
// Delete used token
|
|
os.db.ExecContext(ctx, `DELETE FROM oauth_states WHERE id = $1`, state.ID)
|
|
|
|
return &state, nil
|
|
}
|
|
|
|
// GetAuthURL returns the OAuth provider authorization URL
|
|
func (os *OAuthService) GetAuthURL(provider string) (string, error) {
|
|
var config *oauth2.Config
|
|
var err error
|
|
|
|
switch provider {
|
|
case "google":
|
|
if os.googleConfig == nil {
|
|
return "", fmt.Errorf("google OAuth not configured")
|
|
}
|
|
config = os.googleConfig
|
|
case "github":
|
|
if os.githubConfig == nil {
|
|
return "", fmt.Errorf("GitHub OAuth not configured")
|
|
}
|
|
config = os.githubConfig
|
|
case "discord":
|
|
if os.discordConfig == nil {
|
|
return "", fmt.Errorf("discord OAuth not configured")
|
|
}
|
|
config = os.discordConfig
|
|
default:
|
|
return "", fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
// Generate state token
|
|
stateToken, err := os.GenerateStateToken(provider, "")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Return authorization URL
|
|
url := config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline)
|
|
return url, nil
|
|
}
|
|
|
|
// HandleCallback processes the OAuth callback
|
|
func (os *OAuthService) HandleCallback(provider, code, state string) (*OAuthUser, string, error) {
|
|
// Validate state
|
|
_, err := os.ValidateStateToken(state)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
var config *oauth2.Config
|
|
switch provider {
|
|
case "google":
|
|
config = os.googleConfig
|
|
case "github":
|
|
config = os.githubConfig
|
|
case "discord":
|
|
config = os.discordConfig
|
|
default:
|
|
return nil, "", fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
// Exchange code for token
|
|
token, err := config.Exchange(context.Background(), code)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Get user info from provider
|
|
oauthUser, err := os.getUserInfo(provider, token.AccessToken)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Check if user already exists (by provider account or email)
|
|
existingUser, err := os.getOrCreateUser(oauthUser)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Save/update OAuth account
|
|
err = os.saveOAuthAccount(oauthUser, existingUser.ID, token)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Generate JWT for the user
|
|
jwtToken, err := os.generateJWT(existingUser.ID)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return &OAuthUser{
|
|
ID: existingUser.ID,
|
|
Email: existingUser.Email,
|
|
}, jwtToken, nil
|
|
}
|
|
|
|
// OAuthUser represents an OAuth authenticated user
|
|
type OAuthUser struct {
|
|
ID uuid.UUID `json:"id"`
|
|
Email string `json:"email"`
|
|
Username string `json:"username"`
|
|
Name string `json:"name"`
|
|
Avatar string `json:"avatar"`
|
|
ProviderID string `json:"-"` // Added to store provider ID
|
|
}
|
|
|
|
// OAuthUserInfo represents a user from the database
|
|
type OAuthUserInfo struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
Email string `json:"email" db:"email"`
|
|
Username string `json:"username" db:"username"`
|
|
}
|
|
|
|
// getUserInfo fetches user information from the OAuth provider
|
|
func (os *OAuthService) getUserInfo(provider, accessToken string) (*OAuthUser, error) {
|
|
var apiURL string
|
|
switch provider {
|
|
case "google":
|
|
apiURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
|
case "github":
|
|
apiURL = "https://api.github.com/user"
|
|
case "discord":
|
|
apiURL = "https://discord.com/api/users/@me"
|
|
default:
|
|
return nil, fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", apiURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Add auth header
|
|
switch provider {
|
|
case "github":
|
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken))
|
|
case "discord":
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
|
default:
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
|
}
|
|
|
|
// MOD-P2-006: Retry avec backoff exponentiel pour requêtes HTTP externes
|
|
// MOD-P2-007: Circuit breaker pour protéger contre dépendances lentes
|
|
maxRetries := 3
|
|
backoff := time.Second
|
|
var resp *http.Response
|
|
|
|
for i := 0; i < maxRetries; i++ {
|
|
var err error
|
|
// MOD-P2-007: Utiliser circuit breaker pour protéger contre dépendances lentes
|
|
resp, err = os.circuitBreaker.Do(req)
|
|
if err == nil {
|
|
break // Succès
|
|
}
|
|
|
|
// Log retry
|
|
if i < maxRetries-1 {
|
|
time.Sleep(backoff)
|
|
backoff *= 2 // Exponential backoff: 1s, 2s, 4s
|
|
} else {
|
|
return nil, fmt.Errorf("OAuth API request failed after %d attempts: %w", maxRetries, err)
|
|
}
|
|
}
|
|
|
|
if resp == nil {
|
|
return nil, fmt.Errorf("OAuth API request failed: no response after %d attempts", maxRetries)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse response based on provider
|
|
var oauthUser OAuthUser
|
|
switch provider {
|
|
case "google":
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
oauthUser.Username = userInfo.Email
|
|
oauthUser.Email = userInfo.Email
|
|
oauthUser.Name = userInfo.Name
|
|
oauthUser.ProviderID = userInfo.ID
|
|
case "github":
|
|
var userInfo struct {
|
|
ID int `json:"id"`
|
|
Login string `json:"login"`
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
oauthUser.Username = userInfo.Login
|
|
oauthUser.Email = userInfo.Email
|
|
oauthUser.Name = userInfo.Name
|
|
oauthUser.ProviderID = fmt.Sprintf("%d", userInfo.ID)
|
|
case "discord":
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
Username string `json:"username"`
|
|
Email string `json:"email"`
|
|
Avatar string `json:"avatar"`
|
|
}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
oauthUser.Username = userInfo.Username
|
|
oauthUser.Email = userInfo.Email
|
|
oauthUser.Name = userInfo.Username
|
|
oauthUser.Avatar = userInfo.Avatar
|
|
oauthUser.ProviderID = userInfo.ID
|
|
}
|
|
|
|
return &oauthUser, nil
|
|
}
|
|
|
|
// getOrCreateUser gets an existing user or creates a new one
|
|
func (os *OAuthService) getOrCreateUser(oauthUser *OAuthUser) (*OAuthUserInfo, error) {
|
|
ctx := context.Background()
|
|
|
|
// Try to find existing user by email
|
|
var user OAuthUserInfo
|
|
err := os.db.QueryRowContext(ctx, `
|
|
SELECT id, email, username
|
|
FROM users
|
|
WHERE email = $1
|
|
`, oauthUser.Email).Scan(&user.ID, &user.Email, &user.Username)
|
|
|
|
if err == nil {
|
|
return &user, nil
|
|
}
|
|
|
|
if err != sql.ErrNoRows {
|
|
return nil, err
|
|
}
|
|
|
|
// T0219: Generate slug from username
|
|
slug := utils.Slugify(oauthUser.Username)
|
|
// Ensure slug is unique by appending a number if needed
|
|
baseSlug := slug
|
|
counter := 1
|
|
for {
|
|
var count int
|
|
err := os.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM users WHERE slug = $1", slug).Scan(&count)
|
|
if err == nil && count == 0 {
|
|
break
|
|
}
|
|
slug = fmt.Sprintf("%s%d", baseSlug, counter)
|
|
counter++
|
|
if counter > 1000 {
|
|
slug = fmt.Sprintf("user_%d", time.Now().Unix())
|
|
break
|
|
}
|
|
}
|
|
|
|
// Create new user
|
|
// ID est généré automatiquement par gen_random_uuid()
|
|
insertQuery := `
|
|
INSERT INTO users (email, username, slug, is_verified, is_active, created_at, updated_at)
|
|
VALUES ($1, $2, $3, TRUE, TRUE, NOW(), NOW())
|
|
RETURNING id, email, username
|
|
`
|
|
err = os.db.QueryRowContext(ctx, insertQuery, oauthUser.Email, oauthUser.Username, slug).Scan(
|
|
&user.ID,
|
|
&user.Email,
|
|
&user.Username,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
os.logger.Info("New user created via OAuth",
|
|
zap.String("email", oauthUser.Email),
|
|
zap.String("provider", "oauth"),
|
|
)
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// saveOAuthAccount saves or updates OAuth account information
|
|
// Uses federated_identities table
|
|
func (os *OAuthService) saveOAuthAccount(oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error {
|
|
ctx := context.Background()
|
|
|
|
// Check if OAuth account already exists
|
|
var existingID uuid.UUID
|
|
err := os.db.QueryRowContext(ctx, `
|
|
SELECT id FROM federated_identities
|
|
WHERE user_id = $1 AND provider_id = $2
|
|
`, userID, oauthUser.ProviderID).Scan(&existingID)
|
|
|
|
if err == nil {
|
|
// Update existing
|
|
_, err = os.db.ExecContext(ctx, `
|
|
UPDATE federated_identities
|
|
SET email = $1, display_name = $2, access_token = $3, refresh_token = $4, expires_at = $5, updated_at = NOW()
|
|
WHERE id = $6
|
|
`, oauthUser.Email, oauthUser.Name, token.AccessToken, token.RefreshToken, token.Expiry, existingID)
|
|
return err
|
|
}
|
|
|
|
if err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
|
|
// Insert new
|
|
_, err = os.db.ExecContext(ctx, `
|
|
INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at)
|
|
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
|
|
`, userID, "oauth", oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, token.AccessToken, token.RefreshToken, token.Expiry)
|
|
|
|
return err
|
|
}
|
|
|
|
// generateJWT generates a JWT token for the user
|
|
func (os *OAuthService) generateJWT(userID uuid.UUID) (string, error) {
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"user_id": userID.String(),
|
|
"sub": userID.String(),
|
|
"exp": time.Now().Add(time.Hour * 24).Unix(),
|
|
})
|
|
|
|
return token.SignedString(os.jwtSecret)
|
|
}
|