X-Git-Url: http://these/git/?a=blobdiff_plain;f=crates%2Futils%2Fsrc%2Frate_limit%2Fmod.rs;h=1bb6f1b5f4e1279f3642b8a6b2c246c4569626a2;hb=92568956353f21649ed9aff68b42699c9d036f30;hp=d3e74ed5bc40ad97360d55b9fa5b91618333d2c8;hpb=ddf4a667b1ba15222287b30455a202bd3a336547;p=lemmy.git diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index d3e74ed5..1bb6f1b5 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,164 +1,210 @@ -use crate::{ - settings::structs::{RateLimitConfig, Settings}, - utils::get_ip, - LemmyError, -}; -use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; +use crate::error::{LemmyError, LemmyErrorType}; +use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; +use enum_map::enum_map; use futures::future::{ok, Ready}; -use rate_limiter::{RateLimitType, RateLimiter}; +use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType}; +use serde::{Deserialize, Serialize}; use std::{ future::Future, + net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, - sync::Arc, + rc::Rc, + str::FromStr, + sync::{Arc, Mutex}, task::{Context, Poll}, + time::Duration, }; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; +use typed_builder::TypedBuilder; pub mod rate_limiter; +#[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)] +pub struct RateLimitConfig { + #[builder(default = 180)] + /// Maximum number of messages created in interval + pub message: i32, + #[builder(default = 60)] + /// Interval length for message limit, in seconds + pub message_per_second: i32, + #[builder(default = 6)] + /// Maximum number of posts created in interval + pub post: i32, + #[builder(default = 300)] + /// Interval length for post limit, in seconds + pub post_per_second: i32, + #[builder(default = 3)] + /// Maximum number of registrations in interval + pub register: i32, + #[builder(default = 3600)] + /// Interval length for registration limit, in seconds + pub register_per_second: i32, + #[builder(default = 6)] + /// Maximum number of image uploads in interval + pub image: i32, + #[builder(default = 3600)] + /// Interval length for image uploads, in seconds + pub image_per_second: i32, + #[builder(default = 6)] + /// Maximum number of comments created in interval + pub comment: i32, + #[builder(default = 600)] + /// Interval length for comment limit, in seconds + pub comment_per_second: i32, + #[builder(default = 60)] + /// Maximum number of searches created in interval + pub search: i32, + #[builder(default = 600)] + /// Interval length for search limit, in seconds + pub search_per_second: i32, +} + #[derive(Debug, Clone)] -pub struct RateLimit { - // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this - // across await points - pub rate_limiter: Arc>, +struct RateLimit { + pub rate_limiter: RateLimitStorage, + pub rate_limit_config: RateLimitConfig, } #[derive(Debug, Clone)] -pub struct RateLimited { - rate_limiter: Arc>, +pub struct RateLimitedGuard { + rate_limit: Arc>, type_: RateLimitType, } -pub struct RateLimitedMiddleware { - rate_limited: RateLimited, - service: S, +/// Single instance of rate limit config and buckets, which is shared across all threads. +#[derive(Clone)] +pub struct RateLimitCell { + tx: Sender, + rate_limit: Arc>, } -impl RateLimit { - pub fn message(&self) -> RateLimited { +impl RateLimitCell { + /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell. + pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self { + static LOCAL_INSTANCE: OnceCell = OnceCell::const_new(); + LOCAL_INSTANCE + .get_or_init(|| async { + let (tx, mut rx) = mpsc::channel::(4); + let rate_limit = Arc::new(Mutex::new(RateLimit { + rate_limiter: Default::default(), + rate_limit_config, + })); + let rate_limit2 = rate_limit.clone(); + tokio::spawn(async move { + while let Some(r) = rx.recv().await { + rate_limit2 + .lock() + .expect("Failed to lock rate limit mutex for updating") + .rate_limit_config = r; + } + }); + RateLimitCell { tx, rate_limit } + }) + .await + } + + /// Call this when the config was updated, to update all in-memory cells. + pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> { + self.tx.send(config).await?; + Ok(()) + } + + /// Remove buckets older than the given duration + pub fn remove_older_than(&self, mut duration: Duration) { + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; + + // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check. + let max_interval_secs = enum_map! { + RateLimitType::Message => rate_limit.message_per_second, + RateLimitType::Post => rate_limit.post_per_second, + RateLimitType::Register => rate_limit.register_per_second, + RateLimitType::Image => rate_limit.image_per_second, + RateLimitType::Comment => rate_limit.comment_per_second, + RateLimitType::Search => rate_limit.search_per_second, + } + .into_values() + .max() + .and_then(|max| u64::try_from(max).ok()) + .unwrap_or(0); + + duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs)); + + guard + .rate_limiter + .remove_older_than(duration, InstantSecs::now()) + } + + pub fn message(&self) -> RateLimitedGuard { self.kind(RateLimitType::Message) } - pub fn post(&self) -> RateLimited { + pub fn post(&self) -> RateLimitedGuard { self.kind(RateLimitType::Post) } - pub fn register(&self) -> RateLimited { + pub fn register(&self) -> RateLimitedGuard { self.kind(RateLimitType::Register) } - pub fn image(&self) -> RateLimited { + pub fn image(&self) -> RateLimitedGuard { self.kind(RateLimitType::Image) } - fn kind(&self, type_: RateLimitType) -> RateLimited { - RateLimited { - rate_limiter: self.rate_limiter.clone(), + pub fn comment(&self) -> RateLimitedGuard { + self.kind(RateLimitType::Comment) + } + + pub fn search(&self) -> RateLimitedGuard { + self.kind(RateLimitType::Search) + } + + fn kind(&self, type_: RateLimitType) -> RateLimitedGuard { + RateLimitedGuard { + rate_limit: self.rate_limit.clone(), type_, } } } -impl RateLimited { - pub async fn wrap( - self, - ip_addr: String, - fut: impl Future>, - ) -> Result - where - E: From, - { +pub struct RateLimitedMiddleware { + rate_limited: RateLimitedGuard, + service: Rc, +} + +impl RateLimitedGuard { + /// Returns true if the request passed the rate limit, false if it failed and should be rejected. + pub fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone - let rate_limit: RateLimitConfig = Settings::get().rate_limit(); - - // before - { - let mut limiter = self.rate_limiter.lock().await; - - match self.type_ { - RateLimitType::Message => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.message, - rate_limit.message_per_second, - false, - )?; - - drop(limiter); - return fut.await; - } - RateLimitType::Post => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.post, - rate_limit.post_per_second, - true, - )?; - } - RateLimitType::Register => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.register, - rate_limit.register_per_second, - true, - )?; - } - RateLimitType::Image => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.image, - rate_limit.image_per_second, - false, - )?; - } - }; - } + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; - let res = fut.await; - - // after - { - let mut limiter = self.rate_limiter.lock().await; - if res.is_ok() { - match self.type_ { - RateLimitType::Post => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.post, - rate_limit.post_per_second, - false, - )?; - } - RateLimitType::Register => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.register, - rate_limit.register_per_second, - false, - )?; - } - _ => (), - }; - } - } + let (kind, interval) = match self.type_ { + RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), + RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second), + RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second), + RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second), + RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), + RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second), + }; + let limiter = &mut guard.rate_limiter; - res + limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now()) } } -impl Transform for RateLimited +impl Transform for RateLimitedGuard where - S: Service, + S: Service + 'static, S::Future: 'static, { - type Request = S::Request; type Response = S::Response; type Error = actix_web::Error; type InitError = (); @@ -168,35 +214,81 @@ where fn new_transform(&self, service: S) -> Self::Future { ok(RateLimitedMiddleware { rate_limited: self.clone(), - service, + service: Rc::new(service), }) } } type FutResult = dyn Future>; -impl Service for RateLimitedMiddleware +impl Service for RateLimitedMiddleware where - S: Service, + S: Service + 'static, S::Future: 'static, { - type Request = S::Request; type Response = S::Response; type Error = actix_web::Error; type Future = Pin>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - fn call(&mut self, req: S::Request) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { let ip_addr = get_ip(&req.connection_info()); - let fut = self - .rate_limited - .clone() - .wrap(ip_addr, self.service.call(req)); + let rate_limited = self.rate_limited.clone(); + let service = self.service.clone(); - Box::pin(async move { fut.await.map_err(actix_web::Error::from) }) + Box::pin(async move { + if rate_limited.check(ip_addr) { + service.call(req).await + } else { + let (http_req, _) = req.into_parts(); + Ok(ServiceResponse::from_err( + LemmyError::from(LemmyErrorType::RateLimitError), + http_req, + )) + } + }) + } +} + +fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { + conn_info + .realip_remote_addr() + .and_then(parse_ip) + .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) +} + +fn parse_ip(addr: &str) -> Option { + if let Some(s) = addr.strip_suffix(']') { + IpAddr::from_str(s.get(1..)?).ok() + } else if let Ok(ip) = IpAddr::from_str(addr) { + Some(ip) + } else if let Ok(socket) = SocketAddr::from_str(addr) { + Some(socket.ip()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + #![allow(clippy::indexing_slicing)] + + #[test] + fn test_parse_ip() { + let ip_addrs = [ + "1.2.3.4", + "1.2.3.4:8000", + "2001:db8::", + "[2001:db8::]", + "[2001:db8::]:8000", + ]; + for addr in ip_addrs { + assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}"); + } } }