veza/veza-stream-server/src/compression/adaptive.rs
2025-12-03 20:36:56 +01:00

394 lines
14 KiB
Rust

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
// Note: Use tracing::info! macro directly instead of importing
use uuid::Uuid;
/// Supported audio quality tiers and the encoding parameters attached to each.
///
/// The enum is fieldless, so it is `Copy`: tiers can be passed around, stored and
/// logged by value without explicit `.clone()` calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AudioQuality {
    /// 128 kbps.
    Low,
    /// 256 kbps.
    Medium,
    /// 320 kbps.
    High,
}

impl AudioQuality {
    /// Target bitrate in bits per second.
    pub fn bitrate(&self) -> u32 {
        match self {
            AudioQuality::Low => 128_000,
            AudioQuality::Medium => 256_000,
            AudioQuality::High => 320_000,
        }
    }

    /// Compression level for this tier; lower tiers compress more aggressively.
    pub fn compression_level(&self) -> u8 {
        match self {
            AudioQuality::Low => 6,    // Most aggressive compression
            AudioQuality::Medium => 4, // Moderate compression
            AudioQuality::High => 2,   // Light compression
        }
    }

    /// Output sample rate in Hz.
    pub fn sample_rate(&self) -> u32 {
        match self {
            AudioQuality::Low => 44_100,
            AudioQuality::Medium => 44_100,
            AudioQuality::High => 48_000,
        }
    }
}
/// Bandwidth metrics recorded for a single client.
///
/// The numeric fields hold smoothed estimates that are updated by
/// `BandwidthDetector::record_measurement`.
#[derive(Debug, Clone)]
pub struct BandwidthMetrics {
    /// Identifier of the client these metrics describe.
    pub client_id: Uuid,
    /// Estimated throughput, in bits per second.
    pub measured_bandwidth: f64, // bps
    /// Packet loss as a fraction (0.0 = no loss, 1.0 = every packet lost).
    pub packet_loss_rate: f64, // 0.0 to 1.0
    /// Smoothed latency, in milliseconds.
    pub latency_ms: u32,
    /// When the most recent measurement was recorded; used to detect stale metrics.
    pub last_measurement: std::time::Instant,
    /// Number of measurements folded into these metrics so far.
    pub measurement_count: u32,
}
/// Automatic per-client bandwidth detection.
///
/// Throughput, packet loss and latency are tracked per client as exponential moving
/// averages, then mapped to a recommended [`AudioQuality`].
#[derive(Debug)]
pub struct BandwidthDetector {
    /// Smoothed metrics, keyed by client id.
    client_metrics: Arc<RwLock<HashMap<Uuid, BandwidthMetrics>>>,
    /// Metrics whose last measurement is older than this are treated as stale.
    measurement_window: std::time::Duration,
    /// Number of measurements required before the estimate is trusted.
    min_measurements: u32,
}

impl BandwidthDetector {
    /// Weight given to the newest sample in the exponential moving averages.
    const SMOOTHING_ALPHA: f64 = 0.3;

    /// Creates a detector with a 30 s staleness window that needs 5 measurements
    /// before it recommends anything other than the default quality.
    pub fn new() -> Self {
        Self {
            client_metrics: Arc::new(RwLock::new(HashMap::new())),
            measurement_window: std::time::Duration::from_secs(30),
            min_measurements: 5,
        }
    }

    /// Records one bandwidth measurement for `client_id`.
    ///
    /// The sample's throughput is `bytes_transferred * 8 / duration` (bits per second).
    /// It is blended into the client's moving averages together with `packet_loss`
    /// (a fraction, clamped to `[0.0, 1.0]`) and `latency_ms`.
    ///
    /// Samples with a zero `duration` or a NaN `packet_loss` are discarded. Either
    /// would produce an infinite/NaN value that never decays out of the moving average,
    /// permanently corrupting the client's estimate.
    pub async fn record_measurement(
        &self,
        client_id: Uuid,
        bytes_transferred: usize,
        duration: std::time::Duration,
        packet_loss: f64,
        latency_ms: u32,
    ) {
        if duration.is_zero() || packet_loss.is_nan() {
            tracing::debug!(
                "Ignoring invalid bandwidth measurement for client {}: duration {:?}, loss {}",
                client_id,
                duration,
                packet_loss
            );
            return;
        }

        let bandwidth_bps = (bytes_transferred as f64 * 8.0) / duration.as_secs_f64();
        let packet_loss = packet_loss.clamp(0.0, 1.0);
        let now = std::time::Instant::now();

        let mut metrics_map = self.client_metrics.write().await;
        // A new client is seeded with this very sample, so the blend below leaves it unchanged.
        let metrics = metrics_map.entry(client_id).or_insert_with(|| BandwidthMetrics {
            client_id,
            measured_bandwidth: bandwidth_bps,
            packet_loss_rate: packet_loss,
            latency_ms,
            last_measurement: now,
            measurement_count: 0,
        });

        // Exponential moving average update.
        let alpha = Self::SMOOTHING_ALPHA;
        metrics.measured_bandwidth =
            alpha * bandwidth_bps + (1.0 - alpha) * metrics.measured_bandwidth;
        metrics.packet_loss_rate = alpha * packet_loss + (1.0 - alpha) * metrics.packet_loss_rate;
        // Truncating cast back to whole milliseconds.
        metrics.latency_ms = (alpha * f64::from(latency_ms)
            + (1.0 - alpha) * f64::from(metrics.latency_ms)) as u32;
        metrics.last_measurement = now;
        // Saturate instead of overflowing (which would panic in debug builds).
        metrics.measurement_count = metrics.measurement_count.saturating_add(1);

        tracing::debug!(
            "Bandwidth measurement for client {}: {:.2} bps, loss: {:.2}%, latency: {}ms",
            client_id,
            bandwidth_bps,
            packet_loss * 100.0,
            latency_ms
        );
    }

    /// Returns the recommended audio quality for `client_id`.
    ///
    /// Returns [`AudioQuality::Medium`] for unknown clients, for clients with fewer than
    /// `min_measurements` samples, and for clients whose metrics are stale. Otherwise:
    /// `High` at >= 2 Mbps with < 1% loss, `Medium` at >= 1 Mbps with < 3% loss, and
    /// `Low` in all other cases.
    pub async fn get_recommended_quality(&self, client_id: Uuid) -> AudioQuality {
        let metrics_map = self.client_metrics.read().await;
        let metrics = match metrics_map.get(&client_id) {
            Some(metrics) => metrics,
            None => return AudioQuality::Medium, // Default for new clients
        };

        let too_few_samples = metrics.measurement_count < self.min_measurements;
        let is_stale = metrics.last_measurement.elapsed() > self.measurement_window;
        if too_few_samples || is_stale {
            return AudioQuality::Medium;
        }

        let bandwidth_mbps = metrics.measured_bandwidth / 1_000_000.0;
        let packet_loss_percent = metrics.packet_loss_rate * 100.0;
        if bandwidth_mbps >= 2.0 && packet_loss_percent < 1.0 {
            AudioQuality::High
        } else if bandwidth_mbps >= 1.0 && packet_loss_percent < 3.0 {
            AudioQuality::Medium
        } else {
            AudioQuality::Low
        }
    }

    /// Drops metrics whose last measurement is at least twice the measurement window old.
    pub async fn cleanup_old_metrics(&self) {
        let max_age = self.measurement_window * 2;
        let mut metrics_map = self.client_metrics.write().await;
        // Compare ages with `elapsed()` instead of computing `Instant::now() - max_age`.
        // `Instant - Duration` may panic when the result is not representable on the
        // platform's clock.
        metrics_map.retain(|_, metrics| metrics.last_measurement.elapsed() < max_age);
        tracing::debug!(
            "Cleaned up old bandwidth metrics, {} clients remaining",
            metrics_map.len()
        );
    }

    /// Returns a snapshot of the metrics recorded for `client_id`, if any.
    pub async fn get_client_metrics(&self, client_id: Uuid) -> Option<BandwidthMetrics> {
        let metrics_map = self.client_metrics.read().await;
        metrics_map.get(&client_id).cloned()
    }
}
/// Adaptive compression manager.
///
/// Picks an [`AudioQuality`] per client. A manual per-client override takes precedence.
/// Otherwise the bandwidth-based recommendation is used, capped by an optional global
/// limit.
#[derive(Debug)]
pub struct AdaptiveCompression {
    /// Per-client bandwidth tracking used for automatic quality selection.
    bandwidth_detector: BandwidthDetector,
    /// Manual per-client overrides; they take precedence over everything else.
    quality_overrides: Arc<RwLock<HashMap<Uuid, AudioQuality>>>,
    /// Optional global ceiling applied to bandwidth-based recommendations.
    global_quality_limit: Arc<RwLock<Option<AudioQuality>>>,
}

impl AdaptiveCompression {
    /// Creates a manager with no overrides and no global limit.
    pub fn new() -> Self {
        Self {
            bandwidth_detector: BandwidthDetector::new(),
            quality_overrides: Arc::new(RwLock::new(HashMap::new())),
            global_quality_limit: Arc::new(RwLock::new(None)),
        }
    }

    /// Returns the audio quality to serve to `client_id`.
    ///
    /// A manual override is returned as-is and is not capped by the global limit.
    /// Otherwise the bandwidth recommendation is returned, capped by the global limit
    /// if one is set.
    pub async fn get_audio_quality(&self, client_id: Uuid) -> AudioQuality {
        // Each guard below is dropped at the end of its `let` statement, so no lock is
        // held while the detector acquires its own.
        let override_quality = self.quality_overrides.read().await.get(&client_id).cloned();
        if let Some(quality) = override_quality {
            return quality;
        }

        let global_limit = self.global_quality_limit.read().await.clone();
        let recommended = self.bandwidth_detector.get_recommended_quality(client_id).await;
        match global_limit {
            Some(limit) => self.apply_quality_limit(recommended, limit),
            None => recommended,
        }
    }

    /// Returns the lower of `recommended` and `limit` (ordering: Low < Medium < High).
    fn apply_quality_limit(&self, recommended: AudioQuality, limit: AudioQuality) -> AudioQuality {
        match (&recommended, &limit) {
            (AudioQuality::High, AudioQuality::High) => AudioQuality::High,
            (AudioQuality::High, AudioQuality::Medium) => AudioQuality::Medium,
            (AudioQuality::High, AudioQuality::Low) => AudioQuality::Low,
            (AudioQuality::Medium, AudioQuality::High) => AudioQuality::Medium,
            (AudioQuality::Medium, AudioQuality::Medium) => AudioQuality::Medium,
            (AudioQuality::Medium, AudioQuality::Low) => AudioQuality::Low,
            (AudioQuality::Low, _) => AudioQuality::Low,
        }
    }

    /// Records a transfer measurement for `client_id` in the bandwidth detector.
    pub async fn record_transfer(
        &self,
        client_id: Uuid,
        bytes_transferred: usize,
        duration: std::time::Duration,
        packet_loss: f64,
        latency_ms: u32,
    ) {
        self.bandwidth_detector
            .record_measurement(client_id, bytes_transferred, duration, packet_loss, latency_ms)
            .await;
    }

    /// Sets (or replaces) the manual quality override for `client_id`.
    pub async fn set_client_quality_override(&self, client_id: Uuid, quality: AudioQuality) {
        let mut overrides = self.quality_overrides.write().await;
        // The map takes ownership of what it stores. Insert a clone so `quality` is still
        // available for the log line.
        overrides.insert(client_id, quality.clone());
        tracing::info!("Set quality override for client {}: {:?}", client_id, quality);
    }

    /// Removes the manual quality override for `client_id`, if any.
    pub async fn remove_client_quality_override(&self, client_id: Uuid) {
        let mut overrides = self.quality_overrides.write().await;
        overrides.remove(&client_id);
        tracing::info!("Removed quality override for client {}", client_id);
    }

    /// Sets the global quality limit, or removes it when `quality` is `None`.
    pub async fn set_global_quality_limit(&self, quality: Option<AudioQuality>) {
        let mut global_limit = self.global_quality_limit.write().await;
        *global_limit = quality;
        match &*global_limit {
            Some(q) => tracing::info!("Set global quality limit: {:?}", q),
            None => tracing::info!("Removed global quality limit"),
        }
    }

    /// Cleans up stale data.
    ///
    /// Only bandwidth metrics are pruned. Manual overrides are deliberately kept, because
    /// this type does not track client activity and so cannot tell which overrides are stale.
    pub async fn cleanup(&self) {
        self.bandwidth_detector.cleanup_old_metrics().await;
    }

    /// Returns statistics about the qualities currently assigned to tracked clients.
    ///
    /// Every client with bandwidth metrics is counted under the quality that
    /// [`Self::get_audio_quality`] would serve it. An override is counted as-is;
    /// otherwise the recommendation is counted, capped by the global limit.
    pub async fn get_compression_stats(&self) -> CompressionStats {
        // Snapshot each piece of shared state and release its lock before computing
        // recommendations. `get_recommended_quality` takes `client_metrics`' read lock
        // itself, and re-acquiring a tokio `RwLock` read lock while this task still holds
        // one can deadlock once a writer is queued in between (tokio's RwLock is fair).
        let client_ids: Vec<Uuid> = self
            .bandwidth_detector
            .client_metrics
            .read()
            .await
            .keys()
            .copied()
            .collect();
        let overrides = self.quality_overrides.read().await.clone();
        let global_limit = self.global_quality_limit.read().await.clone();

        let mut quality_distribution = HashMap::new();
        for client_id in &client_ids {
            let quality = match overrides.get(client_id) {
                Some(overridden) => overridden.clone(),
                None => {
                    let recommended = self
                        .bandwidth_detector
                        .get_recommended_quality(*client_id)
                        .await;
                    match &global_limit {
                        Some(limit) => self.apply_quality_limit(recommended, limit.clone()),
                        None => recommended,
                    }
                }
            };
            *quality_distribution.entry(quality).or_insert(0) += 1;
        }

        CompressionStats {
            total_clients: client_ids.len(),
            quality_distribution,
            has_global_limit: global_limit.is_some(),
            override_count: overrides.len(),
        }
    }
}
/// Snapshot of compression statistics, as produced by
/// `AdaptiveCompression::get_compression_stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of clients that currently have bandwidth metrics.
    pub total_clients: usize,
    /// Number of those clients assigned to each audio quality tier.
    pub quality_distribution: HashMap<AudioQuality, usize>,
    /// Whether a global quality limit is currently set.
    pub has_global_limit: bool,
    /// Number of manual per-client quality overrides.
    pub override_count: usize,
}
/// Brotli compression helper for static assets.
#[derive(Debug, Clone)]
pub struct BrotliCompressor {
    /// Brotli quality level, always within the valid range `0..=11`.
    compression_level: u32,
}

impl BrotliCompressor {
    /// Internal buffer size of the Brotli reader/writer adapters, in bytes.
    const BUFFER_SIZE: usize = 4096;
    /// Base-2 logarithm of the sliding-window size; 22 is Brotli's default window.
    const LG_WINDOW_SIZE: u32 = 22;

    /// Creates a compressor, clamping `level` to Brotli's quality range `0..=11`.
    pub fn new(level: u32) -> Self {
        Self {
            compression_level: level.min(11),
        }
    }

    /// Compresses `data` into a complete Brotli stream.
    ///
    /// # Errors
    /// Returns any I/O error reported by the Brotli writer.
    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        use std::io::Write;

        let mut compressed = Vec::new();
        {
            // `CompressorWriter::new` takes the quality and window size directly. The
            // stream is finalized when the writer is dropped at the end of this scope.
            let mut writer = brotli::CompressorWriter::new(
                &mut compressed,
                Self::BUFFER_SIZE,
                self.compression_level,
                Self::LG_WINDOW_SIZE,
            );
            writer.write_all(data)?;
            writer.flush()?;
        }
        Ok(compressed)
    }

    /// Decompresses a Brotli stream into a new buffer.
    ///
    /// NOTE(review): the output size is not bounded. If `compressed_data` can come from
    /// an untrusted source, consider capping the decompressed size.
    ///
    /// # Errors
    /// Returns an error if `compressed_data` is not a valid Brotli stream.
    pub fn decompress(&self, compressed_data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        use std::io::Read;

        let mut decompressed = Vec::new();
        let mut reader = brotli::Decompressor::new(compressed_data, Self::BUFFER_SIZE);
        reader.read_to_end(&mut decompressed)?;
        Ok(decompressed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_bandwidth_detection() {
        let detector = BandwidthDetector::new();
        let client_id = Uuid::new_v4();

        // Below `min_measurements` the detector must fall back to the default tier.
        detector.record_measurement(client_id, 1024 * 1024, Duration::from_secs(1), 0.001, 50).await;
        detector.record_measurement(client_id, 2 * 1024 * 1024, Duration::from_secs(1), 0.001, 45).await;
        assert_eq!(detector.get_recommended_quality(client_id).await, AudioQuality::Medium);

        // With enough fast (~8 Mbps), low-loss (0.1%) samples the client gets High.
        for _ in 0..5 {
            detector.record_measurement(client_id, 1024 * 1024, Duration::from_secs(1), 0.001, 40).await;
        }
        assert_eq!(detector.get_recommended_quality(client_id).await, AudioQuality::High);
    }

    #[tokio::test]
    async fn test_poor_connection_gets_low_quality() {
        let detector = BandwidthDetector::new();
        let client_id = Uuid::new_v4();

        // 10 kB/s = 80 kbps with 5% loss: well below the Medium thresholds.
        for _ in 0..5 {
            detector.record_measurement(client_id, 10_000, Duration::from_secs(1), 0.05, 200).await;
        }
        assert_eq!(detector.get_recommended_quality(client_id).await, AudioQuality::Low);

        let metrics = detector
            .get_client_metrics(client_id)
            .await
            .expect("metrics were just recorded");
        assert_eq!(metrics.measurement_count, 5);
    }

    #[tokio::test]
    async fn test_adaptive_compression() {
        let compression = AdaptiveCompression::new();
        let client_id = Uuid::new_v4();

        // With an override
        compression.set_client_quality_override(client_id, AudioQuality::Low).await;
        let quality = compression.get_audio_quality(client_id).await;
        assert_eq!(quality, AudioQuality::Low);

        // Without an override
        compression.remove_client_quality_override(client_id).await;
        let quality = compression.get_audio_quality(client_id).await;
        assert_eq!(quality, AudioQuality::Medium); // Default quality
    }

    #[tokio::test]
    async fn test_global_quality_limit() {
        let compression = AdaptiveCompression::new();
        let client_id = Uuid::new_v4();

        // New clients default to Medium; a Low cap must bring them down.
        compression.set_global_quality_limit(Some(AudioQuality::Low)).await;
        assert_eq!(compression.get_audio_quality(client_id).await, AudioQuality::Low);

        // A High cap never raises the recommendation.
        compression.set_global_quality_limit(Some(AudioQuality::High)).await;
        assert_eq!(compression.get_audio_quality(client_id).await, AudioQuality::Medium);

        compression.set_global_quality_limit(None).await;
        assert_eq!(compression.get_audio_quality(client_id).await, AudioQuality::Medium);
    }

    #[tokio::test]
    async fn test_compression_stats() {
        let compression = AdaptiveCompression::new();
        let client_id = Uuid::new_v4();

        // A single sample is below `min_measurements`, so the client counts as Medium.
        compression.record_transfer(client_id, 10_000, Duration::from_secs(1), 0.0, 20).await;
        let stats = compression.get_compression_stats().await;
        assert_eq!(stats.total_clients, 1);
        assert_eq!(stats.quality_distribution.get(&AudioQuality::Medium), Some(&1));
        assert!(!stats.has_global_limit);
        assert_eq!(stats.override_count, 0);
    }

    #[test]
    fn test_brotli_compression() {
        let compressor = BrotliCompressor::new(6);
        // Repetitive input so the "compressed is smaller" check is meaningful.
        let data = b"Hello, world! This is a test string for compression. ".repeat(8);
        let compressed = compressor.compress(&data).unwrap();
        assert!(compressed.len() < data.len());
        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
}