1 use enum_map::{enum_map, EnumMap};
2 use once_cell::sync::Lazy;
6 net::{IpAddr, Ipv4Addr, Ipv6Addr},
7 time::{Duration, Instant},
11 const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
13 static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
15 /// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
17 #[derive(PartialEq, Debug, Clone, Copy)]
18 pub struct InstantSecs {
23 pub fn now() -> Self {
25 secs: u32::try_from(START_TIME.elapsed().as_secs())
26 .expect("server has been running for over 136 years"),
30 fn secs_since(self, earlier: Self) -> u32 {
31 self.secs.saturating_sub(earlier.secs)
34 fn to_instant(self) -> Instant {
35 *START_TIME + Duration::from_secs(self.secs.into())
39 #[derive(PartialEq, Debug, Clone)]
40 struct RateLimitBucket {
41 last_checked: InstantSecs,
42 /// This field stores the amount of tokens that were present at `last_checked`.
43 /// The amount of tokens steadily increases until it reaches the bucket's capacity.
44 /// Performing the rate-limited action consumes 1 token.
48 #[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
49 pub(crate) enum RateLimitType {
58 type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
60 #[derive(PartialEq, Debug, Clone)]
61 struct RateLimitedGroup<C> {
62 total: EnumMap<RateLimitType, RateLimitBucket>,
66 impl<C: Default> RateLimitedGroup<C> {
67 fn new(now: InstantSecs) -> Self {
70 _ => RateLimitBucket {
72 tokens: UNINITIALIZED_TOKEN_AMOUNT,
75 children: Default::default(),
86 let capacity = capacity as f32;
87 let secs_to_refill = secs_to_refill as f32;
89 #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` function
90 let bucket = &mut self.total[type_];
92 if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT {
93 bucket.tokens = capacity;
96 let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32;
97 bucket.last_checked = now;
99 // For `secs_since_last_checked` seconds, increase `bucket.tokens`
100 // by `capacity` every `secs_to_refill` seconds
102 let tokens_per_sec = capacity / secs_to_refill;
103 secs_since_last_checked * tokens_per_sec
106 // Prevent `bucket.tokens` from exceeding `capacity`
107 if bucket.tokens > capacity {
108 bucket.tokens = capacity;
111 if bucket.tokens < 1.0 {
112 // Not enough tokens yet
114 "Rate limited type: {}, time_passed: {}, allowance: {}",
116 secs_since_last_checked,
122 bucket.tokens -= 1.0;
128 /// Rate limiting based on rate type and IP addr
129 #[derive(PartialEq, Debug, Clone, Default)]
130 pub struct RateLimitStorage {
131 /// One bucket per individual IPv4 address
132 ipv4_buckets: Map<Ipv4Addr, ()>,
133 /// Separate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses
134 ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
137 impl RateLimitStorage {
138 /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
140 /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
141 pub(super) fn check_rate_limit_full(
143 type_: RateLimitType,
149 let mut result = true;
152 IpAddr::V4(ipv4) => {
153 // Only used by one address.
157 .or_insert(RateLimitedGroup::new(now));
159 result &= group.check_total(type_, now, capacity, secs_to_refill);
162 IpAddr::V6(ipv6) => {
163 let (key_48, key_56, key_64) = split_ipv6(ipv6);
165 // Contains all addresses with the same first 48 bits. These addresses might be part of the same network.
169 .or_insert(RateLimitedGroup::new(now));
170 result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);
172 // Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
173 let group_56 = group_48
176 .or_insert(RateLimitedGroup::new(now));
177 result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
179 // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
180 let group_64 = group_56
183 .or_insert(RateLimitedGroup::new(now));
185 result &= group_64.check_total(type_, now, capacity, secs_to_refill);
190 debug!("Rate limited IP: {ip}");
196 /// Remove buckets older than the given duration
197 pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
198 // Only retain buckets that were last used after `instant`
199 let Some(instant) = now.to_instant().checked_sub(duration) else {
203 let is_recently_used = |group: &RateLimitedGroup<_>| {
207 .all(|bucket| bucket.last_checked.to_instant() > instant)
210 retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group));
212 retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| {
213 retain_and_shrink(&mut group_48.children, |_, group_56| {
214 retain_and_shrink(&mut group_56.children, |_, group_64| {
215 is_recently_used(group_64)
217 !group_56.children.is_empty()
219 !group_48.children.is_empty()
224 fn retain_and_shrink<K, V, F>(map: &mut HashMap<K, V>, f: F)
227 F: FnMut(&K, &mut V) -> bool,
233 fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
234 let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets();
235 ([a0, a1, a2, a3, a4, a5], b, c)
241 fn test_split_ipv6() {
242 let ip = std::net::Ipv6Addr::new(
243 0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
246 super::split_ipv6(ip),
247 ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77)
252 fn test_rate_limiter() {
253 let mut rate_limiter = super::RateLimitStorage::default();
254 let mut now = super::InstantSecs::now();
264 let ip = ip.parse().unwrap();
266 rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now);
268 rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
269 assert!(message_passed);
270 assert!(post_passed);
273 #[allow(clippy::indexing_slicing)]
274 let expected_buckets = |factor: f32, tokens_consumed: f32| {
275 let mut buckets = super::RateLimitedGroup::<()>::new(now).total;
276 buckets[super::RateLimitType::Message] = super::RateLimitBucket {
278 tokens: (2.0 * factor) - tokens_consumed,
280 buckets[super::RateLimitType::Post] = super::RateLimitBucket {
282 tokens: (3.0 * factor) - tokens_consumed,
287 let bottom_group = |tokens_consumed| super::RateLimitedGroup {
288 total: expected_buckets(1.0, tokens_consumed),
294 super::RateLimitStorage {
295 ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(),
298 super::RateLimitedGroup {
299 total: expected_buckets(16.0, 4.0),
303 super::RateLimitedGroup {
304 total: expected_buckets(4.0, 1.0),
305 children: [(0, bottom_group(1.0)),].into(),
310 super::RateLimitedGroup {
311 total: expected_buckets(4.0, 3.0),
312 children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(),
324 rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now);
325 assert!(rate_limiter.ipv4_buckets.is_empty());
326 assert!(rate_limiter.ipv6_buckets.is_empty());