]> Untitled Git - lemmy.git/blob - crates/utils/src/rate_limit/rate_limiter.rs
12f264faebb8ac28761a1d63dc4d7ac851174a82
[lemmy.git] / crates / utils / src / rate_limit / rate_limiter.rs
1 use enum_map::{enum_map, EnumMap};
2 use once_cell::sync::Lazy;
3 use std::{
4   collections::HashMap,
5   net::{IpAddr, Ipv4Addr, Ipv6Addr},
6   time::{Duration, Instant},
7 };
8 use tracing::debug;
9
10 const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
11
12 static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
13
14 /// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
15 /// store nanoseconds
16 #[derive(PartialEq, Debug, Clone, Copy)]
17 pub struct InstantSecs {
18   secs: u32,
19 }
20
21 impl InstantSecs {
22   pub fn now() -> Self {
23     InstantSecs {
24       secs: u32::try_from(START_TIME.elapsed().as_secs())
25         .expect("server has been running for over 136 years"),
26     }
27   }
28
29   fn secs_since(self, earlier: Self) -> u32 {
30     self.secs.saturating_sub(earlier.secs)
31   }
32
33   fn to_instant(self) -> Instant {
34     *START_TIME + Duration::from_secs(self.secs.into())
35   }
36 }
37
38 #[derive(PartialEq, Debug, Clone)]
39 struct RateLimitBucket {
40   last_checked: InstantSecs,
41   /// This field stores the amount of tokens that were present at `last_checked`.
42   /// The amount of tokens steadily increases until it reaches the bucket's capacity.
43   /// Performing the rate-limited action consumes 1 token.
44   tokens: f32,
45 }
46
47 #[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
48 pub(crate) enum RateLimitType {
49   Message,
50   Register,
51   Post,
52   Image,
53   Comment,
54   Search,
55 }
56
57 type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
58
59 #[derive(PartialEq, Debug, Clone)]
60 struct RateLimitedGroup<C> {
61   total: EnumMap<RateLimitType, RateLimitBucket>,
62   children: C,
63 }
64
65 impl<C: Default> RateLimitedGroup<C> {
66   fn new(now: InstantSecs) -> Self {
67     RateLimitedGroup {
68       total: enum_map! {
69         _ => RateLimitBucket {
70           last_checked: now,
71           tokens: UNINITIALIZED_TOKEN_AMOUNT,
72         },
73       },
74       children: Default::default(),
75     }
76   }
77
78   fn check_total(
79     &mut self,
80     type_: RateLimitType,
81     now: InstantSecs,
82     capacity: i32,
83     secs_to_refill: i32,
84   ) -> bool {
85     let capacity = capacity as f32;
86     let secs_to_refill = secs_to_refill as f32;
87
88     #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton
89     let bucket = &mut self.total[type_];
90
91     if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT {
92       bucket.tokens = capacity;
93     }
94
95     let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32;
96     bucket.last_checked = now;
97
98     // For `secs_since_last_checked` seconds, increase `bucket.tokens`
99     // by `capacity` every `secs_to_refill` seconds
100     bucket.tokens += {
101       let tokens_per_sec = capacity / secs_to_refill;
102       secs_since_last_checked * tokens_per_sec
103     };
104
105     // Prevent `bucket.tokens` from exceeding `capacity`
106     if bucket.tokens > capacity {
107       bucket.tokens = capacity;
108     }
109
110     if bucket.tokens < 1.0 {
111       // Not enough tokens yet
112       debug!(
113         "Rate limited type: {}, time_passed: {}, allowance: {}",
114         type_.as_ref(),
115         secs_since_last_checked,
116         bucket.tokens
117       );
118       false
119     } else {
120       // Consume 1 token
121       bucket.tokens -= 1.0;
122       true
123     }
124   }
125 }
126
127 /// Rate limiting based on rate type and IP addr
128 #[derive(PartialEq, Debug, Clone, Default)]
129 pub struct RateLimitStorage {
130   /// One bucket per individual IPv4 address
131   ipv4_buckets: Map<Ipv4Addr, ()>,
132   /// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses
133   ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
134 }
135
136 impl RateLimitStorage {
137   /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
138   ///
139   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
140   pub(super) fn check_rate_limit_full(
141     &mut self,
142     type_: RateLimitType,
143     ip: IpAddr,
144     capacity: i32,
145     secs_to_refill: i32,
146     now: InstantSecs,
147   ) -> bool {
148     let mut result = true;
149
150     match ip {
151       IpAddr::V4(ipv4) => {
152         // Only used by one address.
153         let group = self
154           .ipv4_buckets
155           .entry(ipv4)
156           .or_insert(RateLimitedGroup::new(now));
157
158         result &= group.check_total(type_, now, capacity, secs_to_refill);
159       }
160
161       IpAddr::V6(ipv6) => {
162         let (key_48, key_56, key_64) = split_ipv6(ipv6);
163
164         // Contains all addresses with the same first 48 bits. These addresses might be part of the same network.
165         let group_48 = self
166           .ipv6_buckets
167           .entry(key_48)
168           .or_insert(RateLimitedGroup::new(now));
169         result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);
170
171         // Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
172         let group_56 = group_48
173           .children
174           .entry(key_56)
175           .or_insert(RateLimitedGroup::new(now));
176         result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
177
178         // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
179         let group_64 = group_56
180           .children
181           .entry(key_64)
182           .or_insert(RateLimitedGroup::new(now));
183
184         result &= group_64.check_total(type_, now, capacity, secs_to_refill);
185       }
186     };
187
188     if !result {
189       debug!("Rate limited IP: {ip}");
190     }
191
192     result
193   }
194
195   /// Remove buckets older than the given duration
196   pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
197     // Only retain buckets that were last used after `instant`
198     let Some(instant) = now.to_instant().checked_sub(duration) else { return };
199
200     let is_recently_used = |group: &RateLimitedGroup<_>| {
201       group
202         .total
203         .values()
204         .all(|bucket| bucket.last_checked.to_instant() > instant)
205     };
206
207     self.ipv4_buckets.retain(|_, group| is_recently_used(group));
208
209     self.ipv6_buckets.retain(|_, group_48| {
210       group_48.children.retain(|_, group_56| {
211         group_56
212           .children
213           .retain(|_, group_64| is_recently_used(group_64));
214         !group_56.children.is_empty()
215       });
216       !group_48.children.is_empty()
217     })
218   }
219 }
220
221 fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
222   let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets();
223   ([a0, a1, a2, a3, a4, a5], b, c)
224 }
225
226 #[cfg(test)]
227 mod tests {
228   #[test]
229   fn test_split_ipv6() {
230     let ip = std::net::Ipv6Addr::new(
231       0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
232     );
233     assert_eq!(
234       super::split_ipv6(ip),
235       ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77)
236     );
237   }
238
239   #[test]
240   fn test_rate_limiter() {
241     let mut rate_limiter = super::RateLimitStorage::default();
242     let mut now = super::InstantSecs::now();
243
244     let ips = [
245       "123.123.123.123",
246       "1:2:3::",
247       "1:2:3:0400::",
248       "1:2:3:0405::",
249       "1:2:3:0405:6::",
250     ];
251     for ip in ips {
252       let ip = ip.parse().unwrap();
253       let message_passed =
254         rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now);
255       let post_passed =
256         rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
257       assert!(message_passed);
258       assert!(post_passed);
259     }
260
261     #[allow(clippy::indexing_slicing)]
262     let expected_buckets = |factor: f32, tokens_consumed: f32| {
263       let mut buckets = super::RateLimitedGroup::<()>::new(now).total;
264       buckets[super::RateLimitType::Message] = super::RateLimitBucket {
265         last_checked: now,
266         tokens: (2.0 * factor) - tokens_consumed,
267       };
268       buckets[super::RateLimitType::Post] = super::RateLimitBucket {
269         last_checked: now,
270         tokens: (3.0 * factor) - tokens_consumed,
271       };
272       buckets
273     };
274
275     let bottom_group = |tokens_consumed| super::RateLimitedGroup {
276       total: expected_buckets(1.0, tokens_consumed),
277       children: (),
278     };
279
280     assert_eq!(
281       rate_limiter,
282       super::RateLimitStorage {
283         ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(),
284         ipv6_buckets: [(
285           [0, 1, 0, 2, 0, 3],
286           super::RateLimitedGroup {
287             total: expected_buckets(16.0, 4.0),
288             children: [
289               (
290                 0,
291                 super::RateLimitedGroup {
292                   total: expected_buckets(4.0, 1.0),
293                   children: [(0, bottom_group(1.0)),].into(),
294                 }
295               ),
296               (
297                 4,
298                 super::RateLimitedGroup {
299                   total: expected_buckets(4.0, 3.0),
300                   children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(),
301                 }
302               ),
303             ]
304             .into(),
305           }
306         ),]
307         .into(),
308       }
309     );
310
311     now.secs += 2;
312     rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now);
313     assert!(rate_limiter.ipv4_buckets.is_empty());
314     assert!(rate_limiter.ipv6_buckets.is_empty());
315   }
316 }