]> Untitled Git - lemmy.git/blob - server/src/rate_limit/rate_limiter.rs
Proxy pictrs requests through Lemmy (fixes #371) (#77)
[lemmy.git] / server / src / rate_limit / rate_limiter.rs
1 use super::IPAddr;
2 use crate::{api::APIError, LemmyError};
3 use log::debug;
4 use std::{collections::HashMap, time::SystemTime};
5 use strum::IntoEnumIterator;
6
7 #[derive(Debug, Clone)]
8 pub struct RateLimitBucket {
9   last_checked: SystemTime,
10   allowance: f64,
11 }
12
13 #[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)]
14 pub enum RateLimitType {
15   Message,
16   Register,
17   Post,
18   Image,
19 }
20
21 /// Rate limiting based on rate type and IP addr
22 #[derive(Debug, Clone)]
23 pub struct RateLimiter {
24   pub buckets: HashMap<RateLimitType, HashMap<IPAddr, RateLimitBucket>>,
25 }
26
27 impl Default for RateLimiter {
28   fn default() -> Self {
29     Self {
30       buckets: HashMap::new(),
31     }
32   }
33 }
34
35 impl RateLimiter {
36   fn insert_ip(&mut self, ip: &str) {
37     for rate_limit_type in RateLimitType::iter() {
38       if self.buckets.get(&rate_limit_type).is_none() {
39         self.buckets.insert(rate_limit_type, HashMap::new());
40       }
41
42       if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
43         if bucket.get(ip).is_none() {
44           bucket.insert(
45             ip.to_string(),
46             RateLimitBucket {
47               last_checked: SystemTime::now(),
48               allowance: -2f64,
49             },
50           );
51         }
52       }
53     }
54   }
55
56   #[allow(clippy::float_cmp)]
57   pub(super) fn check_rate_limit_full(
58     &mut self,
59     type_: RateLimitType,
60     ip: &str,
61     rate: i32,
62     per: i32,
63     check_only: bool,
64   ) -> Result<(), LemmyError> {
65     self.insert_ip(ip);
66     if let Some(bucket) = self.buckets.get_mut(&type_) {
67       if let Some(rate_limit) = bucket.get_mut(ip) {
68         let current = SystemTime::now();
69         let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
70
71         // The initial value
72         if rate_limit.allowance == -2f64 {
73           rate_limit.allowance = rate as f64;
74         };
75
76         rate_limit.last_checked = current;
77         rate_limit.allowance += time_passed * (rate as f64 / per as f64);
78         if !check_only && rate_limit.allowance > rate as f64 {
79           rate_limit.allowance = rate as f64;
80         }
81
82         if rate_limit.allowance < 1.0 {
83           debug!(
84             "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}",
85             type_.as_ref(),
86             ip,
87             time_passed,
88             rate_limit.allowance
89           );
90           Err(
91             APIError {
92               message: format!(
93                 "Too many requests. type: {}, IP: {}, {} per {} seconds",
94                 type_.as_ref(),
95                 ip,
96                 rate,
97                 per
98               ),
99             }
100             .into(),
101           )
102         } else {
103           if !check_only {
104             rate_limit.allowance -= 1.0;
105           }
106           Ok(())
107         }
108       } else {
109         Ok(())
110       }
111     } else {
112       Ok(())
113     }
114   }
115 }