diff --git a/veza-stream-server/src/routes/api.rs b/veza-stream-server/src/routes/api.rs index 3ce636dc3..a9faefea6 100644 --- a/veza-stream-server/src/routes/api.rs +++ b/veza-stream-server/src/routes/api.rs @@ -310,20 +310,17 @@ async fn stream_audio( }) } -// Handler WebSocket wrapper pour utiliser avec AppState +// Handler WebSocket wrapper — delegates to websocket_handler with full AppState for JWT auth async fn websocket_handler_wrapper( ws: axum::extract::ws::WebSocketUpgrade, query: Query, headers: HeaderMap, State(state): State, ) -> Response { - websocket_handler( - ws, - query, - headers, - State(state.websocket_manager.clone()), - ) - .await + match websocket_handler(ws, query, headers, State(state)).await { + Ok(response) => response, + Err((status, json)) => (status, json).into_response(), + } } // Handler pour la master playlist HLS diff --git a/veza-stream-server/src/streaming/websocket.rs b/veza-stream-server/src/streaming/websocket.rs index dbdee62b5..5fdb1fee3 100644 --- a/veza-stream-server/src/streaming/websocket.rs +++ b/veza-stream-server/src/streaming/websocket.rs @@ -3,9 +3,11 @@ use axum::{ ws::{Message, WebSocket, WebSocketUpgrade}, Query, State, }, - http::HeaderMap, - response::Response, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, + Json, }; +use crate::AppState; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, @@ -214,6 +216,7 @@ pub enum WebSocketCommand { pub struct WebSocketConnection { pub id: Uuid, pub user_id: Option, + pub authenticated: bool, pub ip_address: String, pub connected_at: SystemTime, pub last_activity: SystemTime, @@ -275,6 +278,7 @@ impl WebSocketManager { let connection = WebSocketConnection { id: connection_id, user_id: user_id.clone(), + authenticated: user_id.is_some(), ip_address: ip_address.clone(), connected_at: SystemTime::now(), last_activity: SystemTime::now(), @@ -788,13 +792,16 @@ pub struct WebSocketQuery { } /// Handler pour les connexions WebSocket avec authentification JWT +/// +/// Requires a valid JWT token either via `?token=` query param or +/// `Authorization: Bearer ` header. Rejects unauthenticated connections. pub async fn websocket_handler( ws: WebSocketUpgrade, Query(params): Query, headers: HeaderMap, - State(ws_manager): State>, -) -> Response { - // Extraire le token JWT depuis les query params ou headers + State(state): State, +) -> Result)> { + // Extract JWT token from query params or Authorization header let token = params.token.or_else(|| { headers .get("authorization") @@ -803,7 +810,7 @@ pub async fn websocket_handler( .map(|s| s.to_string()) }); - // Extraire l'adresse IP réelle + // Extract real IP address let ip_address = headers .get("x-forwarded-for") .or_else(|| headers.get("x-real-ip")) @@ -811,21 +818,47 @@ pub async fn websocket_handler( .unwrap_or("127.0.0.1") .to_string(); - // Si un token est fourni, on le valide (pour l'instant, on accepte la connexion) - // En production, on validerait le token avec AuthManager - let user_id = params.user_id.or_else(|| { - // Si un token est fourni, on pourrait extraire user_id du token - // Pour l'instant, on utilise le user_id fourni dans les params - None - }); + // Require a token — reject unauthenticated connections + let token = token.ok_or_else(|| { + tracing::warn!("WebSocket connection rejected: no token provided from {}", ip_address); + ( + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({"error": "Authentication token required"})), + ) + })?; + + // Validate the JWT token via AuthManager + let validation_result = state.auth_manager.validate_token(&token).await; + + if !validation_result.valid { + let reason = validation_result.error.unwrap_or_else(|| "Invalid token".to_string()); + tracing::warn!("WebSocket auth failed from {}: {}", ip_address, reason); + return Err(( + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({"error": reason})), + )); + } + + // Extract user_id from validated token claims (not from query params) + let claims = validation_result.claims.ok_or_else(|| { + tracing::error!("Token valid but claims missing — this should not happen"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": "Internal authentication error"})), + ) + })?; + let user_id = Some(claims.sub.clone()); tracing::info!( - "Nouvelle connexion WebSocket demandée pour utilisateur: {:?} depuis {}", + "WebSocket connection authenticated for user: {:?} from {}", user_id, ip_address ); - ws_manager.handle_websocket(ws, user_id, ip_address).await + Ok(state + .websocket_manager + .handle_websocket(ws, user_id, ip_address) + .await) } #[cfg(test)]