]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/rate_limiter.rs
ed3dc569ef797eabbee27ebd808bea0e39252ce4
[lemmy.git] / crates / utils / src / rate_limit / rate_limiter.rs
1 use enum_map::{enum_map, EnumMap};
2 use once_cell::sync::Lazy;
3 use std::{
4   collections::HashMap,
5   hash::Hash,
6   net::{IpAddr, Ipv4Addr, Ipv6Addr},
7   time::{Duration, Instant},
8 };
9 use tracing::debug;
10
11 const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
12
13 static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
14
15 /// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
16 /// store nanoseconds
17 #[derive(PartialEq, Debug, Clone, Copy)]
18 pub struct InstantSecs {
19   secs: u32,
20 }
21
22 impl InstantSecs {
23   pub fn now() -> Self {
24     InstantSecs {
25       secs: u32::try_from(START_TIME.elapsed().as_secs())
26         .expect("server has been running for over 136 years"),
27     }
28   }
29
30   fn secs_since(self, earlier: Self) -> u32 {
31     self.secs.saturating_sub(earlier.secs)
32   }
33
34   fn to_instant(self) -> Instant {
35     *START_TIME + Duration::from_secs(self.secs.into())
36   }
37 }
38
39 #[derive(PartialEq, Debug, Clone)]
40 struct RateLimitBucket {
41   last_checked: InstantSecs,
42   /// This field stores the amount of tokens that were present at `last_checked`.
43   /// The amount of tokens steadily increases until it reaches the bucket's capacity.
44   /// Performing the rate-limited action consumes 1 token.
45   tokens: f32,
46 }
47
48 #[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
49 pub(crate) enum RateLimitType {
50   Message,
51   Register,
52   Post,
53   Image,
54   Comment,
55   Search,
56 }
57
58 type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
59
60 #[derive(PartialEq, Debug, Clone)]
61 struct RateLimitedGroup<C> {
62   total: EnumMap<RateLimitType, RateLimitBucket>,
63   children: C,
64 }
65
66 impl<C: Default> RateLimitedGroup<C> {
67   fn new(now: InstantSecs) -> Self {
68     RateLimitedGroup {
69       total: enum_map! {
70         _ => RateLimitBucket {
71           last_checked: now,
72           tokens: UNINITIALIZED_TOKEN_AMOUNT,
73         },
74       },
75       children: Default::default(),
76     }
77   }
78
79   fn check_total(
80     &mut self,
81     type_: RateLimitType,
82     now: InstantSecs,
83     capacity: i32,
84     secs_to_refill: i32,
85   ) -> bool {
86     let capacity = capacity as f32;
87     let secs_to_refill = secs_to_refill as f32;
88
89     #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton
90     let bucket = &mut self.total[type_];
91
92     if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT {
93       bucket.tokens = capacity;
94     }
95
96     let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32;
97     bucket.last_checked = now;
98
99     // For `secs_since_last_checked` seconds, increase `bucket.tokens`
100     // by `capacity` every `secs_to_refill` seconds
101     bucket.tokens += {
102       let tokens_per_sec = capacity / secs_to_refill;
103       secs_since_last_checked * tokens_per_sec
104     };
105
106     // Prevent `bucket.tokens` from exceeding `capacity`
107     if bucket.tokens > capacity {
108       bucket.tokens = capacity;
109     }
110
111     if bucket.tokens < 1.0 {
112       // Not enough tokens yet
113       debug!(
114         "Rate limited type: {}, time_passed: {}, allowance: {}",
115         type_.as_ref(),
116         secs_since_last_checked,
117         bucket.tokens
118       );
119       false
120     } else {
121       // Consume 1 token
122       bucket.tokens -= 1.0;
123       true
124     }
125   }
126 }
127
128 /// Rate limiting based on rate type and IP addr
129 #[derive(PartialEq, Debug, Clone, Default)]
130 pub struct RateLimitStorage {
131   /// One bucket per individual IPv4 address
132   ipv4_buckets: Map<Ipv4Addr, ()>,
133   /// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses
134   ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
135 }
136
137 impl RateLimitStorage {
138   /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
139   ///
140   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
141   pub(super) fn check_rate_limit_full(
142     &mut self,
143     type_: RateLimitType,
144     ip: IpAddr,
145     capacity: i32,
146     secs_to_refill: i32,
147     now: InstantSecs,
148   ) -> bool {
149     let mut result = true;
150
151     match ip {
152       IpAddr::V4(ipv4) => {
153         // Only used by one address.
154         let group = self
155           .ipv4_buckets
156           .entry(ipv4)
157           .or_insert(RateLimitedGroup::new(now));
158
159         result &= group.check_total(type_, now, capacity, secs_to_refill);
160       }
161
162       IpAddr::V6(ipv6) => {
163         let (key_48, key_56, key_64) = split_ipv6(ipv6);
164
165         // Contains all addresses with the same first 48 bits. These addresses might be part of the same network.
166         let group_48 = self
167           .ipv6_buckets
168           .entry(key_48)
169           .or_insert(RateLimitedGroup::new(now));
170         result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);
171
172         // Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
173         let group_56 = group_48
174           .children
175           .entry(key_56)
176           .or_insert(RateLimitedGroup::new(now));
177         result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
178
179         // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
180         let group_64 = group_56
181           .children
182           .entry(key_64)
183           .or_insert(RateLimitedGroup::new(now));
184
185         result &= group_64.check_total(type_, now, capacity, secs_to_refill);
186       }
187     };
188
189     if !result {
190       debug!("Rate limited IP: {ip}");
191     }
192
193     result
194   }
195
196   /// Remove buckets older than the given duration
197   pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
198     // Only retain buckets that were last used after `instant`
199     let Some(instant) = now.to_instant().checked_sub(duration) else {
200       return;
201     };
202
203     let is_recently_used = |group: &RateLimitedGroup<_>| {
204       group
205         .total
206         .values()
207         .all(|bucket| bucket.last_checked.to_instant() > instant)
208     };
209
210     retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group));
211
212     retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| {
213       retain_and_shrink(&mut group_48.children, |_, group_56| {
214         retain_and_shrink(&mut group_56.children, |_, group_64| {
215           is_recently_used(group_64)
216         });
217         !group_56.children.is_empty()
218       });
219       !group_48.children.is_empty()
220     })
221   }
222 }
223
224 fn retain_and_shrink<K, V, F>(map: &mut HashMap<K, V>, f: F)
225 where
226   K: Eq + Hash,
227   F: FnMut(&K, &mut V) -> bool,
228 {
229   map.retain(f);
230   map.shrink_to_fit();
231 }
232
233 fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
234   let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets();
235   ([a0, a1, a2, a3, a4, a5], b, c)
236 }
237
238 #[cfg(test)]
239 mod tests {
240   #[test]
241   fn test_split_ipv6() {
242     let ip = std::net::Ipv6Addr::new(
243       0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
244     );
245     assert_eq!(
246       super::split_ipv6(ip),
247       ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77)
248     );
249   }
250
251   #[test]
252   fn test_rate_limiter() {
253     let mut rate_limiter = super::RateLimitStorage::default();
254     let mut now = super::InstantSecs::now();
255
256     let ips = [
257       "123.123.123.123",
258       "1:2:3::",
259       "1:2:3:0400::",
260       "1:2:3:0405::",
261       "1:2:3:0405:6::",
262     ];
263     for ip in ips {
264       let ip = ip.parse().unwrap();
265       let message_passed =
266         rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now);
267       let post_passed =
268         rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
269       assert!(message_passed);
270       assert!(post_passed);
271     }
272
273     #[allow(clippy::indexing_slicing)]
274     let expected_buckets = |factor: f32, tokens_consumed: f32| {
275       let mut buckets = super::RateLimitedGroup::<()>::new(now).total;
276       buckets[super::RateLimitType::Message] = super::RateLimitBucket {
277         last_checked: now,
278         tokens: (2.0 * factor) - tokens_consumed,
279       };
280       buckets[super::RateLimitType::Post] = super::RateLimitBucket {
281         last_checked: now,
282         tokens: (3.0 * factor) - tokens_consumed,
283       };
284       buckets
285     };
286
287     let bottom_group = |tokens_consumed| super::RateLimitedGroup {
288       total: expected_buckets(1.0, tokens_consumed),
289       children: (),
290     };
291
292     assert_eq!(
293       rate_limiter,
294       super::RateLimitStorage {
295         ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(),
296         ipv6_buckets: [(
297           [0, 1, 0, 2, 0, 3],
298           super::RateLimitedGroup {
299             total: expected_buckets(16.0, 4.0),
300             children: [
301               (
302                 0,
303                 super::RateLimitedGroup {
304                   total: expected_buckets(4.0, 1.0),
305                   children: [(0, bottom_group(1.0)),].into(),
306                 }
307               ),
308               (
309                 4,
310                 super::RateLimitedGroup {
311                   total: expected_buckets(4.0, 3.0),
312                   children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(),
313                 }
314               ),
315             ]
316             .into(),
317           }
318         ),]
319         .into(),
320       }
321     );
322
323     now.secs += 2;
324     rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now);
325     assert!(rate_limiter.ipv4_buckets.is_empty());
326     assert!(rate_limiter.ipv6_buckets.is_empty());
327   }
328 }