]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/mod.rs
Lowering search rate limit. Fixes #2153 (#2154)
[lemmy.git] / crates / utils / src / rate_limit / mod.rs
1 use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr};
2 use actix_web::{
3   dev::{Service, ServiceRequest, ServiceResponse, Transform},
4   HttpResponse,
5 };
6 use futures::future::{ok, Ready};
7 use parking_lot::Mutex;
8 use rate_limiter::{RateLimitType, RateLimiter};
9 use std::{
10   future::Future,
11   pin::Pin,
12   rc::Rc,
13   sync::Arc,
14   task::{Context, Poll},
15 };
16
17 pub mod rate_limiter;
18
19 #[derive(Debug, Clone)]
20 pub struct RateLimit {
21   // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this
22   // across await points
23   pub rate_limiter: Arc<Mutex<RateLimiter>>,
24   pub rate_limit_config: RateLimitConfig,
25 }
26
27 #[derive(Debug, Clone)]
28 pub struct RateLimited {
29   rate_limiter: Arc<Mutex<RateLimiter>>,
30   rate_limit_config: RateLimitConfig,
31   type_: RateLimitType,
32 }
33
34 pub struct RateLimitedMiddleware<S> {
35   rate_limited: RateLimited,
36   service: Rc<S>,
37 }
38
39 impl RateLimit {
40   pub fn message(&self) -> RateLimited {
41     self.kind(RateLimitType::Message)
42   }
43
44   pub fn post(&self) -> RateLimited {
45     self.kind(RateLimitType::Post)
46   }
47
48   pub fn register(&self) -> RateLimited {
49     self.kind(RateLimitType::Register)
50   }
51
52   pub fn image(&self) -> RateLimited {
53     self.kind(RateLimitType::Image)
54   }
55
56   pub fn comment(&self) -> RateLimited {
57     self.kind(RateLimitType::Comment)
58   }
59
60   pub fn search(&self) -> RateLimited {
61     self.kind(RateLimitType::Search)
62   }
63
64   fn kind(&self, type_: RateLimitType) -> RateLimited {
65     RateLimited {
66       rate_limiter: self.rate_limiter.clone(),
67       rate_limit_config: self.rate_limit_config.clone(),
68       type_,
69     }
70   }
71 }
72
73 impl RateLimited {
74   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
75   pub fn check(self, ip_addr: IpAddr) -> bool {
76     // Does not need to be blocking because the RwLock in settings never held across await points,
77     // and the operation here locks only long enough to clone
78     let rate_limit = self.rate_limit_config;
79
80     let (kind, interval) = match self.type_ {
81       RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
82       RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
83       RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
84       RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
85       RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
86       RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
87     };
88     let mut limiter = self.rate_limiter.lock();
89
90     limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
91   }
92 }
93
94 impl<S> Transform<S, ServiceRequest> for RateLimited
95 where
96   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
97   S::Future: 'static,
98 {
99   type Response = S::Response;
100   type Error = actix_web::Error;
101   type InitError = ();
102   type Transform = RateLimitedMiddleware<S>;
103   type Future = Ready<Result<Self::Transform, Self::InitError>>;
104
105   fn new_transform(&self, service: S) -> Self::Future {
106     ok(RateLimitedMiddleware {
107       rate_limited: self.clone(),
108       service: Rc::new(service),
109     })
110   }
111 }
112
113 type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
114
115 impl<S> Service<ServiceRequest> for RateLimitedMiddleware<S>
116 where
117   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
118   S::Future: 'static,
119 {
120   type Response = S::Response;
121   type Error = actix_web::Error;
122   type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;
123
124   fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125     self.service.poll_ready(cx)
126   }
127
128   fn call(&self, req: ServiceRequest) -> Self::Future {
129     let ip_addr = get_ip(&req.connection_info());
130
131     let rate_limited = self.rate_limited.clone();
132     let service = self.service.clone();
133
134     Box::pin(async move {
135       if rate_limited.check(ip_addr) {
136         service.call(req).await
137       } else {
138         let (http_req, _) = req.into_parts();
139         // if rate limit was hit, respond with http 400
140         Ok(ServiceResponse::new(
141           http_req,
142           HttpResponse::BadRequest().finish(),
143         ))
144       }
145     })
146   }
147 }