+
+ IpAddr::V6(ipv6) => {
+ let (key_48, key_56, key_64) = split_ipv6(ipv6);
+
+ // Contains all addresses with the same first 48 bits. These addresses might be part of the same network.
+ let group_48 = self
+ .ipv6_buckets
+ .entry(key_48)
+ .or_insert(RateLimitedGroup::new(now));
+ result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);
+
+ // Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
+ let group_56 = group_48
+ .children
+ .entry(key_56)
+ .or_insert(RateLimitedGroup::new(now));
+ result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
+
+ // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
+ let group_64 = group_56
+ .children
+ .entry(key_64)
+ .or_insert(RateLimitedGroup::new(now));
+
+ result &= group_64.check_total(type_, now, capacity, secs_to_refill);
+ }
+ };
+
+ if !result {
+ debug!("Rate limited IP: {ip}");
+ }
+
+ result
+ }
+
+ /// Remove buckets older than the given duration
+ pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
+ // Only retain buckets that were last used after `instant`
+ let Some(instant) = now.to_instant().checked_sub(duration) else {
+ return;
+ };
+
+ let is_recently_used = |group: &RateLimitedGroup<_>| {
+ group
+ .total
+ .values()
+ .all(|bucket| bucket.last_checked.to_instant() > instant)
+ };
+
+ retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group));
+
+ retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| {
+ retain_and_shrink(&mut group_48.children, |_, group_56| {
+ retain_and_shrink(&mut group_56.children, |_, group_64| {
+ is_recently_used(group_64)
+ });
+ !group_56.children.is_empty()
+ });
+ !group_48.children.is_empty()
+ })
+ }
+}
+
+fn retain_and_shrink<K, V, F>(map: &mut HashMap<K, V>, f: F)
+where
+ K: Eq + Hash,
+ F: FnMut(&K, &mut V) -> bool,
+{
+ map.retain(f);
+ map.shrink_to_fit();
+}
+
+fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
+ let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets();
+ ([a0, a1, a2, a3, a4, a5], b, c)
+}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+ #![allow(clippy::indexing_slicing)]
+
+ #[test]
+ fn test_split_ipv6() {
+ let ip = std::net::Ipv6Addr::new(
+ 0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF,
+ );
+ assert_eq!(
+ super::split_ipv6(ip),
+ ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77)
+ );
+ }
+
+ #[test]
+ fn test_rate_limiter() {
+ let mut rate_limiter = super::RateLimitStorage::default();
+ let mut now = super::InstantSecs::now();
+
+ let ips = [
+ "123.123.123.123",
+ "1:2:3::",
+ "1:2:3:0400::",
+ "1:2:3:0405::",
+ "1:2:3:0405:6::",
+ ];
+ for ip in ips {
+ let ip = ip.parse().unwrap();
+ let message_passed =
+ rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now);
+ let post_passed =
+ rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
+ assert!(message_passed);
+ assert!(post_passed);