use axum::body::Body; use axum::extract::State; use axum::http::{header, HeaderMap, Request, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use serde_json::json; use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::Mutex; use uuid::Uuid; use crate::auth::jwt::verify_token; use crate::config::Config; #[derive(Clone)] pub struct RateLimitState { config: Arc, ip: Arc, user: Arc, auth: Arc, } impl RateLimitState { pub fn new(config: Arc) -> Self { let window = Duration::from_secs(60); Self { ip: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_ip_rpm)), user: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_user_rpm)), auth: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_auth_rpm)), config, } } } struct WindowEntry { window_start: Instant, count: u32, last_seen: Instant, } struct FixedWindowLimiter { window: Duration, limit: u32, entries: Mutex>, } impl FixedWindowLimiter { fn new(window: Duration, limit: u32) -> Self { Self { window, limit, entries: Mutex::new(HashMap::new()), } } async fn check(&self, key: &str) -> Result<(), u64> { if self.limit == 0 { return Ok(()); } let now = Instant::now(); let mut map = self.entries.lock().await; if map.len() > 10_000 { let stale_before = now .checked_sub(self.window.saturating_mul(2)) .unwrap_or(now); map.retain(|_, v| v.last_seen >= stale_before); } let entry = map.entry(key.to_string()).or_insert_with(|| WindowEntry { window_start: now, count: 0, last_seen: now, }); entry.last_seen = now; if now.duration_since(entry.window_start) >= self.window { entry.window_start = now; entry.count = 0; } if entry.count >= self.limit { let elapsed = now.duration_since(entry.window_start); let retry_after = self .window .checked_sub(elapsed) .unwrap_or_else(|| Duration::from_secs(0)); return Err(retry_after.as_secs().max(1)); } entry.count += 1; Ok(()) } } fn parse_ip_from_headers(headers: &HeaderMap) -> Option { if let Some(forwarded) = headers .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) { let first = forwarded.split(',').next().map(|s| s.trim()).unwrap_or(""); if let Ok(ip) = first.parse::() { return Some(ip); } } if let Some(real_ip) = headers .get("x-real-ip") .and_then(|v| v.to_str().ok()) { if let Ok(ip) = real_ip.trim().parse::() { return Some(ip); } } None } fn parse_ip_from_connect_info(request: &Request) -> Option { request .extensions() .get::>() .map(|ci| ci.0.ip()) } fn parse_user_from_auth(headers: &HeaderMap, secret: &str) -> Option { let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?; let token = auth.strip_prefix("Bearer ")?; verify_token(token, secret).ok().map(|c| c.sub) } fn is_bypassed_path(path: &str) -> bool { path == "/" || path == "/health" } fn is_auth_path(path: &str) -> bool { path == "/api/auth/login" || path == "/api/auth/register" } pub async fn rate_limit_middleware( State(state): State, request: Request, next: Next, ) -> Response { if !state.config.rate_limit_enabled { return next.run(request).await; } if request.method() == axum::http::Method::OPTIONS { return next.run(request).await; } let path = request.uri().path(); if is_bypassed_path(path) { return next.run(request).await; } let headers = request.headers(); let ip = parse_ip_from_headers(headers).or_else(|| parse_ip_from_connect_info(&request)); let ip_key = ip .map(|v| format!("ip:{}", v)) .unwrap_or_else(|| "ip:unknown".to_string()); if is_auth_path(path) { match state.auth.check(&ip_key).await { Ok(_) => return next.run(request).await, Err(retry_after) => { return ( StatusCode::TOO_MANY_REQUESTS, [(header::RETRY_AFTER, retry_after.to_string())], axum::Json(json!({"error": "Rate limit exceeded"})), ) .into_response(); } } } if let Err(retry_after) = state.ip.check(&ip_key).await { return ( StatusCode::TOO_MANY_REQUESTS, [(header::RETRY_AFTER, retry_after.to_string())], axum::Json(json!({"error": "Rate limit exceeded"})), ) .into_response(); } if state.config.rate_limit_user_rpm > 0 { if let Some(user_id) = parse_user_from_auth(headers, &state.config.jwt_secret) { let user_key = format!("user:{}", user_id); if let Err(retry_after) = state.user.check(&user_key).await { return ( StatusCode::TOO_MANY_REQUESTS, [(header::RETRY_AFTER, retry_after.to_string())], axum::Json(json!({"error": "Rate limit exceeded"})), ) .into_response(); } } } next.run(request).await } #[cfg(test)] mod tests { use super::FixedWindowLimiter; use std::time::Duration; #[tokio::test] async fn fixed_window_allows_up_to_limit() { let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 3); assert!(limiter.check("k").await.is_ok()); assert!(limiter.check("k").await.is_ok()); assert!(limiter.check("k").await.is_ok()); } #[tokio::test] async fn fixed_window_denies_after_limit() { let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 2); assert!(limiter.check("k").await.is_ok()); assert!(limiter.check("k").await.is_ok()); let retry_after = limiter.check("k").await.err().expect("expected denial"); assert!(retry_after >= 1); } #[tokio::test] async fn fixed_window_zero_limit_disables() { let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 0); for _ in 0..10 { assert!(limiter.check("k").await.is_ok()); } } }