security(chat-server): implement auth middleware and permission checks for HTTP API

This commit is contained in:
okinrev 2025-12-06 13:18:12 +01:00
parent 76f2677c17
commit 4422e249a2

View file

@ -1,7 +1,8 @@
use axum::extract::ws::{Message as WsMessage, WebSocket};
use axum::{
extract::{Query, State, WebSocketUpgrade},
extract::{Extension, Query, State, WebSocketUpgrade},
http::StatusCode,
middleware,
response::Response,
routing::{get, post},
Json, Router,
@ -12,7 +13,7 @@ use chat_server::{
delivered_status::DeliveredStatusManager, // Add DeliveredStatusManager
error::ChatError,
event_bus::RabbitMQEventBus, // Add RabbitMQEventBus import
jwt_manager::JwtManager,
jwt_manager::{AccessTokenClaims, JwtManager},
models::message::Message, // Add Message model
read_receipts::ReadReceiptManager, // Add ReadReceiptManager
repository::MessageRepository, // Add MessageRepository
@ -51,10 +52,7 @@ struct AppState {
struct SendMessageRequest {
conversation_id: Uuid, // Add conversation_id
content: String,
sender_id: Uuid, // Use Uuid for sender_id
// author: String, // Remove author
// room: Option<String>, // Remove room
// is_direct: Option<bool>, // Remove is_direct
// sender_id is now taken from JWT token
}
/// Paramètres de récupération de messages
@ -276,9 +274,14 @@ async fn main() -> Result<(), ChatError> {
"/metrics",
get(move || std::future::ready(prometheus_handle.render())),
) // Prometheus metrics
.route("/api/messages/{conversation_id}", get(get_messages)) // Update route
.route("/api/messages/stats", get(get_stats));
let api_routes = Router::new()
.route("/api/messages/{conversation_id}", get(get_messages))
.route("/api/messages", post(send_message))
.route("/api/messages/stats", get(get_stats))
.route_layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
let app = app.merge(api_routes)
.route(
"/ws",
get({
@ -395,9 +398,19 @@ async fn health_check(State(state): State<AppState>) -> Json<ApiResponse<HashMap
#[tracing::instrument(skip(state, params))]
async fn get_messages(
State(state): State<AppState>,
Extension(claims): Extension<AccessTokenClaims>,
axum::extract::Path(conversation_id): axum::extract::Path<Uuid>, // Extract conversation_id from path
Query(params): Query<GetMessagesQuery>,
) -> Result<Json<ApiResponse<Vec<Message>>>, StatusCode> {
// Validate User ID from token
let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?;
// Check permission to read conversation
state.permission_service
.can_read_conversation(user_uuid, conversation_id)
.await
.map_err(|_| StatusCode::FORBIDDEN)?;
// Use Message model
let limit = params.limit.unwrap_or(50).min(100);
@ -417,12 +430,22 @@ async fn get_messages(
#[tracing::instrument(skip(state, payload))]
async fn send_message(
State(state): State<AppState>,
Extension(claims): Extension<AccessTokenClaims>,
Json(payload): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<Uuid>>, StatusCode> {
// Validate User ID from token
let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?;
// Check permission to send message
state.permission_service
.can_send_message(user_uuid, payload.conversation_id)
.await
.map_err(|_| StatusCode::FORBIDDEN)?;
// Return Uuid
let message = state
.message_repo
.create(payload.conversation_id, payload.sender_id, &payload.content) // Use message_repo
.create(payload.conversation_id, user_uuid, &payload.content) // Use user_uuid from token
.await
.map_err(|e| {
warn!("Erreur envoi message: {}", e);
@ -437,18 +460,40 @@ async fn send_message(
Ok(Json(ApiResponse::success(message.id)))
}
/// Statistiques basiques
#[tracing::instrument(skip(_state))]
async fn get_stats(State(_state): State<AppState>) -> Json<ApiResponse<HashMap<String, u32>>> {
let mut stats = HashMap::new();
stats.insert("total_messages".to_string(), 2);
stats.insert("active_users".to_string(), 1);
stats.insert("rooms".to_string(), 1);
stats.insert("websocket_enabled".to_string(), 1);
Json(ApiResponse::success(stats))
}
/// Middleware d'authentification
async fn auth_middleware(
State(state): State<AppState>,
mut req: axum::extract::Request,
next: axum::middleware::Next,
) -> Result<axum::response::Response, StatusCode> {
let auth_header = req.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
let auth_header = if let Some(auth_header) = auth_header {
auth_header
} else {
return Err(StatusCode::UNAUTHORIZED);
};
if !auth_header.starts_with("Bearer ") {
return Err(StatusCode::UNAUTHORIZED);
}
let token = &auth_header[7..];
match state.jwt_manager.validate_access_token(token).await {
Ok(claims) => {
req.extensions_mut().insert(claims);
Ok(next.run(req).await)
}
Err(_) => Err(StatusCode::UNAUTHORIZED),
}
}
/// Gestionnaire de signal d'arrêt (Graceful Shutdown)
async fn shutdown_signal() {
let ctrl_c = async {