X-Git-Url: http://these/git/?a=blobdiff_plain;f=crates%2Futils%2Fsrc%2Frate_limit%2Fmod.rs;h=d1c51265d5f1d59d8ccec10a52781cfb61bfa288;hb=45818fb4c50dc06f6909312bd034f68af92c3146;hp=b7000e3e6589690f7fd42ee9493dec0682780fec;hpb=b214d3dc00c269d7987ace7f5522e2ff406eec03;p=lemmy.git diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index b7000e3e..d1c51265 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,14 +1,18 @@ -use crate::{error::LemmyError, IpAddr}; +use crate::error::LemmyError; use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; +use enum_map::enum_map; use futures::future::{ok, Ready}; -use rate_limiter::{RateLimitStorage, RateLimitType}; +use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType}; use serde::{Deserialize, Serialize}; use std::{ future::Future, + net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, rc::Rc, + str::FromStr, sync::{Arc, Mutex}, task::{Context, Poll}, + time::Duration, }; use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; use typed_builder::TypedBuilder; @@ -105,6 +109,35 @@ impl RateLimitCell { Ok(()) } + /// Remove buckets older than the given duration + pub fn remove_older_than(&self, mut duration: Duration) { + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; + + // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check. + let max_interval_secs = enum_map! { + RateLimitType::Message => rate_limit.message_per_second, + RateLimitType::Post => rate_limit.post_per_second, + RateLimitType::Register => rate_limit.register_per_second, + RateLimitType::Image => rate_limit.image_per_second, + RateLimitType::Comment => rate_limit.comment_per_second, + RateLimitType::Search => rate_limit.search_per_second, + } + .into_values() + .max() + .and_then(|max| u64::try_from(max).ok()) + .unwrap_or(0); + + duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs)); + + guard + .rate_limiter + .remove_older_than(duration, InstantSecs::now()) + } + pub fn message(&self) -> RateLimitedGuard { self.kind(RateLimitType::Message) } @@ -163,7 +196,7 @@ impl RateLimitedGuard { }; let limiter = &mut guard.rate_limiter; - limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) + limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now()) } } @@ -222,13 +255,37 @@ where } fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { - IpAddr( - conn_info - .realip_remote_addr() - .unwrap_or("127.0.0.1:12345") - .split(':') - .next() - .unwrap_or("127.0.0.1") - .to_string(), - ) + conn_info + .realip_remote_addr() + .and_then(parse_ip) + .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) +} + +fn parse_ip(addr: &str) -> Option { + if let Some(s) = addr.strip_suffix(']') { + IpAddr::from_str(s.get(1..)?).ok() + } else if let Ok(ip) = IpAddr::from_str(addr) { + Some(ip) + } else if let Ok(socket) = SocketAddr::from_str(addr) { + Some(socket.ip()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_parse_ip() { + let ip_addrs = [ + "1.2.3.4", + "1.2.3.4:8000", + "2001:db8::", + "[2001:db8::]", + "[2001:db8::]:8000", + ]; + for addr in ip_addrs { + assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}"); + } + } }