use axum::body::Body; use axum::extract::State; use axum::http::{header, HeaderMap, Request, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use serde_json::json; use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::Mutex; use uuid::Uuid; use crate::auth::jwt::verify_token; use crate::config::Config; #[derive(Clone)] pub struct RateLimitState { config: Arc, ip: Arc, user: Arc, auth: Arc, } impl RateLimitState { pub fn new(config: Arc) -> Self { let window = Duration::from_secs(60); Self { ip: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_ip_rpm)), user: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_user_rpm)), auth: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_auth_rpm)), config, } } } struct WindowEntry { window_start: Instant, count: u32, last_seen: Instant, } struct FixedWindowLimiter { window: Duration, limit: u32, entries: Mutex>, } impl FixedWindowLimiter { fn new(window: Duration, limit: u32) -> Self { Self { window, limit, entries: Mutex::new(HashMap::new()), } } async fn check(&self, key: &str) -> Result<(), u64> { if self.limit == 0 { return Ok(()); } let now = Instant::now(); let mut map = self.entries.lock().await; if map.len() > 10_000 { let stale_before = now .checked_sub(self.window.saturating_mul(2)) .unwrap_or(now); map.retain(|_, v| v.last_seen >= stale_before); } let entry = map.entry(key.to_string()).or_insert_with(|| WindowEntry { window_start: now, count: 0, last_seen: now, }); entry.last_seen = now; if now.duration_since(entry.window_start) >= self.window { entry.window_start = now; entry.count = 0; } if entry.count >= self.limit { let elapsed = now.duration_since(entry.window_start); let retry_after = self .window .checked_sub(elapsed) .unwrap_or_else(|| Duration::from_secs(0)); return Err(retry_after.as_secs().max(1)); } entry.count += 1; Ok(()) } } fn parse_ip_from_headers(headers: &HeaderMap) -> Option { if let Some(forwarded) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) { let first = forwarded.split(',').next().map(|s| s.trim()).unwrap_or(""); if let Ok(ip) = first.parse::() { return Some(ip); } } if let Some(real_ip) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) { if let Ok(ip) = real_ip.trim().parse::() { return Some(ip); } } None } fn parse_ip_from_connect_info(request: &Request) -> Option { request .extensions() .get::>() .map(|ci| ci.0.ip()) } fn parse_user_from_auth(headers: &HeaderMap, secret: &str) -> Option { let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?; let token = auth.strip_prefix("Bearer ")?; verify_token(token, secret).ok().map(|c| c.sub) } fn is_bypassed_path(path: &str) -> bool { path == "/" || path == "/health" } fn is_auth_path(path: &str) -> bool { path == "/api/auth/login" || path == "/api/auth/register" } pub async fn rate_limit_middleware( State(state): State, request: Request, next: Next, ) -> Response { if !state.config.rate_limit_enabled { return next.run(request).await; } if request.method() == axum::http::Method::OPTIONS { return next.run(request).await; } let path = request.uri().path(); if is_bypassed_path(path) { return next.run(request).await; } let headers = request.headers(); let ip = parse_ip_from_headers(headers).or_else(|| parse_ip_from_connect_info(&request)); let ip_key = ip .map(|v| format!("ip:{}", v)) .unwrap_or_else(|| "ip:unknown".to_string()); if is_auth_path(path) { match state.auth.check(&ip_key).await { Ok(_) => return next.run(request).await, Err(retry_after) => { return ( StatusCode::TOO_MANY_REQUESTS, [(header::RETRY_AFTER, retry_after.to_string())], axum::Json(json!({"error": "Rate limit exceeded"})), ) .into_response(); } } } if let Err(retry_after) = state.ip.check(&ip_key).await { return ( StatusCode::TOO_MANY_REQUESTS, [(header::RETRY_AFTER, retry_after.to_string())], axum::Json(json!({"error": "Rate limit exceeded"})), ) .into_response(); } if state.config.rate_limit_user_rpm > 0 { if let Some(user_id) = parse_user_from_auth(headers, &state.config.jwt_secret) { let user_key = format!("user:{}", user_id); if let Err(retry_after) = state.user.check(&user_key).await { return ( StatusCode::TOO_MANY_REQUESTS, [(header::RETRY_AFTER, retry_after.to_string())], axum::Json(json!({"error": "Rate limit exceeded"})), ) .into_response(); } } } next.run(request).await } #[cfg(test)] mod tests { use super::{rate_limit_middleware, FixedWindowLimiter, RateLimitState}; use axum::body::Body; use axum::http::{header, Request, StatusCode}; use axum::routing::get; use axum::Router; use std::sync::Arc; use tower::ServiceExt; use uuid::Uuid; use crate::auth::jwt::create_token; use crate::config::Config; use std::time::Duration; #[tokio::test] async fn fixed_window_allows_up_to_limit() { let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 3); assert!(limiter.check("k").await.is_ok()); assert!(limiter.check("k").await.is_ok()); assert!(limiter.check("k").await.is_ok()); } #[tokio::test] async fn fixed_window_denies_after_limit() { let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 2); assert!(limiter.check("k").await.is_ok()); assert!(limiter.check("k").await.is_ok()); let retry_after = limiter.check("k").await.expect_err("expected denial"); assert!(retry_after >= 1); } #[tokio::test] async fn fixed_window_zero_limit_disables() { let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 0); for _ in 0..10 { assert!(limiter.check("k").await.is_ok()); } } #[tokio::test] async fn middleware_bypasses_health() { let cfg = Arc::new(Config { rate_limit_enabled: true, rate_limit_ip_rpm: 1, rate_limit_user_rpm: 0, rate_limit_auth_rpm: 0, ..Default::default() }); let app = Router::new() .route("/health", get(|| async { "ok" })) .layer(axum::middleware::from_fn_with_state( RateLimitState::new(cfg.clone()), rate_limit_middleware, )); for _ in 0..5 { let req = Request::builder() .method("GET") .uri("/health") .header("x-forwarded-for", "1.2.3.4") .body(Body::empty()) .unwrap(); let res = app.clone().oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } } #[tokio::test] async fn middleware_enforces_ip_rate_limit() { let cfg = Arc::new(Config { rate_limit_enabled: true, rate_limit_ip_rpm: 2, rate_limit_user_rpm: 0, rate_limit_auth_rpm: 0, ..Default::default() }); let app = Router::new() .route("/api/ping", get(|| async { "ok" })) .layer(axum::middleware::from_fn_with_state( RateLimitState::new(cfg.clone()), rate_limit_middleware, )); for _ in 0..2 { let req = Request::builder() .method("GET") .uri("/api/ping") .header("x-forwarded-for", "1.2.3.4") .body(Body::empty()) .unwrap(); let res = app.clone().oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } let req = Request::builder() .method("GET") .uri("/api/ping") .header("x-forwarded-for", "1.2.3.4") .body(Body::empty()) .unwrap(); let res = app.clone().oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS); assert!(res.headers().get(header::RETRY_AFTER).is_some()); } #[tokio::test] async fn middleware_enforces_auth_rate_limit_on_login() { let cfg = Arc::new(Config { rate_limit_enabled: true, rate_limit_ip_rpm: 0, rate_limit_user_rpm: 0, rate_limit_auth_rpm: 1, ..Default::default() }); let app = Router::new() .route("/api/auth/login", get(|| async { "ok" })) .layer(axum::middleware::from_fn_with_state( RateLimitState::new(cfg.clone()), rate_limit_middleware, )); let req = Request::builder() .method("GET") .uri("/api/auth/login") .header("x-forwarded-for", "1.2.3.4") .body(Body::empty()) .unwrap(); let res = app.clone().oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let req = Request::builder() .method("GET") .uri("/api/auth/login") .header("x-forwarded-for", "1.2.3.4") .body(Body::empty()) .unwrap(); let res = app.clone().oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS); assert!(res.headers().get(header::RETRY_AFTER).is_some()); } #[tokio::test] async fn middleware_enforces_user_rate_limit_when_token_is_valid() { let cfg = Arc::new(Config { rate_limit_enabled: true, rate_limit_ip_rpm: 0, rate_limit_user_rpm: 1, rate_limit_auth_rpm: 0, jwt_secret: "testsecret".to_string(), ..Default::default() }); let user_id = Uuid::new_v4(); let token = create_token(user_id, "u", &cfg.jwt_secret).unwrap(); let app = Router::new() .route("/api/ping", get(|| async { "ok" })) .layer(axum::middleware::from_fn_with_state( RateLimitState::new(cfg.clone()), rate_limit_middleware, )); let req = Request::builder() .method("GET") .uri("/api/ping") .header("x-forwarded-for", "1.2.3.4") .header(header::AUTHORIZATION, format!("Bearer {token}")) .body(Body::empty()) .unwrap(); let res = app.clone().oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let req = Request::builder() .method("GET") .uri("/api/ping") .header("x-forwarded-for", "1.2.3.4") .header(header::AUTHORIZATION, format!("Bearer {token}")) .body(Body::empty()) .unwrap(); let res = app.clone().oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS); assert!(res.headers().get(header::RETRY_AFTER).is_some()); } }