]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/rate_limiter.rs
Lowering search rate limit. Fixes #2153 (#2154)
[lemmy.git] / crates / utils / src / rate_limit / rate_limiter.rs
1 use crate::IpAddr;
2 use std::{collections::HashMap, time::Instant};
3 use strum::IntoEnumIterator;
4 use tracing::debug;
5
6 #[derive(Debug, Clone)]
7 struct RateLimitBucket {
8   last_checked: Instant,
9   allowance: f64,
10 }
11
12 #[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)]
13 pub(crate) enum RateLimitType {
14   Message,
15   Register,
16   Post,
17   Image,
18   Comment,
19   Search,
20 }
21
22 /// Rate limiting based on rate type and IP addr
23 #[derive(Debug, Clone, Default)]
24 pub struct RateLimiter {
25   buckets: HashMap<RateLimitType, HashMap<IpAddr, RateLimitBucket>>,
26 }
27
28 impl RateLimiter {
29   fn insert_ip(&mut self, ip: &IpAddr) {
30     for rate_limit_type in RateLimitType::iter() {
31       if self.buckets.get(&rate_limit_type).is_none() {
32         self.buckets.insert(rate_limit_type, HashMap::new());
33       }
34
35       if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
36         if bucket.get(ip).is_none() {
37           bucket.insert(
38             ip.clone(),
39             RateLimitBucket {
40               last_checked: Instant::now(),
41               allowance: -2f64,
42             },
43           );
44         }
45       }
46     }
47   }
48
49   /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
50   ///
51   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
52   #[allow(clippy::float_cmp)]
53   pub(super) fn check_rate_limit_full(
54     &mut self,
55     type_: RateLimitType,
56     ip: &IpAddr,
57     rate: i32,
58     per: i32,
59   ) -> bool {
60     self.insert_ip(ip);
61     if let Some(bucket) = self.buckets.get_mut(&type_) {
62       if let Some(rate_limit) = bucket.get_mut(ip) {
63         let current = Instant::now();
64         let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64;
65
66         // The initial value
67         if rate_limit.allowance == -2f64 {
68           rate_limit.allowance = rate as f64;
69         };
70
71         rate_limit.last_checked = current;
72         rate_limit.allowance += time_passed * (rate as f64 / per as f64);
73         if rate_limit.allowance > rate as f64 {
74           rate_limit.allowance = rate as f64;
75         }
76
77         if rate_limit.allowance < 1.0 {
78           debug!(
79             "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}",
80             type_.as_ref(),
81             ip,
82             time_passed,
83             rate_limit.allowance
84           );
85           false
86         } else {
87           rate_limit.allowance -= 1.0;
88           true
89         }
90       } else {
91         true
92       }
93     } else {
94       true
95     }
96   }
97 }