]> Untitled Git - lemmy.git/blob - lemmy_rate_limit/src/lib.rs
routes.api: fix get_captcha endpoint (#1135)
[lemmy.git] / lemmy_rate_limit / src / lib.rs
1 #[macro_use]
2 extern crate strum_macros;
3 extern crate actix_web;
4 extern crate futures;
5 extern crate log;
6 extern crate tokio;
7
8 use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
9 use futures::future::{ok, Ready};
10 use lemmy_utils::{
11   settings::{RateLimitConfig, Settings},
12   utils::get_ip,
13   LemmyError,
14 };
15 use rate_limiter::{RateLimitType, RateLimiter};
16 use std::{
17   future::Future,
18   pin::Pin,
19   sync::Arc,
20   task::{Context, Poll},
21 };
22 use tokio::sync::Mutex;
23
24 pub mod rate_limiter;
25
26 #[derive(Debug, Clone)]
27 pub struct RateLimit {
28   // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this
29   // across await points
30   pub rate_limiter: Arc<Mutex<RateLimiter>>,
31 }
32
33 #[derive(Debug, Clone)]
34 pub struct RateLimited {
35   rate_limiter: Arc<Mutex<RateLimiter>>,
36   type_: RateLimitType,
37 }
38
39 pub struct RateLimitedMiddleware<S> {
40   rate_limited: RateLimited,
41   service: S,
42 }
43
44 impl RateLimit {
45   pub fn message(&self) -> RateLimited {
46     self.kind(RateLimitType::Message)
47   }
48
49   pub fn post(&self) -> RateLimited {
50     self.kind(RateLimitType::Post)
51   }
52
53   pub fn register(&self) -> RateLimited {
54     self.kind(RateLimitType::Register)
55   }
56
57   pub fn image(&self) -> RateLimited {
58     self.kind(RateLimitType::Image)
59   }
60
61   fn kind(&self, type_: RateLimitType) -> RateLimited {
62     RateLimited {
63       rate_limiter: self.rate_limiter.clone(),
64       type_,
65     }
66   }
67 }
68
69 impl RateLimited {
70   pub async fn wrap<T, E>(
71     self,
72     ip_addr: String,
73     fut: impl Future<Output = Result<T, E>>,
74   ) -> Result<T, E>
75   where
76     E: From<LemmyError>,
77   {
78     // Does not need to be blocking because the RwLock in settings never held across await points,
79     // and the operation here locks only long enough to clone
80     let rate_limit: RateLimitConfig = Settings::get().rate_limit;
81
82     // before
83     {
84       let mut limiter = self.rate_limiter.lock().await;
85
86       match self.type_ {
87         RateLimitType::Message => {
88           limiter.check_rate_limit_full(
89             self.type_,
90             &ip_addr,
91             rate_limit.message,
92             rate_limit.message_per_second,
93             false,
94           )?;
95
96           drop(limiter);
97           return fut.await;
98         }
99         RateLimitType::Post => {
100           limiter.check_rate_limit_full(
101             self.type_,
102             &ip_addr,
103             rate_limit.post,
104             rate_limit.post_per_second,
105             true,
106           )?;
107         }
108         RateLimitType::Register => {
109           limiter.check_rate_limit_full(
110             self.type_,
111             &ip_addr,
112             rate_limit.register,
113             rate_limit.register_per_second,
114             true,
115           )?;
116         }
117         RateLimitType::Image => {
118           limiter.check_rate_limit_full(
119             self.type_,
120             &ip_addr,
121             rate_limit.image,
122             rate_limit.image_per_second,
123             false,
124           )?;
125         }
126       };
127     }
128
129     let res = fut.await;
130
131     // after
132     {
133       let mut limiter = self.rate_limiter.lock().await;
134       if res.is_ok() {
135         match self.type_ {
136           RateLimitType::Post => {
137             limiter.check_rate_limit_full(
138               self.type_,
139               &ip_addr,
140               rate_limit.post,
141               rate_limit.post_per_second,
142               false,
143             )?;
144           }
145           RateLimitType::Register => {
146             limiter.check_rate_limit_full(
147               self.type_,
148               &ip_addr,
149               rate_limit.register,
150               rate_limit.register_per_second,
151               false,
152             )?;
153           }
154           _ => (),
155         };
156       }
157     }
158
159     res
160   }
161 }
162
163 impl<S> Transform<S> for RateLimited
164 where
165   S: Service<Request = ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
166   S::Future: 'static,
167 {
168   type Request = S::Request;
169   type Response = S::Response;
170   type Error = actix_web::Error;
171   type InitError = ();
172   type Transform = RateLimitedMiddleware<S>;
173   type Future = Ready<Result<Self::Transform, Self::InitError>>;
174
175   fn new_transform(&self, service: S) -> Self::Future {
176     ok(RateLimitedMiddleware {
177       rate_limited: self.clone(),
178       service,
179     })
180   }
181 }
182
183 type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
184
185 impl<S> Service for RateLimitedMiddleware<S>
186 where
187   S: Service<Request = ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
188   S::Future: 'static,
189 {
190   type Request = S::Request;
191   type Response = S::Response;
192   type Error = actix_web::Error;
193   type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;
194
195   fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196     self.service.poll_ready(cx)
197   }
198
199   fn call(&mut self, req: S::Request) -> Self::Future {
200     let ip_addr = get_ip(&req.connection_info());
201
202     let fut = self
203       .rate_limited
204       .clone()
205       .wrap(ip_addr, self.service.call(req));
206
207     Box::pin(async move { fut.await.map_err(actix_web::Error::from) })
208   }
209 }