-use crate::{error::LemmyError, utils::get_ip, IpAddr};
-use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
+use crate::error::LemmyError;
+use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
+use enum_map::enum_map;
use futures::future::{ok, Ready};
-use rate_limiter::{RateLimitType, RateLimiter};
+use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
use serde::{Deserialize, Serialize};
use std::{
future::Future,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
rc::Rc,
+ str::FromStr,
sync::{Arc, Mutex},
task::{Context, Poll},
+ time::Duration,
};
+use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
use typed_builder::TypedBuilder;
pub mod rate_limiter;
}
#[derive(Debug, Clone)]
-pub struct RateLimit {
- // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this
- // across await points
- pub rate_limiter: Arc<Mutex<RateLimiter>>,
+struct RateLimit {
+ pub rate_limiter: RateLimitStorage,
pub rate_limit_config: RateLimitConfig,
}
#[derive(Debug, Clone)]
-pub struct RateLimited {
- rate_limiter: Arc<Mutex<RateLimiter>>,
- rate_limit_config: RateLimitConfig,
+pub struct RateLimitedGuard {
+ rate_limit: Arc<Mutex<RateLimit>>,
type_: RateLimitType,
}
-pub struct RateLimitedMiddleware<S> {
- rate_limited: RateLimited,
- service: Rc<S>,
+/// Single instance of rate limit config and buckets, which is shared across all threads.
+#[derive(Clone)]
+pub struct RateLimitCell {
+ tx: Sender<RateLimitConfig>,
+ rate_limit: Arc<Mutex<RateLimit>>,
}
-impl RateLimit {
- pub fn message(&self) -> RateLimited {
+impl RateLimitCell {
+ /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell.
+ pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self {
+ static LOCAL_INSTANCE: OnceCell<RateLimitCell> = OnceCell::const_new();
+ LOCAL_INSTANCE
+ .get_or_init(|| async {
+ let (tx, mut rx) = mpsc::channel::<RateLimitConfig>(4);
+ let rate_limit = Arc::new(Mutex::new(RateLimit {
+ rate_limiter: Default::default(),
+ rate_limit_config,
+ }));
+ let rate_limit2 = rate_limit.clone();
+ tokio::spawn(async move {
+ while let Some(r) = rx.recv().await {
+ rate_limit2
+ .lock()
+ .expect("Failed to lock rate limit mutex for updating")
+ .rate_limit_config = r;
+ }
+ });
+ RateLimitCell { tx, rate_limit }
+ })
+ .await
+ }
+
+ /// Call this when the config was updated, to update all in-memory cells.
+ pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> {
+ self.tx.send(config).await?;
+ Ok(())
+ }
+
+ /// Remove buckets older than the given duration
+ pub fn remove_older_than(&self, mut duration: Duration) {
+ let mut guard = self
+ .rate_limit
+ .lock()
+ .expect("Failed to lock rate limit mutex for reading");
+ let rate_limit = &guard.rate_limit_config;
+
+ // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
+ let max_interval_secs = enum_map! {
+ RateLimitType::Message => rate_limit.message_per_second,
+ RateLimitType::Post => rate_limit.post_per_second,
+ RateLimitType::Register => rate_limit.register_per_second,
+ RateLimitType::Image => rate_limit.image_per_second,
+ RateLimitType::Comment => rate_limit.comment_per_second,
+ RateLimitType::Search => rate_limit.search_per_second,
+ }
+ .into_values()
+ .max()
+ .and_then(|max| u64::try_from(max).ok())
+ .unwrap_or(0);
+
+ duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
+
+ guard
+ .rate_limiter
+ .remove_older_than(duration, InstantSecs::now())
+ }
+
+ pub fn message(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Message)
}
- pub fn post(&self) -> RateLimited {
+ pub fn post(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Post)
}
- pub fn register(&self) -> RateLimited {
+ pub fn register(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Register)
}
- pub fn image(&self) -> RateLimited {
+ pub fn image(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Image)
}
- pub fn comment(&self) -> RateLimited {
+ pub fn comment(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Comment)
}
- pub fn search(&self) -> RateLimited {
+ pub fn search(&self) -> RateLimitedGuard {
self.kind(RateLimitType::Search)
}
- fn kind(&self, type_: RateLimitType) -> RateLimited {
- RateLimited {
- rate_limiter: self.rate_limiter.clone(),
- rate_limit_config: self.rate_limit_config.clone(),
+ fn kind(&self, type_: RateLimitType) -> RateLimitedGuard {
+ RateLimitedGuard {
+ rate_limit: self.rate_limit.clone(),
type_,
}
}
}
-impl RateLimited {
+pub struct RateLimitedMiddleware<S> {
+ rate_limited: RateLimitedGuard,
+ service: Rc<S>,
+}
+
+impl RateLimitedGuard {
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
pub fn check(self, ip_addr: IpAddr) -> bool {
// Does not need to be blocking because the RwLock in settings never held across await points,
// and the operation here locks only long enough to clone
- let rate_limit = self.rate_limit_config;
+ let mut guard = self
+ .rate_limit
+ .lock()
+ .expect("Failed to lock rate limit mutex for reading");
+ let rate_limit = &guard.rate_limit_config;
let (kind, interval) = match self.type_ {
RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
};
- let mut limiter = self.rate_limiter.lock().expect("mutex poison error");
+ let limiter = &mut guard.rate_limiter;
- limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
+ limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
}
}
-impl<S> Transform<S, ServiceRequest> for RateLimited
+impl<S> Transform<S, ServiceRequest> for RateLimitedGuard
where
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
S::Future: 'static,
})
}
}
+
+fn get_ip(conn_info: &ConnectionInfo) -> IpAddr {
+ conn_info
+ .realip_remote_addr()
+ .and_then(parse_ip)
+ .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
+}
+
+fn parse_ip(addr: &str) -> Option<IpAddr> {
+ if let Some(s) = addr.strip_suffix(']') {
+ IpAddr::from_str(s.get(1..)?).ok()
+ } else if let Ok(ip) = IpAddr::from_str(addr) {
+ Some(ip)
+ } else if let Ok(socket) = SocketAddr::from_str(addr) {
+ Some(socket.ip())
+ } else {
+ None
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ #[test]
+ fn test_parse_ip() {
+ let ip_addrs = [
+ "1.2.3.4",
+ "1.2.3.4:8000",
+ "2001:db8::",
+ "[2001:db8::]",
+ "[2001:db8::]:8000",
+ ];
+ for addr in ip_addrs {
+ assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}");
+ }
+ }
+}