X-Git-Url: http://these/git/?a=blobdiff_plain;f=crates%2Futils%2Fsrc%2Frate_limit%2Fmod.rs;h=1bb6f1b5f4e1279f3642b8a6b2c246c4569626a2;hb=92568956353f21649ed9aff68b42699c9d036f30;hp=6dc9dcbef69d997539f5df8a552097669f0560a3;hpb=24756af84b8cdce9cd21dd63713b5d44474ad1f8;p=lemmy.git diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index 6dc9dcbe..1bb6f1b5 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,14 +1,18 @@ -use crate::{error::LemmyError, utils::get_ip, IpAddr}; -use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; +use crate::error::{LemmyError, LemmyErrorType}; +use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; +use enum_map::enum_map; use futures::future::{ok, Ready}; -use rate_limiter::{RateLimitStorage, RateLimitType}; +use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType}; use serde::{Deserialize, Serialize}; use std::{ future::Future, + net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, rc::Rc, + str::FromStr, sync::{Arc, Mutex}, task::{Context, Poll}, + time::Duration, }; use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; use typed_builder::TypedBuilder; @@ -105,6 +109,35 @@ impl RateLimitCell { Ok(()) } + /// Remove buckets older than the given duration + pub fn remove_older_than(&self, mut duration: Duration) { + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; + + // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check. + let max_interval_secs = enum_map! { + RateLimitType::Message => rate_limit.message_per_second, + RateLimitType::Post => rate_limit.post_per_second, + RateLimitType::Register => rate_limit.register_per_second, + RateLimitType::Image => rate_limit.image_per_second, + RateLimitType::Comment => rate_limit.comment_per_second, + RateLimitType::Search => rate_limit.search_per_second, + } + .into_values() + .max() + .and_then(|max| u64::try_from(max).ok()) + .unwrap_or(0); + + duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs)); + + guard + .rate_limiter + .remove_older_than(duration, InstantSecs::now()) + } + pub fn message(&self) -> RateLimitedGuard { self.kind(RateLimitType::Message) } @@ -163,7 +196,7 @@ impl RateLimitedGuard { }; let limiter = &mut guard.rate_limiter; - limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) + limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now()) } } @@ -213,10 +246,49 @@ where } else { let (http_req, _) = req.into_parts(); Ok(ServiceResponse::from_err( - LemmyError::from_message("rate_limit_error"), + LemmyError::from(LemmyErrorType::RateLimitError), http_req, )) } }) } } + +fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { + conn_info + .realip_remote_addr() + .and_then(parse_ip) + .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) +} + +fn parse_ip(addr: &str) -> Option { + if let Some(s) = addr.strip_suffix(']') { + IpAddr::from_str(s.get(1..)?).ok() + } else if let Ok(ip) = IpAddr::from_str(addr) { + Some(ip) + } else if let Ok(socket) = SocketAddr::from_str(addr) { + Some(socket.ip()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + #![allow(clippy::indexing_slicing)] + + #[test] + fn test_parse_ip() { + let ip_addrs = [ + "1.2.3.4", + "1.2.3.4:8000", + "2001:db8::", + "[2001:db8::]", + "[2001:db8::]:8000", + ]; + for addr in ip_addrs { + assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}"); + } + } +}