// Extracted listing header: 118 lines, 3.4 KiB, Rust.
use axum::{
|
|
extract::{Request, State},
|
|
http::{HeaderMap, StatusCode},
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use governor::{
|
|
clock::DefaultClock,
|
|
state::keyed::DashMapStateStore,
|
|
Quota, RateLimiter as GovLimiter,
|
|
};
|
|
use std::num::NonZeroU32;
|
|
use crate::AppState;
|
|
|
|
/// Per-IP keyed rate limiter backed by the `governor` crate.
|
|
/// Uses an in-memory DashMap store; `governor` enforces the quota with the
/// Generic Cell Rate Algorithm (GCRA), not a sliding window.
|
|
type KeyedLimiter = GovLimiter<String, DashMapStateStore<String>, DefaultClock>;
|
|
|
|
lazy_static::lazy_static! {
    /// Process-wide keyed rate limiter shared by every request.
    ///
    /// The quota is fixed at 120 requests per minute per key and is built
    /// exactly once, the first time this static is dereferenced. Nothing in
    /// this initializer reads runtime configuration, so changing the limit
    /// requires changing the constant below.
    static ref LIMITER: KeyedLimiter = {
        // `expect` cannot fire: the literal is a non-zero constant.
        let per_minute = NonZeroU32::new(120).expect("120 is non-zero");
        GovLimiter::keyed(Quota::per_minute(per_minute))
    };
}
|
|
|
|
pub async fn rate_limit_middleware(
|
|
State(state): State<AppState>,
|
|
request: Request,
|
|
next: Next,
|
|
) -> Result<Response, StatusCode> {
|
|
let headers = request.headers();
|
|
let client_ip = extract_client_ip(headers);
|
|
|
|
if !check_rate_limit(&state, &client_ip) {
|
|
tracing::warn!(
|
|
client_ip = %client_ip,
|
|
"Rate limit exceeded"
|
|
);
|
|
|
|
state.metrics.increment_rate_limited();
|
|
|
|
return Err(StatusCode::TOO_MANY_REQUESTS);
|
|
}
|
|
|
|
let response = next.run(request).await;
|
|
Ok(response)
|
|
}
|
|
|
|
/// Derives the rate-limit key for a request from its proxy headers.
///
/// Lookup order:
/// 1. the first comma-separated entry of `x-forwarded-for`,
/// 2. the `x-real-ip` header,
/// 3. the literal `"unknown"`.
///
/// A candidate is used only if it is non-empty after trimming surrounding
/// whitespace; otherwise the next source is tried. This prevents a blank or
/// malformed `x-forwarded-for` (for example `""` or `", 1.2.3.4"`) from
/// producing an empty key, and ensures the same address with stray whitespace
/// maps to a single key.
///
/// NOTE(review): both headers are supplied by the client unless a trusted
/// reverse proxy overwrites them, so a caller can choose its own key. Confirm
/// that the deployment always sits behind such a proxy.
pub(crate) fn extract_client_ip(headers: &HeaderMap) -> String {
    let forwarded = trimmed_header(headers, "x-forwarded-for")
        // `split` always yields at least one item, so `next()` is `Some`.
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());

    forwarded
        .or_else(|| trimmed_header(headers, "x-real-ip"))
        .map_or_else(|| "unknown".to_string(), str::to_string)
}

/// Returns the trimmed text of header `name`, or `None` when the header is
/// absent, cannot be converted with `to_str`, or is blank after trimming.
fn trimmed_header<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
    let text = headers.get(name)?.to_str().ok()?.trim();
    (!text.is_empty()).then_some(text)
}
|
|
|
|
fn check_rate_limit(_state: &AppState, client_ip: &str) -> bool {
|
|
match LIMITER.check_key(&client_ip.to_string()) {
|
|
Ok(_) => true,
|
|
Err(_not_until) => {
|
|
tracing::debug!(
|
|
client_ip = %client_ip,
|
|
"Rate limit check: request denied"
|
|
);
|
|
false
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a `HeaderMap` from static `(name, value)` pairs, inserted in order.
    fn headers_from(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// The first entry of a multi-hop `x-forwarded-for` list is used.
    #[test]
    fn test_extract_client_ip_x_forwarded_for() {
        let headers = headers_from(&[("x-forwarded-for", "10.0.0.1, 192.168.1.1")]);
        assert_eq!(extract_client_ip(&headers), "10.0.0.1");
    }

    /// `x-real-ip` is used when no `x-forwarded-for` header is present.
    #[test]
    fn test_extract_client_ip_x_real_ip() {
        let headers = headers_from(&[("x-real-ip", "203.0.113.50")]);
        assert_eq!(extract_client_ip(&headers), "203.0.113.50");
    }

    /// With neither header present the key falls back to `"unknown"`.
    #[test]
    fn test_extract_client_ip_fallback_unknown() {
        let headers = headers_from(&[]);
        assert_eq!(extract_client_ip(&headers), "unknown");
    }

    /// `x-forwarded-for` wins when both headers are present.
    #[test]
    fn test_extract_client_ip_forwarded_takes_precedence() {
        let headers = headers_from(&[
            ("x-forwarded-for", "1.2.3.4"),
            ("x-real-ip", "5.6.7.8"),
        ]);
        assert_eq!(extract_client_ip(&headers), "1.2.3.4");
    }
}
|