//! Request rate limiting middleware (likwid/backend/src/rate_limit.rs).

use axum::body::Body;
use axum::extract::State;
use axum::http::{header, HeaderMap, Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::auth::jwt::verify_token;
use crate::config::Config;
/// Shared state for [`rate_limit_middleware`]: the application config plus one
/// fixed-window limiter per bucket (client IP, authenticated user, auth endpoints).
#[derive(Clone)]
pub struct RateLimitState {
    config: Arc<Config>,
    ip: Arc<FixedWindowLimiter>,
    user: Arc<FixedWindowLimiter>,
    auth: Arc<FixedWindowLimiter>,
}

impl RateLimitState {
    /// Creates the three limiters from `config`.
    ///
    /// Every limiter uses a 60-second window; its budget comes from the matching
    /// `rate_limit_*_rpm` setting.
    pub fn new(config: Arc<Config>) -> Self {
        const WINDOW: Duration = Duration::from_secs(60);
        let limiter =
            |per_window| Arc::new(FixedWindowLimiter::new(WINDOW, per_window));

        let ip = limiter(config.rate_limit_ip_rpm);
        let user = limiter(config.rate_limit_user_rpm);
        let auth = limiter(config.rate_limit_auth_rpm);

        Self {
            config,
            ip,
            user,
            auth,
        }
    }
}
/// Counter state tracked for one key within its current window.
struct WindowEntry {
    /// Instant at which the key's current window began.
    window_start: Instant,
    /// Requests admitted so far in the current window; never exceeds the limiter's `limit`.
    count: u32,
    /// Most recent time this key was checked; used to evict idle keys.
    last_seen: Instant,
}

/// Number of tracked keys above which idle entries become eligible for eviction.
const SWEEP_THRESHOLD: usize = 10_000;

/// Mutable limiter state, guarded by a single lock.
struct LimiterState {
    /// Per-key window counters.
    entries: HashMap<String, WindowEntry>,
    /// When idle entries were last swept. Sweeps are throttled to at most once per
    /// window so that a large population of active keys does not trigger an O(n)
    /// scan on every request.
    last_sweep: Instant,
}

/// Fixed-window rate limiter: each key may make up to `limit` requests per `window`.
///
/// A `limit` of 0 disables limiting entirely.
struct FixedWindowLimiter {
    window: Duration,
    limit: u32,
    state: Mutex<LimiterState>,
}

impl FixedWindowLimiter {
    fn new(window: Duration, limit: u32) -> Self {
        Self {
            window,
            limit,
            state: Mutex::new(LimiterState {
                entries: HashMap::new(),
                last_sweep: Instant::now(),
            }),
        }
    }

    /// Records one request for `key` and decides whether it is allowed.
    ///
    /// # Errors
    ///
    /// Returns `Err(secs)` when `key` has used up its budget for the current window.
    /// `secs` is the number of seconds until the window resets, rounded up and at
    /// least 1, so it can be used directly as a `Retry-After` value.
    async fn check(&self, key: &str) -> Result<(), u64> {
        if self.limit == 0 {
            return Ok(());
        }
        let now = Instant::now();
        let mut state = self.state.lock().await;

        if state.entries.len() > SWEEP_THRESHOLD
            && now.duration_since(state.last_sweep) >= self.window
        {
            // Compare elapsed durations rather than computing `now - 2 * window`:
            // `Instant::checked_sub` can fail on platforms whose clock origin is
            // recent, and falling back to `now` would evict (reset) every key.
            let idle_limit = self.window.saturating_mul(2);
            state
                .entries
                .retain(|_, e| now.duration_since(e.last_seen) <= idle_limit);
            state.last_sweep = now;
        }

        let entry = state
            .entries
            .entry(key.to_owned())
            .or_insert_with(|| WindowEntry {
                window_start: now,
                count: 0,
                last_seen: now,
            });
        entry.last_seen = now;

        // The window has expired: start a fresh one for this key.
        if now.duration_since(entry.window_start) >= self.window {
            entry.window_start = now;
            entry.count = 0;
        }

        if entry.count >= self.limit {
            let remaining = self
                .window
                .saturating_sub(now.duration_since(entry.window_start));
            // Round up so a client that honours Retry-After does not come back a
            // fraction of a second early and get rejected again.
            let secs = remaining.as_secs() + u64::from(remaining.subsec_nanos() > 0);
            return Err(secs.max(1));
        }

        entry.count += 1;
        Ok(())
    }
}
fn parse_ip_from_headers(headers: &HeaderMap) -> Option<IpAddr> {
if let Some(forwarded) = headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
{
let first = forwarded.split(',').next().map(|s| s.trim()).unwrap_or("");
if let Ok(ip) = first.parse::<IpAddr>() {
return Some(ip);
}
}
if let Some(real_ip) = headers
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
{
if let Ok(ip) = real_ip.trim().parse::<IpAddr>() {
return Some(ip);
}
}
None
}
/// Returns the peer IP from the request's `ConnectInfo<SocketAddr>` extension,
/// or `None` when that extension is absent.
fn parse_ip_from_connect_info<B>(request: &Request<B>) -> Option<IpAddr> {
    let axum::extract::ConnectInfo(peer) = request
        .extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()?;
    Some(peer.ip())
}
/// Returns the `sub` claim of a valid `Authorization: Bearer <token>` JWT.
///
/// Returns `None` if the header is missing, not text, lacks the exact `"Bearer "`
/// prefix, or the token fails `verify_token` against `secret`.
fn parse_user_from_auth(headers: &HeaderMap, secret: &str) -> Option<Uuid> {
    let token = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))?;
    let claims = verify_token(token, secret).ok()?;
    Some(claims.sub)
}
/// Paths exempt from rate limiting (root and health check), matched exactly.
fn is_bypassed_path(path: &str) -> bool {
    matches!(path, "/" | "/health")
}
/// Credential endpoints that are limited by the dedicated auth limiter, matched exactly.
fn is_auth_path(path: &str) -> bool {
    matches!(path, "/api/auth/login" | "/api/auth/register")
}
/// Builds the `429 Too Many Requests` response shared by every limiter, with a
/// `Retry-After` header carrying `retry_after` seconds.
fn too_many_requests(retry_after: u64) -> Response {
    (
        StatusCode::TOO_MANY_REQUESTS,
        [(header::RETRY_AFTER, retry_after.to_string())],
        axum::Json(json!({"error": "Rate limit exceeded"})),
    )
        .into_response()
}

/// Axum middleware enforcing the configured rate limits.
///
/// Requests pass straight through when limiting is disabled, for `OPTIONS`
/// requests, and for bypassed paths. Auth endpoints are counted only against the
/// auth limiter, keyed by client IP. All other requests are counted against the
/// per-IP limiter, and then against the per-user limiter when a valid bearer
/// token is present. Requests whose IP cannot be determined share the
/// `ip:unknown` bucket.
pub async fn rate_limit_middleware(
    State(state): State<RateLimitState>,
    request: Request<Body>,
    next: Next,
) -> Response {
    if !state.config.rate_limit_enabled
        || request.method() == axum::http::Method::OPTIONS
        || is_bypassed_path(request.uri().path())
    {
        return next.run(request).await;
    }

    let headers = request.headers();
    let ip_key = parse_ip_from_headers(headers)
        .or_else(|| parse_ip_from_connect_info(&request))
        .map_or_else(|| "ip:unknown".to_string(), |ip| format!("ip:{}", ip));

    if is_auth_path(request.uri().path()) {
        return match state.auth.check(&ip_key).await {
            Ok(()) => next.run(request).await,
            Err(retry_after) => too_many_requests(retry_after),
        };
    }

    if let Err(retry_after) = state.ip.check(&ip_key).await {
        return too_many_requests(retry_after);
    }

    // Skip JWT verification entirely when the per-user limit is disabled.
    if state.config.rate_limit_user_rpm > 0 {
        if let Some(user_id) = parse_user_from_auth(headers, &state.config.jwt_secret) {
            let user_key = format!("user:{}", user_id);
            if let Err(retry_after) = state.user.check(&user_key).await {
                return too_many_requests(retry_after);
            }
        }
    }

    next.run(request).await
}
#[cfg(test)]
mod tests {
    use super::FixedWindowLimiter;
    use std::time::Duration;

    /// Window shared by every test; long enough that it never rolls over mid-test.
    const WINDOW: Duration = Duration::from_secs(60);

    #[tokio::test]
    async fn fixed_window_allows_up_to_limit() {
        let limiter = FixedWindowLimiter::new(WINDOW, 3);
        for _ in 0..3 {
            assert!(limiter.check("k").await.is_ok());
        }
    }

    #[tokio::test]
    async fn fixed_window_denies_after_limit() {
        let limiter = FixedWindowLimiter::new(WINDOW, 2);
        for _ in 0..2 {
            assert!(limiter.check("k").await.is_ok());
        }
        match limiter.check("k").await {
            Ok(()) => panic!("expected denial"),
            Err(retry_after) => assert!(retry_after >= 1),
        }
    }

    #[tokio::test]
    async fn fixed_window_zero_limit_disables() {
        let limiter = FixedWindowLimiter::new(WINDOW, 0);
        for attempt in 0..10 {
            assert!(
                limiter.check("k").await.is_ok(),
                "attempt {} was rejected",
                attempt
            );
        }
    }
}