322 lines
10 KiB
Rust
322 lines
10 KiB
Rust
|
|
//! Module d'authentification WebSocket pour le serveur de chat
//!
//! Ce module implémente l'authentification JWT pour les connexions WebSocket,
//! la validation des permissions par conversation, et le rate limiting.
|
||
|
|
|
||
|
|
use crate::error::{ChatError, Result};
|
||
|
|
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use std::collections::HashMap;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||
|
|
use tokio::sync::RwLock;
|
||
|
|
use uuid::Uuid;
|
||
|
|
|
||
|
|
/// Claims JWT pour l'authentification
|
||
|
|
#[derive(Debug, Serialize, Deserialize)]
|
||
|
|
pub struct JwtClaims {
|
||
|
|
pub user_id: Uuid,
|
||
|
|
pub username: String,
|
||
|
|
pub exp: u64,
|
||
|
|
pub iat: u64,
|
||
|
|
pub permissions: Vec<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Gestionnaire d'authentification WebSocket
|
||
|
|
pub struct WebSocketAuthManager {
|
||
|
|
jwt_secret: String,
|
||
|
|
active_sessions: Arc<RwLock<HashMap<Uuid, UserSession>>>,
|
||
|
|
rate_limits: Arc<RwLock<HashMap<Uuid, RateLimitState>>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Session utilisateur active
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub struct UserSession {
|
||
|
|
pub user_id: Uuid,
|
||
|
|
pub username: String,
|
||
|
|
pub connected_at: SystemTime,
|
||
|
|
pub last_activity: SystemTime,
|
||
|
|
pub permissions: Vec<String>,
|
||
|
|
pub conversation_access: Vec<Uuid>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Fixed-window message rate-limit bookkeeping for one connection.
#[derive(Debug, Clone)]
pub struct RateLimitState {
    /// Messages accepted in the current window.
    pub message_count: u32,
    /// Start of the current window.
    pub window_start: SystemTime,
    /// Time the most recently accepted message was recorded.
    pub last_message_time: SystemTime,
}
|
||
|
|
|
||
|
|
impl WebSocketAuthManager {
|
||
|
|
/// Crée un nouveau gestionnaire d'authentification
|
||
|
|
pub fn new(jwt_secret: String) -> Self {
|
||
|
|
Self {
|
||
|
|
jwt_secret,
|
||
|
|
active_sessions: Arc::new(RwLock::new(HashMap::new())),
|
||
|
|
rate_limits: Arc::new(RwLock::new(HashMap::new())),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Valide un token JWT et retourne les claims
|
||
|
|
pub fn validate_jwt_token(&self, token: &str) -> Result<JwtClaims> {
|
||
|
|
let decoding_key = DecodingKey::from_secret(self.jwt_secret.as_ref());
|
||
|
|
let validation = Validation::default();
|
||
|
|
|
||
|
|
match decode::<JwtClaims>(token, &decoding_key, &validation) {
|
||
|
|
Ok(token_data) => {
|
||
|
|
// Vérifier l'expiration
|
||
|
|
let now = SystemTime::now()
|
||
|
|
.duration_since(UNIX_EPOCH)
|
||
|
|
.map_err(|e| ChatError::authentication_error(&format!("Time error: {}", e)))?
|
||
|
|
.as_secs();
|
||
|
|
|
||
|
|
if token_data.claims.exp < now {
|
||
|
|
return Err(ChatError::authentication_error("Token expired"));
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(token_data.claims)
|
||
|
|
}
|
||
|
|
Err(e) => Err(ChatError::authentication_error(&format!("Invalid token: {}", e))),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Authentifie un utilisateur WebSocket
|
||
|
|
pub async fn authenticate_websocket_user(
|
||
|
|
&self,
|
||
|
|
token: &str,
|
||
|
|
connection_id: Uuid,
|
||
|
|
) -> Result<UserSession> {
|
||
|
|
let claims = self.validate_jwt_token(token)?;
|
||
|
|
|
||
|
|
// Créer la session utilisateur
|
||
|
|
let session = UserSession {
|
||
|
|
user_id: claims.user_id,
|
||
|
|
username: claims.username,
|
||
|
|
connected_at: SystemTime::now(),
|
||
|
|
last_activity: SystemTime::now(),
|
||
|
|
permissions: claims.permissions,
|
||
|
|
conversation_access: Vec::new(), // Sera rempli lors de la jointure aux conversations
|
||
|
|
};
|
||
|
|
|
||
|
|
// Enregistrer la session active
|
||
|
|
let mut sessions = self.active_sessions.write().await;
|
||
|
|
sessions.insert(connection_id, session.clone());
|
||
|
|
|
||
|
|
// Initialiser le rate limiting
|
||
|
|
let mut rate_limits = self.rate_limits.write().await;
|
||
|
|
rate_limits.insert(connection_id, RateLimitState {
|
||
|
|
message_count: 0,
|
||
|
|
window_start: SystemTime::now(),
|
||
|
|
last_message_time: SystemTime::now(),
|
||
|
|
});
|
||
|
|
|
||
|
|
Ok(session)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Vérifie les permissions pour une conversation
|
||
|
|
pub async fn check_conversation_permission(
|
||
|
|
&self,
|
||
|
|
connection_id: Uuid,
|
||
|
|
conversation_id: Uuid,
|
||
|
|
) -> Result<bool> {
|
||
|
|
let sessions = self.active_sessions.read().await;
|
||
|
|
|
||
|
|
if let Some(session) = sessions.get(&connection_id) {
|
||
|
|
// Vérifier si l'utilisateur a accès à cette conversation
|
||
|
|
// Pour l'instant, on autorise tous les utilisateurs authentifiés
|
||
|
|
// Dans une implémentation complète, on vérifierait les permissions spécifiques
|
||
|
|
Ok(session.conversation_access.contains(&conversation_id) ||
|
||
|
|
session.permissions.contains(&"chat:all".to_string()))
|
||
|
|
} else {
|
||
|
|
Err(ChatError::authentication_error("Session not found"))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Vérifie le rate limiting pour les messages
|
||
|
|
pub async fn check_message_rate_limit(&self, connection_id: Uuid) -> Result<bool> {
|
||
|
|
const MAX_MESSAGES_PER_MINUTE: u32 = 60;
|
||
|
|
const WINDOW_DURATION_SECONDS: u64 = 60;
|
||
|
|
|
||
|
|
let mut rate_limits = self.rate_limits.write().await;
|
||
|
|
|
||
|
|
if let Some(rate_limit) = rate_limits.get_mut(&connection_id) {
|
||
|
|
let now = SystemTime::now();
|
||
|
|
|
||
|
|
// Vérifier si la fenêtre de temps a expiré
|
||
|
|
if now.duration_since(rate_limit.window_start)
|
||
|
|
.map_err(|e| ChatError::rate_limit_error(&format!("Time error: {}", e)))?
|
||
|
|
.as_secs() >= WINDOW_DURATION_SECONDS {
|
||
|
|
// Réinitialiser le compteur
|
||
|
|
rate_limit.message_count = 0;
|
||
|
|
rate_limit.window_start = now;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Vérifier la limite
|
||
|
|
if rate_limit.message_count >= MAX_MESSAGES_PER_MINUTE {
|
||
|
|
return Ok(false);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Incrémenter le compteur
|
||
|
|
rate_limit.message_count += 1;
|
||
|
|
rate_limit.last_message_time = now;
|
||
|
|
|
||
|
|
Ok(true)
|
||
|
|
} else {
|
||
|
|
Err(ChatError::authentication_error("Rate limit state not found"))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Met à jour l'activité d'un utilisateur
|
||
|
|
pub async fn update_user_activity(&self, connection_id: Uuid) -> Result<()> {
|
||
|
|
let mut sessions = self.active_sessions.write().await;
|
||
|
|
|
||
|
|
if let Some(session) = sessions.get_mut(&connection_id) {
|
||
|
|
session.last_activity = SystemTime::now();
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Ajoute l'accès à une conversation pour un utilisateur
|
||
|
|
pub async fn grant_conversation_access(
|
||
|
|
&self,
|
||
|
|
connection_id: Uuid,
|
||
|
|
conversation_id: Uuid,
|
||
|
|
) -> Result<()> {
|
||
|
|
let mut sessions = self.active_sessions.write().await;
|
||
|
|
|
||
|
|
if let Some(session) = sessions.get_mut(&connection_id) {
|
||
|
|
if !session.conversation_access.contains(&conversation_id) {
|
||
|
|
session.conversation_access.push(conversation_id);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Retire l'accès à une conversation pour un utilisateur
|
||
|
|
pub async fn revoke_conversation_access(
|
||
|
|
&self,
|
||
|
|
connection_id: Uuid,
|
||
|
|
conversation_id: Uuid,
|
||
|
|
) -> Result<()> {
|
||
|
|
let mut sessions = self.active_sessions.write().await;
|
||
|
|
|
||
|
|
if let Some(session) = sessions.get_mut(&connection_id) {
|
||
|
|
session.conversation_access.retain(|&id| id != conversation_id);
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Déconnecte un utilisateur
|
||
|
|
pub async fn disconnect_user(&self, connection_id: Uuid) -> Result<()> {
|
||
|
|
let mut sessions = self.active_sessions.write().await;
|
||
|
|
sessions.remove(&connection_id);
|
||
|
|
|
||
|
|
let mut rate_limits = self.rate_limits.write().await;
|
||
|
|
rate_limits.remove(&connection_id);
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Nettoie les sessions expirées
|
||
|
|
pub async fn cleanup_expired_sessions(&self, max_idle_duration: Duration) -> Result<()> {
|
||
|
|
let now = SystemTime::now();
|
||
|
|
let mut sessions = self.active_sessions.write().await;
|
||
|
|
let mut rate_limits = self.rate_limits.write().await;
|
||
|
|
|
||
|
|
let expired_connections: Vec<Uuid> = sessions
|
||
|
|
.iter()
|
||
|
|
.filter(|(_, session)| {
|
||
|
|
now.duration_since(session.last_activity)
|
||
|
|
.map(|d| d > max_idle_duration)
|
||
|
|
.unwrap_or(true)
|
||
|
|
})
|
||
|
|
.map(|(id, _)| *id)
|
||
|
|
.collect();
|
||
|
|
|
||
|
|
for connection_id in expired_connections {
|
||
|
|
sessions.remove(&connection_id);
|
||
|
|
rate_limits.remove(&connection_id);
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Obtient les statistiques des sessions actives
|
||
|
|
pub async fn get_session_stats(&self) -> Result<SessionStats> {
|
||
|
|
let sessions = self.active_sessions.read().await;
|
||
|
|
|
||
|
|
let total_sessions = sessions.len();
|
||
|
|
let now = SystemTime::now();
|
||
|
|
|
||
|
|
let active_last_hour = sessions
|
||
|
|
.values()
|
||
|
|
.filter(|session| {
|
||
|
|
now.duration_since(session.last_activity)
|
||
|
|
.map(|d| d.as_secs() < 3600)
|
||
|
|
.unwrap_or(false)
|
||
|
|
})
|
||
|
|
.count();
|
||
|
|
|
||
|
|
Ok(SessionStats {
|
||
|
|
total_sessions,
|
||
|
|
active_last_hour,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Statistiques des sessions
|
||
|
|
#[derive(Debug, Serialize)]
|
||
|
|
pub struct SessionStats {
|
||
|
|
pub total_sessions: usize,
|
||
|
|
pub active_last_hour: usize,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for WebSocketAuthManager {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self::new("default_secret_key".to_string())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    const SECRET: &str = "test_secret";

    /// Current time in whole seconds since the UNIX epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is before the UNIX epoch")
            .as_secs()
    }

    /// Signs a token for a fresh user with the given secret, expiry and permissions.
    fn make_token(secret: &str, exp: u64, permissions: Vec<String>) -> String {
        let claims = JwtClaims {
            user_id: Uuid::new_v4(),
            username: "test_user".to_string(),
            exp,
            iat: now_secs(),
            permissions,
        };
        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(secret.as_ref()),
        )
        .expect("encoding a test token should succeed")
    }

    #[tokio::test]
    async fn test_jwt_validation() {
        let auth_manager = WebSocketAuthManager::new(SECRET.to_string());

        // Malformed token.
        assert!(auth_manager.validate_jwt_token("invalid_token").is_err());

        // Correctly signed, not expired.
        let valid = make_token(SECRET, now_secs() + 3600, Vec::new());
        assert!(auth_manager.validate_jwt_token(&valid).is_ok());

        // Signed with a different secret.
        let forged = make_token("other_secret", now_secs() + 3600, Vec::new());
        assert!(auth_manager.validate_jwt_token(&forged).is_err());

        // Expired an hour ago.
        let expired = make_token(SECRET, now_secs() - 3600, Vec::new());
        assert!(auth_manager.validate_jwt_token(&expired).is_err());
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let auth_manager = WebSocketAuthManager::new(SECRET.to_string());
        let connection_id = Uuid::new_v4();

        // A connection that never authenticated has no rate-limit state.
        assert!(auth_manager
            .check_message_rate_limit(connection_id)
            .await
            .is_err());

        let token = make_token(SECRET, now_secs() + 3600, vec!["chat:all".to_string()]);
        assert!(auth_manager
            .authenticate_websocket_user(&token, connection_id)
            .await
            .is_ok());

        for i in 0..60 {
            assert!(
                matches!(
                    auth_manager.check_message_rate_limit(connection_id).await,
                    Ok(true)
                ),
                "message {} should be allowed",
                i
            );
        }

        // The 61st message inside the same window is rejected.
        assert!(matches!(
            auth_manager.check_message_rate_limit(connection_id).await,
            Ok(false)
        ));
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let auth_manager = WebSocketAuthManager::new(SECRET.to_string());
        let connection_id = Uuid::new_v4();
        let conversation_id = Uuid::new_v4();

        let token = make_token(SECRET, now_secs() + 3600, Vec::new());
        assert!(auth_manager
            .authenticate_websocket_user(&token, connection_id)
            .await
            .is_ok());

        // Without "chat:all", access requires an explicit grant.
        assert!(matches!(
            auth_manager
                .check_conversation_permission(connection_id, conversation_id)
                .await,
            Ok(false)
        ));
        assert!(auth_manager
            .grant_conversation_access(connection_id, conversation_id)
            .await
            .is_ok());
        assert!(matches!(
            auth_manager
                .check_conversation_permission(connection_id, conversation_id)
                .await,
            Ok(true)
        ));
        assert!(auth_manager
            .revoke_conversation_access(connection_id, conversation_id)
            .await
            .is_ok());
        assert!(matches!(
            auth_manager
                .check_conversation_permission(connection_id, conversation_id)
                .await,
            Ok(false)
        ));

        // A generous idle limit keeps the freshly created session.
        assert!(auth_manager
            .cleanup_expired_sessions(Duration::from_secs(3600))
            .await
            .is_ok());
        match auth_manager.get_session_stats().await {
            Ok(stats) => {
                assert_eq!(stats.total_sessions, 1);
                assert_eq!(stats.active_last_hour, 1);
            }
            Err(_) => panic!("session stats should be available"),
        }

        // After disconnecting, no state remains for the connection.
        assert!(auth_manager.disconnect_user(connection_id).await.is_ok());
        assert!(auth_manager
            .check_conversation_permission(connection_id, conversation_id)
            .await
            .is_err());
        assert!(auth_manager
            .check_message_rate_limit(connection_id)
            .await
            .is_err());
    }
}
|