]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/mod.rs
Live reload settings (fixes #2508) (#2543)
[lemmy.git] / crates / utils / src / rate_limit / mod.rs
1 use crate::{error::LemmyError, utils::get_ip, IpAddr};
2 use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
3 use futures::future::{ok, Ready};
4 use rate_limiter::{RateLimitStorage, RateLimitType};
5 use serde::{Deserialize, Serialize};
6 use std::{
7   future::Future,
8   pin::Pin,
9   rc::Rc,
10   sync::{Arc, Mutex},
11   task::{Context, Poll},
12 };
13 use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
14 use typed_builder::TypedBuilder;
15
16 pub mod rate_limiter;
17
18 #[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)]
19 pub struct RateLimitConfig {
20   #[builder(default = 180)]
21   /// Maximum number of messages created in interval
22   pub message: i32,
23   #[builder(default = 60)]
24   /// Interval length for message limit, in seconds
25   pub message_per_second: i32,
26   #[builder(default = 6)]
27   /// Maximum number of posts created in interval
28   pub post: i32,
29   #[builder(default = 300)]
30   /// Interval length for post limit, in seconds
31   pub post_per_second: i32,
32   #[builder(default = 3)]
33   /// Maximum number of registrations in interval
34   pub register: i32,
35   #[builder(default = 3600)]
36   /// Interval length for registration limit, in seconds
37   pub register_per_second: i32,
38   #[builder(default = 6)]
39   /// Maximum number of image uploads in interval
40   pub image: i32,
41   #[builder(default = 3600)]
42   /// Interval length for image uploads, in seconds
43   pub image_per_second: i32,
44   #[builder(default = 6)]
45   /// Maximum number of comments created in interval
46   pub comment: i32,
47   #[builder(default = 600)]
48   /// Interval length for comment limit, in seconds
49   pub comment_per_second: i32,
50   #[builder(default = 60)]
51   /// Maximum number of searches created in interval
52   pub search: i32,
53   #[builder(default = 600)]
54   /// Interval length for search limit, in seconds
55   pub search_per_second: i32,
56 }
57
58 #[derive(Debug, Clone)]
59 struct RateLimit {
60   pub rate_limiter: RateLimitStorage,
61   pub rate_limit_config: RateLimitConfig,
62 }
63
64 #[derive(Debug, Clone)]
65 pub struct RateLimitedGuard {
66   rate_limit: Arc<Mutex<RateLimit>>,
67   type_: RateLimitType,
68 }
69
70 /// Single instance of rate limit config and buckets, which is shared across all threads.
71 #[derive(Clone)]
72 pub struct RateLimitCell {
73   tx: Sender<RateLimitConfig>,
74   rate_limit: Arc<Mutex<RateLimit>>,
75 }
76
77 impl RateLimitCell {
78   /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell.
79   pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self {
80     static LOCAL_INSTANCE: OnceCell<RateLimitCell> = OnceCell::const_new();
81     LOCAL_INSTANCE
82       .get_or_init(|| async {
83         let (tx, mut rx) = mpsc::channel::<RateLimitConfig>(4);
84         let rate_limit = Arc::new(Mutex::new(RateLimit {
85           rate_limiter: Default::default(),
86           rate_limit_config,
87         }));
88         let rate_limit2 = rate_limit.clone();
89         tokio::spawn(async move {
90           while let Some(r) = rx.recv().await {
91             rate_limit2
92               .lock()
93               .expect("Failed to lock rate limit mutex for updating")
94               .rate_limit_config = r;
95           }
96         });
97         RateLimitCell { tx, rate_limit }
98       })
99       .await
100   }
101
102   /// Call this when the config was updated, to update all in-memory cells.
103   pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> {
104     self.tx.send(config).await?;
105     Ok(())
106   }
107
108   pub fn message(&self) -> RateLimitedGuard {
109     self.kind(RateLimitType::Message)
110   }
111
112   pub fn post(&self) -> RateLimitedGuard {
113     self.kind(RateLimitType::Post)
114   }
115
116   pub fn register(&self) -> RateLimitedGuard {
117     self.kind(RateLimitType::Register)
118   }
119
120   pub fn image(&self) -> RateLimitedGuard {
121     self.kind(RateLimitType::Image)
122   }
123
124   pub fn comment(&self) -> RateLimitedGuard {
125     self.kind(RateLimitType::Comment)
126   }
127
128   pub fn search(&self) -> RateLimitedGuard {
129     self.kind(RateLimitType::Search)
130   }
131
132   fn kind(&self, type_: RateLimitType) -> RateLimitedGuard {
133     RateLimitedGuard {
134       rate_limit: self.rate_limit.clone(),
135       type_,
136     }
137   }
138 }
139
140 pub struct RateLimitedMiddleware<S> {
141   rate_limited: RateLimitedGuard,
142   service: Rc<S>,
143 }
144
145 impl RateLimitedGuard {
146   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
147   pub fn check(self, ip_addr: IpAddr) -> bool {
148     // Does not need to be blocking because the RwLock in settings never held across await points,
149     // and the operation here locks only long enough to clone
150     let mut guard = self
151       .rate_limit
152       .lock()
153       .expect("Failed to lock rate limit mutex for reading");
154     let rate_limit = &guard.rate_limit_config;
155
156     let (kind, interval) = match self.type_ {
157       RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
158       RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
159       RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
160       RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
161       RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
162       RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
163     };
164     let limiter = &mut guard.rate_limiter;
165
166     limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
167   }
168 }
169
170 impl<S> Transform<S, ServiceRequest> for RateLimitedGuard
171 where
172   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
173   S::Future: 'static,
174 {
175   type Response = S::Response;
176   type Error = actix_web::Error;
177   type InitError = ();
178   type Transform = RateLimitedMiddleware<S>;
179   type Future = Ready<Result<Self::Transform, Self::InitError>>;
180
181   fn new_transform(&self, service: S) -> Self::Future {
182     ok(RateLimitedMiddleware {
183       rate_limited: self.clone(),
184       service: Rc::new(service),
185     })
186   }
187 }
188
189 type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
190
191 impl<S> Service<ServiceRequest> for RateLimitedMiddleware<S>
192 where
193   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
194   S::Future: 'static,
195 {
196   type Response = S::Response;
197   type Error = actix_web::Error;
198   type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;
199
200   fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
201     self.service.poll_ready(cx)
202   }
203
204   fn call(&self, req: ServiceRequest) -> Self::Future {
205     let ip_addr = get_ip(&req.connection_info());
206
207     let rate_limited = self.rate_limited.clone();
208     let service = self.service.clone();
209
210     Box::pin(async move {
211       if rate_limited.check(ip_addr) {
212         service.call(req).await
213       } else {
214         let (http_req, _) = req.into_parts();
215         Ok(ServiceResponse::from_err(
216           LemmyError::from_message("rate_limit_error"),
217           http_req,
218         ))
219       }
220     })
221   }
222 }