use clap::Parser; use futures_util::{SinkExt, StreamExt}; use std::sync::{ atomic::{AtomicU64, Ordering}, Arc, }; use std::time::{Duration, SystemTime}; use stream_server::streaming::websocket::{WebSocketCommand, WebSocketEvent}; use tokio::net::TcpStream; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use uuid::Uuid; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// WebSocket URL to connect to #[arg(short, long, default_value = "ws://localhost:8082/ws")] url: String, /// Number of simulated clients #[arg(short, long, default_value_t = 10)] clients: usize, /// Duration of the test in seconds #[arg(short, long, default_value_t = 30)] duration: u64, } #[derive(Debug, Default)] struct LoadStats { total_messages_sent: AtomicU64, total_messages_received: AtomicU64, sync_adjustments_received: AtomicU64, errors: AtomicU64, active_clients: AtomicU64, } struct SimulatedClient { id: Uuid, url: String, stats: Arc, } impl SimulatedClient { fn new(url: String, stats: Arc) -> Self { Self { id: Uuid::new_v4(), url, stats, } } async fn run(self, duration: Duration) { let connect_url = format!("{}?user_id=load_test_{}", self.url, self.id); match connect_async(&connect_url).await { Ok((ws_stream, _)) => { self.stats.active_clients.fetch_add(1, Ordering::SeqCst); if let Err(e) = self.handle_connection(ws_stream, duration).await { // Ignore connection closed errors at shutdown if expected if !e.to_string().contains("Connection reset without closing handshake") { eprintln!("Client {} error: {}", self.id, e); self.stats.errors.fetch_add(1, Ordering::SeqCst); } } self.stats.active_clients.fetch_sub(1, Ordering::SeqCst); } Err(e) => { eprintln!("Failed to connect client {}: {}", self.id, e); self.stats.errors.fetch_add(1, Ordering::SeqCst); } } } async fn handle_connection( &self, ws_stream: WebSocketStream>, duration: Duration, ) -> Result<(), Box> { let (mut write, mut read) = ws_stream.split(); let mut interval = tokio::time::interval(Duration::from_secs(5)); let timeout = tokio::time::sleep(duration); tokio::pin!(timeout); // Subscribe command let subscribe_cmd = WebSocketCommand::Subscribe { command_id: Uuid::new_v4().to_string(), events: vec!["*".to_string()], filters: None, }; let msg = tokio_tungstenite::tungstenite::Message::Text(serde_json::to_string(&subscribe_cmd)?); write.send(msg).await?; self.stats.total_messages_sent.fetch_add(1, Ordering::Relaxed); loop { tokio::select! { _ = &mut timeout => { break; } _ = interval.tick() => { // Send Keepalive / Ping let status_cmd = WebSocketCommand::Ping { command_id: Uuid::new_v4().to_string(), }; if let Ok(json) = serde_json::to_string(&status_cmd) { write.send(tokio_tungstenite::tungstenite::Message::Text(json)).await?; self.stats.total_messages_sent.fetch_add(1, Ordering::Relaxed); } } msg = read.next() => { match msg { Some(Ok(message)) => { self.stats.total_messages_received.fetch_add(1, Ordering::Relaxed); if let tokio_tungstenite::tungstenite::Message::Text(text) = message { if let Ok(event) = serde_json::from_str::(&text) { match event { WebSocketEvent::SyncPing { ping_id, server_timestamp: _ } => { // Reply with SyncPong let pong_cmd = WebSocketCommand::SyncPong { ping_id, client_timestamp: SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH)? .as_millis() as u64, }; let json = serde_json::to_string(&pong_cmd)?; write.send(tokio_tungstenite::tungstenite::Message::Text(json)).await?; self.stats.total_messages_sent.fetch_add(1, Ordering::Relaxed); } WebSocketEvent::SyncAdjustment { .. } => { self.stats.sync_adjustments_received.fetch_add(1, Ordering::Relaxed); } _ => {} } } } } Some(Err(e)) => { return Err(Box::new(e)); } None => break, } } } } let _ = write.close().await; Ok(()) } } #[tokio::main] async fn main() { let args = Args::parse(); let stats = Arc::new(LoadStats::default()); println!("šŸš€ Starting Stream Server Load Test"); println!(" Target: {}", args.url); println!(" Clients: {}", args.clients); println!(" Duration: {}s", args.duration); let start_time = SystemTime::now(); let mut handles = Vec::new(); for _ in 0..args.clients { let client = SimulatedClient::new(args.url.clone(), stats.clone()); let duration = Duration::from_secs(args.duration); handles.push(tokio::spawn(async move { client.run(duration).await; })); // Stagger connections slightly tokio::time::sleep(Duration::from_millis(50)).await; } // Wait for all clients futures::future::join_all(handles).await; let elapsed = start_time.elapsed().unwrap_or_default(); println!("\nšŸ“Š Load Test Report"); println!("===================="); println!("Duration: {:.2?}", elapsed); println!("Total Messages Sent: {}", stats.total_messages_sent.load(Ordering::SeqCst)); println!("Total Messages Received: {}", stats.total_messages_received.load(Ordering::SeqCst)); println!("Sync Adjustments Received: {}", stats.sync_adjustments_received.load(Ordering::SeqCst)); println!("Errors: {}", stats.errors.load(Ordering::SeqCst)); // Validate results - allow small error margin for connection teardown let errors = stats.errors.load(Ordering::SeqCst); if errors > (args.clients as u64 / 2) { println!("āŒ Test Failed with high error rate: {}", errors); std::process::exit(1); } else { println!("āœ… Test Passed!"); } }