//! Per-host request pacing. //! //! `RateLimiter` is a single-token bucket: each `wait().await` returns //! immediately when at least `interval` has elapsed since the last call, //! otherwise sleeps just enough to satisfy it. Uses //! `tokio::time::Instant` so tests can run under `start_paused` virtual //! time without sleeping for real. //! //! `HostRateLimiters` is the multi-host wrapper actually used by the //! crawler — concurrent workers issuing requests to different origins //! (catalog vs. CDN) don't contend on a shared budget; each host gets //! its own bucket. `wait_for(url)` extracts the host, lazily creates a //! limiter for it, and serializes only against other callers hitting //! the same host. use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; use tokio::time::Instant; #[derive(Debug)] pub struct RateLimiter { interval: Duration, last: Option, } impl RateLimiter { pub fn new(interval: Duration) -> Self { Self { interval, last: None, } } pub async fn wait(&mut self) { if let Some(last) = self.last { let elapsed = last.elapsed(); if elapsed < self.interval { tokio::time::sleep(self.interval - elapsed).await; } } self.last = Some(Instant::now()); } } /// Per-host rate limiter map. The outer `Mutex` is held only /// during the entry-or-insert + Arc clone; the per-host `Mutex` /// is held during the actual `wait().await`. So N workers calling /// `wait_for(url)` on N different hosts contend nowhere except the brief /// HashMap lookup; workers hitting the same host serialize on that /// host's bucket. #[derive(Debug)] pub struct HostRateLimiters { default_interval: Duration, overrides: HashMap, map: Mutex>>>, } impl HostRateLimiters { pub fn new(default_interval: Duration) -> Self { Self { default_interval, overrides: HashMap::new(), map: Mutex::new(HashMap::new()), } } /// Set a per-host interval that overrides `default_interval`. Calls /// after a host's limiter has been instantiated do *not* re-create /// it — set all overrides before the first `wait_for` to that host. pub fn with_override(mut self, host: impl Into, interval: Duration) -> Self { self.overrides.insert(host.into(), interval); self } /// Block until the per-host budget allows the next request to /// `url`'s host. Returns an error only when the URL has no host /// (malformed input). pub async fn wait_for(&self, url: &str) -> anyhow::Result<()> { let host = host_of(url) .ok_or_else(|| anyhow::anyhow!("no host in url: {url}"))?; let limiter = { let mut map = self.map.lock().await; map.entry(host.clone()) .or_insert_with(|| { let interval = self .overrides .get(&host) .copied() .unwrap_or(self.default_interval); Arc::new(Mutex::new(RateLimiter::new(interval))) }) .clone() }; limiter.lock().await.wait().await; Ok(()) } } /// Extract the host (no port) from a URL string. Returns `None` for /// inputs without a `scheme://host` shape — those would never have /// reached the network layer anyway. fn host_of(url: &str) -> Option { let after_scheme = url.split_once("://")?.1; let host_with_port = after_scheme.split('/').next()?; let host = host_with_port.rsplit_once(':').map_or(host_with_port, |(h, _)| h); (!host.is_empty()).then(|| host.to_ascii_lowercase()) } #[cfg(test)] mod tests { use super::*; #[tokio::test(start_paused = true)] async fn first_call_does_not_sleep() { let mut rl = RateLimiter::new(Duration::from_millis(100)); let t0 = Instant::now(); rl.wait().await; assert_eq!(Instant::now() - t0, Duration::ZERO); } #[tokio::test(start_paused = true)] async fn second_call_sleeps_to_fill_interval() { let mut rl = RateLimiter::new(Duration::from_millis(100)); let t0 = Instant::now(); rl.wait().await; rl.wait().await; // Second call had to wait the full 100ms after the (instant) // first call. assert_eq!(Instant::now() - t0, Duration::from_millis(100)); } #[tokio::test(start_paused = true)] async fn no_sleep_if_interval_already_elapsed() { let mut rl = RateLimiter::new(Duration::from_millis(100)); rl.wait().await; tokio::time::sleep(Duration::from_millis(250)).await; let t0 = Instant::now(); rl.wait().await; // Already 250ms past — no further wait needed. assert_eq!(Instant::now() - t0, Duration::ZERO); } #[test] fn host_of_parses_scheme_path_and_port() { assert_eq!(host_of("https://Example.com/path").as_deref(), Some("example.com")); assert_eq!(host_of("http://cdn.foo.bar/img.jpg").as_deref(), Some("cdn.foo.bar")); assert_eq!(host_of("http://localhost:8080/x").as_deref(), Some("localhost")); assert!(host_of("not a url").is_none()); } #[tokio::test(start_paused = true)] async fn host_rate_limiters_pace_per_host() { // Two hosts at 100ms each. Two consecutive calls to the SAME // host wait 100ms total. Two consecutive calls to DIFFERENT // hosts both fire immediately. let rl = HostRateLimiters::new(Duration::from_millis(100)); let t0 = Instant::now(); rl.wait_for("https://a.example/x").await.unwrap(); rl.wait_for("https://b.example/y").await.unwrap(); assert_eq!(Instant::now() - t0, Duration::ZERO, "different hosts don't contend"); let t1 = Instant::now(); rl.wait_for("https://a.example/x").await.unwrap(); assert_eq!( Instant::now() - t1, Duration::from_millis(100), "second call to same host waits a full interval" ); } #[tokio::test(start_paused = true)] async fn host_rate_limiters_honor_overrides() { let rl = HostRateLimiters::new(Duration::from_millis(1000)) .with_override("fast.example", Duration::from_millis(100)); rl.wait_for("https://fast.example/a").await.unwrap(); let t0 = Instant::now(); rl.wait_for("https://fast.example/b").await.unwrap(); assert_eq!(Instant::now() - t0, Duration::from_millis(100)); } }