security(chat-server): implement auth middleware and permission checks for HTTP API
This commit is contained in:
parent
a47464509a
commit
99f960140a
1 changed files with 63 additions and 18 deletions
|
|
@ -1,7 +1,8 @@
|
|||
use axum::extract::ws::{Message as WsMessage, WebSocket};
|
||||
use axum::{
|
||||
extract::{Query, State, WebSocketUpgrade},
|
||||
extract::{Extension, Query, State, WebSocketUpgrade},
|
||||
http::StatusCode,
|
||||
middleware,
|
||||
response::Response,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
|
|
@ -12,7 +13,7 @@ use chat_server::{
|
|||
delivered_status::DeliveredStatusManager, // Add DeliveredStatusManager
|
||||
error::ChatError,
|
||||
event_bus::RabbitMQEventBus, // Add RabbitMQEventBus import
|
||||
jwt_manager::JwtManager,
|
||||
jwt_manager::{AccessTokenClaims, JwtManager},
|
||||
models::message::Message, // Add Message model
|
||||
read_receipts::ReadReceiptManager, // Add ReadReceiptManager
|
||||
repository::MessageRepository, // Add MessageRepository
|
||||
|
|
@ -51,10 +52,7 @@ struct AppState {
|
|||
struct SendMessageRequest {
|
||||
conversation_id: Uuid, // Add conversation_id
|
||||
content: String,
|
||||
sender_id: Uuid, // Use Uuid for sender_id
|
||||
// author: String, // Remove author
|
||||
// room: Option<String>, // Remove room
|
||||
// is_direct: Option<bool>, // Remove is_direct
|
||||
// sender_id is now taken from JWT token
|
||||
}
|
||||
|
||||
/// Paramètres de récupération de messages
|
||||
|
|
@ -276,9 +274,14 @@ async fn main() -> Result<(), ChatError> {
|
|||
"/metrics",
|
||||
get(move || std::future::ready(prometheus_handle.render())),
|
||||
) // Prometheus metrics
|
||||
.route("/api/messages/{conversation_id}", get(get_messages)) // Update route
|
||||
.route("/api/messages/stats", get(get_stats));
|
||||
|
||||
let api_routes = Router::new()
|
||||
.route("/api/messages/{conversation_id}", get(get_messages))
|
||||
.route("/api/messages", post(send_message))
|
||||
.route("/api/messages/stats", get(get_stats))
|
||||
.route_layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
|
||||
|
||||
let app = app.merge(api_routes)
|
||||
.route(
|
||||
"/ws",
|
||||
get({
|
||||
|
|
@ -395,9 +398,19 @@ async fn health_check(State(state): State<AppState>) -> Json<ApiResponse<HashMap
|
|||
#[tracing::instrument(skip(state, params))]
|
||||
async fn get_messages(
|
||||
State(state): State<AppState>,
|
||||
Extension(claims): Extension<AccessTokenClaims>,
|
||||
axum::extract::Path(conversation_id): axum::extract::Path<Uuid>, // Extract conversation_id from path
|
||||
Query(params): Query<GetMessagesQuery>,
|
||||
) -> Result<Json<ApiResponse<Vec<Message>>>, StatusCode> {
|
||||
// Validate User ID from token
|
||||
let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
// Check permission to read conversation
|
||||
state.permission_service
|
||||
.can_read_conversation(user_uuid, conversation_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::FORBIDDEN)?;
|
||||
|
||||
// Use Message model
|
||||
let limit = params.limit.unwrap_or(50).min(100);
|
||||
|
||||
|
|
@ -417,12 +430,22 @@ async fn get_messages(
|
|||
#[tracing::instrument(skip(state, payload))]
|
||||
async fn send_message(
|
||||
State(state): State<AppState>,
|
||||
Extension(claims): Extension<AccessTokenClaims>,
|
||||
Json(payload): Json<SendMessageRequest>,
|
||||
) -> Result<Json<ApiResponse<Uuid>>, StatusCode> {
|
||||
// Validate User ID from token
|
||||
let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
// Check permission to send message
|
||||
state.permission_service
|
||||
.can_send_message(user_uuid, payload.conversation_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::FORBIDDEN)?;
|
||||
|
||||
// Return Uuid
|
||||
let message = state
|
||||
.message_repo
|
||||
.create(payload.conversation_id, payload.sender_id, &payload.content) // Use message_repo
|
||||
.create(payload.conversation_id, user_uuid, &payload.content) // Use user_uuid from token
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("Erreur envoi message: {}", e);
|
||||
|
|
@ -437,18 +460,40 @@ async fn send_message(
|
|||
Ok(Json(ApiResponse::success(message.id)))
|
||||
}
|
||||
|
||||
/// Statistiques basiques
|
||||
#[tracing::instrument(skip(_state))]
|
||||
async fn get_stats(State(_state): State<AppState>) -> Json<ApiResponse<HashMap<String, u32>>> {
|
||||
let mut stats = HashMap::new();
|
||||
stats.insert("total_messages".to_string(), 2);
|
||||
stats.insert("active_users".to_string(), 1);
|
||||
stats.insert("rooms".to_string(), 1);
|
||||
stats.insert("websocket_enabled".to_string(), 1);
|
||||
|
||||
Json(ApiResponse::success(stats))
|
||||
}
|
||||
|
||||
/// Middleware d'authentification
|
||||
async fn auth_middleware(
|
||||
State(state): State<AppState>,
|
||||
mut req: axum::extract::Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let auth_header = req.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok());
|
||||
|
||||
let auth_header = if let Some(auth_header) = auth_header {
|
||||
auth_header
|
||||
} else {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
};
|
||||
|
||||
if !auth_header.starts_with("Bearer ") {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let token = &auth_header[7..];
|
||||
|
||||
match state.jwt_manager.validate_access_token(token).await {
|
||||
Ok(claims) => {
|
||||
req.extensions_mut().insert(claims);
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gestionnaire de signal d'arrêt (Graceful Shutdown)
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
|
|
|
|||
Loading…
Reference in a new issue