]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/mod.rs
7a5c1ec685759ee4a89a90562c92309d2a725edd
[lemmy.git] / crates / utils / src / rate_limit / mod.rs
1 use crate::error::{LemmyError, LemmyErrorType};
2 use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
3 use enum_map::enum_map;
4 use futures::future::{ok, Ready};
5 use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
6 use serde::{Deserialize, Serialize};
7 use std::{
8   future::Future,
9   net::{IpAddr, Ipv4Addr, SocketAddr},
10   pin::Pin,
11   rc::Rc,
12   str::FromStr,
13   sync::{Arc, Mutex},
14   task::{Context, Poll},
15   time::Duration,
16 };
17 use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
18 use typed_builder::TypedBuilder;
19
20 pub mod rate_limiter;
21
22 #[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)]
23 pub struct RateLimitConfig {
24   #[builder(default = 180)]
25   /// Maximum number of messages created in interval
26   pub message: i32,
27   #[builder(default = 60)]
28   /// Interval length for message limit, in seconds
29   pub message_per_second: i32,
30   #[builder(default = 6)]
31   /// Maximum number of posts created in interval
32   pub post: i32,
33   #[builder(default = 300)]
34   /// Interval length for post limit, in seconds
35   pub post_per_second: i32,
36   #[builder(default = 3)]
37   /// Maximum number of registrations in interval
38   pub register: i32,
39   #[builder(default = 3600)]
40   /// Interval length for registration limit, in seconds
41   pub register_per_second: i32,
42   #[builder(default = 6)]
43   /// Maximum number of image uploads in interval
44   pub image: i32,
45   #[builder(default = 3600)]
46   /// Interval length for image uploads, in seconds
47   pub image_per_second: i32,
48   #[builder(default = 6)]
49   /// Maximum number of comments created in interval
50   pub comment: i32,
51   #[builder(default = 600)]
52   /// Interval length for comment limit, in seconds
53   pub comment_per_second: i32,
54   #[builder(default = 60)]
55   /// Maximum number of searches created in interval
56   pub search: i32,
57   #[builder(default = 600)]
58   /// Interval length for search limit, in seconds
59   pub search_per_second: i32,
60 }
61
62 #[derive(Debug, Clone)]
63 struct RateLimit {
64   pub rate_limiter: RateLimitStorage,
65   pub rate_limit_config: RateLimitConfig,
66 }
67
68 #[derive(Debug, Clone)]
69 pub struct RateLimitedGuard {
70   rate_limit: Arc<Mutex<RateLimit>>,
71   type_: RateLimitType,
72 }
73
74 /// Single instance of rate limit config and buckets, which is shared across all threads.
75 #[derive(Clone)]
76 pub struct RateLimitCell {
77   tx: Sender<RateLimitConfig>,
78   rate_limit: Arc<Mutex<RateLimit>>,
79 }
80
81 impl RateLimitCell {
82   /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell.
83   pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self {
84     static LOCAL_INSTANCE: OnceCell<RateLimitCell> = OnceCell::const_new();
85     LOCAL_INSTANCE
86       .get_or_init(|| async {
87         let (tx, mut rx) = mpsc::channel::<RateLimitConfig>(4);
88         let rate_limit = Arc::new(Mutex::new(RateLimit {
89           rate_limiter: Default::default(),
90           rate_limit_config,
91         }));
92         let rate_limit2 = rate_limit.clone();
93         tokio::spawn(async move {
94           while let Some(r) = rx.recv().await {
95             rate_limit2
96               .lock()
97               .expect("Failed to lock rate limit mutex for updating")
98               .rate_limit_config = r;
99           }
100         });
101         RateLimitCell { tx, rate_limit }
102       })
103       .await
104   }
105
106   /// Call this when the config was updated, to update all in-memory cells.
107   pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> {
108     self.tx.send(config).await?;
109     Ok(())
110   }
111
112   /// Remove buckets older than the given duration
113   pub fn remove_older_than(&self, mut duration: Duration) {
114     let mut guard = self
115       .rate_limit
116       .lock()
117       .expect("Failed to lock rate limit mutex for reading");
118     let rate_limit = &guard.rate_limit_config;
119
120     // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
121     let max_interval_secs = enum_map! {
122       RateLimitType::Message => rate_limit.message_per_second,
123       RateLimitType::Post => rate_limit.post_per_second,
124       RateLimitType::Register => rate_limit.register_per_second,
125       RateLimitType::Image => rate_limit.image_per_second,
126       RateLimitType::Comment => rate_limit.comment_per_second,
127       RateLimitType::Search => rate_limit.search_per_second,
128     }
129     .into_values()
130     .max()
131     .and_then(|max| u64::try_from(max).ok())
132     .unwrap_or(0);
133
134     duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
135
136     guard
137       .rate_limiter
138       .remove_older_than(duration, InstantSecs::now())
139   }
140
141   pub fn message(&self) -> RateLimitedGuard {
142     self.kind(RateLimitType::Message)
143   }
144
145   pub fn post(&self) -> RateLimitedGuard {
146     self.kind(RateLimitType::Post)
147   }
148
149   pub fn register(&self) -> RateLimitedGuard {
150     self.kind(RateLimitType::Register)
151   }
152
153   pub fn image(&self) -> RateLimitedGuard {
154     self.kind(RateLimitType::Image)
155   }
156
157   pub fn comment(&self) -> RateLimitedGuard {
158     self.kind(RateLimitType::Comment)
159   }
160
161   pub fn search(&self) -> RateLimitedGuard {
162     self.kind(RateLimitType::Search)
163   }
164
165   fn kind(&self, type_: RateLimitType) -> RateLimitedGuard {
166     RateLimitedGuard {
167       rate_limit: self.rate_limit.clone(),
168       type_,
169     }
170   }
171 }
172
173 pub struct RateLimitedMiddleware<S> {
174   rate_limited: RateLimitedGuard,
175   service: Rc<S>,
176 }
177
178 impl RateLimitedGuard {
179   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
180   pub fn check(self, ip_addr: IpAddr) -> bool {
181     // Does not need to be blocking because the RwLock in settings never held across await points,
182     // and the operation here locks only long enough to clone
183     let mut guard = self
184       .rate_limit
185       .lock()
186       .expect("Failed to lock rate limit mutex for reading");
187     let rate_limit = &guard.rate_limit_config;
188
189     let (kind, interval) = match self.type_ {
190       RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
191       RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
192       RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
193       RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
194       RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
195       RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
196     };
197     let limiter = &mut guard.rate_limiter;
198
199     limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
200   }
201 }
202
203 impl<S> Transform<S, ServiceRequest> for RateLimitedGuard
204 where
205   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
206   S::Future: 'static,
207 {
208   type Response = S::Response;
209   type Error = actix_web::Error;
210   type InitError = ();
211   type Transform = RateLimitedMiddleware<S>;
212   type Future = Ready<Result<Self::Transform, Self::InitError>>;
213
214   fn new_transform(&self, service: S) -> Self::Future {
215     ok(RateLimitedMiddleware {
216       rate_limited: self.clone(),
217       service: Rc::new(service),
218     })
219   }
220 }
221
222 type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
223
224 impl<S> Service<ServiceRequest> for RateLimitedMiddleware<S>
225 where
226   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
227   S::Future: 'static,
228 {
229   type Response = S::Response;
230   type Error = actix_web::Error;
231   type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;
232
233   fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
234     self.service.poll_ready(cx)
235   }
236
237   fn call(&self, req: ServiceRequest) -> Self::Future {
238     let ip_addr = get_ip(&req.connection_info());
239
240     let rate_limited = self.rate_limited.clone();
241     let service = self.service.clone();
242
243     Box::pin(async move {
244       if rate_limited.check(ip_addr) {
245         service.call(req).await
246       } else {
247         let (http_req, _) = req.into_parts();
248         Ok(ServiceResponse::from_err(
249           LemmyError::from(LemmyErrorType::RateLimitError),
250           http_req,
251         ))
252       }
253     })
254   }
255 }
256
257 fn get_ip(conn_info: &ConnectionInfo) -> IpAddr {
258   conn_info
259     .realip_remote_addr()
260     .and_then(parse_ip)
261     .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
262 }
263
264 fn parse_ip(addr: &str) -> Option<IpAddr> {
265   if let Some(s) = addr.strip_suffix(']') {
266     IpAddr::from_str(s.get(1..)?).ok()
267   } else if let Ok(ip) = IpAddr::from_str(addr) {
268     Some(ip)
269   } else if let Ok(socket) = SocketAddr::from_str(addr) {
270     Some(socket.ip())
271   } else {
272     None
273   }
274 }
275
276 #[cfg(test)]
277 mod tests {
278   #[test]
279   fn test_parse_ip() {
280     let ip_addrs = [
281       "1.2.3.4",
282       "1.2.3.4:8000",
283       "2001:db8::",
284       "[2001:db8::]",
285       "[2001:db8::]:8000",
286     ];
287     for addr in ip_addrs {
288       assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}");
289     }
290   }
291 }