// Mirror of https://codeberg.org/likwid/likwid.git
// Synced 2026-02-09 21:13:09 +00:00
// 239 lines, 6.7 KiB, Rust
|
|
use axum::body::Body;
|
||
|
|
use axum::extract::State;
|
||
|
|
use axum::http::{header, HeaderMap, Request, StatusCode};
|
||
|
|
use axum::middleware::Next;
|
||
|
|
use axum::response::{IntoResponse, Response};
|
||
|
|
use serde_json::json;
|
||
|
|
use std::collections::HashMap;
|
||
|
|
use std::net::IpAddr;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use std::time::{Duration, Instant};
|
||
|
|
use tokio::sync::Mutex;
|
||
|
|
use uuid::Uuid;
|
||
|
|
|
||
|
|
use crate::auth::jwt::verify_token;
|
||
|
|
use crate::config::Config;
|
||
|
|
|
||
|
|
#[derive(Clone)]
|
||
|
|
pub struct RateLimitState {
|
||
|
|
config: Arc<Config>,
|
||
|
|
ip: Arc<FixedWindowLimiter>,
|
||
|
|
user: Arc<FixedWindowLimiter>,
|
||
|
|
auth: Arc<FixedWindowLimiter>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl RateLimitState {
|
||
|
|
pub fn new(config: Arc<Config>) -> Self {
|
||
|
|
let window = Duration::from_secs(60);
|
||
|
|
Self {
|
||
|
|
ip: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_ip_rpm)),
|
||
|
|
user: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_user_rpm)),
|
||
|
|
auth: Arc::new(FixedWindowLimiter::new(window, config.rate_limit_auth_rpm)),
|
||
|
|
config,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Per-key counter state tracked by the fixed-window limiter.
struct WindowEntry {
    /// Start of the window that `count` belongs to.
    window_start: Instant,
    /// Requests admitted since `window_start`.
    count: u32,
    /// Time of the most recent check for this key; used to evict idle keys.
    last_seen: Instant,
}
|
||
|
|
|
||
|
|
struct FixedWindowLimiter {
|
||
|
|
window: Duration,
|
||
|
|
limit: u32,
|
||
|
|
entries: Mutex<HashMap<String, WindowEntry>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl FixedWindowLimiter {
|
||
|
|
fn new(window: Duration, limit: u32) -> Self {
|
||
|
|
Self {
|
||
|
|
window,
|
||
|
|
limit,
|
||
|
|
entries: Mutex::new(HashMap::new()),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn check(&self, key: &str) -> Result<(), u64> {
|
||
|
|
if self.limit == 0 {
|
||
|
|
return Ok(());
|
||
|
|
}
|
||
|
|
|
||
|
|
let now = Instant::now();
|
||
|
|
let mut map = self.entries.lock().await;
|
||
|
|
|
||
|
|
if map.len() > 10_000 {
|
||
|
|
let stale_before = now
|
||
|
|
.checked_sub(self.window.saturating_mul(2))
|
||
|
|
.unwrap_or(now);
|
||
|
|
map.retain(|_, v| v.last_seen >= stale_before);
|
||
|
|
}
|
||
|
|
|
||
|
|
let entry = map.entry(key.to_string()).or_insert_with(|| WindowEntry {
|
||
|
|
window_start: now,
|
||
|
|
count: 0,
|
||
|
|
last_seen: now,
|
||
|
|
});
|
||
|
|
|
||
|
|
entry.last_seen = now;
|
||
|
|
|
||
|
|
if now.duration_since(entry.window_start) >= self.window {
|
||
|
|
entry.window_start = now;
|
||
|
|
entry.count = 0;
|
||
|
|
}
|
||
|
|
|
||
|
|
if entry.count >= self.limit {
|
||
|
|
let elapsed = now.duration_since(entry.window_start);
|
||
|
|
let retry_after = self
|
||
|
|
.window
|
||
|
|
.checked_sub(elapsed)
|
||
|
|
.unwrap_or_else(|| Duration::from_secs(0));
|
||
|
|
return Err(retry_after.as_secs().max(1));
|
||
|
|
}
|
||
|
|
|
||
|
|
entry.count += 1;
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn parse_ip_from_headers(headers: &HeaderMap) -> Option<IpAddr> {
|
||
|
|
if let Some(forwarded) = headers
|
||
|
|
.get("x-forwarded-for")
|
||
|
|
.and_then(|v| v.to_str().ok())
|
||
|
|
{
|
||
|
|
let first = forwarded.split(',').next().map(|s| s.trim()).unwrap_or("");
|
||
|
|
if let Ok(ip) = first.parse::<IpAddr>() {
|
||
|
|
return Some(ip);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if let Some(real_ip) = headers
|
||
|
|
.get("x-real-ip")
|
||
|
|
.and_then(|v| v.to_str().ok())
|
||
|
|
{
|
||
|
|
if let Ok(ip) = real_ip.trim().parse::<IpAddr>() {
|
||
|
|
return Some(ip);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
None
|
||
|
|
}
|
||
|
|
|
||
|
|
fn parse_ip_from_connect_info<B>(request: &Request<B>) -> Option<IpAddr> {
|
||
|
|
request
|
||
|
|
.extensions()
|
||
|
|
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
|
||
|
|
.map(|ci| ci.0.ip())
|
||
|
|
}
|
||
|
|
|
||
|
|
fn parse_user_from_auth(headers: &HeaderMap, secret: &str) -> Option<Uuid> {
|
||
|
|
let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
|
||
|
|
let token = auth.strip_prefix("Bearer ")?;
|
||
|
|
verify_token(token, secret).ok().map(|c| c.sub)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Returns `true` for paths that are never rate limited: the root path and
/// the health check. The match is exact (no prefix or trailing-slash
/// handling).
fn is_bypassed_path(path: &str) -> bool {
    matches!(path, "/" | "/health")
}
|
||
|
|
|
||
|
|
/// Returns `true` for the login and registration endpoints, which are
/// counted against the dedicated auth limiter. The match is exact.
fn is_auth_path(path: &str) -> bool {
    matches!(path, "/api/auth/login" | "/api/auth/register")
}
|
||
|
|
|
||
|
|
pub async fn rate_limit_middleware(
|
||
|
|
State(state): State<RateLimitState>,
|
||
|
|
request: Request<Body>,
|
||
|
|
next: Next,
|
||
|
|
) -> Response {
|
||
|
|
if !state.config.rate_limit_enabled {
|
||
|
|
return next.run(request).await;
|
||
|
|
}
|
||
|
|
|
||
|
|
if request.method() == axum::http::Method::OPTIONS {
|
||
|
|
return next.run(request).await;
|
||
|
|
}
|
||
|
|
|
||
|
|
let path = request.uri().path();
|
||
|
|
if is_bypassed_path(path) {
|
||
|
|
return next.run(request).await;
|
||
|
|
}
|
||
|
|
|
||
|
|
let headers = request.headers();
|
||
|
|
let ip = parse_ip_from_headers(headers).or_else(|| parse_ip_from_connect_info(&request));
|
||
|
|
|
||
|
|
let ip_key = ip
|
||
|
|
.map(|v| format!("ip:{}", v))
|
||
|
|
.unwrap_or_else(|| "ip:unknown".to_string());
|
||
|
|
|
||
|
|
if is_auth_path(path) {
|
||
|
|
match state.auth.check(&ip_key).await {
|
||
|
|
Ok(_) => return next.run(request).await,
|
||
|
|
Err(retry_after) => {
|
||
|
|
return (
|
||
|
|
StatusCode::TOO_MANY_REQUESTS,
|
||
|
|
[(header::RETRY_AFTER, retry_after.to_string())],
|
||
|
|
axum::Json(json!({"error": "Rate limit exceeded"})),
|
||
|
|
)
|
||
|
|
.into_response();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if let Err(retry_after) = state.ip.check(&ip_key).await {
|
||
|
|
return (
|
||
|
|
StatusCode::TOO_MANY_REQUESTS,
|
||
|
|
[(header::RETRY_AFTER, retry_after.to_string())],
|
||
|
|
axum::Json(json!({"error": "Rate limit exceeded"})),
|
||
|
|
)
|
||
|
|
.into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
if state.config.rate_limit_user_rpm > 0 {
|
||
|
|
if let Some(user_id) = parse_user_from_auth(headers, &state.config.jwt_secret) {
|
||
|
|
let user_key = format!("user:{}", user_id);
|
||
|
|
if let Err(retry_after) = state.user.check(&user_key).await {
|
||
|
|
return (
|
||
|
|
StatusCode::TOO_MANY_REQUESTS,
|
||
|
|
[(header::RETRY_AFTER, retry_after.to_string())],
|
||
|
|
axum::Json(json!({"error": "Rate limit exceeded"})),
|
||
|
|
)
|
||
|
|
.into_response();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
next.run(request).await
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::FixedWindowLimiter;
    use std::time::Duration;

    /// Every request up to the limit is admitted.
    #[tokio::test]
    async fn fixed_window_allows_up_to_limit() {
        let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 3);
        assert!(limiter.check("k").await.is_ok());
        assert!(limiter.check("k").await.is_ok());
        assert!(limiter.check("k").await.is_ok());
    }

    /// The first request past the limit is denied with a usable retry hint.
    #[tokio::test]
    async fn fixed_window_denies_after_limit() {
        let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 2);
        assert!(limiter.check("k").await.is_ok());
        assert!(limiter.check("k").await.is_ok());
        let retry_after = limiter.check("k").await.expect_err("expected denial");
        assert!(retry_after >= 1);
        // The hint can never exceed the window length.
        assert!(retry_after <= 60);
    }

    /// A limit of zero turns the limiter off entirely.
    #[tokio::test]
    async fn fixed_window_zero_limit_disables() {
        let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 0);
        for _ in 0..10 {
            assert!(limiter.check("k").await.is_ok());
        }
    }

    /// Exhausting one key must not affect another key.
    #[tokio::test]
    async fn fixed_window_tracks_keys_independently() {
        let limiter = FixedWindowLimiter::new(Duration::from_secs(60), 1);
        assert!(limiter.check("a").await.is_ok());
        assert!(limiter.check("a").await.is_err());
        assert!(limiter.check("b").await.is_ok());
    }
}
|