veza/veza-backend-api/internal/services/oauth_service.go

480 lines
13 KiB
Go

package services
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"veza-backend-api/internal/database"
"veza-backend-api/internal/utils"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// OAuthService handles OAuth authentication for the Google, GitHub and
// Discord providers. It issues and consumes CSRF state tokens, exchanges
// authorization codes for provider tokens, fetches the provider profile,
// links it to a local user and signs an HS256 JWT for that user.
type OAuthService struct {
// db is used for the oauth_states, users and federated_identities tables.
db *database.Database
logger *zap.Logger
// Per-provider client configs. NewOAuthService leaves them nil; they are
// set by InitializeConfigs, and GetAuthURL reports "not configured" for nil.
googleConfig *oauth2.Config
githubConfig *oauth2.Config
discordConfig *oauth2.Config
// jwtSecret is the HMAC key generateJWT signs issued tokens with.
jwtSecret []byte
}
// OAuthAccount represents an OAuth account linking: one external provider
// identity attached to a local user.
// Mapped to federated_identities table (column names are in the db tags).
// NOTE(review): nothing in this file scans rows into this struct;
// saveOAuthAccount writes the table directly, and currently stores the
// literal "oauth" in the provider column rather than the provider name.
type OAuthAccount struct {
ID uuid.UUID `json:"id" db:"id"`
UserID uuid.UUID `json:"user_id" db:"user_id"`
Provider string `json:"provider" db:"provider"`
// ProviderID is the user's identifier at the external provider.
ProviderID string `json:"provider_id" db:"provider_id"`
Email string `json:"email" db:"email"`
DisplayName string `json:"display_name" db:"display_name"`
AvatarURL string `json:"avatar_url" db:"avatar_url"`
// Provider tokens are never serialized to JSON (json:"-").
AccessToken string `json:"-" db:"access_token"`
RefreshToken string `json:"-" db:"refresh_token"`
// ExpiresAt is the provider access-token expiry.
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// OAuthState represents an OAuth state for CSRF protection: a row of the
// oauth_states table, created by GenerateStateToken and redeemed once by
// ValidateStateToken.
type OAuthState struct {
ID int64 `db:"id"`
// StateToken is the random value sent to the provider as the "state" parameter.
StateToken string `db:"state_token"`
// Provider is the provider name the token was issued for.
Provider string `db:"provider"`
// RedirectURL is stored alongside the token; GetAuthURL currently passes "".
RedirectURL string `db:"redirect_url"`
// ExpiresAt is set 10 minutes after issuance by GenerateStateToken.
ExpiresAt time.Time `db:"expires_at"`
CreatedAt time.Time `db:"created_at"`
}
// NewOAuthService builds an OAuthService that persists state in db, logs
// through logger and signs issued JWTs with jwtSecret.
//
// Provider configurations are left unset; call InitializeConfigs before
// starting a login flow.
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtSecret []byte) *OAuthService {
	svc := new(OAuthService)
	svc.db = db
	svc.logger = logger
	svc.jwtSecret = jwtSecret
	return svc
}
// InitializeConfigs builds the OAuth 2.0 client configuration for every
// supported provider (Google, GitHub, Discord).
//
// Each provider's redirect URL is
// "<baseURL>/api/v1/auth/oauth/<provider>/callback". All three configs are
// always assigned, even when a client ID or secret is empty.
func (os *OAuthService) InitializeConfigs(googleClientID, googleClientSecret, githubClientID, githubClientSecret, discordClientID, discordClientSecret, baseURL string) {
	callbackFor := func(provider string) string {
		return fmt.Sprintf("%s/api/v1/auth/oauth/%s/callback", baseURL, provider)
	}

	// Google: email + basic profile scopes on Google's standard endpoint.
	os.googleConfig = &oauth2.Config{
		Endpoint:     google.Endpoint,
		ClientID:     googleClientID,
		ClientSecret: googleClientSecret,
		RedirectURL:  callbackFor("google"),
		Scopes: []string{
			"https://www.googleapis.com/auth/userinfo.email",
			"https://www.googleapis.com/auth/userinfo.profile",
		},
	}

	// GitHub: explicit endpoint URLs.
	os.githubConfig = &oauth2.Config{
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://github.com/login/oauth/authorize",
			TokenURL: "https://github.com/login/oauth/access_token",
		},
		ClientID:     githubClientID,
		ClientSecret: githubClientSecret,
		RedirectURL:  callbackFor("github"),
		Scopes:       []string{"user:email", "read:user"},
	}

	// Discord: explicit endpoint URLs.
	os.discordConfig = &oauth2.Config{
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://discord.com/api/oauth2/authorize",
			TokenURL: "https://discord.com/api/oauth2/token",
		},
		ClientID:     discordClientID,
		ClientSecret: discordClientSecret,
		RedirectURL:  callbackFor("discord"),
		Scopes:       []string{"identify", "email"},
	}

	os.logger.Info("OAuth configs initialized")
}
// GenerateStateToken generates a secure state token for CSRF protection.
//
// The token is 32 bytes from crypto/rand, URL-safe base64 encoded. It is stored
// in oauth_states together with provider, redirectURL and an expiry 10
// minutes from now, and must later be redeemed with ValidateStateToken.
// Errors from the random source or the database are returned wrapped.
func (os *OAuthService) GenerateStateToken(provider, redirectURL string) (string, error) {
	tokenBytes := make([]byte, 32)
	if _, err := rand.Read(tokenBytes); err != nil {
		return "", fmt.Errorf("generating oauth state token: %w", err)
	}
	stateToken := base64.URLEncoding.EncodeToString(tokenBytes)

	ctx := context.Background()
	expiresAt := time.Now().Add(10 * time.Minute)
	if _, err := os.db.ExecContext(ctx, `
		INSERT INTO oauth_states (state_token, provider, redirect_url, expires_at)
		VALUES ($1, $2, $3, $4)
	`, stateToken, provider, redirectURL, expiresAt); err != nil {
		return "", fmt.Errorf("storing oauth state token: %w", err)
	}

	os.logger.Debug("State token generated", zap.String("provider", provider))
	return stateToken, nil
}
// ValidateStateToken validates and consumes a state token.
//
// The row is deleted in the same statement that reads it, so a token can be
// redeemed at most once even when two callbacks race. Expired tokens are also
// removed on their first (rejected) use instead of lingering in the table.
//
// Returns "invalid state token" when no row matches (or stateToken is empty),
// "state token expired" when the token is past its expiry, and the database
// error otherwise.
func (os *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
	if stateToken == "" {
		return nil, fmt.Errorf("invalid state token")
	}

	ctx := context.Background()
	var state OAuthState
	// DELETE ... RETURNING (PostgreSQL) makes lookup and consumption one
	// atomic step; a separate SELECT followed by DELETE would let concurrent
	// requests both accept the same token.
	err := os.db.QueryRowContext(ctx, `
		DELETE FROM oauth_states
		WHERE state_token = $1
		RETURNING id, state_token, provider, redirect_url, expires_at, created_at
	`, stateToken).Scan(
		&state.ID,
		&state.StateToken,
		&state.Provider,
		&state.RedirectURL,
		&state.ExpiresAt,
		&state.CreatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, fmt.Errorf("invalid state token")
	}
	if err != nil {
		return nil, err
	}

	if time.Now().After(state.ExpiresAt) {
		return nil, fmt.Errorf("state token expired")
	}
	return &state, nil
}
// GetAuthURL returns the authorization URL for provider ("google", "github"
// or "discord") that the user should be redirected to.
//
// A fresh CSRF state token is generated and persisted for every call, and
// the URL requests offline access. It fails for unknown providers and for
// providers whose config was never initialized.
func (os *OAuthService) GetAuthURL(provider string) (string, error) {
	var cfg *oauth2.Config
	var label string
	switch provider {
	case "google":
		cfg, label = os.googleConfig, "Google"
	case "github":
		cfg, label = os.githubConfig, "GitHub"
	case "discord":
		cfg, label = os.discordConfig, "Discord"
	default:
		return "", fmt.Errorf("unknown provider: %s", provider)
	}
	if cfg == nil {
		return "", fmt.Errorf("%s OAuth not configured", label)
	}

	state, err := os.GenerateStateToken(provider, "")
	if err != nil {
		return "", err
	}
	return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}
// HandleCallback processes the OAuth callback for provider.
//
// It consumes the CSRF state token and checks that it was issued for this
// same provider. It then exchanges code for provider tokens, fetches the
// provider profile, finds or creates the local user, stores the linked
// identity, and returns the user together with a freshly signed JWT.
//
// It fails for unknown or unconfigured providers; in the latter case it
// errors instead of panicking on a nil config.
func (os *OAuthService) HandleCallback(provider, code, state string) (*OAuthUser, string, error) {
	var config *oauth2.Config
	switch provider {
	case "google":
		config = os.googleConfig
	case "github":
		config = os.githubConfig
	case "discord":
		config = os.discordConfig
	default:
		return nil, "", fmt.Errorf("unknown provider: %s", provider)
	}
	// Configs stay nil until InitializeConfigs runs; config.Exchange on a nil
	// config would dereference a nil pointer.
	if config == nil {
		return nil, "", fmt.Errorf("%s OAuth not configured", provider)
	}

	// Validate (and consume) the state token for CSRF protection.
	oauthState, err := os.ValidateStateToken(state)
	if err != nil {
		return nil, "", err
	}
	// A state issued for one provider's flow must not complete another's.
	if oauthState.Provider != provider {
		return nil, "", fmt.Errorf("state token was issued for provider %q, not %q", oauthState.Provider, provider)
	}

	token, err := config.Exchange(context.Background(), code)
	if err != nil {
		return nil, "", fmt.Errorf("exchanging %s authorization code: %w", provider, err)
	}

	oauthUser, err := os.getUserInfo(provider, token.AccessToken)
	if err != nil {
		return nil, "", err
	}

	existingUser, err := os.getOrCreateUser(oauthUser)
	if err != nil {
		return nil, "", err
	}

	if err := os.saveOAuthAccount(oauthUser, existingUser.ID, token); err != nil {
		return nil, "", err
	}

	jwtToken, err := os.generateJWT(existingUser.ID)
	if err != nil {
		return nil, "", err
	}

	return &OAuthUser{
		ID:       existingUser.ID,
		Email:    existingUser.Email,
		Username: existingUser.Username,
		Name:     oauthUser.Name,
		Avatar:   oauthUser.Avatar,
	}, jwtToken, nil
}
// OAuthUser represents an OAuth authenticated user. getUserInfo fills it from
// the provider's profile, and HandleCallback returns one describing the
// local user that was signed in.
type OAuthUser struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Name string `json:"name"`
Avatar string `json:"avatar"`
// ProviderID is the user's identifier at the OAuth provider; it is
// excluded from JSON output.
ProviderID string `json:"-"`
}
// OAuthUserInfo represents a user from the database: the id, email and
// username columns of the users table that getOrCreateUser reads or
// inserts.
type OAuthUserInfo struct {
ID uuid.UUID `json:"id" db:"id"`
Email string `json:"email" db:"email"`
Username string `json:"username" db:"username"`
}
// getUserInfo fetches user information from the OAuth provider using
// accessToken and maps it onto an OAuthUser (Username, Email, Name, Avatar,
// ProviderID).
//
// It returns an error in any of these cases:
//   - the provider responds with a non-200 status;
//   - the profile cannot be decoded;
//   - the profile carries no provider user ID;
//   - Google or Discord reports the email address as unverified.
//
// Unverified emails are rejected because getOrCreateUser links accounts by
// email; trusting one would let anyone sign in as that address's owner.
func (os *OAuthService) getUserInfo(provider, accessToken string) (*OAuthUser, error) {
	var apiURL string
	switch provider {
	case "google":
		apiURL = "https://www.googleapis.com/oauth2/v2/userinfo"
	case "github":
		apiURL = "https://api.github.com/user"
	case "discord":
		apiURL = "https://discord.com/api/users/@me"
	default:
		return nil, fmt.Errorf("unknown provider: %s", provider)
	}

	req, err := http.NewRequest(http.MethodGet, apiURL, nil)
	if err != nil {
		return nil, err
	}
	// GitHub accepts the "token" scheme; Google and Discord expect "Bearer".
	if provider == "github" {
		req.Header.Set("Authorization", "token "+accessToken)
	} else {
		req.Header.Set("Authorization", "Bearer "+accessToken)
	}

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetching %s user info: %w", provider, err)
	}
	defer resp.Body.Close()

	// Profiles are small; cap the read so a misbehaving endpoint cannot
	// exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return nil, fmt.Errorf("reading %s user info: %w", provider, err)
	}
	// Without this check an error payload (e.g. 401 for a revoked token)
	// would silently decode into an empty profile.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching %s user info: unexpected status %d", provider, resp.StatusCode)
	}

	var oauthUser OAuthUser
	switch provider {
	case "google":
		var userInfo struct {
			ID            string `json:"id"`
			Email         string `json:"email"`
			VerifiedEmail bool   `json:"verified_email"`
			Name          string `json:"name"`
			Picture       string `json:"picture"`
		}
		if err := json.Unmarshal(body, &userInfo); err != nil {
			return nil, err
		}
		if userInfo.Email != "" && !userInfo.VerifiedEmail {
			return nil, fmt.Errorf("google account email is not verified")
		}
		oauthUser.Username = userInfo.Email
		oauthUser.Email = userInfo.Email
		oauthUser.Name = userInfo.Name
		oauthUser.Avatar = userInfo.Picture
		oauthUser.ProviderID = userInfo.ID
	case "github":
		// NOTE(review): GitHub only exposes the profile's public email here
		// (often empty); fetching /user/emails would be needed for private
		// addresses.
		var userInfo struct {
			ID        int64  `json:"id"`
			Login     string `json:"login"`
			Email     string `json:"email"`
			Name      string `json:"name"`
			AvatarURL string `json:"avatar_url"`
		}
		if err := json.Unmarshal(body, &userInfo); err != nil {
			return nil, err
		}
		oauthUser.Username = userInfo.Login
		oauthUser.Email = userInfo.Email
		oauthUser.Name = userInfo.Name
		oauthUser.Avatar = userInfo.AvatarURL
		if userInfo.ID != 0 {
			oauthUser.ProviderID = fmt.Sprintf("%d", userInfo.ID)
		}
	case "discord":
		var userInfo struct {
			ID       string `json:"id"`
			Username string `json:"username"`
			Email    string `json:"email"`
			Verified bool   `json:"verified"`
			Avatar   string `json:"avatar"`
		}
		if err := json.Unmarshal(body, &userInfo); err != nil {
			return nil, err
		}
		if userInfo.Email != "" && !userInfo.Verified {
			return nil, fmt.Errorf("discord account email is not verified")
		}
		oauthUser.Username = userInfo.Username
		oauthUser.Email = userInfo.Email
		oauthUser.Name = userInfo.Username
		// NOTE(review): Discord returns an avatar hash, not a URL.
		oauthUser.Avatar = userInfo.Avatar
		oauthUser.ProviderID = userInfo.ID
	}

	if oauthUser.ProviderID == "" {
		return nil, fmt.Errorf("%s user info did not include a user id", provider)
	}
	return &oauthUser, nil
}
// getOrCreateUser gets an existing user or creates a new one.
//
// An existing user is matched on oauthUser.Email. Otherwise a new verified,
// active user is inserted with a slug derived from oauthUser.Username and
// made unique by appending a counter.
//
// An empty email is rejected: matching on "" would attach every email-less
// OAuth login to the same account. Database errors are returned rather than
// retried.
// NOTE(review): the existence checks and the INSERT are not transactional,
// so concurrent first logins can still collide on unique constraints.
func (os *OAuthService) getOrCreateUser(oauthUser *OAuthUser) (*OAuthUserInfo, error) {
	if oauthUser.Email == "" {
		return nil, fmt.Errorf("oauth provider did not return an email address")
	}

	ctx := context.Background()
	var user OAuthUserInfo
	err := os.db.QueryRowContext(ctx, `
		SELECT id, email, username
		FROM users
		WHERE email = $1
	`, oauthUser.Email).Scan(&user.ID, &user.Email, &user.Username)
	if err == nil {
		return &user, nil
	}
	if err != sql.ErrNoRows {
		return nil, err
	}

	// T0219: generate a slug from the username, falling back to "user" when
	// the username slugifies to nothing.
	baseSlug := utils.Slugify(oauthUser.Username)
	if baseSlug == "" {
		baseSlug = "user"
	}
	slug := baseSlug
	for counter := 1; ; counter++ {
		var count int
		if err := os.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM users WHERE slug = $1", slug).Scan(&count); err != nil {
			return nil, fmt.Errorf("checking slug availability: %w", err)
		}
		if count == 0 {
			break
		}
		if counter > 1000 {
			slug = fmt.Sprintf("user_%d", time.Now().Unix())
			break
		}
		slug = fmt.Sprintf("%s%d", baseSlug, counter)
	}

	// The ID column is populated by the database default (gen_random_uuid()).
	insertQuery := `
		INSERT INTO users (email, username, slug, is_verified, is_active, created_at, updated_at)
		VALUES ($1, $2, $3, TRUE, TRUE, NOW(), NOW())
		RETURNING id, email, username
	`
	err = os.db.QueryRowContext(ctx, insertQuery, oauthUser.Email, oauthUser.Username, slug).Scan(
		&user.ID,
		&user.Email,
		&user.Username,
	)
	if err != nil {
		return nil, err
	}

	os.logger.Info("New user created via OAuth",
		zap.String("email", oauthUser.Email),
		zap.String("provider", "oauth"),
	)
	return &user, nil
}
// saveOAuthAccount saves or updates OAuth account information.
// Uses federated_identities table.
//
// An existing row is matched on (user_id, provider_id). It is updated with
// the latest email, display name, avatar, access token and expiry. The stored
// refresh token is kept when the provider did not send a new one; Google, for
// instance, typically returns a refresh token only on first consent.
// Otherwise a new row is inserted.
//
// NOTE(review): the provider column is written as the literal "oauth"
// because the provider name is not passed in, so rows from different
// providers cannot be told apart and the lookup ignores provider. Fixing this
// needs a signature change (TODO).
// NOTE(review): tokens are stored as received; confirm whether they should
// be encrypted at rest.
func (os *OAuthService) saveOAuthAccount(oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error {
	ctx := context.Background()

	var existingID uuid.UUID
	err := os.db.QueryRowContext(ctx, `
		SELECT id FROM federated_identities
		WHERE user_id = $1 AND provider_id = $2
	`, userID, oauthUser.ProviderID).Scan(&existingID)
	switch {
	case err == nil:
		_, err = os.db.ExecContext(ctx, `
			UPDATE federated_identities
			SET email = $1,
			    display_name = $2,
			    avatar_url = $3,
			    access_token = $4,
			    refresh_token = COALESCE(NULLIF($5::text, ''), refresh_token),
			    expires_at = $6,
			    updated_at = NOW()
			WHERE id = $7
		`, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, token.AccessToken, token.RefreshToken, token.Expiry, existingID)
		return err
	case err != sql.ErrNoRows:
		return err
	}

	_, err = os.db.ExecContext(ctx, `
		INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at)
		VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
	`, userID, "oauth", oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar, token.AccessToken, token.RefreshToken, token.Expiry)
	return err
}
// generateJWT generates an HS256-signed JWT for userID that expires 24 hours
// after issuance.
//
// The user ID is carried in both the "sub" and "user_id" claims, and "iat"
// records the issue time. It refuses to sign when no secret is configured,
// because an empty HMAC key would let anyone forge tokens.
func (os *OAuthService) generateJWT(userID uuid.UUID) (string, error) {
	if len(os.jwtSecret) == 0 {
		return "", fmt.Errorf("jwt secret is not configured")
	}

	now := time.Now()
	claims := jwt.MapClaims{
		"user_id": userID.String(),
		"sub":     userID.String(),
		"iat":     now.Unix(),
		"exp":     now.Add(24 * time.Hour).Unix(),
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(os.jwtSecret)
}