X-Git-Url: http://these/git/?a=blobdiff_plain;f=crates%2Futils%2Fsrc%2Frate_limit%2Fmod.rs;h=1bb6f1b5f4e1279f3642b8a6b2c246c4569626a2;hb=92568956353f21649ed9aff68b42699c9d036f30;hp=6027520f0ac1038e7a7d6c938a14b9bc6c728848;hpb=a5ff629b2492ff21ca24c9de79e41ad341e204db;p=lemmy.git diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index 6027520f..1bb6f1b5 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,79 +1,190 @@ -use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr}; -use actix_web::{ - dev::{Service, ServiceRequest, ServiceResponse, Transform}, - HttpResponse, -}; +use crate::error::{LemmyError, LemmyErrorType}; +use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; +use enum_map::enum_map; use futures::future::{ok, Ready}; -use rate_limiter::{RateLimitType, RateLimiter}; +use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType}; +use serde::{Deserialize, Serialize}; use std::{ future::Future, + net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, rc::Rc, - sync::Arc, + str::FromStr, + sync::{Arc, Mutex}, task::{Context, Poll}, + time::Duration, }; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; +use typed_builder::TypedBuilder; pub mod rate_limiter; +#[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)] +pub struct RateLimitConfig { + #[builder(default = 180)] + /// Maximum number of messages created in interval + pub message: i32, + #[builder(default = 60)] + /// Interval length for message limit, in seconds + pub message_per_second: i32, + #[builder(default = 6)] + /// Maximum number of posts created in interval + pub post: i32, + #[builder(default = 300)] + /// Interval length for post limit, in seconds + pub post_per_second: i32, + #[builder(default = 3)] + /// Maximum number of registrations in interval + pub register: i32, + #[builder(default = 3600)] + /// Interval length for registration limit, in seconds + pub register_per_second: i32, + #[builder(default = 6)] + /// Maximum number of image uploads in interval + pub image: i32, + #[builder(default = 3600)] + /// Interval length for image uploads, in seconds + pub image_per_second: i32, + #[builder(default = 6)] + /// Maximum number of comments created in interval + pub comment: i32, + #[builder(default = 600)] + /// Interval length for comment limit, in seconds + pub comment_per_second: i32, + #[builder(default = 60)] + /// Maximum number of searches created in interval + pub search: i32, + #[builder(default = 600)] + /// Interval length for search limit, in seconds + pub search_per_second: i32, +} + #[derive(Debug, Clone)] -pub struct RateLimit { - // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this - // across await points - pub rate_limiter: Arc>, +struct RateLimit { + pub rate_limiter: RateLimitStorage, pub rate_limit_config: RateLimitConfig, } #[derive(Debug, Clone)] -pub struct RateLimited { - rate_limiter: Arc>, - rate_limit_config: RateLimitConfig, +pub struct RateLimitedGuard { + rate_limit: Arc>, type_: RateLimitType, } -pub struct RateLimitedMiddleware { - rate_limited: RateLimited, - service: Rc, +/// Single instance of rate limit config and buckets, which is shared across all threads. +#[derive(Clone)] +pub struct RateLimitCell { + tx: Sender, + rate_limit: Arc>, } -impl RateLimit { - pub fn message(&self) -> RateLimited { +impl RateLimitCell { + /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell. + pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self { + static LOCAL_INSTANCE: OnceCell = OnceCell::const_new(); + LOCAL_INSTANCE + .get_or_init(|| async { + let (tx, mut rx) = mpsc::channel::(4); + let rate_limit = Arc::new(Mutex::new(RateLimit { + rate_limiter: Default::default(), + rate_limit_config, + })); + let rate_limit2 = rate_limit.clone(); + tokio::spawn(async move { + while let Some(r) = rx.recv().await { + rate_limit2 + .lock() + .expect("Failed to lock rate limit mutex for updating") + .rate_limit_config = r; + } + }); + RateLimitCell { tx, rate_limit } + }) + .await + } + + /// Call this when the config was updated, to update all in-memory cells. + pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> { + self.tx.send(config).await?; + Ok(()) + } + + /// Remove buckets older than the given duration + pub fn remove_older_than(&self, mut duration: Duration) { + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; + + // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check. + let max_interval_secs = enum_map! { + RateLimitType::Message => rate_limit.message_per_second, + RateLimitType::Post => rate_limit.post_per_second, + RateLimitType::Register => rate_limit.register_per_second, + RateLimitType::Image => rate_limit.image_per_second, + RateLimitType::Comment => rate_limit.comment_per_second, + RateLimitType::Search => rate_limit.search_per_second, + } + .into_values() + .max() + .and_then(|max| u64::try_from(max).ok()) + .unwrap_or(0); + + duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs)); + + guard + .rate_limiter + .remove_older_than(duration, InstantSecs::now()) + } + + pub fn message(&self) -> RateLimitedGuard { self.kind(RateLimitType::Message) } - pub fn post(&self) -> RateLimited { + pub fn post(&self) -> RateLimitedGuard { self.kind(RateLimitType::Post) } - pub fn register(&self) -> RateLimited { + pub fn register(&self) -> RateLimitedGuard { self.kind(RateLimitType::Register) } - pub fn image(&self) -> RateLimited { + pub fn image(&self) -> RateLimitedGuard { self.kind(RateLimitType::Image) } - pub fn comment(&self) -> RateLimited { + pub fn comment(&self) -> RateLimitedGuard { self.kind(RateLimitType::Comment) } - fn kind(&self, type_: RateLimitType) -> RateLimited { - RateLimited { - rate_limiter: self.rate_limiter.clone(), - rate_limit_config: self.rate_limit_config.clone(), + pub fn search(&self) -> RateLimitedGuard { + self.kind(RateLimitType::Search) + } + + fn kind(&self, type_: RateLimitType) -> RateLimitedGuard { + RateLimitedGuard { + rate_limit: self.rate_limit.clone(), type_, } } } -impl RateLimited { +pub struct RateLimitedMiddleware { + rate_limited: RateLimitedGuard, + service: Rc, +} + +impl RateLimitedGuard { /// Returns true if the request passed the rate limit, false if it failed and should be rejected. - pub async fn check(self, ip_addr: IpAddr) -> bool { + pub fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone - let rate_limit = self.rate_limit_config; - - let mut limiter = self.rate_limiter.lock().await; + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; let (kind, interval) = match self.type_ { RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), @@ -81,12 +192,15 @@ impl RateLimited { RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second), RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second), RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), + RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second), }; - limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) + let limiter = &mut guard.rate_limiter; + + limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now()) } } -impl Transform for RateLimited +impl Transform for RateLimitedGuard where S: Service + 'static, S::Future: 'static, @@ -127,16 +241,54 @@ where let service = self.service.clone(); Box::pin(async move { - if rate_limited.check(ip_addr).await { + if rate_limited.check(ip_addr) { service.call(req).await } else { let (http_req, _) = req.into_parts(); - // if rate limit was hit, respond with http 400 - Ok(ServiceResponse::new( + Ok(ServiceResponse::from_err( + LemmyError::from(LemmyErrorType::RateLimitError), http_req, - HttpResponse::BadRequest().finish(), )) } }) } } + +fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { + conn_info + .realip_remote_addr() + .and_then(parse_ip) + .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) +} + +fn parse_ip(addr: &str) -> Option { + if let Some(s) = addr.strip_suffix(']') { + IpAddr::from_str(s.get(1..)?).ok() + } else if let Ok(ip) = IpAddr::from_str(addr) { + Some(ip) + } else if let Ok(socket) = SocketAddr::from_str(addr) { + Some(socket.ip()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + #![allow(clippy::indexing_slicing)] + + #[test] + fn test_parse_ip() { + let ip_addrs = [ + "1.2.3.4", + "1.2.3.4:8000", + "2001:db8::", + "[2001:db8::]", + "[2001:db8::]:8000", + ]; + for addr in ip_addrs { + assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}"); + } + } +}