diff --git a/Cargo.lock b/Cargo.lock index 7a8dab2..0537c02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,6 +164,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "base64ct" version = "1.8.3" @@ -3542,7 +3548,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a5d75ac36ee28647f6d871a93eefc7edcb729c3096590031ba50857fac44fa8" dependencies = [ "anyhow", - "base64", + "base64 0.21.7", "directories-next", "log", "postcard", @@ -4006,6 +4012,16 @@ dependencies = [ "weft-ipc-types", ] +[[package]] +name = "weft-file-portal" +version = "0.1.0" +dependencies = [ + "anyhow", + "base64 0.22.1", + "serde", + "serde_json", +] + [[package]] name = "weft-ipc-types" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e6e17e2..6015158 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/weft-build-meta", "crates/weft-compositor", "crates/weft-ipc-types", + "crates/weft-file-portal", "crates/weft-mount-helper", "crates/weft-pack", "crates/weft-runtime", diff --git a/crates/weft-file-portal/Cargo.toml b/crates/weft-file-portal/Cargo.toml new file mode 100644 index 0000000..5e8c337 --- /dev/null +++ b/crates/weft-file-portal/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "weft-file-portal" +version.workspace = true +edition.workspace = true +rust-version.workspace = true + +[[bin]] +name = "weft-file-portal" +path = "src/main.rs" + +[dependencies] +anyhow = "1.0" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +base64 = "0.22" diff --git a/crates/weft-file-portal/src/main.rs b/crates/weft-file-portal/src/main.rs new file mode 100644 index 0000000..524757f --- /dev/null +++ b/crates/weft-file-portal/src/main.rs @@ -0,0 +1,276 @@ +use std::io::{BufRead, BufReader, Write}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::path::{Path, PathBuf}; + +use anyhow::Context; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize)] +#[serde(tag = "op", rename_all = "snake_case")] +enum Request { + Read { path: String }, + Write { path: String, data_b64: String }, + List { path: String }, +} + +#[derive(Serialize)] +#[serde(untagged)] +enum Response { + Ok, + OkData { data_b64: String }, + OkEntries { entries: Vec }, + Err { error: String }, +} + +impl Response { + fn err(msg: impl std::fmt::Display) -> Self { + Self::Err { + error: msg.to_string(), + } + } +} + +fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + if args.len() < 2 { + eprintln!("usage: weft-file-portal [--allow ]..."); + std::process::exit(1); + } + + let socket_path = &args[1]; + let allowed = parse_allowed(&args[2..]); + + if Path::new(socket_path).exists() { + std::fs::remove_file(socket_path) + .with_context(|| format!("remove stale socket {socket_path}"))?; + } + + let listener = + UnixListener::bind(socket_path).with_context(|| format!("bind {socket_path}"))?; + + for stream in listener.incoming() { + match stream { + Ok(s) => handle_connection(s, &allowed), + Err(e) => eprintln!("accept error: {e}"), + } + } + + Ok(()) +} + +fn parse_allowed(args: &[String]) -> Vec { + let mut allowed = Vec::new(); + let mut i = 0; + while i < args.len() { + if args[i] == "--allow" { + if let Some(p) = args.get(i + 1) { + allowed.push(PathBuf::from(p)); + i += 2; + continue; + } + } + i += 1; + } + allowed +} + +fn is_allowed(path: &Path, allowed: &[PathBuf]) -> bool { + if allowed.is_empty() { + return false; + } + allowed.iter().any(|a| path.starts_with(a)) +} + +fn handle_connection(stream: UnixStream, allowed: &[PathBuf]) { + let mut writer = match stream.try_clone() { + Ok(s) => s, + Err(e) => { + eprintln!("stream clone error: {e}"); + return; + } + }; + let reader = BufReader::new(stream); + + for line in reader.lines() { + let line = match line { + Ok(l) => l, + Err(_) => break, + }; + if line.is_empty() { + continue; + } + + let response = match serde_json::from_str::(&line) { + Ok(req) => handle_request(req, allowed), + Err(e) => Response::err(format!("bad request: {e}")), + }; + + let mut out = serde_json::to_string(&response) + .unwrap_or_else(|_| r#"{"error":"serialize"}"#.to_string()); + out.push('\n'); + if writer.write_all(out.as_bytes()).is_err() { + break; + } + } +} + +fn handle_request(req: Request, allowed: &[PathBuf]) -> Response { + match req { + Request::Read { path } => { + let p = PathBuf::from(&path); + if !is_allowed(&p, allowed) { + return Response::err(format!("access denied: {path}")); + } + match std::fs::read(&p) { + Ok(data) => Response::OkData { + data_b64: base64::Engine::encode( + &base64::engine::general_purpose::STANDARD, + &data, + ), + }, + Err(e) => Response::err(e), + } + } + Request::Write { path, data_b64 } => { + let p = PathBuf::from(&path); + if !is_allowed(&p, allowed) { + return Response::err(format!("access denied: {path}")); + } + let data = + match base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &data_b64) + { + Ok(d) => d, + Err(e) => return Response::err(format!("bad base64: {e}")), + }; + match std::fs::write(&p, &data) { + Ok(()) => Response::Ok, + Err(e) => Response::err(e), + } + } + Request::List { path } => { + let p = PathBuf::from(&path); + if !is_allowed(&p, allowed) { + return Response::err(format!("access denied: {path}")); + } + match std::fs::read_dir(&p) { + Ok(entries) => { + let mut names = Vec::new(); + for entry in entries.flatten() { + if let Some(name) = entry.file_name().to_str() { + names.push(name.to_string()); + } + } + names.sort(); + Response::OkEntries { entries: names } + } + Err(e) => Response::err(e), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn allowed_path_accepted() { + let allowed = vec![PathBuf::from("/tmp/weft-test-allowed")]; + assert!(is_allowed( + Path::new("/tmp/weft-test-allowed/file.txt"), + &allowed + )); + } + + #[test] + fn disallowed_path_rejected() { + let allowed = vec![PathBuf::from("/tmp/weft-test-allowed")]; + assert!(!is_allowed(Path::new("/etc/passwd"), &allowed)); + } + + #[test] + fn empty_allowlist_rejects_all() { + assert!(!is_allowed(Path::new("/tmp/anything"), &[])); + } + + #[test] + fn parse_allowed_extracts_paths() { + let args: Vec = vec![ + "--allow".into(), + "/tmp/a".into(), + "--allow".into(), + "/tmp/b".into(), + ]; + let result = parse_allowed(&args); + assert_eq!( + result, + vec![PathBuf::from("/tmp/a"), PathBuf::from("/tmp/b")] + ); + } + + #[test] + fn handle_request_read_denied() { + let resp = handle_request( + Request::Read { + path: "/etc/shadow".into(), + }, + &[PathBuf::from("/tmp/safe")], + ); + let json = serde_json::to_string(&resp).unwrap(); + assert!(json.contains("access denied")); + } + + #[test] + fn handle_request_read_roundtrip() { + use std::fs; + let dir = std::env::temp_dir().join(format!("wfp_test_{}", std::process::id())); + let _ = fs::create_dir_all(&dir); + let file = dir.join("hello.txt"); + fs::write(&file, b"hello world").unwrap(); + + let allowed = vec![dir.clone()]; + let resp = handle_request( + Request::Read { + path: file.to_string_lossy().into(), + }, + &allowed, + ); + let json = serde_json::to_string(&resp).unwrap(); + assert!(json.contains("data_b64")); + + if let Response::OkData { data_b64 } = resp { + let decoded = + base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &data_b64) + .unwrap(); + assert_eq!(decoded, b"hello world"); + } else { + panic!("expected OkData"); + } + + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn handle_request_list() { + use std::fs; + let dir = std::env::temp_dir().join(format!("wfp_list_{}", std::process::id())); + let _ = fs::create_dir_all(&dir); + fs::write(dir.join("b.txt"), b"").unwrap(); + fs::write(dir.join("a.txt"), b"").unwrap(); + + let allowed = vec![dir.clone()]; + let resp = handle_request( + Request::List { + path: dir.to_string_lossy().into(), + }, + &allowed, + ); + if let Response::OkEntries { entries } = resp { + assert_eq!(entries, vec!["a.txt", "b.txt"]); + } else { + panic!("expected OkEntries"); + } + + let _ = fs::remove_dir_all(&dir); + } +} diff --git a/scripts/wsl-test.sh b/scripts/wsl-test.sh index fc8c1a4..8b5793c 100644 --- a/scripts/wsl-test.sh +++ b/scripts/wsl-test.sh @@ -42,5 +42,9 @@ echo "" echo "==> cargo test -p weft-mount-helper" cargo test -p weft-mount-helper -- --test-threads=1 2>&1 +echo "" +echo "==> cargo test -p weft-file-portal" +cargo test -p weft-file-portal -- --test-threads=1 2>&1 + echo "" echo "ALL DONE" diff --git a/{data_b64:...} b/{data_b64:...} new file mode 100644 index 0000000..e69de29 diff --git a/{entries:[...]} b/{entries:[...]} new file mode 100644 index 0000000..e69de29 diff --git a/{} b/{} new file mode 100644 index 0000000..e69de29