X-Git-Url: http://these/git/?a=blobdiff_plain;f=crates%2Futils%2Fsrc%2Frate_limit%2Fmod.rs;h=1bb6f1b5f4e1279f3642b8a6b2c246c4569626a2;hb=92568956353f21649ed9aff68b42699c9d036f30;hp=ed019255f8026262f6977080c5667c4ee867cfb1;hpb=93931958277714fd738a8dd8ab965c39b5016f63;p=lemmy.git

diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs
index ed019255..1bb6f1b5 100644
--- a/crates/utils/src/rate_limit/mod.rs
+++ b/crates/utils/src/rate_limit/mod.rs
@@ -1,15 +1,20 @@
-use crate::{error::LemmyError, utils::get_ip, IpAddr};
-use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
+use crate::error::{LemmyError, LemmyErrorType};
+use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
+use enum_map::enum_map;
 use futures::future::{ok, Ready};
-use rate_limiter::{RateLimitType, RateLimiter};
+use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
 use serde::{Deserialize, Serialize};
 use std::{
   future::Future,
+  net::{IpAddr, Ipv4Addr, SocketAddr},
   pin::Pin,
   rc::Rc,
+  str::FromStr,
   sync::{Arc, Mutex},
   task::{Context, Poll},
+  time::Duration,
 };
+use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
 use typed_builder::TypedBuilder;
 
 pub mod rate_limiter;
@@ -55,65 +60,131 @@ pub struct RateLimitConfig {
 }
 
 #[derive(Debug, Clone)]
-pub struct RateLimit {
-  // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this
-  // across await points
-  pub rate_limiter: Arc<Mutex<RateLimiter>>,
+struct RateLimit {
+  pub rate_limiter: RateLimitStorage,
   pub rate_limit_config: RateLimitConfig,
 }
 
 #[derive(Debug, Clone)]
-pub struct RateLimited {
-  rate_limiter: Arc<Mutex<RateLimiter>>,
-  rate_limit_config: RateLimitConfig,
+pub struct RateLimitedGuard {
+  rate_limit: Arc<Mutex<RateLimit>>,
   type_: RateLimitType,
 }
 
-pub struct RateLimitedMiddleware<S> {
-  rate_limited: RateLimited,
-  service: Rc<S>,
+/// Single instance of rate limit config and buckets, which is shared across all threads.
+#[derive(Clone)]
+pub struct RateLimitCell {
+  tx: Sender<RateLimitConfig>,
+  rate_limit: Arc<Mutex<RateLimit>>,
 }
 
-impl RateLimit {
-  pub fn message(&self) -> RateLimited {
+impl RateLimitCell {
+  /// Initialize cell if it wasn't initialized yet. Otherwise returns the existing cell.
+  pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self {
+    static LOCAL_INSTANCE: OnceCell<RateLimitCell> = OnceCell::const_new();
+    LOCAL_INSTANCE
+      .get_or_init(|| async {
+        let (tx, mut rx) = mpsc::channel::<RateLimitConfig>(4);
+        let rate_limit = Arc::new(Mutex::new(RateLimit {
+          rate_limiter: Default::default(),
+          rate_limit_config,
+        }));
+        let rate_limit2 = rate_limit.clone();
+        tokio::spawn(async move {
+          while let Some(r) = rx.recv().await {
+            rate_limit2
+              .lock()
+              .expect("Failed to lock rate limit mutex for updating")
+              .rate_limit_config = r;
+          }
+        });
+        RateLimitCell { tx, rate_limit }
+      })
+      .await
+  }
+
+  /// Call this when the config was updated, to update all in-memory cells.
+  pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> {
+    self.tx.send(config).await?;
+    Ok(())
+  }
+
+  /// Remove buckets older than the given duration
+  pub fn remove_older_than(&self, mut duration: Duration) {
+    let mut guard = self
+      .rate_limit
+      .lock()
+      .expect("Failed to lock rate limit mutex for reading");
+    let rate_limit = &guard.rate_limit_config;
+
+    // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
+    let max_interval_secs = enum_map! {
+      RateLimitType::Message => rate_limit.message_per_second,
+      RateLimitType::Post => rate_limit.post_per_second,
+      RateLimitType::Register => rate_limit.register_per_second,
+      RateLimitType::Image => rate_limit.image_per_second,
+      RateLimitType::Comment => rate_limit.comment_per_second,
+      RateLimitType::Search => rate_limit.search_per_second,
+    }
+    .into_values()
+    .max()
+    .and_then(|max| u64::try_from(max).ok())
+    .unwrap_or(0);
+
+    duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
+
+    guard
+      .rate_limiter
+      .remove_older_than(duration, InstantSecs::now())
+  }
+
+  pub fn message(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Message)
   }
 
-  pub fn post(&self) -> RateLimited {
+  pub fn post(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Post)
   }
 
-  pub fn register(&self) -> RateLimited {
+  pub fn register(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Register)
   }
 
-  pub fn image(&self) -> RateLimited {
+  pub fn image(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Image)
   }
 
-  pub fn comment(&self) -> RateLimited {
+  pub fn comment(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Comment)
   }
 
-  pub fn search(&self) -> RateLimited {
+  pub fn search(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Search)
   }
 
-  fn kind(&self, type_: RateLimitType) -> RateLimited {
-    RateLimited {
-      rate_limiter: self.rate_limiter.clone(),
-      rate_limit_config: self.rate_limit_config.clone(),
+  fn kind(&self, type_: RateLimitType) -> RateLimitedGuard {
+    RateLimitedGuard {
+      rate_limit: self.rate_limit.clone(),
       type_,
     }
   }
 }
 
-impl RateLimited {
+pub struct RateLimitedMiddleware<S> {
+  rate_limited: RateLimitedGuard,
+  service: Rc<S>,
+}
+
+impl RateLimitedGuard {
   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
   pub fn check(self, ip_addr: IpAddr) -> bool {
     // Does not need to be blocking because the RwLock in settings is never held across await points,
     // and the operation here locks only long enough to clone
-    let rate_limit = self.rate_limit_config;
+    let mut guard = self
+      .rate_limit
+      .lock()
+      .expect("Failed to lock rate limit mutex for reading");
+    let rate_limit = &guard.rate_limit_config;
 
     let (kind, interval) = match self.type_ {
       RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
@@ -123,13 +194,13 @@ impl RateLimited {
       RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
       RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
     };
-    let mut limiter = self.rate_limiter.lock().expect("mutex poison error");
+    let limiter = &mut guard.rate_limiter;
 
-    limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
+    limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
   }
 }
 
-impl<S> Transform<S, ServiceRequest> for RateLimited
+impl<S> Transform<S, ServiceRequest> for RateLimitedGuard
 where
   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
   S::Future: 'static,
@@ -175,10 +246,49 @@ where
       } else {
         let (http_req, _) = req.into_parts();
         Ok(ServiceResponse::from_err(
-          LemmyError::from_message("rate_limit_error"),
+          LemmyError::from(LemmyErrorType::RateLimitError),
           http_req,
         ))
       }
     })
   }
 }
+
+fn get_ip(conn_info: &ConnectionInfo) -> IpAddr {
+  conn_info
+    .realip_remote_addr()
+    .and_then(parse_ip)
+    .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
+}
+
+fn parse_ip(addr: &str) -> Option<IpAddr> {
+  if let Some(s) = addr.strip_suffix(']') {
+    IpAddr::from_str(s.get(1..)?).ok()
+  } else if let Ok(ip) = IpAddr::from_str(addr) {
+    Some(ip)
+  } else if let Ok(socket) = SocketAddr::from_str(addr) {
+    Some(socket.ip())
+  } else {
+    None
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  #![allow(clippy::unwrap_used)]
+  #![allow(clippy::indexing_slicing)]
+
+  #[test]
+  fn test_parse_ip() {
+    let ip_addrs = [
+      "1.2.3.4",
+      "1.2.3.4:8000",
+      "2001:db8::",
+      "[2001:db8::]",
+      "[2001:db8::]:8000",
+    ];
+    for addr in ip_addrs {
+      assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}");
+    }
+  }
+}
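
Usage sketch, not part of the patch above: a minimal example of how calling code might drive the reworked API. Only RateLimitCell, RateLimitedGuard, message(), check(), send() and remove_older_than() come from this diff; the lemmy_utils import paths, the function names, and the one-hour cutoff are illustrative assumptions.

use std::{net::IpAddr, time::Duration};

// Assumed import paths; the diff only identifies this file as crates/utils/src/rate_limit/mod.rs.
use lemmy_utils::error::LemmyError;
use lemmy_utils::rate_limit::{RateLimitCell, RateLimitConfig};

// Hypothetical handler-side check: allow the request only while the message
// bucket for this IP still has capacity. message() clones the shared
// Arc<Mutex<RateLimit>> into a guard scoped to RateLimitType::Message,
// and check() locks it only long enough to update the bucket.
fn allow_message(rate_limit: &RateLimitCell, ip: IpAddr) -> bool {
  rate_limit.message().check(ip)
}

// Hypothetical config hot-reload: push an updated RateLimitConfig to the
// background task spawned in RateLimitCell::new(), which swaps it in under the mutex.
async fn apply_new_config(
  rate_limit: &RateLimitCell,
  new_config: RateLimitConfig,
) -> Result<(), LemmyError> {
  rate_limit.send(new_config).await
}

// Hypothetical maintenance task: drop buckets that have been idle for at least an
// hour (remove_older_than clamps the cutoff up to the largest configured interval).
fn prune_buckets(rate_limit: &RateLimitCell) {
  rate_limit.remove_older_than(Duration::from_secs(3600));
}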