veza/veza-common/src/config/redis.rs
2025-12-03 22:24:14 +01:00

312 lines
9.1 KiB
Rust

//! Redis configuration types
//!
//! This module defines configuration structures for Redis connections.
use serde::{Deserialize, Serialize};
use std::time::Duration;
use crate::error::{CommonError, CommonResult};
/// Redis configuration
///
/// Configuration for Redis connections including connection pooling settings.
///
/// Only `url` is required when deserializing: the fields carrying a serde
/// `default` fall back to the values documented below, and the optional
/// credentials deserialize to `None` when absent.
///
/// NOTE(review): `Debug` is derived, so formatting this struct with `{:?}`
/// prints `password` in clear text. Avoid logging whole configs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RedisConfig {
/// Redis connection URL
///
/// `validate` requires a `redis://` or `rediss://` scheme; `is_ssl` reports
/// whether the `rediss://` form is used.
pub url: String,
/// Maximum number of connections in the pool
///
/// Defaults to 10 when omitted; `validate` rejects 0.
#[serde(default = "default_max_connections")]
pub max_connections: u32,
/// Connection timeout
///
/// (De)serialized through `duration_secs` as a whole number of seconds.
/// Defaults to 5 seconds when omitted.
#[serde(with = "duration_secs", default = "default_connection_timeout")]
pub connection_timeout: Duration,
/// Command timeout
///
/// (De)serialized through `duration_secs` as a whole number of seconds.
/// Defaults to 3 seconds when omitted.
#[serde(with = "duration_secs", default = "default_command_timeout")]
pub command_timeout: Duration,
/// Enable connection keepalive
///
/// Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub keepalive: bool,
/// Database number (0-15)
///
/// Defaults to 0 when omitted; `validate` rejects values above 15.
#[serde(default = "default_db")]
pub db: u8,
/// Password for authentication (optional)
///
/// Left out of the serialized output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub password: Option<String>,
/// Username for authentication (optional)
///
/// Left out of the serialized output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub username: Option<String>,
}
/// Serde default for `RedisConfig::max_connections`: a pool of 10 connections.
fn default_max_connections() -> u32 {
    const DEFAULT_POOL_SIZE: u32 = 10;
    DEFAULT_POOL_SIZE
}
/// Serde default for `RedisConfig::connection_timeout`: 5 seconds.
fn default_connection_timeout() -> Duration {
    const DEFAULT_CONNECT_SECS: u64 = 5;
    Duration::from_secs(DEFAULT_CONNECT_SECS)
}
/// Serde default for `RedisConfig::command_timeout`: 3 seconds.
fn default_command_timeout() -> Duration {
    const DEFAULT_COMMAND_SECS: u64 = 3;
    Duration::from_secs(DEFAULT_COMMAND_SECS)
}
/// Serde default for boolean switches that are on unless disabled
/// (used by `RedisConfig::keepalive`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde default for `RedisConfig::db`: database number 0.
fn default_db() -> u8 {
    const FIRST_DATABASE: u8 = 0;
    FIRST_DATABASE
}
/// Serde adapter that stores a `Duration` as an unsigned integer number of seconds.
///
/// Intended for `#[serde(with = "duration_secs")]`. Only whole seconds are
/// written (`Duration::as_secs`), so any sub-second part of a duration does
/// not survive a serialize/deserialize round trip.
mod duration_secs {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Writes `duration` as its whole-second count.
    pub fn serialize<S: Serializer>(
        duration: &Duration,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        let whole_seconds = duration.as_secs();
        serializer.serialize_u64(whole_seconds)
    }

    /// Reads an unsigned integer and interprets it as a number of seconds.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Duration, D::Error> {
        u64::deserialize(deserializer).map(Duration::from_secs)
    }
}
impl Default for RedisConfig {
fn default() -> Self {
Self {
url: "redis://localhost:6379".to_string(),
max_connections: 10,
connection_timeout: Duration::from_secs(5),
command_timeout: Duration::from_secs(3),
keepalive: true,
db: 0,
password: None,
username: None,
}
}
}
impl RedisConfig {
/// Create a new Redis configuration
pub fn new(url: String) -> Self {
Self {
url,
max_connections: 10,
connection_timeout: Duration::from_secs(5),
command_timeout: Duration::from_secs(3),
keepalive: true,
db: 0,
password: None,
username: None,
}
}
/// Validate the Redis configuration
pub fn validate(&self) -> CommonResult<()> {
if self.url.is_empty() {
return Err(CommonError::ValidationError(
"Redis URL cannot be empty".to_string()
));
}
if !self.url.starts_with("redis://") && !self.url.starts_with("rediss://") {
return Err(CommonError::ValidationError(
"Redis URL must start with redis:// or rediss://".to_string()
));
}
if self.max_connections == 0 {
return Err(CommonError::ValidationError(
"Max connections must be greater than 0".to_string()
));
}
if self.db > 15 {
return Err(CommonError::ValidationError(
"Database number must be between 0 and 15".to_string()
));
}
if self.connection_timeout.as_secs() == 0 {
return Err(CommonError::ValidationError(
"Connection timeout must be greater than 0".to_string()
));
}
Ok(())
}
/// Get the host from the URL
pub fn host(&self) -> Option<String> {
self.url
.strip_prefix("redis://")
.or_else(|| self.url.strip_prefix("rediss://"))
.and_then(|s| {
// Remove user:password@ if present
let s = if s.contains('@') {
s.split('@').nth(1).unwrap_or(s)
} else {
s
};
// Extract host:port
s.split('/').next().map(|s| {
if s.contains(':') {
s.split(':').next().unwrap_or(s).to_string()
} else {
s.to_string()
}
})
})
}
/// Get the port from the URL
pub fn port(&self) -> Option<u16> {
self.url
.strip_prefix("redis://")
.or_else(|| self.url.strip_prefix("rediss://"))
.and_then(|s| {
let s = if s.contains('@') {
s.split('@').nth(1).unwrap_or(s)
} else {
s
};
s.split('/').next().and_then(|s| {
if s.contains(':') {
s.split(':').nth(1).and_then(|p| p.parse().ok())
} else {
None
}
})
})
.or(Some(6379)) // Default Redis port
}
/// Check if SSL/TLS is enabled
pub fn is_ssl(&self) -> bool {
self.url.starts_with("rediss://")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain-TCP URL used by most tests.
    const LOCAL_URL: &str = "redis://localhost:6379";

    /// Builds a config for `url` with every other field left at its default.
    fn config_for(url: &str) -> RedisConfig {
        RedisConfig::new(url.to_string())
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn test_redis_config_default() {
        let defaults = RedisConfig::default();
        assert!(!defaults.url.is_empty());
        assert_eq!(defaults.max_connections, 10);
        assert_eq!(defaults.db, 0);
        assert!(defaults.keepalive);
    }

    #[test]
    fn test_redis_config_new() {
        let built = config_for(LOCAL_URL);
        assert_eq!(built.url, "redis://localhost:6379");
        assert_eq!(built.max_connections, 10);
        assert_eq!(built.db, 0);
    }

    // --- validation -------------------------------------------------------

    #[test]
    fn test_redis_config_validate_success() {
        assert!(config_for(LOCAL_URL).validate().is_ok());
    }

    #[test]
    fn test_redis_config_validate_empty_url() {
        let blank_url = RedisConfig {
            url: "".to_string(),
            ..RedisConfig::default()
        };
        assert!(blank_url.validate().is_err());
    }

    #[test]
    fn test_redis_config_validate_invalid_url() {
        assert!(config_for("http://localhost:6379").validate().is_err());
    }

    #[test]
    fn test_redis_config_validate_zero_max_connections() {
        let empty_pool = RedisConfig {
            max_connections: 0,
            ..config_for(LOCAL_URL)
        };
        assert!(empty_pool.validate().is_err());
    }

    #[test]
    fn test_redis_config_validate_db_too_large() {
        let out_of_range = RedisConfig {
            db: 16,
            ..config_for(LOCAL_URL)
        };
        assert!(out_of_range.validate().is_err());
    }

    // --- URL accessors ----------------------------------------------------

    #[test]
    fn test_redis_config_host() {
        let host = config_for(LOCAL_URL).host();
        assert_eq!(host, Some("localhost".to_string()));
    }

    #[test]
    fn test_redis_config_host_with_auth() {
        let host = config_for("redis://user:pass@localhost:6379").host();
        assert_eq!(host, Some("localhost".to_string()));
    }

    #[test]
    fn test_redis_config_port() {
        let port = config_for("redis://localhost:6380").port();
        assert_eq!(port, Some(6380));
    }

    #[test]
    fn test_redis_config_port_default() {
        let port = config_for("redis://localhost").port();
        assert_eq!(port, Some(6379));
    }

    #[test]
    fn test_redis_config_is_ssl() {
        let tls = config_for("rediss://localhost:6379");
        let plain = config_for(LOCAL_URL);
        assert!(tls.is_ssl());
        assert!(!plain.is_ssl());
    }

    // --- serde ------------------------------------------------------------

    #[test]
    fn test_redis_config_serialize() {
        let json = serde_json::to_string(&RedisConfig::default()).unwrap();
        assert!(json.contains("url"));
        assert!(json.contains("max_connections"));
        // A `None` password must not appear in the output at all.
        assert!(!json.contains("password"));
    }

    #[test]
    fn test_redis_config_deserialize() {
        // Only some fields are given; the rest rely on the serde defaults.
        let json = r#"{
"url": "redis://localhost:6379",
"max_connections": 20,
"db": 1
}"#;
        let parsed: RedisConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.url, "redis://localhost:6379");
        assert_eq!(parsed.max_connections, 20);
        assert_eq!(parsed.db, 1);
    }

    #[test]
    fn test_redis_config_with_password() {
        let secured = RedisConfig {
            password: Some("secret".to_string()),
            ..config_for(LOCAL_URL)
        };
        assert_eq!(secured.password, Some("secret".to_string()));
    }
}