]> Untitled Git - lemmy.git/blobdiff - crates/utils/src/rate_limit/mod.rs
Reduce memory usage of rate limiting (#3111)
[lemmy.git] / crates / utils / src / rate_limit / mod.rs
index b7000e3e6589690f7fd42ee9493dec0682780fec..d1c51265d5f1d59d8ccec10a52781cfb61bfa288 100644 (file)
@@ -1,14 +1,18 @@
-use crate::{error::LemmyError, IpAddr};
+use crate::error::LemmyError;
 use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
+use enum_map::enum_map;
 use futures::future::{ok, Ready};
-use rate_limiter::{RateLimitStorage, RateLimitType};
+use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
 use serde::{Deserialize, Serialize};
 use std::{
   future::Future,
+  net::{IpAddr, Ipv4Addr, SocketAddr},
   pin::Pin,
   rc::Rc,
+  str::FromStr,
   sync::{Arc, Mutex},
   task::{Context, Poll},
+  time::Duration,
 };
 use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
 use typed_builder::TypedBuilder;
@@ -105,6 +109,35 @@ impl RateLimitCell {
     Ok(())
   }
 
+  /// Remove buckets older than the given duration
+  pub fn remove_older_than(&self, mut duration: Duration) {
+    let mut guard = self
+      .rate_limit
+      .lock()
+      .expect("Failed to lock rate limit mutex for reading");
+    let rate_limit = &guard.rate_limit_config;
+
+    // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
+    let max_interval_secs = enum_map! {
+      RateLimitType::Message => rate_limit.message_per_second,
+      RateLimitType::Post => rate_limit.post_per_second,
+      RateLimitType::Register => rate_limit.register_per_second,
+      RateLimitType::Image => rate_limit.image_per_second,
+      RateLimitType::Comment => rate_limit.comment_per_second,
+      RateLimitType::Search => rate_limit.search_per_second,
+    }
+    .into_values()
+    .max()
+    .and_then(|max| u64::try_from(max).ok())
+    .unwrap_or(0);
+
+    duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
+
+    guard
+      .rate_limiter
+      .remove_older_than(duration, InstantSecs::now())
+  }
+
   pub fn message(&self) -> RateLimitedGuard {
     self.kind(RateLimitType::Message)
   }
@@ -163,7 +196,7 @@ impl RateLimitedGuard {
     };
     let limiter = &mut guard.rate_limiter;
 
-    limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
+    limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
   }
 }
 
@@ -222,13 +255,37 @@ where
 }
 
 fn get_ip(conn_info: &ConnectionInfo) -> IpAddr {
-  IpAddr(
-    conn_info
-      .realip_remote_addr()
-      .unwrap_or("127.0.0.1:12345")
-      .split(':')
-      .next()
-      .unwrap_or("127.0.0.1")
-      .to_string(),
-  )
+  conn_info
+    .realip_remote_addr()
+    .and_then(parse_ip)
+    .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
+}
+
+fn parse_ip(addr: &str) -> Option<IpAddr> {
+  if let Some(s) = addr.strip_suffix(']') {
+    IpAddr::from_str(s.get(1..)?).ok()
+  } else if let Ok(ip) = IpAddr::from_str(addr) {
+    Some(ip)
+  } else if let Ok(socket) = SocketAddr::from_str(addr) {
+    Some(socket.ip())
+  } else {
+    None
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  #[test]
+  fn test_parse_ip() {
+    let ip_addrs = [
+      "1.2.3.4",
+      "1.2.3.4:8000",
+      "2001:db8::",
+      "[2001:db8::]",
+      "[2001:db8::]:8000",
+    ];
+    for addr in ip_addrs {
+      assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}");
+    }
+  }
 }