//! Likwid backend binary entry point (`likwid/backend/src/main.rs`).

mod api;
mod auth;
mod config;
mod db;
mod demo;
mod models;
mod plugins;
mod rate_limit;
mod voting;
use axum::http::{HeaderName, HeaderValue, Request};
use axum::response::Response;
use axum::{middleware, Extension};
use chrono::{Datelike, Timelike, Utc, Weekday};
use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tower_http::cors::{Any, CorsLayer};
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::trace::TraceLayer;
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
use crate::config::Config;
use crate::plugins::HookContext;
#[derive(Debug, Error)]
enum StartupError {
#[error("Failed to load configuration: {0}")]
Config(#[from] envy::Error),
#[error("JWT_SECRET must be set")]
MissingJwtSecret,
#[error("Failed to create database pool: {0}")]
Db(#[from] sqlx::Error),
#[error("Failed to run database migrations: {0}")]
Migrations(#[from] sqlx::migrate::MigrateError),
#[error("Failed to initialize plugins: {0}")]
Plugins(#[from] crate::plugins::PluginError),
#[error("Failed to bind server listener: {0}")]
Bind(#[from] std::io::Error),
#[error("Server error: {0}")]
Serve(String),
}
/// Process entry point.
///
/// Installs the tracing subscriber (filter taken from `RUST_LOG`, defaulting to `info`),
/// runs the backend, and on failure logs the error and exits with status 1.
#[tokio::main]
async fn main() {
    let env_filter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new("info"),
    };
    tracing_subscriber::fmt().with_env_filter(env_filter).init();

    let Err(e) = run().await else {
        return;
    };
    tracing::error!("{e}");
    std::process::exit(1);
}
#[cfg(test)]
mod security_headers_tests {
    use super::add_security_headers;
    use axum::body::Body;
    use axum::http::{header, Request, StatusCode};
    use axum::middleware;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use tower::ServiceExt;

    /// Sends `GET /api/ping` through `app` and returns the response.
    async fn ping(app: Router) -> Response {
        let request = Request::builder()
            .method("GET")
            .uri("/api/ping")
            .body(Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap()
    }

    /// Reads header `name` from `res` as a string, if present and printable.
    fn header_value<'a>(res: &'a Response, name: &str) -> Option<&'a str> {
        res.headers().get(name).and_then(|v| v.to_str().ok())
    }

    #[tokio::test]
    async fn security_headers_are_added_by_default() {
        let app = Router::new()
            .route("/api/ping", get(|| async { "ok" }))
            .layer(middleware::map_response(add_security_headers));
        let res = ping(app).await;

        assert_eq!(res.status(), StatusCode::OK);
        let expected = [
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("referrer-policy", "no-referrer"),
            (
                "permissions-policy",
                "camera=(), microphone=(), geolocation=()",
            ),
            ("x-permitted-cross-domain-policies", "none"),
        ];
        for (name, value) in expected {
            assert_eq!(header_value(&res, name), Some(value));
        }
    }

    #[tokio::test]
    async fn security_headers_do_not_override_explicit_headers() {
        // The handler sets two of the defaulted headers itself; the middleware must keep them.
        let handler = || async {
            Response::builder()
                .status(StatusCode::OK)
                .header("x-frame-options", "SAMEORIGIN")
                .header(header::REFERRER_POLICY, "strict-origin")
                .body(Body::from("ok"))
                .unwrap()
        };
        let app = Router::new()
            .route("/api/ping", get(handler))
            .layer(middleware::map_response(add_security_headers));
        let res = ping(app).await;

        assert_eq!(header_value(&res, "x-frame-options"), Some("SAMEORIGIN"));
        assert_eq!(
            header_value(&res, header::REFERRER_POLICY.as_str()),
            Some("strict-origin")
        );
    }
}
async fn run() -> Result<(), StartupError> {
dotenvy::dotenv().ok();
// Load configuration
let config = Arc::new(Config::from_env()?);
if config.jwt_secret.trim().is_empty() {
return Err(StartupError::MissingJwtSecret);
}
if config.is_demo() {
tracing::info!("🎭 DEMO MODE ENABLED - Some actions are restricted");
}
let database_url =
std::env::var("DATABASE_URL").unwrap_or_else(|_| config.database_url.clone());
let pool = {
let start = Instant::now();
let mut attempt: u32 = 1;
loop {
match db::create_pool(&database_url).await {
Ok(pool) => {
tracing::info!(
elapsed_ms = start.elapsed().as_millis(),
"Connected to database"
);
break pool;
}
Err(e) => {
if attempt >= 30 {
return Err(StartupError::Db(e));
}
tracing::warn!(attempt, error = %e, "Failed to connect to database; retrying");
tokio::time::sleep(Duration::from_secs(1)).await;
attempt += 1;
}
}
}
};
let mut migrator = sqlx::migrate!("./migrations");
if config.is_demo() {
migrator.set_ignore_missing(true);
}
{
let start = Instant::now();
let mut attempt: u32 = 1;
loop {
match migrator.run(&pool).await {
Ok(()) => {
tracing::info!(
elapsed_ms = start.elapsed().as_millis(),
"Database migrations applied"
);
break;
}
Err(e) => {
if attempt >= 30 {
return Err(StartupError::Migrations(e));
}
tracing::warn!(attempt, error = %e, "Database migrations failed; retrying");
tokio::time::sleep(Duration::from_secs(1)).await;
attempt += 1;
}
}
}
}
if config.is_demo() {
let mut demo_migrator = sqlx::migrate!("./migrations_demo");
demo_migrator.set_ignore_missing(true);
let start = Instant::now();
let mut attempt: u32 = 1;
loop {
match demo_migrator.run(&pool).await {
Ok(()) => {
tracing::info!(
elapsed_ms = start.elapsed().as_millis(),
"Demo database migrations applied"
);
break;
}
Err(e) => {
if attempt >= 30 {
return Err(StartupError::Migrations(e));
}
tracing::warn!(attempt, error = %e, "Demo migrations failed; retrying");
tokio::time::sleep(Duration::from_secs(1)).await;
attempt += 1;
}
}
}
}
let cors = build_cors_layer(config.as_ref());
let plugins = plugins::PluginManager::new(pool.clone())
.register_builtin_plugins()
.initialize()
.await?;
{
let cron_plugins = plugins.clone();
let cron_pool = pool.clone();
tokio::spawn(async move {
let mut last_minute_key: i64 = -1;
let mut last_hour_key: i64 = -1;
let mut last_day_key: i64 = -1;
let mut last_week_key: i64 = -1;
let mut last_15min_key: i64 = -1;
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
let now = Utc::now();
let minute_key = now.timestamp() / 60;
if minute_key == last_minute_key {
continue;
}
last_minute_key = minute_key;
let ctx = HookContext {
pool: cron_pool.clone(),
community_id: None,
actor_user_id: None,
};
let payload = json!({"ts": now.to_rfc3339()});
cron_plugins
.do_action("cron.minute", ctx.clone(), payload.clone())
.await;
cron_plugins
.do_action("cron.minutely", ctx.clone(), payload.clone())
.await;
let min15_key = now.timestamp() / 900;
let is_15min = min15_key != last_15min_key;
if is_15min {
last_15min_key = min15_key;
cron_plugins
.do_action("cron.every_15_minutes", ctx.clone(), payload.clone())
.await;
}
let hour_key = now.timestamp() / 3600;
let is_hour = hour_key != last_hour_key;
if is_hour {
last_hour_key = hour_key;
if now.minute() == 0 {
cron_plugins
.do_action("cron.hourly", ctx.clone(), payload.clone())
.await;
}
}
let day_key = now.timestamp() / 86_400;
let is_day = day_key != last_day_key;
if is_day {
last_day_key = day_key;
if now.hour() == 0 && now.minute() == 0 {
cron_plugins
.do_action("cron.daily", ctx.clone(), payload.clone())
.await;
}
}
let iso_week = now.iso_week();
let week_key = (iso_week.year() as i64) * 100 + (iso_week.week() as i64);
let is_week = week_key != last_week_key;
if is_week {
last_week_key = week_key;
if now.weekday() == Weekday::Mon && now.hour() == 0 && now.minute() == 0 {
cron_plugins
.do_action("cron.weekly", ctx.clone(), payload.clone())
.await;
}
}
// WASM plugins need per-community context.
let community_ids: Vec<Uuid> =
match sqlx::query_scalar("SELECT id FROM communities WHERE is_active = true")
.fetch_all(&cron_pool)
.await
{
Ok(ids) => ids,
Err(e) => {
tracing::error!("cron: failed to list communities: {}", e);
continue;
}
};
let mut wasm_hooks: Vec<&'static str> = vec!["cron.minute", "cron.minutely"];
if is_15min {
wasm_hooks.push("cron.every_15_minutes");
}
if is_hour && now.minute() == 0 {
wasm_hooks.push("cron.hourly");
}
if is_day && now.hour() == 0 && now.minute() == 0 {
wasm_hooks.push("cron.daily");
if is_week && now.weekday() == Weekday::Mon {
wasm_hooks.push("cron.weekly");
}
}
for cid in community_ids {
for hook in &wasm_hooks {
cron_plugins
.do_wasm_action_for_community(hook, cid, payload.clone())
.await;
}
}
}
});
}
let request_id_header = HeaderName::from_static("x-request-id");
let trace_layer = {
let request_id_header = request_id_header.clone();
TraceLayer::new_for_http().make_span_with(move |request: &Request<_>| {
let request_id = request
.headers()
.get(request_id_header.as_str())
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
tracing::info_span!(
"http.request",
method = %request.method(),
uri = %request.uri(),
request_id = %request_id,
)
})
};
let app = api::create_router(pool.clone(), config.clone())
.layer(Extension(plugins))
.layer(Extension(config.clone()))
.layer(cors)
.layer(axum::middleware::from_fn_with_state(
rate_limit::RateLimitState::new(config.clone()),
rate_limit::rate_limit_middleware,
))
.layer(trace_layer)
.layer(PropagateRequestIdLayer::new(request_id_header.clone()))
.layer(SetRequestIdLayer::new(request_id_header, MakeRequestUuid))
.layer(middleware::map_response(add_security_headers));
let host: std::net::IpAddr = config
.server_host
.parse()
.unwrap_or_else(|_| std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)));
let addr = SocketAddr::from((host, config.server_port));
tracing::info!("Likwid backend listening on http://{}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.map_err(|e| StartupError::Serve(e.to_string()))?;
Ok(())
}
/// Response middleware that fills in baseline security headers.
///
/// A header is inserted only when the handler has not already set it, so individual
/// routes can override any of these defaults.
async fn add_security_headers(mut res: Response) -> Response {
    const DEFAULTS: [(&str, &str); 5] = [
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        ("referrer-policy", "no-referrer"),
        (
            "permissions-policy",
            "camera=(), microphone=(), geolocation=()",
        ),
        ("x-permitted-cross-domain-policies", "none"),
    ];

    let headers = res.headers_mut();
    for (name, value) in DEFAULTS {
        // `entry` + `or_insert` leaves an existing (handler-set) value untouched.
        headers
            .entry(name)
            .or_insert(HeaderValue::from_static(value));
    }
    res
}
/// Builds the CORS layer from `config.cors_allowed_origins`.
///
/// The setting is a comma-separated list of origins, e.g.
/// `"https://a.example, https://b.example"`. Any method and any request header are allowed.
///
/// - Unset or blank setting: every origin is allowed.
/// - Otherwise: only the listed origins receive CORS headers.
/// - Entries that are not valid header values are logged and skipped.
/// - If no entry survives, the layer falls back to allowing every origin. This is logged
///   as a warning, because an explicit allowlist that ends up empty almost certainly
///   means misconfiguration.
fn build_cors_layer(config: &Config) -> CorsLayer {
    let layer = CorsLayer::new().allow_methods(Any).allow_headers(Any);

    let Some(allowed) = config
        .cors_allowed_origins
        .as_deref()
        .map(str::trim)
        .filter(|v| !v.is_empty())
    else {
        return layer.allow_origin(Any);
    };

    let origins: Vec<HeaderValue> = allowed
        .split(',')
        .map(str::trim)
        .filter(|v| !v.is_empty())
        .filter_map(|v| match HeaderValue::from_str(v) {
            Ok(hv) => Some(hv),
            Err(e) => {
                tracing::warn!(origin = v, error = %e, "Invalid CORS origin; ignoring");
                None
            }
        })
        .collect();

    if origins.is_empty() {
        // Fail-open fallback kept for compatibility, but make it visible.
        tracing::warn!(
            configured = allowed,
            "CORS allowlist contains no valid origins; allowing any origin"
        );
        return layer.allow_origin(Any);
    }
    layer.allow_origin(origins)
}
#[cfg(test)]
mod cors_tests {
    use super::build_cors_layer;
    use axum::body::Body;
    use axum::http::{header, Request, StatusCode};
    use axum::routing::get;
    use axum::Router;
    use tower::ServiceExt;
    use crate::config::Config;

    /// Router with a single `GET /api/ping` route behind the CORS layer built from `cfg`.
    fn ping_app(cfg: &Config) -> Router {
        Router::new()
            .route("/api/ping", get(|| async { "ok" }))
            .layer(build_cors_layer(cfg))
    }

    /// Default config with `cors_allowed_origins` set to `origins`.
    fn config_with_origins(origins: &str) -> Config {
        Config {
            cors_allowed_origins: Some(origins.to_string()),
            ..Default::default()
        }
    }

    /// `GET /api/ping` carrying the given `Origin` header.
    fn ping_from(origin: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri("/api/ping")
            .header(header::ORIGIN, origin)
            .body(Body::empty())
            .unwrap()
    }

    /// The `Access-Control-Allow-Origin` value of `res`, if present and printable.
    fn allow_origin<B>(res: &axum::http::Response<B>) -> Option<&str> {
        res.headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .and_then(|v| v.to_str().ok())
    }

    #[tokio::test]
    async fn cors_default_allows_any_origin() {
        let res = ping_app(&Config::default())
            .oneshot(ping_from("https://example.com"))
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(allow_origin(&res), Some("*"));
    }

    #[tokio::test]
    async fn cors_allowlist_only_allows_listed_origins() {
        let app = ping_app(&config_with_origins("https://a.example, https://b.example"));

        let listed = app
            .clone()
            .oneshot(ping_from("https://a.example"))
            .await
            .unwrap();
        assert_eq!(listed.status(), StatusCode::OK);
        assert_eq!(allow_origin(&listed), Some("https://a.example"));

        let unlisted = app
            .clone()
            .oneshot(ping_from("https://c.example"))
            .await
            .unwrap();
        assert_eq!(unlisted.status(), StatusCode::OK);
        assert!(unlisted
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .is_none());
    }

    #[tokio::test]
    async fn cors_preflight_responds_for_allowed_origin() {
        let preflight = Request::builder()
            .method("OPTIONS")
            .uri("/api/ping")
            .header(header::ORIGIN, "https://a.example")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .body(Body::empty())
            .unwrap();
        let res = ping_app(&config_with_origins("https://a.example"))
            .oneshot(preflight)
            .await
            .unwrap();
        assert!(res.status().is_success());
        assert_eq!(allow_origin(&res), Some("https://a.example"));
    }
}