]> Untitled Git - lemmy.git/blob - lemmy_rate_limit/src/rate_limiter.rs
routes.api: fix get_captcha endpoint (#1135)
[lemmy.git] / lemmy_rate_limit / src / rate_limiter.rs
1 use lemmy_utils::{APIError, IPAddr, LemmyError};
2 use log::debug;
3 use std::{collections::HashMap, time::SystemTime};
4 use strum::IntoEnumIterator;
5
6 #[derive(Debug, Clone)]
7 pub struct RateLimitBucket {
8   last_checked: SystemTime,
9   allowance: f64,
10 }
11
12 #[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)]
13 pub enum RateLimitType {
14   Message,
15   Register,
16   Post,
17   Image,
18 }
19
20 /// Rate limiting based on rate type and IP addr
21 #[derive(Debug, Clone)]
22 pub struct RateLimiter {
23   pub buckets: HashMap<RateLimitType, HashMap<IPAddr, RateLimitBucket>>,
24 }
25
26 impl Default for RateLimiter {
27   fn default() -> Self {
28     Self {
29       buckets: HashMap::<RateLimitType, HashMap<IPAddr, RateLimitBucket>>::new(),
30     }
31   }
32 }
33
34 impl RateLimiter {
35   fn insert_ip(&mut self, ip: &str) {
36     for rate_limit_type in RateLimitType::iter() {
37       if self.buckets.get(&rate_limit_type).is_none() {
38         self.buckets.insert(rate_limit_type, HashMap::new());
39       }
40
41       if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
42         if bucket.get(ip).is_none() {
43           bucket.insert(
44             ip.to_string(),
45             RateLimitBucket {
46               last_checked: SystemTime::now(),
47               allowance: -2f64,
48             },
49           );
50         }
51       }
52     }
53   }
54
55   #[allow(clippy::float_cmp)]
56   pub(super) fn check_rate_limit_full(
57     &mut self,
58     type_: RateLimitType,
59     ip: &str,
60     rate: i32,
61     per: i32,
62     check_only: bool,
63   ) -> Result<(), LemmyError> {
64     self.insert_ip(ip);
65     if let Some(bucket) = self.buckets.get_mut(&type_) {
66       if let Some(rate_limit) = bucket.get_mut(ip) {
67         let current = SystemTime::now();
68         let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
69
70         // The initial value
71         if rate_limit.allowance == -2f64 {
72           rate_limit.allowance = rate as f64;
73         };
74
75         rate_limit.last_checked = current;
76         rate_limit.allowance += time_passed * (rate as f64 / per as f64);
77         if !check_only && rate_limit.allowance > rate as f64 {
78           rate_limit.allowance = rate as f64;
79         }
80
81         if rate_limit.allowance < 1.0 {
82           debug!(
83             "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}",
84             type_.as_ref(),
85             ip,
86             time_passed,
87             rate_limit.allowance
88           );
89           Err(
90             APIError {
91               message: format!(
92                 "Too many requests. type: {}, IP: {}, {} per {} seconds",
93                 type_.as_ref(),
94                 ip,
95                 rate,
96                 per
97               ),
98             }
99             .into(),
100           )
101         } else {
102           if !check_only {
103             rate_limit.allowance -= 1.0;
104           }
105           Ok(())
106         }
107       } else {
108         Ok(())
109       }
110     } else {
111       Ok(())
112     }
113   }
114 }