backend: add configurable rate limiting

This commit is contained in:
Marco Allegretti 2026-02-02 18:51:14 +01:00
parent 49579e9286
commit ed728979b6
5 changed files with 337 additions and 65 deletions

View file

@ -1,57 +1,85 @@
use serde::Deserialize; use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct Config { pub struct Config {
#[serde(default = "default_database_url")] #[serde(default = "default_database_url")]
pub database_url: String, pub database_url: String,
#[serde(default = "default_server_host")] #[serde(default = "default_server_host")]
pub server_host: String, pub server_host: String,
#[serde(default = "default_server_port")] #[serde(default = "default_server_port")]
pub server_port: u16, pub server_port: u16,
/// Enable demo mode - restricts destructive actions and enables demo accounts /// Enable demo mode - restricts destructive actions and enables demo accounts
#[serde(default)] #[serde(default)]
pub demo_mode: bool, pub demo_mode: bool,
/// Secret key for JWT tokens /// Secret key for JWT tokens
#[serde(default = "default_jwt_secret")] #[serde(default = "default_jwt_secret")]
pub jwt_secret: String, pub jwt_secret: String,
} #[serde(default = "default_rate_limit_enabled")]
pub rate_limit_enabled: bool,
fn default_database_url() -> String { #[serde(default = "default_rate_limit_ip_rpm")]
"postgres://likwid:likwid@localhost:5432/likwid".to_string() pub rate_limit_ip_rpm: u32,
} #[serde(default = "default_rate_limit_user_rpm")]
pub rate_limit_user_rpm: u32,
fn default_server_host() -> String { #[serde(default = "default_rate_limit_auth_rpm")]
"127.0.0.1".to_string() pub rate_limit_auth_rpm: u32,
} }
fn default_server_port() -> u16 { fn default_database_url() -> String {
3000 "postgres://likwid:likwid@localhost:5432/likwid".to_string()
} }
fn default_jwt_secret() -> String { fn default_server_host() -> String {
"".to_string() "127.0.0.1".to_string()
} }
impl Config { fn default_server_port() -> u16 {
pub fn from_env() -> Result<Self, envy::Error> { 3000
dotenvy::dotenv().ok(); }
envy::from_env::<Config>()
} fn default_jwt_secret() -> String {
"".to_string()
/// Check if demo mode is enabled }
pub fn is_demo(&self) -> bool {
self.demo_mode fn default_rate_limit_enabled() -> bool {
} true
} }
impl Default for Config { fn default_rate_limit_ip_rpm() -> u32 {
fn default() -> Self { 300
Self { }
database_url: "postgres://likwid:likwid@localhost:5432/likwid".to_string(),
server_host: "127.0.0.1".to_string(), fn default_rate_limit_user_rpm() -> u32 {
server_port: 3000, 1200
demo_mode: false, }
jwt_secret: default_jwt_secret(),
} fn default_rate_limit_auth_rpm() -> u32 {
} 30
} }
impl Config {
pub fn from_env() -> Result<Self, envy::Error> {
dotenvy::dotenv().ok();
envy::from_env::<Config>()
}
/// Check if demo mode is enabled
pub fn is_demo(&self) -> bool {
self.demo_mode
}
}
impl Default for Config {
fn default() -> Self {
Self {
database_url: "postgres://likwid:likwid@localhost:5432/likwid".to_string(),
server_host: "127.0.0.1".to_string(),
server_port: 3000,
demo_mode: false,
jwt_secret: default_jwt_secret(),
rate_limit_enabled: default_rate_limit_enabled(),
rate_limit_ip_rpm: default_rate_limit_ip_rpm(),
rate_limit_user_rpm: default_rate_limit_user_rpm(),
rate_limit_auth_rpm: default_rate_limit_auth_rpm(),
}
}
}

View file

@ -5,6 +5,7 @@ mod db;
mod demo; mod demo;
mod models; mod models;
mod plugins; mod plugins;
mod rate_limit;
mod voting; mod voting;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -207,6 +208,10 @@ async fn run() -> Result<(), StartupError> {
.layer(Extension(plugins)) .layer(Extension(plugins))
.layer(Extension(config.clone())) .layer(Extension(config.clone()))
.layer(cors) .layer(cors)
.layer(axum::middleware::from_fn_with_state(
rate_limit::RateLimitState::new(config.clone()),
rate_limit::rate_limit_middleware,
))
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer(middleware::map_response(add_security_headers)); .layer(middleware::map_response(add_security_headers));
@ -216,7 +221,7 @@ async fn run() -> Result<(), StartupError> {
tracing::info!("Likwid backend listening on http://{}", addr); tracing::info!("Likwid backend listening on http://{}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?; let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app) axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
.await .await
.map_err(|e| StartupError::Serve(e.to_string()))?; .map_err(|e| StartupError::Serve(e.to_string()))?;

238
backend/src/rate_limit.rs Normal file
View file

@ -0,0 +1,238 @@
use axum::body::Body;
use axum::extract::State;
use axum::http::{header, HeaderMap, Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::auth::jwt::verify_token;
use crate::config::Config;
#[derive(Clone)]
pub struct RateLimitState {
config: Arc<Config>,
ip: Arc<FixedWindowLimiter>,
user: Arc<FixedWindowLimiter>,
auth: Arc<FixedWindowLimiter>,
}
impl RateLimitState {
pub fn new(config: Arc<Config>) -> Self {
let window = Duration::from_secs(60);
Self {
ip: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_ip_rpm)),
user: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_user_rpm)),
auth: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_auth_rpm)),
config,
}
}
}
struct WindowEntry {
window_start: Instant,
count: u32,
last_seen: Instant,
}
struct FixedWindowLimiter {
window: Duration,
limit: u32,
entries: Mutex<HashMap<String, WindowEntry>>,
}
impl FixedWindowLimiter {
fn new(window: Duration, limit: u32) -> Self {
Self {
window,
limit,
entries: Mutex::new(HashMap::new()),
}
}
async fn check(&self, key: &str) -> Result<(), u64> {
if self.limit == 0 {
return Ok(());
}
let now = Instant::now();
let mut map = self.entries.lock().await;
if map.len() > 10_000 {
let stale_before = now
.checked_sub(self.window.saturating_mul(2))
.unwrap_or(now);
map.retain(|_, v| v.last_seen >= stale_before);
}
let entry = map.entry(key.to_string()).or_insert_with(|| WindowEntry {
window_start: now,
count: 0,
last_seen: now,
});
entry.last_seen = now;
if now.duration_since(entry.window_start) >= self.window {
entry.window_start = now;
entry.count = 0;
}
if entry.count >= self.limit {
let elapsed = now.duration_since(entry.window_start);
let retry_after = self
.window
.checked_sub(elapsed)
.unwrap_or_else(|| Duration::from_secs(0));
return Err(retry_after.as_secs().max(1));
}
entry.count += 1;
Ok(())
}
}
fn parse_ip_from_headers(headers: &HeaderMap) -> Option<IpAddr> {
if let Some(forwarded) = headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
{
let first = forwarded.split(',').next().map(|s| s.trim()).unwrap_or("");
if let Ok(ip) = first.parse::<IpAddr>() {
return Some(ip);
}
}
if let Some(real_ip) = headers
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
{
if let Ok(ip) = real_ip.trim().parse::<IpAddr>() {
return Some(ip);
}
}
None
}
fn parse_ip_from_connect_info<B>(request: &Request<B>) -> Option<IpAddr> {
request
.extensions()
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
.map(|ci| ci.0.ip())
}
fn parse_user_from_auth(headers: &HeaderMap, secret: &str) -> Option<Uuid> {
let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
let token = auth.strip_prefix("Bearer ")?;
verify_token(token, secret).ok().map(|c| c.sub)
}
fn is_bypassed_path(path: &str) -> bool {
path == "/" || path == "/health"
}
fn is_auth_path(path: &str) -> bool {
path == "/api/auth/login" || path == "/api/auth/register"
}
pub async fn rate_limit_middleware(
State(state): State<RateLimitState>,
request: Request<Body>,
next: Next,
) -> Response {
if !state.config.rate_limit_enabled {
return next.run(request).await;
}
if request.method() == axum::http::Method::OPTIONS {
return next.run(request).await;
}
let path = request.uri().path();
if is_bypassed_path(path) {
return next.run(request).await;
}
let headers = request.headers();
let ip = parse_ip_from_headers(headers).or_else(|| parse_ip_from_connect_info(&request));
let ip_key = ip
.map(|v| format!("ip:{}", v))
.unwrap_or_else(|| "ip:unknown".to_string());
if is_auth_path(path) {
match state.auth.check(&ip_key).await {
Ok(_) => return next.run(request).await,
Err(retry_after) => {
return (
StatusCode::TOO_MANY_REQUESTS,
[(header::RETRY_AFTER, retry_after.to_string())],
axum::Json(json!({"error": "Rate limit exceeded"})),
)
.into_response();
}
}
}
if let Err(retry_after) = state.ip.check(&ip_key).await {
return (
StatusCode::TOO_MANY_REQUESTS,
[(header::RETRY_AFTER, retry_after.to_string())],
axum::Json(json!({"error": "Rate limit exceeded"})),
)
.into_response();
}
if state.config.rate_limit_user_rpm > 0 {
if let Some(user_id) = parse_user_from_auth(headers, &state.config.jwt_secret) {
let user_key = format!("user:{}", user_id);
if let Err(retry_after) = state.user.check(&user_key).await {
return (
StatusCode::TOO_MANY_REQUESTS,
[(header::RETRY_AFTER, retry_after.to_string())],
axum::Json(json!({"error": "Rate limit exceeded"})),
)
.into_response();
}
}
}
next.run(request).await
}
#[cfg(test)]
mod tests {
use super::FixedWindowLimiter;
use std::time::Duration;
#[tokio::test]
async fn fixed_window_allows_up_to_limit() {
let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 3);
assert!(limiter.check("k").await.is_ok());
assert!(limiter.check("k").await.is_ok());
assert!(limiter.check("k").await.is_ok());
}
#[tokio::test]
async fn fixed_window_denies_after_limit() {
let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 2);
assert!(limiter.check("k").await.is_ok());
assert!(limiter.check("k").await.is_ok());
let retry_after = limiter.check("k").await.err().expect("expected denial");
assert!(retry_after >= 1);
}
#[tokio::test]
async fn fixed_window_zero_limit_disables() {
let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 0);
for _ in 0..10 {
assert!(limiter.check("k").await.is_ok());
}
}
}

View file

@ -13,6 +13,10 @@ Likwid is configured through environment variables and database settings.
| `SERVER_HOST` | No | `127.0.0.1` | Bind address | | `SERVER_HOST` | No | `127.0.0.1` | Bind address |
| `SERVER_PORT` | No | `3000` | HTTP port | | `SERVER_PORT` | No | `3000` | HTTP port |
| `DEMO_MODE` | No | `false` | Enable demo features | | `DEMO_MODE` | No | `false` | Enable demo features |
| `RATE_LIMIT_ENABLED` | No | `true` | Enable API rate limiting |
| `RATE_LIMIT_IP_RPM` | No | `300` | Requests per minute per IP |
| `RATE_LIMIT_USER_RPM` | No | `1200` | Requests per minute per authenticated user |
| `RATE_LIMIT_AUTH_RPM` | No | `30` | Requests per minute per IP for auth endpoints (`/api/auth/login`, `/api/auth/register`) |
| `RUST_LOG` | No | `info` | Log level (trace, debug, info, warn, error) | | `RUST_LOG` | No | `info` | Log level (trace, debug, info, warn, error) |
### Frontend ### Frontend
@ -73,10 +77,7 @@ Each community can configure:
## API Configuration ## API Configuration
### Rate Limiting ### Rate Limiting
Configure in backend settings: Rate limiting is configured via backend environment variables.
- Requests per minute per IP
- Requests per minute per user
- Burst allowance
### CORS ### CORS
By default, CORS allows all origins in development. For production: By default, CORS allows all origins in development. For production:

View file

@ -35,9 +35,9 @@ CORS_ALLOWED_ORIGINS=https://likwid.example.org
### Rate Limiting ### Rate Limiting
Protect against abuse: Protect against abuse:
- 100 requests/minute per IP (default) - 300 requests/minute per IP (default)
- 1000 requests/minute per authenticated user - 1200 requests/minute per authenticated user
- Configurable per endpoint - 30 requests/minute per IP for auth endpoints
## Database Security ## Database Security