use std::sync::Arc; use axum::extract::DefaultBodyLimit; use axum::http::{HeaderName, HeaderValue, Method}; use axum::Router; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::trace::TraceLayer; use crate::config::{AuthConfig, Config, UploadConfig}; use crate::storage::{LocalStorage, Storage}; #[derive(Clone)] pub struct AppState { pub db: PgPool, pub storage: Arc, pub auth: AuthConfig, pub upload: UploadConfig, } pub async fn build(config: Config) -> anyhow::Result { let db = PgPoolOptions::new() .max_connections(10) .connect(&config.database_url) .await?; sqlx::migrate!("./migrations").run(&db).await?; let storage: Arc = Arc::new(LocalStorage::new(config.storage_dir.clone())); let state = AppState { db, storage, auth: config.auth.clone(), upload: config.upload.clone(), }; Ok(router(state).layer(cors_layer(&config.cors_allowed_origins))) } /// Build a router from a pre-assembled state. Used by integration tests /// so they can swap in a test DB pool and a `tempfile`-backed storage. pub fn router(state: AppState) -> Router { let max_request_bytes = state.upload.max_request_bytes; Router::new() .nest("/api/v1", crate::api::routes()) .layer(DefaultBodyLimit::max(max_request_bytes)) .with_state(state) .layer(TraceLayer::new_for_http()) } pub(crate) fn cors_layer(allowed_origins: &[String]) -> CorsLayer { if allowed_origins.is_empty() { // Same-origin only — no CORS headers emitted. return CorsLayer::new(); } let origins: Vec = allowed_origins .iter() .filter_map(|o| HeaderValue::from_str(o).ok()) .collect(); CorsLayer::new() .allow_origin(AllowOrigin::list(origins)) .allow_credentials(true) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) .allow_headers([ HeaderName::from_static("content-type"), HeaderName::from_static("authorization"), ]) } #[cfg(test)] mod tests { use super::*; use axum::body::Body; use axum::http::Request; use axum::routing::get; use tower::ServiceExt; fn test_router() -> Router { Router::new().route("/", get(|| async { "ok" })) } #[tokio::test] async fn allowlist_preflight_emits_credentialed_headers() { let app = test_router().layer(cors_layer(&["https://app.example.com".to_string()])); let resp = app .oneshot( Request::builder() .method(Method::OPTIONS) .uri("/") .header("origin", "https://app.example.com") .header("access-control-request-method", "POST") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!( resp.headers().get("access-control-allow-origin").unwrap(), "https://app.example.com" ); assert_eq!( resp.headers().get("access-control-allow-credentials").unwrap(), "true" ); } #[tokio::test] async fn allowlist_rejects_unlisted_origin() { let app = test_router().layer(cors_layer(&["https://app.example.com".to_string()])); let resp = app .oneshot( Request::builder() .method(Method::OPTIONS) .uri("/") .header("origin", "https://evil.example.org") .header("access-control-request-method", "POST") .body(Body::empty()) .unwrap(), ) .await .unwrap(); // Browsers will refuse the response when the allow-origin header // is absent (or doesn't echo the requesting origin). assert!(resp.headers().get("access-control-allow-origin").is_none()); } #[tokio::test] async fn empty_allowlist_is_same_origin_only() { let app = test_router().layer(cors_layer(&[])); let resp = app .oneshot( Request::builder() .method(Method::OPTIONS) .uri("/") .header("origin", "https://app.example.com") .header("access-control-request-method", "POST") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert!(resp.headers().get("access-control-allow-origin").is_none()); assert!(resp.headers().get("access-control-allow-credentials").is_none()); } }