From ed728979b6584c68fbae3002c48a496e2951bc77 Mon Sep 17 00:00:00 2001 From: Marco Allegretti Date: Mon, 2 Feb 2026 18:51:14 +0100 Subject: [PATCH] backend: add configurable rate limiting --- backend/src/config/mod.rs | 142 ++++++++++++--------- backend/src/main.rs | 7 +- backend/src/rate_limit.rs | 238 ++++++++++++++++++++++++++++++++++++ docs/admin/configuration.md | 9 +- docs/admin/security.md | 6 +- 5 files changed, 337 insertions(+), 65 deletions(-) create mode 100644 backend/src/rate_limit.rs diff --git a/backend/src/config/mod.rs b/backend/src/config/mod.rs index 46534d0..0bd44dc 100644 --- a/backend/src/config/mod.rs +++ b/backend/src/config/mod.rs @@ -1,57 +1,85 @@ -use serde::Deserialize; - -#[derive(Debug, Clone, Deserialize)] -pub struct Config { - #[serde(default = "default_database_url")] - pub database_url: String, - #[serde(default = "default_server_host")] - pub server_host: String, - #[serde(default = "default_server_port")] - pub server_port: u16, - /// Enable demo mode - restricts destructive actions and enables demo accounts - #[serde(default)] - pub demo_mode: bool, - /// Secret key for JWT tokens - #[serde(default = "default_jwt_secret")] - pub jwt_secret: String, -} - -fn default_database_url() -> String { - "postgres://likwid:likwid@localhost:5432/likwid".to_string() -} - -fn default_server_host() -> String { - "127.0.0.1".to_string() -} - -fn default_server_port() -> u16 { - 3000 -} - -fn default_jwt_secret() -> String { - "".to_string() -} - -impl Config { - pub fn from_env() -> Result { - dotenvy::dotenv().ok(); - envy::from_env::() - } - - /// Check if demo mode is enabled - pub fn is_demo(&self) -> bool { - self.demo_mode - } -} - -impl Default for Config { - fn default() -> Self { - Self { - database_url: "postgres://likwid:likwid@localhost:5432/likwid".to_string(), - server_host: "127.0.0.1".to_string(), - server_port: 3000, - demo_mode: false, - jwt_secret: default_jwt_secret(), - } - } -} +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + #[serde(default = "default_database_url")] + pub database_url: String, + #[serde(default = "default_server_host")] + pub server_host: String, + #[serde(default = "default_server_port")] + pub server_port: u16, + /// Enable demo mode - restricts destructive actions and enables demo accounts + #[serde(default)] + pub demo_mode: bool, + /// Secret key for JWT tokens + #[serde(default = "default_jwt_secret")] + pub jwt_secret: String, + #[serde(default = "default_rate_limit_enabled")] + pub rate_limit_enabled: bool, + #[serde(default = "default_rate_limit_ip_rpm")] + pub rate_limit_ip_rpm: u32, + #[serde(default = "default_rate_limit_user_rpm")] + pub rate_limit_user_rpm: u32, + #[serde(default = "default_rate_limit_auth_rpm")] + pub rate_limit_auth_rpm: u32, +} + +fn default_database_url() -> String { + "postgres://likwid:likwid@localhost:5432/likwid".to_string() +} + +fn default_server_host() -> String { + "127.0.0.1".to_string() +} + +fn default_server_port() -> u16 { + 3000 +} + +fn default_jwt_secret() -> String { + "".to_string() +} + +fn default_rate_limit_enabled() -> bool { + true +} + +fn default_rate_limit_ip_rpm() -> u32 { + 300 +} + +fn default_rate_limit_user_rpm() -> u32 { + 1200 +} + +fn default_rate_limit_auth_rpm() -> u32 { + 30 +} + +impl Config { + pub fn from_env() -> Result { + dotenvy::dotenv().ok(); + envy::from_env::() + } + + /// Check if demo mode is enabled + pub fn is_demo(&self) -> bool { + self.demo_mode + } +} + +impl Default for Config { + fn default() -> Self { + Self { + database_url: "postgres://likwid:likwid@localhost:5432/likwid".to_string(), + server_host: "127.0.0.1".to_string(), + server_port: 3000, + demo_mode: false, + jwt_secret: default_jwt_secret(), + rate_limit_enabled: default_rate_limit_enabled(), + rate_limit_ip_rpm: default_rate_limit_ip_rpm(), + rate_limit_user_rpm: default_rate_limit_user_rpm(), + rate_limit_auth_rpm: default_rate_limit_auth_rpm(), + } + } +} diff --git a/backend/src/main.rs b/backend/src/main.rs index a2706a0..046ae85 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -5,6 +5,7 @@ mod db; mod demo; mod models; mod plugins; +mod rate_limit; mod voting; use std::net::SocketAddr; @@ -207,6 +208,10 @@ async fn run() -> Result<(), StartupError> { .layer(Extension(plugins)) .layer(Extension(config.clone())) .layer(cors) + .layer(axum::middleware::from_fn_with_state( + rate_limit::RateLimitState::new(config.clone()), + rate_limit::rate_limit_middleware, + )) .layer(TraceLayer::new_for_http()) .layer(middleware::map_response(add_security_headers)); @@ -216,7 +221,7 @@ async fn run() -> Result<(), StartupError> { tracing::info!("Likwid backend listening on http://{}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; - axum::serve(listener, app) + axum::serve(listener, app.into_make_service_with_connect_info::()) .await .map_err(|e| StartupError::Serve(e.to_string()))?; diff --git a/backend/src/rate_limit.rs b/backend/src/rate_limit.rs new file mode 100644 index 0000000..0bf23ad --- /dev/null +++ b/backend/src/rate_limit.rs @@ -0,0 +1,238 @@ +use axum::body::Body; +use axum::extract::State; +use axum::http::{header, HeaderMap, Request, StatusCode}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use serde_json::json; +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::auth::jwt::verify_token; +use crate::config::Config; + +#[derive(Clone)] +pub struct RateLimitState { + config: Arc, + ip: Arc, + user: Arc, + auth: Arc, +} + +impl RateLimitState { + pub fn new(config: Arc) -> Self { + let window = Duration::from_secs(60); + Self { + ip: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_ip_rpm)), + user: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_user_rpm)), + auth: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_auth_rpm)), + config, + } + } +} + +struct WindowEntry { + window_start: Instant, + count: u32, + last_seen: Instant, +} + +struct FixedWindowLimiter { + window: Duration, + limit: u32, + entries: Mutex>, +} + +impl FixedWindowLimiter { + fn new(window: Duration, limit: u32) -> Self { + Self { + window, + limit, + entries: Mutex::new(HashMap::new()), + } + } + + async fn check(&self, key: &str) -> Result<(), u64> { + if self.limit == 0 { + return Ok(()); + } + + let now = Instant::now(); + let mut map = self.entries.lock().await; + + if map.len() > 10_000 { + let stale_before = now + .checked_sub(self.window.saturating_mul(2)) + .unwrap_or(now); + map.retain(|_, v| v.last_seen >= stale_before); + } + + let entry = map.entry(key.to_string()).or_insert_with(|| WindowEntry { + window_start: now, + count: 0, + last_seen: now, + }); + + entry.last_seen = now; + + if now.duration_since(entry.window_start) >= self.window { + entry.window_start = now; + entry.count = 0; + } + + if entry.count >= self.limit { + let elapsed = now.duration_since(entry.window_start); + let retry_after = self + .window + .checked_sub(elapsed) + .unwrap_or_else(|| Duration::from_secs(0)); + return Err(retry_after.as_secs().max(1)); + } + + entry.count += 1; + Ok(()) + } +} + +fn parse_ip_from_headers(headers: &HeaderMap) -> Option { + if let Some(forwarded) = headers + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + { + let first = forwarded.split(',').next().map(|s| s.trim()).unwrap_or(""); + if let Ok(ip) = first.parse::() { + return Some(ip); + } + } + + if let Some(real_ip) = headers + .get("x-real-ip") + .and_then(|v| v.to_str().ok()) + { + if let Ok(ip) = real_ip.trim().parse::() { + return Some(ip); + } + } + + None +} + +fn parse_ip_from_connect_info(request: &Request) -> Option { + request + .extensions() + .get::>() + .map(|ci| ci.0.ip()) +} + +fn parse_user_from_auth(headers: &HeaderMap, secret: &str) -> Option { + let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?; + let token = auth.strip_prefix("Bearer ")?; + verify_token(token, secret).ok().map(|c| c.sub) +} + +fn is_bypassed_path(path: &str) -> bool { + path == "/" || path == "/health" +} + +fn is_auth_path(path: &str) -> bool { + path == "/api/auth/login" || path == "/api/auth/register" +} + +pub async fn rate_limit_middleware( + State(state): State, + request: Request, + next: Next, +) -> Response { + if !state.config.rate_limit_enabled { + return next.run(request).await; + } + + if request.method() == axum::http::Method::OPTIONS { + return next.run(request).await; + } + + let path = request.uri().path(); + if is_bypassed_path(path) { + return next.run(request).await; + } + + let headers = request.headers(); + let ip = parse_ip_from_headers(headers).or_else(|| parse_ip_from_connect_info(&request)); + + let ip_key = ip + .map(|v| format!("ip:{}", v)) + .unwrap_or_else(|| "ip:unknown".to_string()); + + if is_auth_path(path) { + match state.auth.check(&ip_key).await { + Ok(_) => return next.run(request).await, + Err(retry_after) => { + return ( + StatusCode::TOO_MANY_REQUESTS, + [(header::RETRY_AFTER, retry_after.to_string())], + axum::Json(json!({"error": "Rate limit exceeded"})), + ) + .into_response(); + } + } + } + + if let Err(retry_after) = state.ip.check(&ip_key).await { + return ( + StatusCode::TOO_MANY_REQUESTS, + [(header::RETRY_AFTER, retry_after.to_string())], + axum::Json(json!({"error": "Rate limit exceeded"})), + ) + .into_response(); + } + + if state.config.rate_limit_user_rpm > 0 { + if let Some(user_id) = parse_user_from_auth(headers, &state.config.jwt_secret) { + let user_key = format!("user:{}", user_id); + if let Err(retry_after) = state.user.check(&user_key).await { + return ( + StatusCode::TOO_MANY_REQUESTS, + [(header::RETRY_AFTER, retry_after.to_string())], + axum::Json(json!({"error": "Rate limit exceeded"})), + ) + .into_response(); + } + } + } + + next.run(request).await +} + +#[cfg(test)] +mod tests { + use super::FixedWindowLimiter; + use std::time::Duration; + + #[tokio::test] + async fn fixed_window_allows_up_to_limit() { + let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 3); + assert!(limiter.check("k").await.is_ok()); + assert!(limiter.check("k").await.is_ok()); + assert!(limiter.check("k").await.is_ok()); + } + + #[tokio::test] + async fn fixed_window_denies_after_limit() { + let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 2); + assert!(limiter.check("k").await.is_ok()); + assert!(limiter.check("k").await.is_ok()); + let retry_after = limiter.check("k").await.err().expect("expected denial"); + assert!(retry_after >= 1); + } + + #[tokio::test] + async fn fixed_window_zero_limit_disables() { + let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 0); + for _ in 0..10 { + assert!(limiter.check("k").await.is_ok()); + } + } +} diff --git a/docs/admin/configuration.md b/docs/admin/configuration.md index d4d621c..042fd4a 100644 --- a/docs/admin/configuration.md +++ b/docs/admin/configuration.md @@ -13,6 +13,10 @@ Likwid is configured through environment variables and database settings. | `SERVER_HOST` | No | `127.0.0.1` | Bind address | | `SERVER_PORT` | No | `3000` | HTTP port | | `DEMO_MODE` | No | `false` | Enable demo features | +| `RATE_LIMIT_ENABLED` | No | `true` | Enable API rate limiting | +| `RATE_LIMIT_IP_RPM` | No | `300` | Requests per minute per IP | +| `RATE_LIMIT_USER_RPM` | No | `1200` | Requests per minute per authenticated user | +| `RATE_LIMIT_AUTH_RPM` | No | `30` | Requests per minute per IP for auth endpoints (`/api/auth/login`, `/api/auth/register`) | | `RUST_LOG` | No | `info` | Log level (trace, debug, info, warn, error) | ### Frontend @@ -73,10 +77,7 @@ Each community can configure: ## API Configuration ### Rate Limiting -Configure in backend settings: -- Requests per minute per IP -- Requests per minute per user -- Burst allowance +Rate limiting is configured via backend environment variables. ### CORS By default, CORS allows all origins in development. For production: diff --git a/docs/admin/security.md b/docs/admin/security.md index 703188c..3a536cb 100644 --- a/docs/admin/security.md +++ b/docs/admin/security.md @@ -35,9 +35,9 @@ CORS_ALLOWED_ORIGINS=https://likwid.example.org ### Rate Limiting Protect against abuse: -- 100 requests/minute per IP (default) -- 1000 requests/minute per authenticated user -- Configurable per endpoint +- 300 requests/minute per IP (default) +- 1200 requests/minute per authenticated user +- 30 requests/minute per IP for auth endpoints ## Database Security