]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/mod.rs
79bdbb1d8fd5dc393a386f33f5744dc95ae361d6
[lemmy.git] / crates / utils / src / rate_limit / mod.rs
1 use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr};
2 use actix_web::{
3   dev::{Service, ServiceRequest, ServiceResponse, Transform},
4   HttpResponse,
5 };
6 use futures::future::{ok, Ready};
7 use rate_limiter::{RateLimitType, RateLimiter};
8 use std::{
9   future::Future,
10   pin::Pin,
11   rc::Rc,
12   sync::{Arc, Mutex},
13   task::{Context, Poll},
14 };
15
16 pub mod rate_limiter;
17
18 #[derive(Debug, Clone)]
19 pub struct RateLimit {
20   // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this
21   // across await points
22   pub rate_limiter: Arc<Mutex<RateLimiter>>,
23   pub rate_limit_config: RateLimitConfig,
24 }
25
26 #[derive(Debug, Clone)]
27 pub struct RateLimited {
28   rate_limiter: Arc<Mutex<RateLimiter>>,
29   rate_limit_config: RateLimitConfig,
30   type_: RateLimitType,
31 }
32
33 pub struct RateLimitedMiddleware<S> {
34   rate_limited: RateLimited,
35   service: Rc<S>,
36 }
37
38 impl RateLimit {
39   pub fn message(&self) -> RateLimited {
40     self.kind(RateLimitType::Message)
41   }
42
43   pub fn post(&self) -> RateLimited {
44     self.kind(RateLimitType::Post)
45   }
46
47   pub fn register(&self) -> RateLimited {
48     self.kind(RateLimitType::Register)
49   }
50
51   pub fn image(&self) -> RateLimited {
52     self.kind(RateLimitType::Image)
53   }
54
55   pub fn comment(&self) -> RateLimited {
56     self.kind(RateLimitType::Comment)
57   }
58
59   pub fn search(&self) -> RateLimited {
60     self.kind(RateLimitType::Search)
61   }
62
63   fn kind(&self, type_: RateLimitType) -> RateLimited {
64     RateLimited {
65       rate_limiter: self.rate_limiter.clone(),
66       rate_limit_config: self.rate_limit_config.clone(),
67       type_,
68     }
69   }
70 }
71
72 impl RateLimited {
73   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
74   pub fn check(self, ip_addr: IpAddr) -> bool {
75     // Does not need to be blocking because the RwLock in settings never held across await points,
76     // and the operation here locks only long enough to clone
77     let rate_limit = self.rate_limit_config;
78
79     let (kind, interval) = match self.type_ {
80       RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
81       RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
82       RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
83       RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
84       RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
85       RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
86     };
87     let mut limiter = self.rate_limiter.lock().expect("mutex poison error");
88
89     limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
90   }
91 }
92
93 impl<S> Transform<S, ServiceRequest> for RateLimited
94 where
95   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
96   S::Future: 'static,
97 {
98   type Response = S::Response;
99   type Error = actix_web::Error;
100   type InitError = ();
101   type Transform = RateLimitedMiddleware<S>;
102   type Future = Ready<Result<Self::Transform, Self::InitError>>;
103
104   fn new_transform(&self, service: S) -> Self::Future {
105     ok(RateLimitedMiddleware {
106       rate_limited: self.clone(),
107       service: Rc::new(service),
108     })
109   }
110 }
111
112 type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
113
114 impl<S> Service<ServiceRequest> for RateLimitedMiddleware<S>
115 where
116   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
117   S::Future: 'static,
118 {
119   type Response = S::Response;
120   type Error = actix_web::Error;
121   type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;
122
123   fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124     self.service.poll_ready(cx)
125   }
126
127   fn call(&self, req: ServiceRequest) -> Self::Future {
128     let ip_addr = get_ip(&req.connection_info());
129
130     let rate_limited = self.rate_limited.clone();
131     let service = self.service.clone();
132
133     Box::pin(async move {
134       if rate_limited.check(ip_addr) {
135         service.call(req).await
136       } else {
137         let (http_req, _) = req.into_parts();
138         // if rate limit was hit, respond with http 400
139         Ok(ServiceResponse::new(
140           http_req,
141           HttpResponse::BadRequest().finish(),
142         ))
143       }
144     })
145   }
146 }