// veza/veza-stream-server/src/bin/stream_load_test.rs

use clap::Parser;
use futures_util::{SinkExt, StreamExt};
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use std::time::{Duration, SystemTime};
use stream_server::streaming::websocket::{WebSocketCommand, WebSocketEvent};
use tokio::net::TcpStream;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use uuid::Uuid;
// Command-line options for the load-test binary, parsed with clap's derive API.
// The `///` lines on each field are what clap prints for `--help`, so they are
// runtime text and must stay as written.
#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// WebSocket URL to connect to
    #[arg(long, short = 'u', default_value = "ws://localhost:8082/ws")]
    url: String,

    /// Number of simulated clients
    #[arg(long, short = 'c', default_value_t = 10)]
    clients: usize,

    /// Duration of the test in seconds
    #[arg(long, short = 'd', default_value_t = 30)]
    duration: u64,
}
/// Counters shared by every simulated client and read back for the final report.
///
/// Every field is an atomic so concurrent clients can update it through a
/// shared `Arc<LoadStats>` without a lock; `Default` starts all counters at 0.
#[derive(Default, Debug)]
struct LoadStats {
    /// Commands written to the socket by clients.
    total_messages_sent: AtomicU64,
    /// Frames read from the socket, whatever their type.
    total_messages_received: AtomicU64,
    /// `SyncAdjustment` events seen by clients.
    sync_adjustments_received: AtomicU64,
    /// Connection failures and session errors that were not ignored.
    errors: AtomicU64,
    /// Clients whose connection is currently open.
    active_clients: AtomicU64,
}
struct SimulatedClient {
id: Uuid,
url: String,
stats: Arc<LoadStats>,
}
impl SimulatedClient {
fn new(url: String, stats: Arc<LoadStats>) -> Self {
Self {
id: Uuid::new_v4(),
url,
stats,
}
}
async fn run(self, duration: Duration) {
let connect_url = format!("{}?user_id=load_test_{}", self.url, self.id);
match connect_async(&connect_url).await {
Ok((ws_stream, _)) => {
self.stats.active_clients.fetch_add(1, Ordering::SeqCst);
if let Err(e) = self.handle_connection(ws_stream, duration).await {
// Ignore connection closed errors at shutdown if expected
if !e.to_string().contains("Connection reset without closing handshake") {
eprintln!("Client {} error: {}", self.id, e);
self.stats.errors.fetch_add(1, Ordering::SeqCst);
}
}
self.stats.active_clients.fetch_sub(1, Ordering::SeqCst);
}
Err(e) => {
eprintln!("Failed to connect client {}: {}", self.id, e);
self.stats.errors.fetch_add(1, Ordering::SeqCst);
}
}
}
async fn handle_connection(
&self,
ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
duration: Duration,
) -> Result<(), Box<dyn std::error::Error>> {
let (mut write, mut read) = ws_stream.split();
let mut interval = tokio::time::interval(Duration::from_secs(5));
let timeout = tokio::time::sleep(duration);
tokio::pin!(timeout);
// Subscribe command
let subscribe_cmd = WebSocketCommand::Subscribe {
command_id: Uuid::new_v4().to_string(),
events: vec!["*".to_string()],
filters: None,
};
let msg = tokio_tungstenite::tungstenite::Message::Text(serde_json::to_string(&subscribe_cmd)?);
write.send(msg).await?;
self.stats.total_messages_sent.fetch_add(1, Ordering::Relaxed);
loop {
tokio::select! {
_ = &mut timeout => {
break;
}
_ = interval.tick() => {
// Send Keepalive / Ping
let status_cmd = WebSocketCommand::Ping {
command_id: Uuid::new_v4().to_string(),
};
if let Ok(json) = serde_json::to_string(&status_cmd) {
write.send(tokio_tungstenite::tungstenite::Message::Text(json)).await?;
self.stats.total_messages_sent.fetch_add(1, Ordering::Relaxed);
}
}
msg = read.next() => {
match msg {
Some(Ok(message)) => {
self.stats.total_messages_received.fetch_add(1, Ordering::Relaxed);
if let tokio_tungstenite::tungstenite::Message::Text(text) = message {
if let Ok(event) = serde_json::from_str::<WebSocketEvent>(&text) {
match event {
WebSocketEvent::SyncPing { ping_id, server_timestamp: _ } => {
// Reply with SyncPong
let pong_cmd = WebSocketCommand::SyncPong {
ping_id,
client_timestamp: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis() as u64,
};
let json = serde_json::to_string(&pong_cmd)?;
write.send(tokio_tungstenite::tungstenite::Message::Text(json)).await?;
self.stats.total_messages_sent.fetch_add(1, Ordering::Relaxed);
}
WebSocketEvent::SyncAdjustment { .. } => {
self.stats.sync_adjustments_received.fetch_add(1, Ordering::Relaxed);
}
_ => {}
}
}
}
}
Some(Err(e)) => {
return Err(Box::new(e));
}
None => break,
}
}
}
}
let _ = write.close().await;
Ok(())
}
}
#[tokio::main]
async fn main() {
let args = Args::parse();
let stats = Arc::new(LoadStats::default());
println!("🚀 Starting Stream Server Load Test");
println!(" Target: {}", args.url);
println!(" Clients: {}", args.clients);
println!(" Duration: {}s", args.duration);
let start_time = SystemTime::now();
let mut handles = Vec::new();
for _ in 0..args.clients {
let client = SimulatedClient::new(args.url.clone(), stats.clone());
let duration = Duration::from_secs(args.duration);
handles.push(tokio::spawn(async move {
client.run(duration).await;
}));
// Stagger connections slightly
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Wait for all clients
futures::future::join_all(handles).await;
let elapsed = start_time.elapsed().unwrap_or_default();
println!("\n📊 Load Test Report");
println!("====================");
println!("Duration: {:.2?}", elapsed);
println!("Total Messages Sent: {}", stats.total_messages_sent.load(Ordering::SeqCst));
println!("Total Messages Received: {}", stats.total_messages_received.load(Ordering::SeqCst));
println!("Sync Adjustments Received: {}", stats.sync_adjustments_received.load(Ordering::SeqCst));
println!("Errors: {}", stats.errors.load(Ordering::SeqCst));
// Validate results - allow small error margin for connection teardown
let errors = stats.errors.load(Ordering::SeqCst);
if errors > (args.clients as u64 / 2) {
println!("❌ Test Failed with high error rate: {}", errors);
std::process::exit(1);
} else {
println!("✅ Test Passed!");
}
}