veza/veza-backend-api/internal/handlers/oauth_handlers.go
2026-03-05 23:03:43 +01:00

212 lines
6.9 KiB
Go

package handlers
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"veza-backend-api/internal/config"
"veza-backend-api/internal/models"
"veza-backend-api/internal/services"
"github.com/gin-gonic/gin"
)
// OAuthServiceInterface is the subset of the OAuth service that OAuthHandlers
// depends on. Declaring it on the consumer side lets tests inject a fake
// implementation through NewOAuthHandlerWithInterface.
type OAuthServiceInterface interface {
	// GetAvailableProviders returns the IDs of the providers that are configured.
	GetAvailableProviders() []string

	// GetAuthURL returns the URL the browser is redirected to in order to start
	// the flow for provider; an error is reported to the client as 400.
	GetAuthURL(provider string) (string, error)

	// HandleCallback completes the flow from the provider's code and state.
	// It returns the OAuth user, the issued token pair and a redirect URL
	// (described by the caller as already validated); ipAddress and userAgent
	// describe the client making the callback request.
	HandleCallback(ctx context.Context, provider, code, state, ipAddress, userAgent string) (*services.OAuthUser, *models.TokenPair, string, error)
}
// OAuthHandlers handles OAuth authentication flows: listing configured
// providers (GetOAuthProviders), redirecting to a provider (InitiateOAuth) and
// completing the provider callback (OAuthCallback).
type OAuthHandlers struct {
	oauthService           OAuthServiceInterface
	logger                 interface{}    // NOTE(review): stored by the constructors but not read by any code visible in this file
	allowedRedirectOrigins []string       // SECURITY: allowlist for OAuth redirect URLs; empty selects the development fallback in isAllowedRedirectOrigin
	frontendURL            string         // Frontend URL used as the default OAuth redirect base (from config)
	cfg                    *config.Config // Cookie settings and JWT service; when nil, OAuthCallback sets no auth cookies
}
// OAuthHandlersInstance is the package-level OAuth handler set assigned by
// InitOAuthHandlers; it is nil until that function runs.
// NOTE(review): mutable global state — handlers built with NewOAuthHandler or
// NewOAuthHandlerWithInterface do not touch it.
var OAuthHandlersInstance *OAuthHandlers
// InitOAuthHandlers assigns OAuthHandlersInstance a handler set backed by
// oauthService.
//
// Only the service is set. The logger, redirect allowlist, frontend URL and
// config are left at their zero values. As a result, OAuthCallback uses the
// development origin check, has no default frontend URL and sets no auth
// cookies. Use NewOAuthHandler when those settings are required.
func InitOAuthHandlers(oauthService *services.OAuthService) {
	h := new(OAuthHandlers)
	h.oauthService = oauthService
	OAuthHandlersInstance = h
}
// NewOAuthHandler returns a fully configured OAuth handler set
// (BE-API-042: OAuth callback endpoint).
//
// Parameters:
//   - oauthService: performs the provider-specific OAuth work.
//   - logger: stored as-is.
//   - allowedRedirectOrigins: origins OAuthCallback may redirect to; an empty
//     list selects the development fallback in isAllowedRedirectOrigin.
//   - frontendURL: default redirect base, taken from config.FrontendURL
//     (FRONTEND_URL or VITE_FRONTEND_URL env).
//   - cfg: cookie settings and JWT service; nil disables auth cookies.
func NewOAuthHandler(oauthService *services.OAuthService, logger interface{}, allowedRedirectOrigins []string, frontendURL string, cfg *config.Config) *OAuthHandlers {
	h := &OAuthHandlers{cfg: cfg, logger: logger}
	h.oauthService = oauthService
	h.allowedRedirectOrigins = allowedRedirectOrigins
	h.frontendURL = frontendURL
	return h
}
// NewOAuthHandlerWithInterface returns an OAuth handler set that wraps any
// OAuthServiceInterface implementation, so tests can inject a fake service.
//
// The redirect settings are fixed for tests:
//   - there is no allowlist, so isAllowedRedirectOrigin takes its development
//     fallback;
//   - the default frontend URL is http://localhost:5173.
func NewOAuthHandlerWithInterface(oauthService OAuthServiceInterface, logger interface{}, cfg *config.Config) *OAuthHandlers {
	const testFrontendURL = "http://localhost:5173"

	var h OAuthHandlers
	h.oauthService = oauthService
	h.logger = logger
	h.cfg = cfg
	h.frontendURL = testFrontendURL
	// allowedRedirectOrigins deliberately stays nil (dev fallback).
	return &h
}
// providerMeta maps each supported OAuth provider ID to the display metadata
// returned by GetOAuthProviders.
//
// Every entry has exactly these keys:
//   - "name": a human-readable label;
//   - "id": the provider ID;
//   - "authorizeUrl": the provider's /api/v1/auth/oauth/<id> route;
//   - "icon": the icon key, which equals the ID.
var providerMeta = func() map[string]map[string]interface{} {
	supported := [...]struct{ id, label string }{
		{"google", "Google"},
		{"github", "GitHub"},
		{"discord", "Discord"},
		{"spotify", "Spotify"},
	}
	table := make(map[string]map[string]interface{}, len(supported))
	for _, p := range supported {
		table[p.id] = map[string]interface{}{
			"name":         p.label,
			"id":           p.id,
			"authorizeUrl": "/api/v1/auth/oauth/" + p.id,
			"icon":         p.id,
		}
	}
	return table
}()
// GetOAuthProviders responds 200 with {"providers": [...]}. The list holds the
// providerMeta entry for every provider the OAuth service reports as
// configured, in the order the service returns them. Configured IDs without
// an entry in providerMeta are omitted.
func (oh *OAuthHandlers) GetOAuthProviders(c *gin.Context) {
	configured := oh.oauthService.GetAvailableProviders()
	list := make([]map[string]interface{}, 0, len(configured))
	for i := range configured {
		meta, known := providerMeta[configured[i]]
		if !known {
			continue
		}
		list = append(list, meta)
	}
	RespondSuccess(c, http.StatusOK, gin.H{"providers": list})
}
// InitiateOAuth starts the OAuth flow for the ":provider" route parameter.
// It responds with a 307 redirect to the provider's authorization URL. When
// the service cannot build that URL, it responds 400 with the service error
// message.
func (oh *OAuthHandlers) InitiateOAuth(c *gin.Context) {
	authURL, err := oh.oauthService.GetAuthURL(c.Param("provider"))
	if err == nil {
		c.Redirect(http.StatusTemporaryRedirect, authURL)
		return
	}
	c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
// OAuthCallback completes an OAuth login for the ":provider" route parameter.
//
// The provider redirects here with "code" and "state" query parameters; both
// are required (400 otherwise). They are passed to the OAuth service together
// with the client IP and User-Agent. The service returns the user, a token
// pair and a redirect URL. A service error is reported as 400.
//
// Only the scheme://host part of the returned redirect URL is kept. When that
// URL is empty or unusable, oh.frontendURL is used instead. The resulting base
// must pass isAllowedRedirectOrigin (500 otherwise).
//
// On success the tokens are written as cookies (only when oh.cfg is set), and
// the browser receives a 307 redirect to <base>/auth/callback?user_id=<id>.
// Tokens never appear in the URL.
func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) {
	provider := c.Param("provider")
	code := c.Query("code")
	state := c.Query("state")
	if code == "" || state == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "missing code or state"})
		return
	}

	// VEZA-SEC-001: the service creates the session and returns a TokenPair;
	// v0.902: it also returns a validated redirect URL.
	user, tokens, redirectURL, err := oh.oauthService.HandleCallback(c.Request.Context(), provider, code, state, c.ClientIP(), c.Request.UserAgent())
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// user is dereferenced for the final redirect. A nil user with a nil error
	// (presumably a service contract violation) would otherwise panic after
	// the cookies were already written.
	if user == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "incomplete OAuth callback result"})
		return
	}

	// Use the origin of the service-provided URL as base, or fall back to frontendURL.
	baseURL := oh.frontendURL
	if redirectURL != "" {
		// An unparsable URL, or one without scheme/host, keeps the configured frontend URL.
		if parsed, parseErr := url.Parse(redirectURL); parseErr == nil && parsed.Scheme != "" && parsed.Host != "" {
			baseURL = parsed.Scheme + "://" + parsed.Host
		}
	}
	if !oh.isAllowedRedirectOrigin(baseURL) {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid redirect configuration"})
		return
	}

	// VEZA-SEC-001: set httpOnly cookies (same as login flow).
	if oh.cfg != nil {
		// tokens is only dereferenced here; guard it so a nil pair yields a
		// clean error instead of a panic.
		if tokens == nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "incomplete OAuth callback result"})
			return
		}
		// Both auth cookies share every attribute except name, value and lifetime.
		setTokenCookie := func(name, value string, ttl time.Duration) {
			http.SetCookie(c.Writer, &http.Cookie{
				Name:     name,
				Value:    value,
				Path:     oh.cfg.CookiePath,
				Domain:   oh.cfg.CookieDomain,
				MaxAge:   int(ttl.Seconds()),
				HttpOnly: oh.cfg.CookieHttpOnly,
				Secure:   oh.cfg.ShouldUseSecureCookies(),
				SameSite: oh.cfg.GetCookieSameSite(),
			})
		}

		setTokenCookie("refresh_token", tokens.RefreshToken, 14*24*time.Hour)

		accessTokenTTL := 5 * time.Minute // Match JWTService default
		if oh.cfg.JWTService != nil {
			accessTokenTTL = oh.cfg.JWTService.GetConfig().AccessTokenTTL
		}
		setTokenCookie("access_token", tokens.AccessToken, accessTokenTTL)
	}

	// Redirect to the frontend; only the (query-escaped) user ID travels in the URL.
	query := url.Values{"user_id": {user.ID.String()}}
	finalRedirect := fmt.Sprintf("%s/auth/callback?%s", strings.TrimSuffix(baseURL, "/"), query.Encode())
	c.Redirect(http.StatusTemporaryRedirect, finalRedirect)
}
// isAllowedRedirectOrigin reports whether frontendURL may be used as the base
// of the post-login redirect. Unparsable URLs are never allowed.
//
// With no allowlist configured (development), only plain-HTTP loopback
// origins are accepted: after parsing, the hostname must be exactly
// "localhost" or "127.0.0.1" (any port). The host is checked after parsing,
// not by string prefix. A prefix check would also accept attacker-controlled
// hosts, which turns the OAuth callback into an open redirect. Examples:
//   - "http://localhost.evil.example"
//   - "http://localhost@evil.example" (userinfo; the real host is evil.example)
//
// With an allowlist, the scheme://host origin of frontendURL must match one
// of the entries. Whitespace and a trailing "/" are ignored on entries, and
// the comparison is case-insensitive. An entry of "*" allows any URL.
func (oh *OAuthHandlers) isAllowedRedirectOrigin(frontendURL string) bool {
	parsed, err := url.Parse(frontendURL)
	if err != nil {
		return false
	}

	if len(oh.allowedRedirectOrigins) == 0 {
		// Development: allow loopback over plain HTTP only.
		if parsed.Scheme != "http" {
			return false
		}
		switch strings.ToLower(parsed.Hostname()) {
		case "localhost", "127.0.0.1":
			return true
		default:
			return false
		}
	}

	origin := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
	for _, allowed := range oh.allowedRedirectOrigins {
		allowed = strings.TrimSuffix(strings.TrimSpace(allowed), "/")
		if allowed == "*" || strings.EqualFold(allowed, origin) {
			return true
		}
	}
	return false
}