//! Module de protection DoS pour WebSocket //! //! Ce module implémente la protection contre les attaques par déni de service //! pour les connexions WebSocket, incluant la limitation des connexions simultanées, //! la taille des messages, et la détection de flood. use crate::error::{ChatError, Result}; use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use uuid::Uuid; /// Configuration de protection DoS #[derive(Debug, Clone)] pub struct DosProtectionConfig { pub max_connections_per_user: usize, pub max_message_size: usize, pub inactivity_timeout: Duration, pub flood_detection_window: Duration, pub max_messages_per_window: u32, pub connection_rate_limit: u32, pub connection_rate_window: Duration, } impl Default for DosProtectionConfig { fn default() -> Self { Self { max_connections_per_user: 5, max_message_size: 64 * 1024, // 64KB inactivity_timeout: Duration::from_secs(300), // 5 minutes flood_detection_window: Duration::from_secs(10), max_messages_per_window: 10, connection_rate_limit: 10, connection_rate_window: Duration::from_secs(60), } } } /// État de protection DoS pour un utilisateur #[derive(Debug, Clone)] pub struct UserDosState { pub user_id: Uuid, pub active_connections: Vec, pub message_timestamps: Vec, pub connection_timestamps: Vec, pub last_activity: SystemTime, pub is_blocked: bool, pub block_until: Option, } /// Gestionnaire de protection DoS pub struct DosProtectionManager { config: DosProtectionConfig, user_states: Arc>>, ip_states: Arc>>, } /// État de protection DoS pour une IP #[derive(Debug, Clone)] pub struct IpDosState { pub connection_timestamps: Vec, pub is_blocked: bool, pub block_until: Option, } impl DosProtectionManager { /// Crée un nouveau gestionnaire de protection DoS pub fn new(config: DosProtectionConfig) -> Self { Self { config, user_states: Arc::new(RwLock::new(HashMap::new())), ip_states: Arc::new(RwLock::new(HashMap::new())), } } /// Vérifie si une nouvelle connexion est autorisée pour un utilisateur pub async fn check_connection_allowed( &self, user_id: Uuid, connection_id: Uuid, ip_address: &str, ) -> Result { // Vérifier la limitation par IP if !self.check_ip_rate_limit(ip_address).await? { return Ok(false); } // Vérifier la limitation par utilisateur if !self.check_user_connection_limit(user_id, connection_id).await? { return Ok(false); } // Enregistrer la connexion self.record_connection(user_id, connection_id, ip_address).await?; Ok(true) } /// Vérifie la limitation de taux par IP async fn check_ip_rate_limit(&self, ip_address: &str) -> Result { let mut ip_states = self.ip_states.write().await; let now = SystemTime::now(); // Nettoyer les timestamps expirés if let Some(ip_state) = ip_states.get_mut(ip_address) { // Vérifier si l'IP est bloquée if let Some(block_until) = ip_state.block_until { if now < block_until { return Ok(false); } else { ip_state.is_blocked = false; ip_state.block_until = None; } } // Nettoyer les timestamps expirés ip_state.connection_timestamps.retain(|×tamp| { now.duration_since(timestamp) .map(|d| d < self.config.connection_rate_window) .unwrap_or(false) }); // Vérifier la limite de taux if ip_state.connection_timestamps.len() >= self.config.connection_rate_limit as usize { // Bloquer l'IP temporairement ip_state.is_blocked = true; ip_state.block_until = Some(now + Duration::from_secs(300)); // 5 minutes return Ok(false); } ip_state.connection_timestamps.push(now); } else { // Première connexion depuis cette IP let mut ip_state = IpDosState { connection_timestamps: vec![now], is_blocked: false, block_until: None, }; ip_states.insert(ip_address.to_string(), ip_state); } Ok(true) } /// Vérifie la limitation de connexions par utilisateur async fn check_user_connection_limit(&self, user_id: Uuid, connection_id: Uuid) -> Result { let mut user_states = self.user_states.write().await; let now = SystemTime::now(); if let Some(user_state) = user_states.get_mut(&user_id) { // Vérifier si l'utilisateur est bloqué if let Some(block_until) = user_state.block_until { if now < block_until { return Ok(false); } else { user_state.is_blocked = false; user_state.block_until = None; } } // Nettoyer les connexions inactives user_state.active_connections.retain(|&conn_id| { // Dans une implémentation complète, on vérifierait l'état réel de la connexion conn_id != connection_id // Garder toutes les connexions sauf celle qu'on veut ajouter }); // Vérifier la limite de connexions if user_state.active_connections.len() >= self.config.max_connections_per_user { return Ok(false); } user_state.active_connections.push(connection_id); user_state.last_activity = now; } else { // Premier utilisateur let user_state = UserDosState { user_id, active_connections: vec![connection_id], message_timestamps: Vec::new(), connection_timestamps: vec![now], last_activity: now, is_blocked: false, block_until: None, }; user_states.insert(user_id, user_state); } Ok(true) } /// Enregistre une nouvelle connexion async fn record_connection( &self, user_id: Uuid, connection_id: Uuid, ip_address: &str, ) -> Result<()> { // L'enregistrement est déjà fait dans check_user_connection_limit // Cette méthode peut être étendue pour des logs ou métriques supplémentaires Ok(()) } /// Vérifie si un message est autorisé (protection contre le flood) pub async fn check_message_allowed(&self, user_id: Uuid) -> Result { let mut user_states = self.user_states.write().await; let now = SystemTime::now(); if let Some(user_state) = user_states.get_mut(&user_id) { // Vérifier si l'utilisateur est bloqué if let Some(block_until) = user_state.block_until { if now < block_until { return Ok(false); } else { user_state.is_blocked = false; user_state.block_until = None; } } // Nettoyer les timestamps expirés user_state.message_timestamps.retain(|×tamp| { now.duration_since(timestamp) .map(|d| d < self.config.flood_detection_window) .unwrap_or(false) }); // Vérifier la limite de messages if user_state.message_timestamps.len() >= self.config.max_messages_per_window as usize { // Bloquer l'utilisateur temporairement user_state.is_blocked = true; user_state.block_until = Some(now + Duration::from_secs(60)); // 1 minute return Ok(false); } user_state.message_timestamps.push(now); user_state.last_activity = now; } Ok(true) } /// Valide la taille d'un message pub fn validate_message_size(&self, message_size: usize) -> Result { if message_size > self.config.max_message_size { Err(ChatError::rate_limit_error(&format!( "Message too large: {} bytes (max: {} bytes)", message_size, self.config.max_message_size ))) } else { Ok(true) } } /// Déconnecte un utilisateur pub async fn disconnect_user(&self, user_id: Uuid, connection_id: Uuid) -> Result<()> { let mut user_states = self.user_states.write().await; if let Some(user_state) = user_states.get_mut(&user_id) { user_state.active_connections.retain(|&id| id != connection_id); // Supprimer l'état utilisateur s'il n'y a plus de connexions if user_state.active_connections.is_empty() { user_states.remove(&user_id); } } Ok(()) } /// Nettoie les états expirés pub async fn cleanup_expired_states(&self) -> Result<()> { let now = SystemTime::now(); // Nettoyer les états utilisateur { let mut user_states = self.user_states.write().await; user_states.retain(|_, user_state| { // Garder les utilisateurs avec des connexions actives ou récentes !user_state.active_connections.is_empty() || now.duration_since(user_state.last_activity) .map(|d| d < Duration::from_secs(3600)) // 1 heure .unwrap_or(false) }); } // Nettoyer les états IP { let mut ip_states = self.ip_states.write().await; ip_states.retain(|_, ip_state| { // Garder les IPs avec des connexions récentes !ip_state.connection_timestamps.is_empty() || ip_state.is_blocked }); } Ok(()) } /// Obtient les statistiques de protection DoS pub async fn get_dos_stats(&self) -> Result { let user_states = self.user_states.read().await; let ip_states = self.ip_states.read().await; let total_users = user_states.len(); let blocked_users = user_states.values().filter(|s| s.is_blocked).count(); let total_connections: usize = user_states.values().map(|s| s.active_connections.len()).sum(); let blocked_ips = ip_states.values().filter(|s| s.is_blocked).count(); Ok(DosStats { total_users, blocked_users, total_connections, blocked_ips, max_connections_per_user: self.config.max_connections_per_user, max_message_size: self.config.max_message_size, }) } /// Force le déblocage d'un utilisateur (pour les administrateurs) pub async fn unblock_user(&self, user_id: Uuid) -> Result<()> { let mut user_states = self.user_states.write().await; if let Some(user_state) = user_states.get_mut(&user_id) { user_state.is_blocked = false; user_state.block_until = None; } Ok(()) } /// Force le déblocage d'une IP (pour les administrateurs) pub async fn unblock_ip(&self, ip_address: &str) -> Result<()> { let mut ip_states = self.ip_states.write().await; if let Some(ip_state) = ip_states.get_mut(ip_address) { ip_state.is_blocked = false; ip_state.block_until = None; } Ok(()) } } /// Statistiques de protection DoS #[derive(Debug, Serialize)] pub struct DosStats { pub total_users: usize, pub blocked_users: usize, pub total_connections: usize, pub blocked_ips: usize, pub max_connections_per_user: usize, pub max_message_size: usize, } impl Default for DosProtectionManager { fn default() -> Self { Self::new(DosProtectionConfig::default()) } } #[cfg(test)] mod tests { use super::*; use std::time::Duration; #[tokio::test] async fn test_connection_limit() { let config = DosProtectionConfig { max_connections_per_user: 2, max_message_size: 1024, inactivity_timeout: Duration::from_secs(300), flood_detection_window: Duration::from_secs(10), max_messages_per_window: 10, connection_rate_limit: 5, connection_rate_window: Duration::from_secs(60), }; let manager = DosProtectionManager::new(config); let user_id = Uuid::new_v4(); let ip_address = "192.168.1.1"; // Première connexion - devrait être autorisée let conn1 = Uuid::new_v4(); assert!(manager.check_connection_allowed(user_id, conn1, ip_address).await.unwrap()); // Deuxième connexion - devrait être autorisée let conn2 = Uuid::new_v4(); assert!(manager.check_connection_allowed(user_id, conn2, ip_address).await.unwrap()); // Troisième connexion - devrait être refusée let conn3 = Uuid::new_v4(); assert!(!manager.check_connection_allowed(user_id, conn3, ip_address).await.unwrap()); } #[tokio::test] async fn test_message_flood_protection() { let manager = DosProtectionManager::default(); let user_id = Uuid::new_v4(); // Envoyer des messages rapidement for i in 0..12 { let allowed = manager.check_message_allowed(user_id).await.unwrap(); if i < 10 { assert!(allowed, "Message {} should be allowed", i); } else { assert!(!allowed, "Message {} should be blocked", i); } } } #[test] fn test_message_size_validation() { let manager = DosProtectionManager::default(); // Message de taille normale assert!(manager.validate_message_size(512).is_ok()); // Message trop grand assert!(manager.validate_message_size(128 * 1024).is_err()); } }