veza/veza-chat-server/src/main.rs
2025-12-06 14:45:07 +01:00

498 lines
17 KiB
Rust

use axum::extract::ws::{Message as WsMessage, WebSocket};
use axum::{
extract::{Extension, Query, State, WebSocketUpgrade},
http::StatusCode,
middleware,
response::Response,
routing::{get, post},
Json, Router,
};
use chat_server::{
config::SecurityConfig,
database::pool::create_pool_from_env,
delivered_status::DeliveredStatusManager,
error::ChatError,
event_bus::RabbitMQEventBus,
jwt_manager::{AccessTokenClaims, JwtManager},
models::message::Message,
read_receipts::ReadReceiptManager,
repository::MessageRepository,
security::permission::PermissionService,
services::MessageEditService,
typing_indicator::TypingIndicatorManager,
monitoring::ChatMetrics,
websocket::{
handler::{websocket_handler, WebSocketState},
IncomingMessage, OutgoingMessage, WebSocketManager,
},
};
use futures_util::{FutureExt, SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tracing::{error, info, warn};
use uuid::Uuid;
/// Shared application state handed to every HTTP route and to `auth_middleware`.
///
/// `Clone` is required because axum's `State` extractor clones the state it hands
/// out; the service members are wrapped in `Arc` so those clones share one instance.
#[derive(Clone)]
struct AppState {
    // --- configuration & infrastructure ---
    /// Application configuration loaded from the environment at startup.
    config: chat_server::config::Config,
    /// PostgreSQL pool, probed by the health and readiness endpoints.
    database_pool: Option<sqlx::PgPool>,
    /// RabbitMQ event bus; `None` when its initialisation failed at startup.
    event_bus: Option<Arc<RabbitMQEventBus>>,
    /// Chat metrics; `get_stats` reads system metrics from it.
    metrics: Arc<ChatMetrics>,

    // --- domain services ---
    /// Message persistence used by the message routes.
    message_repo: Arc<MessageRepository>,
    /// Conversation-level authorisation checks.
    permission_service: Arc<PermissionService>,
    /// Access-token validation used by `auth_middleware`.
    jwt_manager: Arc<JwtManager>,
    /// Kept for compatibility; no handler in this file reads it (hence the underscore).
    _ws_manager: Arc<WebSocketManager>,
}
/// JSON body accepted by `POST /api/messages`.
#[derive(Deserialize)]
struct SendMessageRequest {
    /// Conversation the message is posted to.
    conversation_id: Uuid,
    /// Message text as submitted by the client.
    content: String,
}
/// Query-string parameters for `GET /api/messages/{conversation_id}`.
///
/// The conversation id is taken from the URL path (`Path<Uuid>` in `get_messages`),
/// so it is deliberately not a field here. Requiring it in the query string as well
/// made every `GET /api/messages/<id>?limit=N` request fail with 400 (missing field).
/// serde ignores unknown keys by default, so clients that still send
/// `conversation_id=` in the query keep working.
#[derive(Deserialize)]
struct GetMessagesQuery {
    /// Maximum number of messages to return; `get_messages` applies a default and bounds.
    limit: Option<i64>,
}
/// Uniform JSON envelope returned by the HTTP API.
///
/// Serialised as `{"success": …, "data": …, "message": …}`; the field declaration
/// order below is the key order of the emitted JSON.
#[derive(Serialize)]
struct ApiResponse<T> {
    /// `true` for a successful response.
    success: bool,
    /// Endpoint-specific payload.
    data: T,
    /// Optional human-readable note; serialised as `null` when absent.
    message: Option<String>,
}
impl<T> ApiResponse<T> {
    /// Wraps `data` in a successful envelope that carries no message.
    fn success(data: T) -> Self {
        let message = None;
        Self {
            message,
            success: true,
            data,
        }
    }
}
use metrics_exporter_prometheus::PrometheusBuilder;
#[tokio::main]
async fn main() -> Result<(), ChatError> {
// Configuration du logging avec tracing
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
let is_prod = std::env::var("APP_ENV").unwrap_or_default() == "production";
if is_prod {
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.json()
.init();
} else {
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.with_target(true)
.with_file(true)
.with_line_number(true)
.init();
}
// Initialisation des métriques Prometheus
let builder = PrometheusBuilder::new();
let prometheus_handle = builder
.install_recorder()
.map_err(|e| ChatError::configuration_error(&format!("Failed to install Prometheus recorder: {}", e)))?;
info!("🚀 Démarrage du serveur de chat Veza...");
let app_config =
chat_server::config::Config::from_env().map_err(|e| ChatError::Configuration {
message: e.to_string(),
})?;
// Initialisation du pool de connexions à la base de données
let database_pool = match create_pool_from_env(Some(&app_config.database_url)).await {
Ok(pool) => {
info!("✅ Pool de connexions PostgreSQL initialisé avec succès");
Some(pool)
}
Err(e) => {
warn!("⚠️ Échec d'initialisation du pool de connexions: {}. Le serveur continuera sans base de données.", e);
None
}
};
// Database pool est requis pour les managers
let pool_ref = database_pool.as_ref().ok_or_else(|| {
ChatError::configuration_error("Database pool is required but not initialized")
})?;
let message_repo = Arc::new(MessageRepository::new(pool_ref.clone()));
let read_receipt_manager = Arc::new(ReadReceiptManager::new(pool_ref.clone()));
let delivered_status_manager = Arc::new(DeliveredStatusManager::new(pool_ref.clone()));
let typing_indicator_manager = Arc::new(TypingIndicatorManager::new());
let permission_service = Arc::new(PermissionService::new(pool_ref.clone()));
let message_edit_service = Arc::new(MessageEditService::new(pool_ref.clone()));
// Metrics
let metrics = Arc::new(ChatMetrics::new());
// Initialisation de l'Event Bus RabbitMQ
let event_bus = match RabbitMQEventBus::new_with_retry(app_config.rabbit_mq.clone()).await {
Ok(eb) => {
info!("✅ Event Bus RabbitMQ initialisé avec succès");
Some(eb)
}
Err(e) => {
warn!("⚠️ Échec d'initialisation de l'Event Bus RabbitMQ: {}. Le serveur démarrera en mode dégradé (sans Event Bus).", e);
None
}
};
// Initialisation du gestionnaire WebSocket
let ws_manager = Arc::new(WebSocketManager::new());
// Initialisation du gestionnaire JWT
let jwt_secret = chat_server::env::require_env_min_length("JWT_SECRET", 32);
let security_config = SecurityConfig {
jwt_secret,
jwt_access_duration: Duration::from_secs(900), // 15 min
jwt_refresh_duration: Duration::from_secs(86400 * 30), // 30 days
jwt_algorithm: "HS256".to_string(),
jwt_audience: "veza-chat".to_string(),
jwt_issuer: "veza-backend".to_string(),
enable_2fa: false,
totp_window: 1,
content_filtering: false,
password_min_length: 8,
bcrypt_cost: 12,
};
// Créer JwtManager avec pool DB si disponible
let jwt_manager = Arc::new(
if let Some(ref pool) = database_pool {
JwtManager::with_pool(security_config, pool.clone())
.map_err(|e| ChatError::configuration_error(&format!("JWT Manager error: {}", e)))?
} else {
JwtManager::new(security_config)
.map_err(|e| ChatError::configuration_error(&format!("JWT Manager error: {}", e)))?
}
);
// Définir l'adresse d'écoute
let bind_addr = format!("{}:{}", app_config.host, app_config.port);
// État pour les routes HTTP (AppState reste pour compatibilité)
let state = AppState {
message_repo: message_repo.clone(),
_ws_manager: ws_manager.clone(),
database_pool: database_pool.clone(),
event_bus: event_bus.map(Arc::new),
config: app_config.clone(),
jwt_manager: jwt_manager.clone(),
metrics: metrics.clone(),
permission_service: permission_service.clone(),
};
// État pour le handler WebSocket
let ws_state = WebSocketState {
message_repo: message_repo.clone(),
read_receipt_manager: read_receipt_manager.clone(),
delivered_status_manager: delivered_status_manager.clone(),
typing_indicator_manager: typing_indicator_manager.clone(),
message_edit_service: message_edit_service.clone(),
ws_manager: ws_manager.clone(),
jwt_manager: jwt_manager.clone(),
permission_service: permission_service.clone(),
metrics: metrics.clone(),
};
// Démarrer le task de monitoring des typing indicators
let typing_manager_monitor = typing_indicator_manager.clone();
let ws_manager_monitor = ws_manager.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
loop {
interval.tick().await;
let expired_changes = typing_manager_monitor.monitor_timeouts().await;
for change in expired_changes {
let typing_message = OutgoingMessage::UserTyping {
conversation_id: change.conversation_id,
user_id: change.user_id,
is_typing: false,
};
if let Err(e) = ws_manager_monitor
.broadcast_to_conversation(change.conversation_id, typing_message)
.await
{
warn!(
conversation_id = %change.conversation_id,
user_id = %change.user_id,
error = %e,
"Erreur lors du broadcast de typing timeout"
);
}
}
}
});
info!("✅ Task de monitoring des typing indicators démarré");
// Configuration des routes
let app = Router::new()
.route("/health", get(health_check))
.route("/healthz", get(health_check))
.route("/readyz", get(readiness_check))
.route(
"/metrics",
get(move || std::future::ready(prometheus_handle.render())),
)
.route("/api/messages/stats", get(get_stats));
let api_routes = Router::new()
.route("/api/messages/{conversation_id}", get(get_messages))
.route("/api/messages", post(send_message))
.route_layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
let app = app.merge(api_routes)
.route(
"/ws",
get({
let ws_state_clone = ws_state.clone();
move |ws: WebSocketUpgrade, query: Query<HashMap<String, String>>| async move {
websocket_handler(ws, query, State(ws_state_clone)).await
}
}),
)
.with_state(state);
// Démarrage du serveur
let listener = TcpListener::bind(&bind_addr)
.await
.map_err(|e| ChatError::configuration_error(&format!("Bind error on {bind_addr}: {e}")))?;
info!("✅ Serveur démarré sur http://{}", bind_addr);
info!("📊 Endpoints disponibles:");
info!(" - GET /health - Vérification de santé");
info!(" - GET /api/messages/:conversation_id - Récupération des messages");
info!(" - POST /api/messages - Envoi de message");
info!(" - GET /api/messages/stats - Statistiques");
info!(" - GET /ws - WebSocket Chat (🆕)");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.map_err(|e| ChatError::configuration_error(&format!("Server error: {e}")))?;
Ok(())
}
/// Readiness probe (`/readyz`).
///
/// Answers 503 in any of these cases:
/// - no database pool exists;
/// - the pool fails `SELECT 1`;
/// - RabbitMQ is enabled in config but the event bus is missing or reports
///   `is_enabled == false`.
///
/// Otherwise it answers 200 with `{"status": "ready"}`.
async fn readiness_check(
    State(state): State<AppState>,
) -> Result<Json<ApiResponse<HashMap<String, String>>>, StatusCode> {
    // Database: must be configured and answering.
    let Some(pool) = state.database_pool.as_ref() else {
        warn!("Readiness check failed (No DB pool)");
        return Err(StatusCode::SERVICE_UNAVAILABLE);
    };
    if let Err(e) = sqlx::query("SELECT 1").execute(pool).await {
        warn!("Readiness check failed (DB): {}", e);
        return Err(StatusCode::SERVICE_UNAVAILABLE);
    }

    // Event bus: only required when enabled in the configuration.
    if state.config.rabbit_mq.enable {
        match state.event_bus.as_deref() {
            None => {
                warn!("Readiness check failed (RabbitMQ EventBus not initialized but enabled in config)");
                return Err(StatusCode::SERVICE_UNAVAILABLE);
            }
            Some(bus) if !bus.is_enabled => {
                warn!("Readiness check failed (RabbitMQ EventBus not enabled)");
                return Err(StatusCode::SERVICE_UNAVAILABLE);
            }
            Some(_) => {}
        }
    }

    let info = HashMap::from([("status".to_string(), "ready".to_string())]);
    Ok(Json(ApiResponse::success(info)))
}
/// Endpoint de vérification de santé
#[tracing::instrument(skip(state))]
async fn health_check(State(state): State<AppState>) -> Json<ApiResponse<HashMap<String, String>>> {
let mut info = HashMap::new();
info.insert("status".to_string(), "healthy".to_string());
info.insert("service".to_string(), "veza-chat-server".to_string());
info.insert("version".to_string(), "0.3.0".to_string());
info.insert("websocket".to_string(), "enabled".to_string());
if let Some(pool) = &state.database_pool {
match sqlx::query("SELECT 1").execute(pool).await {
Ok(_) => { info.insert("database".to_string(), "connected".to_string()); }
Err(e) => { info.insert("database".to_string(), format!("error: {}", e)); }
}
} else {
info.insert("database".to_string(), "not_configured".to_string());
}
if let Some(event_bus) = &state.event_bus {
if event_bus.is_enabled {
info.insert("rabbitmq".to_string(), "connected".to_string());
} else {
info.insert("rabbitmq".to_string(), "disabled".to_string());
}
} else {
if state.config.rabbit_mq.enable {
info.insert("rabbitmq".to_string(), "disconnected".to_string());
} else {
info.insert("rabbitmq".to_string(), "not_configured".to_string());
}
}
Json(ApiResponse::success(info))
}
/// Returns messages of a conversation the authenticated caller may read.
///
/// Route: `GET /api/messages/{conversation_id}?limit=N`, behind `auth_middleware`.
/// `limit` defaults to 50 and is clamped to `1..=100`, so negative or zero values no
/// longer reach the repository.
///
/// # Errors
/// - `401` if the token's `user_id` is not a valid UUID;
/// - `403` if the permission check returns an error;
/// - `500` if the repository query fails.
#[tracing::instrument(skip(state, params))]
async fn get_messages(
    State(state): State<AppState>,
    Extension(claims): Extension<AccessTokenClaims>,
    axum::extract::Path(conversation_id): axum::extract::Path<Uuid>,
    Query(params): Query<GetMessagesQuery>,
) -> Result<Json<ApiResponse<Vec<Message>>>, StatusCode> {
    const DEFAULT_LIMIT: i64 = 50;
    const MAX_LIMIT: i64 = 100;

    let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?;

    // NOTE(review): every `Err` from the permission service (including a DB failure)
    // becomes 403. If the service ever reports "denied" as an `Ok(false)`-style value
    // instead of `Err`, that value is discarded here. Confirm its contract.
    state
        .permission_service
        .can_read_conversation(user_uuid, conversation_id)
        .await
        .map_err(|_| StatusCode::FORBIDDEN)?;

    let limit = params.limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT);
    let messages = state
        .message_repo
        .get_conversation_messages(conversation_id, limit)
        .await
        .map_err(|e| {
            warn!("Erreur récupération messages conversation: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    Ok(Json(ApiResponse::success(messages)))
}
/// Persists a new message posted by the authenticated caller and returns its id.
///
/// Route: `POST /api/messages`, behind `auth_middleware`; the body is a
/// `SendMessageRequest`.
///
/// # Errors
/// - `401` if the token's `user_id` is not a valid UUID;
/// - `400` if `content` is empty or whitespace-only;
/// - `403` if the permission check returns an error;
/// - `500` if the insert fails.
#[tracing::instrument(skip(state, payload))]
async fn send_message(
    State(state): State<AppState>,
    Extension(claims): Extension<AccessTokenClaims>,
    Json(payload): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<Uuid>>, StatusCode> {
    let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Reject blank messages before doing any database work. The content itself is
    // stored unmodified (no trimming).
    if payload.content.trim().is_empty() {
        return Err(StatusCode::BAD_REQUEST);
    }

    state
        .permission_service
        .can_send_message(user_uuid, payload.conversation_id)
        .await
        .map_err(|_| StatusCode::FORBIDDEN)?;

    let message = state
        .message_repo
        .create(payload.conversation_id, user_uuid, &payload.content)
        .await
        .map_err(|e| {
            warn!("Erreur envoi message: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    info!("✅ Message envoyé - ID: {:?}, sender: {:?}", message.id, message.sender_id);
    Ok(Json(ApiResponse::success(message.id)))
}
/// Server statistics endpoint (`/api/messages/stats`).
///
/// Memory and CPU figures come from `ChatMetrics::get_system_metrics`.
/// `active_users` is still a hard-coded placeholder (always `0`).
#[tracing::instrument(skip(state))]
async fn get_stats(State(state): State<AppState>) -> Json<ApiResponse<HashMap<String, serde_json::Value>>> {
    let (memory_mb, cpu) = state.metrics.get_system_metrics().await;
    let stats: HashMap<String, serde_json::Value> = HashMap::from([
        // Placeholder until a real active-user count is wired in.
        ("active_users".to_string(), serde_json::json!(0)),
        ("server_memory_mb".to_string(), serde_json::json!(memory_mb)),
        ("server_cpu_percent".to_string(), serde_json::json!(cpu)),
        ("websocket_enabled".to_string(), serde_json::json!(true)),
    ]);
    Json(ApiResponse::success(stats))
}
/// Bearer-token authentication middleware for the protected `/api/messages` routes.
///
/// Expects `Authorization: Bearer <access token>`. On success, the claims returned by
/// `JwtManager::validate_access_token` are inserted into the request extensions,
/// where handlers read them via `Extension<AccessTokenClaims>`. A missing, malformed,
/// empty or invalid token yields `401 Unauthorized`.
async fn auth_middleware(
    State(state): State<AppState>,
    mut req: axum::extract::Request,
    next: axum::middleware::Next,
) -> Result<axum::response::Response, StatusCode> {
    let token = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(extract_bearer_token)
        .ok_or(StatusCode::UNAUTHORIZED)?;

    let claims = state
        .jwt_manager
        .validate_access_token(token)
        .await
        .map_err(|_| StatusCode::UNAUTHORIZED)?;
    req.extensions_mut().insert(claims);
    Ok(next.run(req).await)
}

/// Extracts the token from an `Authorization` header value that uses the `Bearer` scheme.
///
/// The scheme name is matched case-insensitively, as auth schemes are per RFC 7235.
/// Whitespace around the token is ignored. Returns `None` for any other scheme, a
/// missing separator, or an empty token.
fn extract_bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, token) = header_value.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = token.trim();
    (!token.is_empty()).then_some(token)
}
/// Resolves when the process should shut down gracefully.
///
/// Waits for Ctrl+C on every platform, and additionally for SIGTERM on Unix. On
/// non-Unix targets the second branch never completes. Panics if a signal handler
/// cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        let installed = tokio::signal::ctrl_c().await;
        installed.expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("failed to install signal handler");
        stream.recv().await;
    };

    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    // Whichever signal arrives first wins.
    tokio::select! {
        () = interrupt => {}
        () = sigterm => {}
    }
    info!("🛑 Signal d'arrêt reçu, fermeture gracieuse...");
}