151 lines
4.4 KiB
Go
151 lines
4.4 KiB
Go
package handlers
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
|
|
"veza-backend-api/internal/services"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// OAuthServiceInterface defines the methods needed for OAuth handlers
|
|
type OAuthServiceInterface interface {
|
|
GetAuthURL(provider string) (string, error)
|
|
HandleCallback(provider, code, state string) (*services.OAuthUser, string, error)
|
|
}
|
|
|
|
// OAuthHandlers handles OAuth authentication flows
|
|
type OAuthHandlers struct {
|
|
oauthService OAuthServiceInterface
|
|
logger interface{}
|
|
allowedRedirectOrigins []string // SECURITY: allowlist for OAuth redirect URLs
|
|
}
|
|
|
|
// OAuthHandlersInstance is the global instance
|
|
var OAuthHandlersInstance *OAuthHandlers
|
|
|
|
// InitOAuthHandlers initializes the OAuth handlers
|
|
func InitOAuthHandlers(oauthService *services.OAuthService) {
|
|
OAuthHandlersInstance = &OAuthHandlers{
|
|
oauthService: oauthService,
|
|
}
|
|
}
|
|
|
|
// NewOAuthHandler creates a new OAuth handler instance
|
|
// BE-API-042: Implement OAuth callback endpoint
|
|
func NewOAuthHandler(oauthService *services.OAuthService, logger interface{}, allowedRedirectOrigins []string) *OAuthHandlers {
|
|
return &OAuthHandlers{
|
|
oauthService: oauthService,
|
|
logger: logger,
|
|
allowedRedirectOrigins: allowedRedirectOrigins,
|
|
}
|
|
}
|
|
|
|
// NewOAuthHandlerWithInterface creates a new OAuth handler instance with an interface (for testing)
|
|
func NewOAuthHandlerWithInterface(oauthService OAuthServiceInterface, logger interface{}) *OAuthHandlers {
|
|
return &OAuthHandlers{
|
|
oauthService: oauthService,
|
|
logger: logger,
|
|
allowedRedirectOrigins: nil, // Tests use nil = dev fallback
|
|
}
|
|
}
|
|
|
|
// GetOAuthProviders returns available OAuth providers
|
|
func (oh *OAuthHandlers) GetOAuthProviders(c *gin.Context) {
|
|
providers := []map[string]interface{}{
|
|
{
|
|
"name": "Google",
|
|
"id": "google",
|
|
"authorizeUrl": "/api/v1/auth/oauth/google",
|
|
"icon": "google",
|
|
},
|
|
{
|
|
"name": "GitHub",
|
|
"id": "github",
|
|
"authorizeUrl": "/api/v1/auth/oauth/github",
|
|
"icon": "github",
|
|
},
|
|
{
|
|
"name": "Discord",
|
|
"id": "discord",
|
|
"authorizeUrl": "/api/v1/auth/oauth/discord",
|
|
"icon": "discord",
|
|
},
|
|
}
|
|
|
|
RespondSuccess(c, http.StatusOK, gin.H{
|
|
"providers": providers,
|
|
})
|
|
}
|
|
|
|
// InitiateOAuth initiates OAuth flow
|
|
func (oh *OAuthHandlers) InitiateOAuth(c *gin.Context) {
|
|
provider := c.Param("provider")
|
|
|
|
// Get authorization URL
|
|
authURL, err := oh.oauthService.GetAuthURL(provider)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Redirect to OAuth provider
|
|
c.Redirect(http.StatusTemporaryRedirect, authURL)
|
|
}
|
|
|
|
// OAuthCallback handles OAuth callback
|
|
func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) {
|
|
provider := c.Param("provider")
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
|
|
if code == "" || state == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "missing code or state"})
|
|
return
|
|
}
|
|
|
|
// Handle callback
|
|
user, token, err := oh.oauthService.HandleCallback(provider, code, state)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Redirect to frontend with token
|
|
frontendURL := os.Getenv("FRONTEND_URL")
|
|
if frontendURL == "" {
|
|
frontendURL = "http://localhost:5173" // Fallback for development
|
|
}
|
|
// SECURITY: Validate redirect URL against allowlist to prevent open redirect
|
|
if !oh.isAllowedRedirectOrigin(frontendURL) {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid redirect configuration"})
|
|
return
|
|
}
|
|
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s&user_id=%s", strings.TrimSuffix(frontendURL, "/"), token, user.ID.String())
|
|
|
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
|
}
|
|
|
|
// isAllowedRedirectOrigin validates that the frontend URL is in the allowlist. Returns true if allowed.
|
|
func (oh *OAuthHandlers) isAllowedRedirectOrigin(frontendURL string) bool {
|
|
if len(oh.allowedRedirectOrigins) == 0 {
|
|
// Development: allow localhost
|
|
return strings.HasPrefix(frontendURL, "http://localhost") || strings.HasPrefix(frontendURL, "http://127.0.0.1")
|
|
}
|
|
parsed, err := url.Parse(frontendURL)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
origin := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
|
|
for _, allowed := range oh.allowedRedirectOrigins {
|
|
allowed = strings.TrimSpace(allowed)
|
|
if allowed == "*" || allowed == origin {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|