]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/rate_limiter.rs
Show rate limit algorithm. Fixes #2136
[lemmy.git] / crates / utils / src / rate_limit / rate_limiter.rs
1 use crate::{IpAddr, LemmyError};
2 use std::{collections::HashMap, time::SystemTime};
3 use strum::IntoEnumIterator;
4 use tracing::debug;
5
6 #[derive(Debug, Clone)]
7 struct RateLimitBucket {
8   last_checked: SystemTime,
9   allowance: f64,
10 }
11
12 #[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)]
13 pub(crate) enum RateLimitType {
14   Message,
15   Register,
16   Post,
17   Image,
18   Comment,
19 }
20
21 /// Rate limiting based on rate type and IP addr
22 #[derive(Debug, Clone, Default)]
23 pub struct RateLimiter {
24   buckets: HashMap<RateLimitType, HashMap<IpAddr, RateLimitBucket>>,
25 }
26
27 impl RateLimiter {
28   fn insert_ip(&mut self, ip: &IpAddr) {
29     for rate_limit_type in RateLimitType::iter() {
30       if self.buckets.get(&rate_limit_type).is_none() {
31         self.buckets.insert(rate_limit_type, HashMap::new());
32       }
33
34       if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
35         if bucket.get(ip).is_none() {
36           bucket.insert(
37             ip.clone(),
38             RateLimitBucket {
39               last_checked: SystemTime::now(),
40               allowance: -2f64,
41             },
42           );
43         }
44       }
45     }
46   }
47
48   /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
49   #[allow(clippy::float_cmp)]
50   pub(super) fn check_rate_limit_full(
51     &mut self,
52     type_: RateLimitType,
53     ip: &IpAddr,
54     rate: i32,
55     per: i32,
56     check_only: bool,
57   ) -> Result<(), LemmyError> {
58     self.insert_ip(ip);
59     if let Some(bucket) = self.buckets.get_mut(&type_) {
60       if let Some(rate_limit) = bucket.get_mut(ip) {
61         let current = SystemTime::now();
62         let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
63
64         // The initial value
65         if rate_limit.allowance == -2f64 {
66           rate_limit.allowance = rate as f64;
67         };
68
69         rate_limit.last_checked = current;
70         rate_limit.allowance += time_passed * (rate as f64 / per as f64);
71         if !check_only && rate_limit.allowance > rate as f64 {
72           rate_limit.allowance = rate as f64;
73         }
74
75         if rate_limit.allowance < 1.0 {
76           debug!(
77             "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}",
78             type_.as_ref(),
79             ip,
80             time_passed,
81             rate_limit.allowance
82           );
83           Err(LemmyError::from_error_message(
84             anyhow::anyhow!(
85               "Too many requests. type: {}, IP: {}, {} per {} seconds",
86               type_.as_ref(),
87               ip,
88               rate,
89               per
90             ),
91             "too_many_requests",
92           ))
93         } else {
94           if !check_only {
95             rate_limit.allowance -= 1.0;
96           }
97           Ok(())
98         }
99       } else {
100         Ok(())
101       }
102     } else {
103       Ok(())
104     }
105   }
106 }