veza/veza-chat-server/src/jwt_manager.rs

641 lines
21 KiB
Rust

//! Gestionnaire JWT avancé avec refresh tokens et rotation
//!
//! Ce module fournit une gestion complète des tokens JWT avec:
//! - Access tokens (courte durée)
//! - Refresh tokens (longue durée)
//! - Rotation automatique des tokens
//! - Blacklist des tokens révoqués
//! - Validation robuste avec métriques
use crate::config::SecurityConfig;
use crate::error::{ChatError, Result};
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Claims pour les access tokens
/// MIGRATION UUID: user_id est maintenant String (UUID serialisé)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenClaims {
/// ID de l'utilisateur (UUID en string)
#[serde(rename = "sub")]
pub user_id: String,
/// Nom d'utilisateur
pub username: String,
/// Rôle de l'utilisateur
pub role: String,
/// Type de token
pub token_type: String,
/// Audience
#[serde(deserialize_with = "deserialize_audience")]
pub aud: Vec<String>,
/// Issuer
pub iss: String,
/// Expiration
pub exp: usize,
/// Émis à
pub iat: usize,
/// JTI (JWT ID) pour la révocation
pub jti: String,
}
/// Claims pour les refresh tokens
/// MIGRATION UUID: user_id est maintenant String (UUID serialisé)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshTokenClaims {
/// ID de l'utilisateur (UUID en string)
#[serde(rename = "sub")]
pub user_id: String,
/// Type de token
pub token_type: String,
/// Audience
#[serde(deserialize_with = "deserialize_audience")]
pub aud: Vec<String>,
/// Issuer
pub iss: String,
/// Expiration
pub exp: usize,
/// Émis à
pub iat: usize,
/// JTI (JWT ID) pour la révocation
pub jti: String,
/// Version de la famille de tokens
pub token_family: String,
}
fn deserialize_audience<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
struct AudienceVisitor;
impl<'de> serde::de::Visitor<'de> for AudienceVisitor {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string or an array of strings")
}
fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(vec![v.to_owned()])
}
fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut res = Vec::new();
while let Some(el) = seq.next_element()? {
res.push(el);
}
Ok(res)
}
}
deserializer.deserialize_any(AudienceVisitor)
}
/// Paire de tokens (access + refresh)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPair {
pub access_token: String,
pub refresh_token: String,
pub expires_in: u64,
pub token_type: String,
}
/// Informations sur un token révoqué
/// MIGRATION UUID: user_id est maintenant String
#[derive(Debug, Clone)]
pub struct RevokedToken {
pub jti: String,
pub user_id: String, // UUID as string
pub revoked_at: DateTime<Utc>,
pub reason: RevocationReason,
}
/// Raison de révocation d'un token
#[derive(Debug, Clone)]
pub enum RevocationReason {
UserLogout,
TokenRefresh,
SecurityViolation,
AdminRevocation,
Expired,
}
/// Gestionnaire JWT avec rotation et blacklist
pub struct JwtManager {
config: SecurityConfig,
encoding_key: EncodingKey,
decoding_key: DecodingKey,
validation: Validation,
/// Blacklist des tokens révoqués
revoked_tokens: Arc<RwLock<HashSet<String>>>,
/// Cache des familles de tokens actives
active_token_families: Arc<RwLock<HashSet<String>>>,
/// Pool de base de données optionnel pour récupérer les infos utilisateur
db_pool: Option<PgPool>,
}
impl JwtManager {
/// Crée un nouveau gestionnaire JWT
pub fn new(config: SecurityConfig) -> Result<Self> {
let algorithm = match config.jwt_algorithm.as_str() {
"HS256" => Algorithm::HS256,
"HS384" => Algorithm::HS384,
"HS512" => Algorithm::HS512,
"RS256" => Algorithm::RS256,
"RS384" => Algorithm::RS384,
"RS512" => Algorithm::RS512,
_ => return Err(ChatError::configuration_error("Algorithme JWT invalide")),
};
let encoding_key = EncodingKey::from_secret(config.jwt_secret.as_bytes());
let decoding_key = DecodingKey::from_secret(config.jwt_secret.as_bytes());
let mut validation = Validation::new(algorithm);
validation.set_audience(&[&config.jwt_audience]);
validation.set_issuer(&[&config.jwt_issuer]);
validation.set_required_spec_claims(&["exp", "iat", "sub", "aud", "iss", "jti"]);
Ok(Self {
config,
encoding_key,
decoding_key,
validation,
revoked_tokens: Arc::new(RwLock::new(HashSet::new())),
active_token_families: Arc::new(RwLock::new(HashSet::new())),
db_pool: None,
})
}
/// Crée un nouveau gestionnaire JWT avec un pool de base de données
pub fn with_pool(config: SecurityConfig, pool: PgPool) -> Result<Self> {
let mut manager = Self::new(config)?;
manager.db_pool = Some(pool);
Ok(manager)
}
/// Génère une paire de tokens (access + refresh)
/// MIGRATION UUID: user_id est maintenant String (UUID)
pub async fn generate_token_pair(
&self,
user_id: String, // UUID as string
username: String,
role: String,
) -> Result<TokenPair> {
let now = Utc::now();
let access_exp = now + Duration::seconds(self.config.jwt_access_duration.as_secs() as i64);
let refresh_exp =
now + Duration::seconds(self.config.jwt_refresh_duration.as_secs() as i64);
// Générer des JTI uniques
let access_jti = Uuid::new_v4().to_string();
let refresh_jti = Uuid::new_v4().to_string();
let token_family = Uuid::new_v4().to_string();
// Claims pour access token
let access_claims = AccessTokenClaims {
user_id: user_id.clone(),
username: username.clone(),
role: role.clone(),
token_type: "access".to_string(),
aud: vec![self.config.jwt_audience.clone()],
iss: self.config.jwt_issuer.clone(),
exp: access_exp.timestamp() as usize,
iat: now.timestamp() as usize,
jti: access_jti.clone(),
};
// Claims pour refresh token
let refresh_claims = RefreshTokenClaims {
user_id: user_id.clone(),
token_type: "refresh".to_string(),
aud: vec![self.config.jwt_audience.clone()],
iss: self.config.jwt_issuer.clone(),
exp: refresh_exp.timestamp() as usize,
iat: now.timestamp() as usize,
jti: refresh_jti.clone(),
token_family: token_family.clone(),
};
// Encoder les tokens
let access_token =
encode(&Header::default(), &access_claims, &self.encoding_key).map_err(|e| {
ChatError::validation_error(&format!("Erreur encodage access token: {e}"))
})?;
let refresh_token = encode(&Header::default(), &refresh_claims, &self.encoding_key)
.map_err(|e| {
ChatError::validation_error(&format!("Erreur encodage refresh token: {e}"))
})?;
// Enregistrer la famille de tokens comme active
{
let mut families = self.active_token_families.write().await;
families.insert(token_family);
}
tracing::info!(
user_id = %user_id,
username = %username,
role = %role,
access_jti = %access_jti,
refresh_jti = %refresh_jti,
"🔐 Paire de tokens générée"
);
Ok(TokenPair {
access_token,
refresh_token,
expires_in: self.config.jwt_access_duration.as_secs(),
token_type: "Bearer".to_string(),
})
}
/// Valide un access token
pub async fn validate_access_token(&self, token: &str) -> Result<AccessTokenClaims> {
// Vérifier si le token est dans la blacklist
{
let revoked = self.revoked_tokens.read().await;
if revoked.contains(token) {
return Err(ChatError::unauthorized("Token révoqué"));
}
}
// Décoder et valider le token
let token_data = decode::<AccessTokenClaims>(token, &self.decoding_key, &self.validation)
.map_err(|e| {
tracing::warn!(error = %e, "❌ Échec validation access token");
ChatError::unauthorized("Token invalide")
})?;
let claims = token_data.claims;
// Vérifier le type de token
if claims.token_type != "access" {
return Err(ChatError::unauthorized("Type de token invalide"));
}
// Vérifier l'expiration
let now = Utc::now().timestamp() as usize;
if claims.exp < now {
return Err(ChatError::unauthorized("Token expiré"));
}
tracing::debug!(
user_id = %claims.user_id,
username = %claims.username,
jti = %claims.jti,
"✅ Access token validé"
);
Ok(claims)
}
/// Valide un refresh token et génère une nouvelle paire
pub async fn refresh_tokens(&self, refresh_token: &str) -> Result<TokenPair> {
// Vérifier si le token est dans la blacklist
{
let revoked = self.revoked_tokens.read().await;
if revoked.contains(refresh_token) {
return Err(ChatError::unauthorized("Refresh token révoqué"));
}
}
// Décoder et valider le refresh token
let token_data =
decode::<RefreshTokenClaims>(refresh_token, &self.decoding_key, &self.validation)
.map_err(|e| {
tracing::warn!(error = %e, "❌ Échec validation refresh token");
ChatError::unauthorized("Refresh token invalide")
})?;
let claims = token_data.claims;
// Vérifier le type de token
if claims.token_type != "refresh" {
return Err(ChatError::unauthorized("Type de token invalide"));
}
// Vérifier l'expiration
let now = Utc::now().timestamp() as usize;
if claims.exp < now {
return Err(ChatError::unauthorized("Refresh token expiré"));
}
// Vérifier que la famille de tokens est toujours active
{
let families = self.active_token_families.read().await;
if !families.contains(&claims.token_family) {
return Err(ChatError::unauthorized("Famille de tokens révoquée"));
}
}
// Révocation de l'ancien refresh token
self.revoke_token(refresh_token, RevocationReason::TokenRefresh)
.await?;
// Récupérer les informations utilisateur depuis la DB
let (username, role) = if let Some(ref pool) = self.db_pool {
// Parser user_id depuis String vers Uuid
let user_uuid = Uuid::parse_str(&claims.user_id).map_err(|e| {
ChatError::validation_error(&format!("Invalid user UUID in token: {}", e))
})?;
// Récupérer username et role depuis la DB
let user_info: Option<(String, Option<String>)> = sqlx::query_as(
r#"
SELECT username, role FROM users
WHERE id = $1
"#,
)
.bind(user_uuid)
.fetch_optional(pool)
.await
.map_err(|e| ChatError::from_sqlx_error("get_user_info_for_refresh", e))?
.map(|row: (String, Option<String>)| row);
match user_info {
Some((username, role_opt)) => {
let role = role_opt.unwrap_or_else(|| "user".to_string());
(username, role)
}
None => {
tracing::warn!(
user_id = %claims.user_id,
"Utilisateur non trouvé dans la DB lors du refresh token, utilisation de valeurs par défaut"
);
// Fallback si utilisateur non trouvé (ne devrait pas arriver en production)
("user".to_string(), "user".to_string())
}
}
} else {
// Fallback si pas de pool DB (mode dégradé)
tracing::warn!(
user_id = %claims.user_id,
"Pas de pool DB disponible, utilisation de valeurs par défaut pour refresh token"
);
("user".to_string(), "user".to_string())
};
// MIGRATION UUID: Cloner user_id avant de le move
let user_id_clone = claims.user_id.clone();
// Générer une nouvelle paire de tokens
let new_tokens = self
.generate_token_pair(claims.user_id, username, role)
.await?;
tracing::info!(
user_id = %user_id_clone,
old_jti = %claims.jti,
"🔄 Tokens rafraîchis"
);
Ok(new_tokens)
}
/// Révoque un token
pub async fn revoke_token(&self, token: &str, reason: RevocationReason) -> Result<()> {
// Extraire le JTI du token (sans validation complète pour la révocation)
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(ChatError::validation_error("Format de token invalide"));
}
// Décoder le payload pour obtenir le JTI
let payload = parts[1];
let decoded =
base64::Engine::decode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, payload)
.map_err(|e| {
ChatError::validation_error(&format!("Erreur décodage payload: {e}"))
})?;
let claims: serde_json::Value = serde_json::from_slice(&decoded)
.map_err(|e| ChatError::validation_error(&format!("Erreur parsing claims: {e}")))?;
let jti = claims["jti"]
.as_str()
.ok_or_else(|| ChatError::validation_error("JTI manquant"))?;
let user_id = claims["sub"].as_str().unwrap_or("unknown").to_string();
// Ajouter à la blacklist
{
let mut revoked = self.revoked_tokens.write().await;
revoked.insert(token.to_string());
}
// Si c'est un refresh token, révoquer toute la famille
if let Some(token_type) = claims["token_type"].as_str() {
if token_type == "refresh" {
if let Some(family) = claims["token_family"].as_str() {
let mut families = self.active_token_families.write().await;
families.remove(family);
}
}
}
tracing::info!(
jti = %jti,
user_id = %user_id,
reason = ?reason,
"🚫 Token révoqué"
);
Ok(())
}
/// Révoque tous les tokens d'un utilisateur
/// MIGRATION UUID: user_id est String
pub async fn revoke_user_tokens(&self, user_id: String) -> Result<()> {
// En production, on devrait maintenir une liste des familles de tokens par utilisateur
// Pour l'instant, on nettoie toutes les familles actives
let mut families = self.active_token_families.write().await;
families.clear();
tracing::info!(user_id = %user_id, "🚫 Tous les tokens de l'utilisateur révoqués");
Ok(())
}
/// Nettoie les tokens expirés de la blacklist
pub async fn cleanup_expired_tokens(&self) -> Result<usize> {
// En production, on devrait vérifier l'expiration de chaque token
// Pour l'instant, on limite la taille de la blacklist
let mut revoked = self.revoked_tokens.write().await;
let initial_size = revoked.len();
if revoked.len() > 10000 {
// Garder seulement les 5000 plus récents (simulation)
let tokens: Vec<String> = revoked.iter().take(5000).cloned().collect();
revoked.clear();
revoked.extend(tokens);
}
let cleaned = initial_size - revoked.len();
if cleaned > 0 {
tracing::debug!(cleaned = %cleaned, "🧹 Tokens expirés nettoyés de la blacklist");
}
Ok(cleaned)
}
/// Vérifie si un token est révoqué
pub async fn is_token_revoked(&self, token: &str) -> bool {
let revoked = self.revoked_tokens.read().await;
revoked.contains(token)
}
/// Obtient les statistiques des tokens
pub async fn get_token_stats(&self) -> TokenStats {
let revoked_count = self.revoked_tokens.read().await.len();
let active_families = self.active_token_families.read().await.len();
TokenStats {
revoked_tokens: revoked_count,
active_token_families: active_families,
}
}
}
/// Statistiques des tokens
#[derive(Debug, Clone, Serialize)]
pub struct TokenStats {
pub revoked_tokens: usize,
pub active_token_families: usize,
}
/// Fonction utilitaire pour extraire le token du header Authorization
pub fn extract_token_from_header(auth_header: &str) -> Result<&str> {
if !auth_header.starts_with("Bearer ") {
return Err(ChatError::unauthorized("Format d'autorisation invalide"));
}
Ok(&auth_header[7..]) // Retirer "Bearer "
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
fn create_test_config() -> SecurityConfig {
SecurityConfig {
jwt_secret: "test_secret_key_32_chars_minimum_required".to_string(),
jwt_access_duration: Duration::from_secs(3600), // 1 heure
jwt_refresh_duration: Duration::from_secs(86400), // 24 heures
jwt_algorithm: "HS256".to_string(),
jwt_audience: "test".to_string(),
jwt_issuer: "test".to_string(),
enable_2fa: false,
totp_window: 1,
content_filtering: false,
password_min_length: 8,
bcrypt_cost: 12,
}
}
#[tokio::test]
async fn test_generate_and_validate_tokens() {
let config = create_test_config();
let manager = JwtManager::new(config).unwrap();
// UUID de test
let test_user_id = Uuid::new_v4().to_string();
// Générer une paire de tokens
let tokens = manager
.generate_token_pair(
test_user_id.clone(),
"testuser".to_string(),
"user".to_string(),
)
.await
.unwrap();
// Valider l'access token
let claims = manager
.validate_access_token(&tokens.access_token)
.await
.unwrap();
assert_eq!(claims.user_id, test_user_id);
assert_eq!(claims.username, "testuser");
assert_eq!(claims.role, "user");
assert_eq!(claims.token_type, "access");
}
#[tokio::test]
async fn test_token_revocation() {
let config = create_test_config();
let manager = JwtManager::new(config).unwrap();
let test_user_id = Uuid::new_v4().to_string();
// Générer des tokens
let tokens = manager
.generate_token_pair(test_user_id, "testuser".to_string(), "user".to_string())
.await
.unwrap();
// Valider avant révocation
assert!(manager
.validate_access_token(&tokens.access_token)
.await
.is_ok());
// Révoquer le token
manager
.revoke_token(&tokens.access_token, RevocationReason::UserLogout)
.await
.unwrap();
// Vérifier que le token est révoqué
assert!(manager
.validate_access_token(&tokens.access_token)
.await
.is_err());
}
#[tokio::test]
async fn test_token_refresh() {
let config = create_test_config();
let manager = JwtManager::new(config).unwrap();
let test_user_id = Uuid::new_v4().to_string();
// Générer des tokens
let tokens = manager
.generate_token_pair(
test_user_id.clone(),
"testuser".to_string(),
"user".to_string(),
)
.await
.unwrap();
// Rafraîchir les tokens
let new_tokens = manager.refresh_tokens(&tokens.refresh_token).await.unwrap();
// Vérifier que les nouveaux tokens fonctionnent
let claims = manager
.validate_access_token(&new_tokens.access_token)
.await
.unwrap();
assert_eq!(claims.user_id, test_user_id);
// Vérifier que l'ancien refresh token est révoqué
assert!(manager.refresh_tokens(&tokens.refresh_token).await.is_err());
}
}