use axum::extract::ws::{Message as WsMessage, WebSocket}; use axum::{ extract::{Extension, Query, State, WebSocketUpgrade}, http::StatusCode, middleware, response::Response, routing::{get, post}, Json, Router, }; use chat_server::{ config::SecurityConfig, database::pool::create_pool_from_env, delivered_status::DeliveredStatusManager, error::ChatError, event_bus::RabbitMQEventBus, jwt_manager::{AccessTokenClaims, JwtManager}, models::message::Message, read_receipts::ReadReceiptManager, repository::MessageRepository, security::permission::PermissionService, services::MessageEditService, typing_indicator::TypingIndicatorManager, monitoring::ChatMetrics, websocket::{ handler::{websocket_handler, WebSocketState}, IncomingMessage, OutgoingMessage, WebSocketManager, }, }; use futures_util::{FutureExt, SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; use tracing::{error, info, warn}; use uuid::Uuid; /// État global de l'application #[derive(Clone)] struct AppState { message_repo: Arc, _ws_manager: Arc, database_pool: Option, event_bus: Option>, config: chat_server::config::Config, jwt_manager: Arc, metrics: Arc, permission_service: Arc, } /// Requête d'envoi de message #[derive(Deserialize)] struct SendMessageRequest { conversation_id: Uuid, content: String, } /// Paramètres de récupération de messages #[derive(Deserialize)] struct GetMessagesQuery { conversation_id: Uuid, limit: Option, } /// Réponse API standard #[derive(Serialize)] struct ApiResponse { success: bool, data: T, message: Option, } impl ApiResponse { fn success(data: T) -> Self { Self { success: true, data, message: None, } } } use metrics_exporter_prometheus::PrometheusBuilder; #[tokio::main] async fn main() -> Result<(), ChatError> { // Configuration du logging avec tracing let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")); let is_prod = std::env::var("APP_ENV").unwrap_or_default() == "production"; if is_prod { tracing_subscriber::fmt() .with_env_filter(env_filter) .json() .init(); } else { tracing_subscriber::fmt() .with_env_filter(env_filter) .with_target(true) .with_file(true) .with_line_number(true) .init(); } // Initialisation des métriques Prometheus let builder = PrometheusBuilder::new(); let prometheus_handle = builder .install_recorder() .map_err(|e| ChatError::configuration_error(&format!("Failed to install Prometheus recorder: {}", e)))?; info!("🚀 Démarrage du serveur de chat Veza..."); let app_config = chat_server::config::Config::from_env().map_err(|e| ChatError::Configuration { message: e.to_string(), })?; // Initialisation du pool de connexions à la base de données let database_pool = match create_pool_from_env(Some(&app_config.database_url)).await { Ok(pool) => { info!("✅ Pool de connexions PostgreSQL initialisé avec succès"); Some(pool) } Err(e) => { warn!("⚠️ Échec d'initialisation du pool de connexions: {}. Le serveur continuera sans base de données.", e); None } }; // Database pool est requis pour les managers let pool_ref = database_pool.as_ref().ok_or_else(|| { ChatError::configuration_error("Database pool is required but not initialized") })?; let message_repo = Arc::new(MessageRepository::new(pool_ref.clone())); let read_receipt_manager = Arc::new(ReadReceiptManager::new(pool_ref.clone())); let delivered_status_manager = Arc::new(DeliveredStatusManager::new(pool_ref.clone())); let typing_indicator_manager = Arc::new(TypingIndicatorManager::new()); let permission_service = Arc::new(PermissionService::new(pool_ref.clone())); let message_edit_service = Arc::new(MessageEditService::new(pool_ref.clone())); // Metrics let metrics = Arc::new(ChatMetrics::new()); // Initialisation de l'Event Bus RabbitMQ let event_bus = match RabbitMQEventBus::new_with_retry(app_config.rabbit_mq.clone()).await { Ok(eb) => { info!("✅ Event Bus RabbitMQ initialisé avec succès"); Some(eb) } Err(e) => { warn!("⚠️ Échec d'initialisation de l'Event Bus RabbitMQ: {}. Le serveur démarrera en mode dégradé (sans Event Bus).", e); None } }; // Initialisation du gestionnaire WebSocket let ws_manager = Arc::new(WebSocketManager::new()); // Initialisation du gestionnaire JWT let jwt_secret = chat_server::env::require_env_min_length("JWT_SECRET", 32); let security_config = SecurityConfig { jwt_secret, jwt_access_duration: Duration::from_secs(900), // 15 min jwt_refresh_duration: Duration::from_secs(86400 * 30), // 30 days jwt_algorithm: "HS256".to_string(), jwt_audience: "veza-chat".to_string(), jwt_issuer: "veza-backend".to_string(), enable_2fa: false, totp_window: 1, content_filtering: false, password_min_length: 8, bcrypt_cost: 12, }; // Créer JwtManager avec pool DB si disponible let jwt_manager = Arc::new( if let Some(ref pool) = database_pool { JwtManager::with_pool(security_config, pool.clone()) .map_err(|e| ChatError::configuration_error(&format!("JWT Manager error: {}", e)))? } else { JwtManager::new(security_config) .map_err(|e| ChatError::configuration_error(&format!("JWT Manager error: {}", e)))? } ); // Définir l'adresse d'écoute let bind_addr = format!("{}:{}", app_config.host, app_config.port); // État pour les routes HTTP (AppState reste pour compatibilité) let state = AppState { message_repo: message_repo.clone(), _ws_manager: ws_manager.clone(), database_pool: database_pool.clone(), event_bus: event_bus.map(Arc::new), config: app_config.clone(), jwt_manager: jwt_manager.clone(), metrics: metrics.clone(), permission_service: permission_service.clone(), }; // État pour le handler WebSocket let ws_state = WebSocketState { message_repo: message_repo.clone(), read_receipt_manager: read_receipt_manager.clone(), delivered_status_manager: delivered_status_manager.clone(), typing_indicator_manager: typing_indicator_manager.clone(), message_edit_service: message_edit_service.clone(), ws_manager: ws_manager.clone(), jwt_manager: jwt_manager.clone(), permission_service: permission_service.clone(), metrics: metrics.clone(), }; // Démarrer le task de monitoring des typing indicators let typing_manager_monitor = typing_indicator_manager.clone(); let ws_manager_monitor = ws_manager.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500)); loop { interval.tick().await; let expired_changes = typing_manager_monitor.monitor_timeouts().await; for change in expired_changes { let typing_message = OutgoingMessage::UserTyping { conversation_id: change.conversation_id, user_id: change.user_id, is_typing: false, }; if let Err(e) = ws_manager_monitor .broadcast_to_conversation(change.conversation_id, typing_message) .await { warn!( conversation_id = %change.conversation_id, user_id = %change.user_id, error = %e, "Erreur lors du broadcast de typing timeout" ); } } } }); info!("✅ Task de monitoring des typing indicators démarré"); // Configuration des routes let app = Router::new() .route("/health", get(health_check)) .route("/healthz", get(health_check)) .route("/readyz", get(readiness_check)) .route( "/metrics", get(move || std::future::ready(prometheus_handle.render())), ) .route("/api/messages/stats", get(get_stats)); let api_routes = Router::new() .route("/api/messages/{conversation_id}", get(get_messages)) .route("/api/messages", post(send_message)) .route_layer(middleware::from_fn_with_state(state.clone(), auth_middleware)); let app = app.merge(api_routes) .route( "/ws", get({ let ws_state_clone = ws_state.clone(); move |ws: WebSocketUpgrade, query: Query>| async move { websocket_handler(ws, query, State(ws_state_clone)).await } }), ) .with_state(state); // Démarrage du serveur let listener = TcpListener::bind(&bind_addr) .await .map_err(|e| ChatError::configuration_error(&format!("Bind error on {bind_addr}: {e}")))?; info!("✅ Serveur démarré sur http://{}", bind_addr); info!("📊 Endpoints disponibles:"); info!(" - GET /health - Vérification de santé"); info!(" - GET /api/messages/:conversation_id - Récupération des messages"); info!(" - POST /api/messages - Envoi de message"); info!(" - GET /api/messages/stats - Statistiques"); info!(" - GET /ws - WebSocket Chat (🆕)"); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await .map_err(|e| ChatError::configuration_error(&format!("Server error: {e}")))?; Ok(()) } /// Endpoint de readiness (DB check) async fn readiness_check( State(state): State, ) -> Result>>, StatusCode> { let mut info = HashMap::new(); // Check Database if let Some(pool) = &state.database_pool { if let Err(e) = sqlx::query("SELECT 1").execute(pool).await { warn!("Readiness check failed (DB): {}", e); return Err(StatusCode::SERVICE_UNAVAILABLE); } } else { warn!("Readiness check failed (No DB pool)"); return Err(StatusCode::SERVICE_UNAVAILABLE); } // Check RabbitMQ Event Bus if state.config.rabbit_mq.enable { if let Some(ref event_bus) = state.event_bus { if !event_bus.is_enabled { warn!("Readiness check failed (RabbitMQ EventBus not enabled)"); return Err(StatusCode::SERVICE_UNAVAILABLE); } } else { warn!("Readiness check failed (RabbitMQ EventBus not initialized but enabled in config)"); return Err(StatusCode::SERVICE_UNAVAILABLE); } } info.insert("status".to_string(), "ready".to_string()); Ok(Json(ApiResponse::success(info))) } /// Endpoint de vérification de santé #[tracing::instrument(skip(state))] async fn health_check(State(state): State) -> Json>> { let mut info = HashMap::new(); info.insert("status".to_string(), "healthy".to_string()); info.insert("service".to_string(), "veza-chat-server".to_string()); info.insert("version".to_string(), "0.3.0".to_string()); info.insert("websocket".to_string(), "enabled".to_string()); if let Some(pool) = &state.database_pool { match sqlx::query("SELECT 1").execute(pool).await { Ok(_) => { info.insert("database".to_string(), "connected".to_string()); } Err(e) => { info.insert("database".to_string(), format!("error: {}", e)); } } } else { info.insert("database".to_string(), "not_configured".to_string()); } if let Some(event_bus) = &state.event_bus { if event_bus.is_enabled { info.insert("rabbitmq".to_string(), "connected".to_string()); } else { info.insert("rabbitmq".to_string(), "disabled".to_string()); } } else { if state.config.rabbit_mq.enable { info.insert("rabbitmq".to_string(), "disconnected".to_string()); } else { info.insert("rabbitmq".to_string(), "not_configured".to_string()); } } Json(ApiResponse::success(info)) } /// Récupération des messages #[tracing::instrument(skip(state, params))] async fn get_messages( State(state): State, Extension(claims): Extension, axum::extract::Path(conversation_id): axum::extract::Path, Query(params): Query, ) -> Result>>, StatusCode> { let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?; state.permission_service .can_read_conversation(user_uuid, conversation_id) .await .map_err(|_| StatusCode::FORBIDDEN)?; let limit = params.limit.unwrap_or(50).min(100); let messages = state .message_repo .get_conversation_messages(conversation_id, limit) .await .map_err(|e| { warn!("Erreur récupération messages conversation: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; Ok(Json(ApiResponse::success(messages))) } /// Envoi de message #[tracing::instrument(skip(state, payload))] async fn send_message( State(state): State, Extension(claims): Extension, Json(payload): Json, ) -> Result>, StatusCode> { let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?; state.permission_service .can_send_message(user_uuid, payload.conversation_id) .await .map_err(|_| StatusCode::FORBIDDEN)?; let message = state .message_repo .create(payload.conversation_id, user_uuid, &payload.content) .await .map_err(|e| { warn!("Erreur envoi message: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; info!("✅ Message envoyé - ID: {:?}, sender: {:?}", message.id, message.sender_id); Ok(Json(ApiResponse::success(message.id))) } /// Statistiques avec métriques réelles (Memory/CPU) #[tracing::instrument(skip(state))] async fn get_stats(State(state): State) -> Json>> { let mut stats = HashMap::new(); // Récupérer les métriques système via metrics let (memory_mb, cpu) = state.metrics.get_system_metrics().await; stats.insert("active_users".to_string(), serde_json::json!(0)); // Placeholder for active users stats.insert("server_memory_mb".to_string(), serde_json::json!(memory_mb)); stats.insert("server_cpu_percent".to_string(), serde_json::json!(cpu)); stats.insert("websocket_enabled".to_string(), serde_json::json!(true)); Json(ApiResponse::success(stats)) } /// Middleware d'authentification async fn auth_middleware( State(state): State, mut req: axum::extract::Request, next: axum::middleware::Next, ) -> Result { let auth_header = req.headers() .get(axum::http::header::AUTHORIZATION) .and_then(|header| header.to_str().ok()); let auth_header = if let Some(auth_header) = auth_header { auth_header } else { return Err(StatusCode::UNAUTHORIZED); }; if !auth_header.starts_with("Bearer ") { return Err(StatusCode::UNAUTHORIZED); } let token = &auth_header[7..]; match state.jwt_manager.validate_access_token(token).await { Ok(claims) => { req.extensions_mut().insert(claims); Ok(next.run(req).await) } Err(_) => Err(StatusCode::UNAUTHORIZED), } } async fn shutdown_signal() { let ctrl_c = async { tokio::signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } info!("🛑 Signal d'arrêt reçu, fermeture gracieuse..."); }