use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; /// Thread-safe sliding-window rate limiter backed by an in-memory HashMap. /// Each key (e.g. `"join:{ip}"` or `"upload:{user_id}"`) tracks timestamps /// of recent requests and rejects new ones once the window is full. #[derive(Clone)] pub struct RateLimiter { windows: Arc>>>, } impl RateLimiter { pub fn new() -> Self { Self { windows: Arc::new(Mutex::new(HashMap::new())), } } /// Returns `true` if the request is allowed, `false` if rate-limited. pub fn check(&self, key: impl Into, max: usize, window: Duration) -> bool { self.check_with_retry(key, max, window).is_ok() } /// Returns `Ok(())` if allowed, `Err(retry_after_secs)` if rate-limited. /// `retry_after_secs` is how long until the oldest slot in the window expires. pub fn check_with_retry(&self, key: impl Into, max: usize, window: Duration) -> Result<(), u64> { let now = Instant::now(); let key = key.into(); let mut map = self.windows.lock().unwrap(); let timestamps = map.entry(key).or_default(); timestamps.retain(|&t| now.duration_since(t) < window); if timestamps.len() < max { timestamps.push(now); Ok(()) } else { // The oldest timestamp expires at oldest + window; compute remaining seconds let oldest = timestamps[0]; let elapsed = now.duration_since(oldest); let remaining = window.saturating_sub(elapsed); Err(remaining.as_secs().max(1)) } } /// Wipe every tracked window. Used by the test-mode truncate route so a previous /// test's accumulated counters don't bleed into the next test's rate-limit checks. pub fn clear(&self) { self.windows.lock().unwrap().clear(); } /// Drop keys whose windows are empty after expiring old timestamps. Called from a /// background task (see [`crate::services::maintenance`]) so that long-lived /// processes don't accumulate one HashMap entry per IP that ever connected. /// /// Uses a conservative 24h ceiling — anything older than that is gone regardless /// of which endpoint's window it was tracked under (the longest window today is /// 24h for export downloads). If we ever add longer windows, raise this constant. pub fn prune(&self) { let now = Instant::now(); let ceiling = Duration::from_secs(24 * 60 * 60); let mut map = self.windows.lock().unwrap(); let before = map.len(); map.retain(|_, ts| { ts.retain(|&t| now.duration_since(t) < ceiling); !ts.is_empty() }); let dropped = before.saturating_sub(map.len()); if dropped > 0 { tracing::debug!("rate limiter pruned {dropped} idle keys"); } } } /// Extract the client IP from X-Forwarded-For (Caddy sets this) or fall back /// to a provided socket address string. pub fn client_ip(headers: &axum::http::HeaderMap, fallback: &str) -> String { headers .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) .and_then(|s| s.split(',').next()) .map(|s| s.trim().to_owned()) .unwrap_or_else(|| fallback.to_owned()) }