package services

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"time"

	"veza-backend-api/internal/database"
	"veza-backend-api/internal/utils"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)

const (
	// oauthStateTTL is how long a state token created by GenerateStateToken stays valid.
	oauthStateTTL = 10 * time.Minute
	// oauthHTTPTimeout bounds every request to a provider (token exchange and user-info calls).
	oauthHTTPTimeout = 10 * time.Second
	// oauthMaxResponseBytes caps how much of a provider response body is read into memory.
	oauthMaxResponseBytes = 1 << 20
	// oauthLegacyProvider is the provider value that earlier versions of
	// saveOAuthAccount stored for every provider. Such rows are still matched
	// and rewritten with the real provider name on the next login.
	oauthLegacyProvider = "oauth"
)

// oauthHTTPClient is shared by all provider calls so connections are reused
// and every request is bounded by oauthHTTPTimeout.
var oauthHTTPClient = &http.Client{Timeout: oauthHTTPTimeout}

// OAuthService handles OAuth authentication
type OAuthService struct {
	db            *database.Database
	logger        *zap.Logger
	googleConfig  *oauth2.Config // nil until InitializeConfigs is called
	githubConfig  *oauth2.Config // nil until InitializeConfigs is called
	discordConfig *oauth2.Config // nil until InitializeConfigs is called
	jwtSecret     []byte         // HMAC key used by generateJWT
}

// OAuthAccount represents an OAuth account linking
// Mapped to federated_identities table
type OAuthAccount struct {
	ID           uuid.UUID `json:"id" db:"id"`
	UserID       uuid.UUID `json:"user_id" db:"user_id"`
	Provider     string    `json:"provider" db:"provider"`
	ProviderID   string    `json:"provider_id" db:"provider_id"`
	Email        string    `json:"email" db:"email"`
	DisplayName  string    `json:"display_name" db:"display_name"`
	AvatarURL    string    `json:"avatar_url" db:"avatar_url"`
	AccessToken  string    `json:"-" db:"access_token"`
	RefreshToken string    `json:"-" db:"refresh_token"`
	ExpiresAt    time.Time `json:"expires_at" db:"expires_at"`
	CreatedAt    time.Time `json:"created_at" db:"created_at"`
	UpdatedAt    time.Time `json:"updated_at" db:"updated_at"`
}

// OAuthState represents an OAuth state for CSRF protection
type OAuthState struct {
	ID          int64     `db:"id"`
	StateToken  string    `db:"state_token"`
	Provider    string    `db:"provider"`
	RedirectURL string    `db:"redirect_url"`
	ExpiresAt   time.Time `db:"expires_at"`
	CreatedAt   time.Time `db:"created_at"`
}

// NewOAuthService creates a new OAuth service. Provider configurations must be
// set with InitializeConfigs before GetAuthURL or HandleCallback can succeed.
func NewOAuthService(db *database.Database, logger *zap.Logger, jwtSecret []byte) *OAuthService {
	return &OAuthService{
		db:        db,
		logger:    logger,
		jwtSecret: jwtSecret,
	}
}

// InitializeConfigs builds the Google, GitHub and Discord OAuth2 configurations.
// Each callback URL is "<baseURL>/api/v1/auth/oauth/<provider>/callback".
func (s *OAuthService) InitializeConfigs(googleClientID, googleClientSecret, githubClientID, githubClientSecret, discordClientID, discordClientSecret, baseURL string) {
	// Google OAuth
	s.googleConfig = &oauth2.Config{
		ClientID:     googleClientID,
		ClientSecret: googleClientSecret,
		RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/google/callback", baseURL),
		Scopes: []string{
			"https://www.googleapis.com/auth/userinfo.email",
			"https://www.googleapis.com/auth/userinfo.profile",
		},
		Endpoint: google.Endpoint,
	}

	// GitHub OAuth
	s.githubConfig = &oauth2.Config{
		ClientID:     githubClientID,
		ClientSecret: githubClientSecret,
		RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/github/callback", baseURL),
		Scopes:       []string{"user:email", "read:user"},
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://github.com/login/oauth/authorize",
			TokenURL: "https://github.com/login/oauth/access_token",
		},
	}

	// Discord OAuth
	s.discordConfig = &oauth2.Config{
		ClientID:     discordClientID,
		ClientSecret: discordClientSecret,
		RedirectURL:  fmt.Sprintf("%s/api/v1/auth/oauth/discord/callback", baseURL),
		Scopes:       []string{"identify", "email"},
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://discord.com/api/oauth2/authorize",
			TokenURL: "https://discord.com/api/oauth2/token",
		},
	}

	s.logger.Info("OAuth configs initialized")
}

// GenerateStateToken creates a random 32-byte, URL-safe state token for CSRF
// protection, stores it in oauth_states with a 10 minute expiry, and returns it.
func (s *OAuthService) GenerateStateToken(provider, redirectURL string) (string, error) {
	tokenBytes := make([]byte, 32)
	if _, err := rand.Read(tokenBytes); err != nil {
		return "", fmt.Errorf("generating state token: %w", err)
	}
	stateToken := base64.URLEncoding.EncodeToString(tokenBytes)

	expiresAt := time.Now().Add(oauthStateTTL)
	if _, err := s.db.ExecContext(context.Background(), `
		INSERT INTO oauth_states (state_token, provider, redirect_url, expires_at)
		VALUES ($1, $2, $3, $4)
	`, stateToken, provider, redirectURL, expiresAt); err != nil {
		return "", fmt.Errorf("storing state token: %w", err)
	}

	s.logger.Debug("State token generated", zap.String("provider", provider))
	return stateToken, nil
}

// ValidateStateToken consumes a state token and returns the stored state.
//
// The token is deleted in the same statement that reads it, so a given token
// can be consumed at most once even under concurrent callbacks. Expired tokens
// are deleted too, but reported as expired.
//
// Errors: "invalid state token" when the token is unknown (or already used),
// "state token expired" when it is past its expiry, or the database error.
func (s *OAuthService) ValidateStateToken(stateToken string) (*OAuthState, error) {
	var state OAuthState
	err := s.db.QueryRowContext(context.Background(), `
		DELETE FROM oauth_states
		WHERE state_token = $1
		RETURNING id, state_token, provider, redirect_url, expires_at, created_at
	`, stateToken).Scan(
		&state.ID,
		&state.StateToken,
		&state.Provider,
		&state.RedirectURL,
		&state.ExpiresAt,
		&state.CreatedAt,
	)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("invalid state token")
	}
	if err != nil {
		return nil, err
	}

	if time.Now().After(state.ExpiresAt) {
		return nil, fmt.Errorf("state token expired")
	}
	return &state, nil
}

// providerConfig returns the OAuth2 config for provider, or an error when the
// provider is unknown or InitializeConfigs has not configured it.
func (s *OAuthService) providerConfig(provider string) (*oauth2.Config, error) {
	var (
		config *oauth2.Config
		label  string
	)
	switch provider {
	case "google":
		config, label = s.googleConfig, "Google"
	case "github":
		config, label = s.githubConfig, "GitHub"
	case "discord":
		config, label = s.discordConfig, "Discord"
	default:
		return nil, fmt.Errorf("unknown provider: %s", provider)
	}
	if config == nil {
		return nil, fmt.Errorf("%s OAuth not configured", label)
	}
	return config, nil
}

// GetAuthURL returns the provider's authorization URL, embedding a freshly
// stored state token for CSRF protection.
func (s *OAuthService) GetAuthURL(provider string) (string, error) {
	config, err := s.providerConfig(provider)
	if err != nil {
		return "", err
	}

	stateToken, err := s.GenerateStateToken(provider, "")
	if err != nil {
		return "", err
	}

	return config.AuthCodeURL(stateToken, oauth2.AccessTypeOffline), nil
}

// HandleCallback processes the OAuth callback. It consumes the state token,
// exchanges code for provider tokens, fetches the provider profile, finds or
// creates the local user, stores the linked identity, and returns the user
// together with a signed JWT.
func (s *OAuthService) HandleCallback(provider, code, state string) (*OAuthUser, string, error) {
	oauthState, err := s.ValidateStateToken(state)
	if err != nil {
		return nil, "", err
	}
	// A state issued for one provider must not be replayable on another
	// provider's callback.
	if oauthState.Provider != provider {
		return nil, "", fmt.Errorf("state token provider mismatch")
	}

	// Checking for nil here avoids a nil-pointer panic when the provider was never configured.
	config, err := s.providerConfig(provider)
	if err != nil {
		return nil, "", err
	}

	// Route the token exchange through the shared, timeout-bounded client;
	// otherwise oauth2 falls back to http.DefaultClient, which never times out.
	exchangeCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthHTTPClient)
	token, err := config.Exchange(exchangeCtx, code)
	if err != nil {
		return nil, "", fmt.Errorf("exchanging authorization code: %w", err)
	}

	oauthUser, err := s.getUserInfo(provider, token.AccessToken)
	if err != nil {
		return nil, "", err
	}

	user, err := s.getOrCreateUser(provider, oauthUser)
	if err != nil {
		return nil, "", err
	}

	if err := s.saveOAuthAccount(provider, oauthUser, user.ID, token); err != nil {
		return nil, "", err
	}

	jwtToken, err := s.generateJWT(user.ID)
	if err != nil {
		return nil, "", err
	}

	return &OAuthUser{
		ID:       user.ID,
		Email:    user.Email,
		Username: user.Username,
		Name:     oauthUser.Name,
		Avatar:   oauthUser.Avatar,
	}, jwtToken, nil
}

// OAuthUser represents an OAuth authenticated user
type OAuthUser struct {
	ID         uuid.UUID `json:"id"`
	Email      string    `json:"email"`
	Username   string    `json:"username"`
	Name       string    `json:"name"`
	Avatar     string    `json:"avatar"`
	ProviderID string    `json:"-"` // the user's ID at the OAuth provider
}

// OAuthUserInfo represents a user from the database
type OAuthUserInfo struct {
	ID       uuid.UUID `json:"id" db:"id"`
	Email    string    `json:"email" db:"email"`
	Username string    `json:"username" db:"username"`
}

// fetchProviderJSON performs an authenticated GET against a provider API and
// decodes the JSON body into dst. Non-200 responses are returned as errors
// rather than being decoded into an empty value.
func (s *OAuthService) fetchProviderJSON(apiURL, authorization string, dst any) error {
	req, err := http.NewRequest(http.MethodGet, apiURL, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", authorization)
	req.Header.Set("Accept", "application/json")

	resp, err := oauthHTTPClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, oauthMaxResponseBytes))
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%s returned HTTP %d", apiURL, resp.StatusCode)
	}
	return json.Unmarshal(body, dst)
}

// githubPrimaryEmail returns the user's primary, verified GitHub email. It is
// used when /user hides the email; this relies on the "user:email" scope that
// InitializeConfigs requests.
func (s *OAuthService) githubPrimaryEmail(authorization string) (string, error) {
	var emails []struct {
		Email    string `json:"email"`
		Primary  bool   `json:"primary"`
		Verified bool   `json:"verified"`
	}
	if err := s.fetchProviderJSON("https://api.github.com/user/emails", authorization, &emails); err != nil {
		return "", fmt.Errorf("fetching github emails: %w", err)
	}
	for _, e := range emails {
		if e.Primary && e.Verified {
			return e.Email, nil
		}
	}
	return "", errors.New("github account has no verified primary email")
}

// getUserInfo fetches the user's profile from the OAuth provider.
//
// It always returns a non-empty email, because getOrCreateUser links accounts
// by email. Google and Discord profiles whose email is not verified are
// rejected for the same reason.
func (s *OAuthService) getUserInfo(provider, accessToken string) (*OAuthUser, error) {
	var apiURL, authScheme string
	switch provider {
	case "google":
		apiURL, authScheme = "https://www.googleapis.com/oauth2/v2/userinfo", "Bearer"
	case "github":
		apiURL, authScheme = "https://api.github.com/user", "token"
	case "discord":
		apiURL, authScheme = "https://discord.com/api/users/@me", "Bearer"
	default:
		return nil, fmt.Errorf("unknown provider: %s", provider)
	}
	authorization := authScheme + " " + accessToken

	var oauthUser OAuthUser
	switch provider {
	case "google":
		var info struct {
			ID            string `json:"id"`
			Email         string `json:"email"`
			VerifiedEmail bool   `json:"verified_email"`
			Name          string `json:"name"`
			Picture       string `json:"picture"`
		}
		if err := s.fetchProviderJSON(apiURL, authorization, &info); err != nil {
			return nil, err
		}
		if !info.VerifiedEmail {
			return nil, fmt.Errorf("%s account email is not verified", provider)
		}
		oauthUser.Username = info.Email
		oauthUser.Email = info.Email
		oauthUser.Name = info.Name
		oauthUser.Avatar = info.Picture
		oauthUser.ProviderID = info.ID

	case "github":
		var info struct {
			ID        int64  `json:"id"`
			Login     string `json:"login"`
			Email     string `json:"email"`
			Name      string `json:"name"`
			AvatarURL string `json:"avatar_url"`
		}
		if err := s.fetchProviderJSON(apiURL, authorization, &info); err != nil {
			return nil, err
		}
		email := info.Email
		if email == "" {
			// /user only exposes the public email; fall back to the emails API.
			primary, err := s.githubPrimaryEmail(authorization)
			if err != nil {
				return nil, err
			}
			email = primary
		}
		oauthUser.Username = info.Login
		oauthUser.Email = email
		oauthUser.Name = info.Name
		oauthUser.Avatar = info.AvatarURL
		oauthUser.ProviderID = strconv.FormatInt(info.ID, 10)

	case "discord":
		var info struct {
			ID       string `json:"id"`
			Username string `json:"username"`
			Email    string `json:"email"`
			Verified bool   `json:"verified"`
			Avatar   string `json:"avatar"`
		}
		if err := s.fetchProviderJSON(apiURL, authorization, &info); err != nil {
			return nil, err
		}
		if !info.Verified {
			return nil, fmt.Errorf("%s account email is not verified", provider)
		}
		oauthUser.Username = info.Username
		oauthUser.Email = info.Email
		oauthUser.Name = info.Username
		oauthUser.ProviderID = info.ID
		if info.Avatar != "" {
			// Discord returns only an avatar hash; build the CDN URL from it.
			oauthUser.Avatar = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", info.ID, info.Avatar)
		}
	}

	if oauthUser.Email == "" {
		return nil, fmt.Errorf("%s did not return an email address", provider)
	}
	return &oauthUser, nil
}

// getOrCreateUser resolves the local user for an OAuth login. It tries, in
// order: an identity already linked for this provider and provider ID, then a
// user with the same email; otherwise it creates a new verified, active user.
func (s *OAuthService) getOrCreateUser(provider string, oauthUser *OAuthUser) (*OAuthUserInfo, error) {
	ctx := context.Background()
	var user OAuthUserInfo

	// An existing link wins even if the email at the provider has changed.
	err := s.db.QueryRowContext(ctx, `
		SELECT u.id, u.email, u.username
		FROM federated_identities fi
		JOIN users u ON u.id = fi.user_id
		WHERE fi.provider = $1 AND fi.provider_id = $2
		LIMIT 1
	`, provider, oauthUser.ProviderID).Scan(&user.ID, &user.Email, &user.Username)
	if err == nil {
		return &user, nil
	}
	if !errors.Is(err, sql.ErrNoRows) {
		return nil, err
	}

	err = s.db.QueryRowContext(ctx, `
		SELECT id, email, username FROM users WHERE email = $1
	`, oauthUser.Email).Scan(&user.ID, &user.Email, &user.Username)
	if err == nil {
		return &user, nil
	}
	if !errors.Is(err, sql.ErrNoRows) {
		return nil, err
	}

	// T0219: Generate slug from username
	slug, err := s.uniqueSlug(ctx, oauthUser.Username)
	if err != nil {
		return nil, err
	}

	// The id column is filled by gen_random_uuid() in the database.
	err = s.db.QueryRowContext(ctx, `
		INSERT INTO users (email, username, slug, is_verified, is_active, created_at, updated_at)
		VALUES ($1, $2, $3, TRUE, TRUE, NOW(), NOW())
		RETURNING id, email, username
	`, oauthUser.Email, oauthUser.Username, slug).Scan(&user.ID, &user.Email, &user.Username)
	if err != nil {
		return nil, err
	}

	s.logger.Info("New user created via OAuth",
		zap.String("email", oauthUser.Email),
		zap.String("provider", provider),
	)
	return &user, nil
}

// uniqueSlug derives a slug from username that no users row currently uses.
// It tries base, base1, …, base999, then falls back to "user_<unix seconds>".
// A concurrent insert can still take the slug between this check and the
// INSERT; a unique constraint on users.slug, if present, would surface that
// race as an insert error.
func (s *OAuthService) uniqueSlug(ctx context.Context, username string) (string, error) {
	base := utils.Slugify(username)
	if base == "" {
		base = "user"
	}

	slug := base
	for counter := 1; counter <= 1000; counter++ {
		var taken bool
		if err := s.db.QueryRowContext(ctx,
			"SELECT EXISTS(SELECT 1 FROM users WHERE slug = $1)", slug,
		).Scan(&taken); err != nil {
			return "", fmt.Errorf("checking slug availability: %w", err)
		}
		if !taken {
			return slug, nil
		}
		slug = fmt.Sprintf("%s%d", base, counter)
	}
	return fmt.Sprintf("user_%d", time.Now().Unix()), nil
}

// saveOAuthAccount inserts or updates the federated_identities row linking
// userID to this provider identity, storing the provider tokens.
//
// Rows written by earlier versions carry the legacy provider "oauth". They are
// matched too, and rewritten with the real provider name. A refresh token
// already on file is kept when the provider omits one on re-authentication.
func (s *OAuthService) saveOAuthAccount(provider string, oauthUser *OAuthUser, userID uuid.UUID, token *oauth2.Token) error {
	ctx := context.Background()

	var existingID uuid.UUID
	err := s.db.QueryRowContext(ctx, `
		SELECT id FROM federated_identities
		WHERE user_id = $1 AND provider_id = $2 AND provider IN ($3, $4)
		LIMIT 1
	`, userID, oauthUser.ProviderID, provider, oauthLegacyProvider).Scan(&existingID)

	switch {
	case err == nil:
		_, err = s.db.ExecContext(ctx, `
			UPDATE federated_identities
			SET provider = $1,
			    email = $2,
			    display_name = $3,
			    avatar_url = $4,
			    access_token = $5,
			    refresh_token = COALESCE(NULLIF($6, ''), refresh_token),
			    expires_at = $7,
			    updated_at = NOW()
			WHERE id = $8
		`, provider, oauthUser.Email, oauthUser.Name, oauthUser.Avatar,
			token.AccessToken, token.RefreshToken, token.Expiry, existingID)
		return err
	case !errors.Is(err, sql.ErrNoRows):
		return err
	}

	_, err = s.db.ExecContext(ctx, `
		INSERT INTO federated_identities (id, user_id, provider, provider_id, email, display_name, avatar_url, access_token, refresh_token, expires_at, created_at, updated_at)
		VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8, $9, NOW(), NOW())
	`, userID, provider, oauthUser.ProviderID, oauthUser.Email, oauthUser.Name, oauthUser.Avatar,
		token.AccessToken, token.RefreshToken, token.Expiry)
	return err
}

// generateJWT returns an HS256-signed token carrying the user ID in both the
// "user_id" and "sub" claims, and expiring 24 hours after issuance.
func (s *OAuthService) generateJWT(userID uuid.UUID) (string, error) {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id": userID.String(),
		"sub":     userID.String(),
		"exp":     time.Now().Add(24 * time.Hour).Unix(),
	})

	return token.SignedString(s.jwtSecret)
}