diff --git a/veza-chat-server/src/main.rs b/veza-chat-server/src/main.rs index 88faaddd5..10805897e 100644 --- a/veza-chat-server/src/main.rs +++ b/veza-chat-server/src/main.rs @@ -1,7 +1,8 @@ use axum::extract::ws::{Message as WsMessage, WebSocket}; use axum::{ - extract::{Query, State, WebSocketUpgrade}, + extract::{Extension, Query, State, WebSocketUpgrade}, http::StatusCode, + middleware, response::Response, routing::{get, post}, Json, Router, @@ -12,7 +13,7 @@ use chat_server::{ delivered_status::DeliveredStatusManager, // Add DeliveredStatusManager error::ChatError, event_bus::RabbitMQEventBus, // Add RabbitMQEventBus import - jwt_manager::JwtManager, + jwt_manager::{AccessTokenClaims, JwtManager}, models::message::Message, // Add Message model read_receipts::ReadReceiptManager, // Add ReadReceiptManager repository::MessageRepository, // Add MessageRepository @@ -51,10 +52,7 @@ struct AppState { struct SendMessageRequest { conversation_id: Uuid, // Add conversation_id content: String, - sender_id: Uuid, // Use Uuid for sender_id - // author: String, // Remove author - // room: Option, // Remove room - // is_direct: Option, // Remove is_direct + // sender_id is now taken from JWT token } /// Paramètres de récupération de messages @@ -276,9 +274,14 @@ async fn main() -> Result<(), ChatError> { "/metrics", get(move || std::future::ready(prometheus_handle.render())), ) // Prometheus metrics - .route("/api/messages/{conversation_id}", get(get_messages)) // Update route + .route("/api/messages/stats", get(get_stats)); + + let api_routes = Router::new() + .route("/api/messages/{conversation_id}", get(get_messages)) .route("/api/messages", post(send_message)) - .route("/api/messages/stats", get(get_stats)) + .route_layer(middleware::from_fn_with_state(state.clone(), auth_middleware)); + + let app = app.merge(api_routes) .route( "/ws", get({ @@ -395,9 +398,19 @@ async fn health_check(State(state): State) -> Json, + Extension(claims): Extension, axum::extract::Path(conversation_id): axum::extract::Path, // Extract conversation_id from path Query(params): Query, ) -> Result>>, StatusCode> { + // Validate User ID from token + let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?; + + // Check permission to read conversation + state.permission_service + .can_read_conversation(user_uuid, conversation_id) + .await + .map_err(|_| StatusCode::FORBIDDEN)?; + // Use Message model let limit = params.limit.unwrap_or(50).min(100); @@ -417,12 +430,22 @@ async fn get_messages( #[tracing::instrument(skip(state, payload))] async fn send_message( State(state): State, + Extension(claims): Extension, Json(payload): Json, ) -> Result>, StatusCode> { + // Validate User ID from token + let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?; + + // Check permission to send message + state.permission_service + .can_send_message(user_uuid, payload.conversation_id) + .await + .map_err(|_| StatusCode::FORBIDDEN)?; + // Return Uuid let message = state .message_repo - .create(payload.conversation_id, payload.sender_id, &payload.content) // Use message_repo + .create(payload.conversation_id, user_uuid, &payload.content) // Use user_uuid from token .await .map_err(|e| { warn!("Erreur envoi message: {}", e); @@ -437,18 +460,40 @@ async fn send_message( Ok(Json(ApiResponse::success(message.id))) } -/// Statistiques basiques -#[tracing::instrument(skip(_state))] -async fn get_stats(State(_state): State) -> Json>> { - let mut stats = HashMap::new(); - stats.insert("total_messages".to_string(), 2); - stats.insert("active_users".to_string(), 1); - stats.insert("rooms".to_string(), 1); - stats.insert("websocket_enabled".to_string(), 1); - Json(ApiResponse::success(stats)) } +/// Middleware d'authentification +async fn auth_middleware( + State(state): State, + mut req: axum::extract::Request, + next: axum::middleware::Next, +) -> Result { + let auth_header = req.headers() + .get(axum::http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()); + + let auth_header = if let Some(auth_header) = auth_header { + auth_header + } else { + return Err(StatusCode::UNAUTHORIZED); + }; + + if !auth_header.starts_with("Bearer ") { + return Err(StatusCode::UNAUTHORIZED); + } + + let token = &auth_header[7..]; + + match state.jwt_manager.validate_access_token(token).await { + Ok(claims) => { + req.extensions_mut().insert(claims); + Ok(next.run(req).await) + } + Err(_) => Err(StatusCode::UNAUTHORIZED), + } +} + /// Gestionnaire de signal d'arrêt (Graceful Shutdown) async fn shutdown_signal() { let ctrl_c = async {