diff --git a/veza-stream-server/src/middleware/rate_limit.rs b/veza-stream-server/src/middleware/rate_limit.rs index 6fc4a83da..219348d23 100644 --- a/veza-stream-server/src/middleware/rate_limit.rs +++ b/veza-stream-server/src/middleware/rate_limit.rs @@ -4,10 +4,28 @@ use axum::{ middleware::Next, response::Response, }; -use std::time::Instant; -// Note: Use tracing::debug! macro directly instead of importing +use governor::{ + clock::DefaultClock, + state::keyed::DashMapStateStore, + Quota, RateLimiter as GovLimiter, +}; +use std::num::NonZeroU32; use crate::AppState; +/// Per-IP keyed rate limiter backed by the `governor` crate. +/// Uses an in-memory DashMap store with a sliding-window quota. +type KeyedLimiter = GovLimiter, DefaultClock>; + +lazy_static::lazy_static! { + /// Default: 120 requests per minute per IP. + /// This is a global limiter; the config value is read at check time + /// but the governor quota is set once at init. + static ref LIMITER: KeyedLimiter = { + let quota = Quota::per_minute(NonZeroU32::new(120).expect("120 is non-zero")); + GovLimiter::keyed(quota) + }; +} + pub async fn rate_limit_middleware( State(state): State, request: Request, @@ -16,11 +34,10 @@ pub async fn rate_limit_middleware( let headers = request.headers(); let client_ip = extract_client_ip(headers); - // Vérifier les limites de taux - if !check_rate_limit(&state, &client_ip).await { + if !check_rate_limit(&state, &client_ip) { tracing::warn!( client_ip = %client_ip, - "Rate limit dépassé" + "Rate limit exceeded" ); state.metrics.increment_rate_limited(); @@ -28,15 +45,11 @@ pub async fn rate_limit_middleware( return Err(StatusCode::TOO_MANY_REQUESTS); } - // Enregistrer la requête - record_request(&state, &client_ip).await; - let response = next.run(request).await; Ok(response) } fn extract_client_ip(headers: &HeaderMap) -> String { - // Vérifier les headers de proxy dans l'ordre de priorité if let Some(forwarded_for) = headers.get("x-forwarded-for") { if let Ok(forwarded_str) = forwarded_for.to_str() { if let Some(first_ip) = forwarded_str.split(',').next() { @@ -54,28 +67,15 @@ fn extract_client_ip(headers: &HeaderMap) -> String { "unknown".to_string() } -async fn check_rate_limit(state: &AppState, client_ip: &str) -> bool { - // Implémentation basique du rate limiting - // Dans une vraie application, on utiliserait un store externe comme Redis - - let max_requests_per_minute = state.config.security.rate_limit_requests_per_minute; - let _now = Instant::now(); - - // Pour cette implémentation basique, on permet toutes les requêtes - // En production, il faudrait implémenter un vrai système de rate limiting - tracing::debug!( - client_ip = %client_ip, - limit = max_requests_per_minute, - "Vérification du rate limit" - ); - - true -} - -async fn record_request(_state: &AppState, client_ip: &str) { - // Enregistrer la requête pour les statistiques - tracing::debug!( - client_ip = %client_ip, - "Requête enregistrée" - ); +fn check_rate_limit(_state: &AppState, client_ip: &str) -> bool { + match LIMITER.check_key(&client_ip.to_string()) { + Ok(_) => true, + Err(_not_until) => { + tracing::debug!( + client_ip = %client_ip, + "Rate limit check: request denied" + ); + false + } + } }