-use crate::{error::LemmyError, utils::get_ip, IpAddr};
-use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
+use crate::error::LemmyError;
+use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
+use enum_map::enum_map;
use futures::future::{ok, Ready};
-use rate_limiter::{RateLimitStorage, RateLimitType};
+use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
use serde::{Deserialize, Serialize};
use std::{
future::Future,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
rc::Rc,
+ str::FromStr,
sync::{Arc, Mutex},
task::{Context, Poll},
+ time::Duration,
};
use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
use typed_builder::TypedBuilder;
Ok(())
}
+ /// Remove buckets older than the given duration
+ pub fn remove_older_than(&self, mut duration: Duration) {
+ let mut guard = self
+ .rate_limit
+ .lock()
+ .expect("Failed to lock rate limit mutex for reading");
+ let rate_limit = &guard.rate_limit_config;
+
+ // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
+ let max_interval_secs = enum_map! {
+ RateLimitType::Message => rate_limit.message_per_second,
+ RateLimitType::Post => rate_limit.post_per_second,
+ RateLimitType::Register => rate_limit.register_per_second,
+ RateLimitType::Image => rate_limit.image_per_second,
+ RateLimitType::Comment => rate_limit.comment_per_second,
+ RateLimitType::Search => rate_limit.search_per_second,
+ }
+ .into_values()
+ .max()
+ .and_then(|max| u64::try_from(max).ok())
+ .unwrap_or(0);
+
+ duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
+
+ guard
+ .rate_limiter
+ .remove_older_than(duration, InstantSecs::now())
+ }
+
pub fn message(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Message)
}
};
let limiter = &mut guard.rate_limiter;
- limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
+ limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
}
}
})
}
}
+
+fn get_ip(conn_info: &ConnectionInfo) -> IpAddr {
+ conn_info
+ .realip_remote_addr()
+ .and_then(parse_ip)
+ .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
+}
+
+fn parse_ip(addr: &str) -> Option<IpAddr> {
+ if let Some(s) = addr.strip_suffix(']') {
+ IpAddr::from_str(s.get(1..)?).ok()
+ } else if let Ok(ip) = IpAddr::from_str(addr) {
+ Some(ip)
+ } else if let Ok(socket) = SocketAddr::from_str(addr) {
+ Some(socket.ip())
+ } else {
+ None
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ #[test]
+ fn test_parse_ip() {
+ let ip_addrs = [
+ "1.2.3.4",
+ "1.2.3.4:8000",
+ "2001:db8::",
+ "[2001:db8::]",
+ "[2001:db8::]:8000",
+ ];
+ for addr in ip_addrs {
+ assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}");
+ }
+ }
+}