fix(security): validate OAuth redirect URL against allowlist, require auth for internal transcode endpoint
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
ceec16fbd5
commit
f52858f14b
4 changed files with 62 additions and 10 deletions
|
|
@ -472,7 +472,11 @@ func (r *APIRouter) setupAuthRoutes(router *gin.RouterGroup) error {
|
|||
// For MVP, we'll get from environment variables directly
|
||||
baseURL := os.Getenv("BASE_URL")
|
||||
if baseURL == "" {
|
||||
baseURL = "http://localhost:8080" // Default for development
|
||||
appDomain := os.Getenv("APP_DOMAIN")
|
||||
if appDomain == "" {
|
||||
appDomain = "veza.fr"
|
||||
}
|
||||
baseURL = "http://" + appDomain + ":8080"
|
||||
}
|
||||
// Get OAuth credentials from environment variables
|
||||
googleClientID := os.Getenv("OAUTH_GOOGLE_CLIENT_ID")
|
||||
|
|
@ -486,7 +490,7 @@ func (r *APIRouter) setupAuthRoutes(router *gin.RouterGroup) error {
|
|||
oauthService.InitializeConfigs(googleClientID, googleClientSecret, githubClientID, githubClientSecret, discordClientID, discordClientSecret, baseURL)
|
||||
}
|
||||
|
||||
oauthHandler := handlers.NewOAuthHandler(oauthService, r.logger)
|
||||
oauthHandler := handlers.NewOAuthHandler(oauthService, r.logger, r.config.CORSOrigins)
|
||||
oauthGroup := authGroup.Group("/oauth")
|
||||
{
|
||||
// Get available OAuth providers
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ package handlers
|
|||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"veza-backend-api/internal/services"
|
||||
|
||||
|
|
@ -18,8 +20,9 @@ type OAuthServiceInterface interface {
|
|||
|
||||
// OAuthHandlers handles OAuth authentication flows
|
||||
type OAuthHandlers struct {
|
||||
oauthService OAuthServiceInterface
|
||||
logger interface{}
|
||||
oauthService OAuthServiceInterface
|
||||
logger interface{}
|
||||
allowedRedirectOrigins []string // SECURITY: allowlist for OAuth redirect URLs
|
||||
}
|
||||
|
||||
// OAuthHandlersInstance is the global instance
|
||||
|
|
@ -34,18 +37,20 @@ func InitOAuthHandlers(oauthService *services.OAuthService) {
|
|||
|
||||
// NewOAuthHandler creates a new OAuth handler instance
|
||||
// BE-API-042: Implement OAuth callback endpoint
|
||||
func NewOAuthHandler(oauthService *services.OAuthService, logger interface{}) *OAuthHandlers {
|
||||
func NewOAuthHandler(oauthService *services.OAuthService, logger interface{}, allowedRedirectOrigins []string) *OAuthHandlers {
|
||||
return &OAuthHandlers{
|
||||
oauthService: oauthService,
|
||||
logger: logger,
|
||||
oauthService: oauthService,
|
||||
logger: logger,
|
||||
allowedRedirectOrigins: allowedRedirectOrigins,
|
||||
}
|
||||
}
|
||||
|
||||
// NewOAuthHandlerWithInterface creates a new OAuth handler instance with an interface (for testing)
|
||||
func NewOAuthHandlerWithInterface(oauthService OAuthServiceInterface, logger interface{}) *OAuthHandlers {
|
||||
return &OAuthHandlers{
|
||||
oauthService: oauthService,
|
||||
logger: logger,
|
||||
oauthService: oauthService,
|
||||
logger: logger,
|
||||
allowedRedirectOrigins: nil, // Tests use nil = dev fallback
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -115,7 +120,32 @@ func (oh *OAuthHandlers) OAuthCallback(c *gin.Context) {
|
|||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:5173" // Fallback for development
|
||||
}
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s&user_id=%s", frontendURL, token, user.ID.String())
|
||||
// SECURITY: Validate redirect URL against allowlist to prevent open redirect
|
||||
if !oh.isAllowedRedirectOrigin(frontendURL) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid redirect configuration"})
|
||||
return
|
||||
}
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s&user_id=%s", strings.TrimSuffix(frontendURL, "/"), token, user.ID.String())
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
// isAllowedRedirectOrigin validates that the frontend URL is in the allowlist. Returns true if allowed.
|
||||
func (oh *OAuthHandlers) isAllowedRedirectOrigin(frontendURL string) bool {
|
||||
if len(oh.allowedRedirectOrigins) == 0 {
|
||||
// Development: allow localhost
|
||||
return strings.HasPrefix(frontendURL, "http://localhost") || strings.HasPrefix(frontendURL, "http://127.0.0.1")
|
||||
}
|
||||
parsed, err := url.Parse(frontendURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
origin := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
|
||||
for _, allowed := range oh.allowedRedirectOrigins {
|
||||
allowed = strings.TrimSpace(allowed)
|
||||
if allowed == "*" || allowed == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@
|
|||
SECRET_KEY=your-secret-key-minimum-32-characters-long
|
||||
JWT_SECRET=your-jwt-secret-minimum-32-characters-long
|
||||
|
||||
# Internal API key for /internal/jobs/transcode (backend must send X-Internal-API-Key header)
|
||||
INTERNAL_API_KEY=
|
||||
|
||||
# Database
|
||||
DATABASE_URL=postgres://user:password@host:5432/veza?sslmode=disable
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use axum::{
|
|||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use axum::extract::Request;
|
||||
use std::{collections::HashMap, time::Duration, sync::Arc};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
|
|
@ -145,8 +146,22 @@ pub fn create_routes(
|
|||
|
||||
async fn internal_transcode_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(payload): Json<serde_json::Value>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
// SECURITY: Require X-Internal-API-Key when INTERNAL_API_KEY is set
|
||||
if let Ok(expected_key) = std::env::var("INTERNAL_API_KEY") {
|
||||
if !expected_key.is_empty() {
|
||||
let provided = headers
|
||||
.get("X-Internal-API-Key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if provided != expected_key {
|
||||
return Err((StatusCode::UNAUTHORIZED, "Internal API key required".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract fields from payload
|
||||
let track_id = payload
|
||||
.get("track_id")
|
||||
|
|
|
|||
Loading…
Reference in a new issue