From 33311c51c8a7d23bbedb1d3cfc4bc36723eb2e44 Mon Sep 17 00:00:00 2001 From: Marco Allegretti Date: Thu, 12 Feb 2026 18:15:39 +0100 Subject: [PATCH] backend: harden auth token validation --- backend/src/auth/jwt.rs | 9 ++- backend/src/auth/middleware.rs | 112 ++++++++++++++++++++++++++++++--- backend/src/rate_limit.rs | 10 ++- 3 files changed, 121 insertions(+), 10 deletions(-) diff --git a/backend/src/auth/jwt.rs b/backend/src/auth/jwt.rs index 1387972..524632e 100644 --- a/backend/src/auth/jwt.rs +++ b/backend/src/auth/jwt.rs @@ -1,6 +1,7 @@ use chrono::{Duration, Utc}; -use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; +use std::collections::HashSet; use uuid::Uuid; #[derive(Debug, Serialize, Deserialize)] @@ -34,10 +35,14 @@ pub fn create_token( } pub fn verify_token(token: &str, secret: &str) -> Result { + let mut validation = Validation::new(Algorithm::HS256); + validation.leeway = 30; + validation.required_spec_claims = HashSet::from(["exp".to_string(), "sub".to_string()]); + let token_data = decode::( token, &DecodingKey::from_secret(secret.as_bytes()), - &Validation::default(), + &validation, )?; Ok(token_data.claims) } diff --git a/backend/src/auth/middleware.rs b/backend/src/auth/middleware.rs index 9e08160..08507df 100644 --- a/backend/src/auth/middleware.rs +++ b/backend/src/auth/middleware.rs @@ -1,5 +1,6 @@ use axum::{ extract::FromRequestParts, + http::header, http::{request::Parts, StatusCode}, }; use std::sync::Arc; @@ -8,6 +9,7 @@ use uuid::Uuid; use super::jwt::{verify_token, Claims}; use crate::config::Config; +#[derive(Debug)] pub struct AuthUser { pub user_id: Uuid, pub username: String, @@ -22,23 +24,39 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let auth_header = parts .headers - .get("Authorization") + .get(header::AUTHORIZATION) .and_then(|value| value.to_str().ok()) .ok_or((StatusCode::UNAUTHORIZED, "Missing authorization header"))?; - let token = auth_header.strip_prefix("Bearer ").ok_or(( + let mut pieces = auth_header.split_whitespace(); + let scheme = pieces.next().ok_or(( StatusCode::UNAUTHORIZED, "Invalid authorization header format", ))?; + let token = pieces.next().ok_or(( + StatusCode::UNAUTHORIZED, + "Invalid authorization header format", + ))?; + if pieces.next().is_some() || !scheme.eq_ignore_ascii_case("bearer") { + return Err(( + StatusCode::UNAUTHORIZED, + "Invalid authorization header format", + )); + } - let secret = parts + let config = parts .extensions .get::>() - .map(|c| c.jwt_secret.clone()) - .or_else(|| std::env::var("JWT_SECRET").ok()) - .unwrap_or_else(|| "dev-secret-change-in-production".to_string()); + .ok_or((StatusCode::INTERNAL_SERVER_ERROR, "Auth config missing"))?; - let claims: Claims = verify_token(token, &secret) + if config.jwt_secret.trim().is_empty() { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "JWT secret not configured", + )); + } + + let claims: Claims = verify_token(token, &config.jwt_secret) .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token"))?; Ok(AuthUser { @@ -47,3 +65,83 @@ where }) } } + +#[cfg(test)] +mod tests { + use super::AuthUser; + use crate::auth::create_token; + use crate::config::Config; + use axum::body::Body; + use axum::extract::FromRequestParts; + use axum::http::Request; + use std::sync::Arc; + use uuid::Uuid; + + fn parts_with_auth( + auth: Option<&str>, + config: Option>, + ) -> axum::http::request::Parts { + let mut req = Request::builder().uri("/").body(Body::empty()).unwrap(); + if let Some(auth) = auth { + req.headers_mut() + .insert(axum::http::header::AUTHORIZATION, auth.parse().unwrap()); + } + if let Some(config) = config { + req.extensions_mut().insert(config); + } + let (parts, _) = req.into_parts(); + parts + } + + #[tokio::test] + async fn rejects_missing_auth_header() { + let config = Arc::new(Config { + jwt_secret: "secret".to_string(), + ..Config::default() + }); + let mut parts = parts_with_auth(None, Some(config)); + let err = AuthUser::from_request_parts(&mut parts, &()) + .await + .unwrap_err(); + assert_eq!(err.0, axum::http::StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn accepts_bearer_case_insensitive_and_whitespace() { + let config = Arc::new(Config { + jwt_secret: "secret".to_string(), + ..Config::default() + }); + let token = create_token(Uuid::new_v4(), "alice", &config.jwt_secret).unwrap(); + let auth = format!(" bEaReR {token} "); + let mut parts = parts_with_auth(Some(&auth), Some(config)); + let user = AuthUser::from_request_parts(&mut parts, &()).await.unwrap(); + assert_eq!(user.username, "alice"); + } + + #[tokio::test] + async fn rejects_non_bearer_scheme() { + let config = Arc::new(Config { + jwt_secret: "secret".to_string(), + ..Config::default() + }); + let token = create_token(Uuid::new_v4(), "alice", &config.jwt_secret).unwrap(); + let auth = format!("Token {token}"); + let mut parts = parts_with_auth(Some(&auth), Some(config)); + let err = AuthUser::from_request_parts(&mut parts, &()) + .await + .unwrap_err(); + assert_eq!(err.0, axum::http::StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn errors_when_config_missing() { + let token = create_token(Uuid::new_v4(), "alice", "secret").unwrap(); + let auth = format!("Bearer {token}"); + let mut parts = parts_with_auth(Some(&auth), None); + let err = AuthUser::from_request_parts(&mut parts, &()) + .await + .unwrap_err(); + assert_eq!(err.0, axum::http::StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/backend/src/rate_limit.rs b/backend/src/rate_limit.rs index b2bd097..82fbee4 100644 --- a/backend/src/rate_limit.rs +++ b/backend/src/rate_limit.rs @@ -123,7 +123,15 @@ fn parse_ip_from_connect_info(request: &Request) -> Option { fn parse_user_from_auth(headers: &HeaderMap, secret: &str) -> Option { let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?; - let token = auth.strip_prefix("Bearer ")?; + let mut pieces = auth.split_whitespace(); + let scheme = pieces.next()?; + let token = pieces.next()?; + if pieces.next().is_some() { + return None; + } + if !scheme.eq_ignore_ascii_case("bearer") { + return None; + } verify_token(token, secret).ok().map(|c| c.sub) }