X-Git-Url: http://these/git/?a=blobdiff_plain;f=crates%2Futils%2Fsrc%2Frate_limit%2Frate_limiter.rs;h=3acf23ba42835dffafac7fc805075ea92ff6cbab;hb=92568956353f21649ed9aff68b42699c9d036f30;hp=5bb02f596996fc54d0fed01c35546ee8cb0fbaf3;hpb=3b64c58198ddc47305353432026312d1957a4848;p=lemmy.git diff --git a/crates/utils/src/rate_limit/rate_limiter.rs b/crates/utils/src/rate_limit/rate_limiter.rs index 5bb02f59..3acf23ba 100644 --- a/crates/utils/src/rate_limit/rate_limiter.rs +++ b/crates/utils/src/rate_limit/rate_limiter.rs @@ -1,114 +1,331 @@ -use crate::{APIError, IPAddr, LemmyError}; -use log::debug; -use std::{collections::HashMap, time::SystemTime}; -use strum::IntoEnumIterator; +use enum_map::{enum_map, EnumMap}; +use once_cell::sync::Lazy; +use std::{ + collections::HashMap, + hash::Hash, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + time::{Duration, Instant}, +}; +use tracing::debug; -#[derive(Debug, Clone)] +const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0; + +static START_TIME: Lazy = Lazy::new(Instant::now); + +/// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't +/// store nanoseconds +#[derive(PartialEq, Debug, Clone, Copy)] +pub struct InstantSecs { + secs: u32, +} + +impl InstantSecs { + pub fn now() -> Self { + InstantSecs { + secs: u32::try_from(START_TIME.elapsed().as_secs()) + .expect("server has been running for over 136 years"), + } + } + + fn secs_since(self, earlier: Self) -> u32 { + self.secs.saturating_sub(earlier.secs) + } + + fn to_instant(self) -> Instant { + *START_TIME + Duration::from_secs(self.secs.into()) + } +} + +#[derive(PartialEq, Debug, Clone)] struct RateLimitBucket { - last_checked: SystemTime, - allowance: f64, + last_checked: InstantSecs, + /// This field stores the amount of tokens that were present at `last_checked`. + /// The amount of tokens steadily increases until it reaches the bucket's capacity. + /// Performing the rate-limited action consumes 1 token. + tokens: f32, } -#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)] +#[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)] pub(crate) enum RateLimitType { Message, Register, Post, Image, + Comment, + Search, } -/// Rate limiting based on rate type and IP addr -#[derive(Debug, Clone)] -pub struct RateLimiter { - buckets: HashMap>, +type Map = HashMap>; + +#[derive(PartialEq, Debug, Clone)] +struct RateLimitedGroup { + total: EnumMap, + children: C, } -impl Default for RateLimiter { - fn default() -> Self { - Self { - buckets: HashMap::>::new(), +impl RateLimitedGroup { + fn new(now: InstantSecs) -> Self { + RateLimitedGroup { + total: enum_map! { + _ => RateLimitBucket { + last_checked: now, + tokens: UNINITIALIZED_TOKEN_AMOUNT, + }, + }, + children: Default::default(), } } -} -impl RateLimiter { - fn insert_ip(&mut self, ip: &str) { - for rate_limit_type in RateLimitType::iter() { - if self.buckets.get(&rate_limit_type).is_none() { - self.buckets.insert(rate_limit_type, HashMap::new()); - } + fn check_total( + &mut self, + type_: RateLimitType, + now: InstantSecs, + capacity: i32, + secs_to_refill: i32, + ) -> bool { + let capacity = capacity as f32; + let secs_to_refill = secs_to_refill as f32; - if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) { - if bucket.get(ip).is_none() { - bucket.insert( - ip.to_string(), - RateLimitBucket { - last_checked: SystemTime::now(), - allowance: -2f64, - }, - ); - } - } + #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton + let bucket = &mut self.total[type_]; + + if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT { + bucket.tokens = capacity; + } + + let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32; + bucket.last_checked = now; + + // For `secs_since_last_checked` seconds, increase `bucket.tokens` + // by `capacity` every `secs_to_refill` seconds + bucket.tokens += { + let tokens_per_sec = capacity / secs_to_refill; + secs_since_last_checked * tokens_per_sec + }; + + // Prevent `bucket.tokens` from exceeding `capacity` + if bucket.tokens > capacity { + bucket.tokens = capacity; + } + + if bucket.tokens < 1.0 { + // Not enough tokens yet + debug!( + "Rate limited type: {}, time_passed: {}, allowance: {}", + type_.as_ref(), + secs_since_last_checked, + bucket.tokens + ); + false + } else { + // Consume 1 token + bucket.tokens -= 1.0; + true } } +} + +/// Rate limiting based on rate type and IP addr +#[derive(PartialEq, Debug, Clone, Default)] +pub struct RateLimitStorage { + /// One bucket per individual IPv4 address + ipv4_buckets: Map, + /// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses + ipv6_buckets: Map<[u8; 6], Map>>, +} - #[allow(clippy::float_cmp)] +impl RateLimitStorage { + /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478 + /// + /// Returns true if the request passed the rate limit, false if it failed and should be rejected. pub(super) fn check_rate_limit_full( &mut self, type_: RateLimitType, - ip: &str, - rate: i32, - per: i32, - check_only: bool, - ) -> Result<(), LemmyError> { - self.insert_ip(ip); - if let Some(bucket) = self.buckets.get_mut(&type_) { - if let Some(rate_limit) = bucket.get_mut(ip) { - let current = SystemTime::now(); - let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64; - - // The initial value - if rate_limit.allowance == -2f64 { - rate_limit.allowance = rate as f64; - }; - - rate_limit.last_checked = current; - rate_limit.allowance += time_passed * (rate as f64 / per as f64); - if !check_only && rate_limit.allowance > rate as f64 { - rate_limit.allowance = rate as f64; - } - - if rate_limit.allowance < 1.0 { - debug!( - "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}", - type_.as_ref(), - ip, - time_passed, - rate_limit.allowance - ); - Err( - APIError { - message: format!( - "Too many requests. type: {}, IP: {}, {} per {} seconds", - type_.as_ref(), - ip, - rate, - per + ip: IpAddr, + capacity: i32, + secs_to_refill: i32, + now: InstantSecs, + ) -> bool { + let mut result = true; + + match ip { + IpAddr::V4(ipv4) => { + // Only used by one address. + let group = self + .ipv4_buckets + .entry(ipv4) + .or_insert(RateLimitedGroup::new(now)); + + result &= group.check_total(type_, now, capacity, secs_to_refill); + } + + IpAddr::V6(ipv6) => { + let (key_48, key_56, key_64) = split_ipv6(ipv6); + + // Contains all addresses with the same first 48 bits. These addresses might be part of the same network. + let group_48 = self + .ipv6_buckets + .entry(key_48) + .or_insert(RateLimitedGroup::new(now)); + result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill); + + // Contains all addresses with the same first 56 bits. These addresses might be part of the same network. + let group_56 = group_48 + .children + .entry(key_56) + .or_insert(RateLimitedGroup::new(now)); + result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill); + + // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network. + let group_64 = group_56 + .children + .entry(key_64) + .or_insert(RateLimitedGroup::new(now)); + + result &= group_64.check_total(type_, now, capacity, secs_to_refill); + } + }; + + if !result { + debug!("Rate limited IP: {ip}"); + } + + result + } + + /// Remove buckets older than the given duration + pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) { + // Only retain buckets that were last used after `instant` + let Some(instant) = now.to_instant().checked_sub(duration) else { + return; + }; + + let is_recently_used = |group: &RateLimitedGroup<_>| { + group + .total + .values() + .all(|bucket| bucket.last_checked.to_instant() > instant) + }; + + retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group)); + + retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| { + retain_and_shrink(&mut group_48.children, |_, group_56| { + retain_and_shrink(&mut group_56.children, |_, group_64| { + is_recently_used(group_64) + }); + !group_56.children.is_empty() + }); + !group_48.children.is_empty() + }) + } +} + +fn retain_and_shrink(map: &mut HashMap, f: F) +where + K: Eq + Hash, + F: FnMut(&K, &mut V) -> bool, +{ + map.retain(f); + map.shrink_to_fit(); +} + +fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) { + let [a0, a1, a2, a3, a4, a5, b, c, ..] = ip.octets(); + ([a0, a1, a2, a3, a4, a5], b, c) +} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + #![allow(clippy::indexing_slicing)] + + #[test] + fn test_split_ipv6() { + let ip = std::net::Ipv6Addr::new( + 0x0011, 0x2233, 0x4455, 0x6677, 0x8899, 0xAABB, 0xCCDD, 0xEEFF, + ); + assert_eq!( + super::split_ipv6(ip), + ([0x00, 0x11, 0x22, 0x33, 0x44, 0x55], 0x66, 0x77) + ); + } + + #[test] + fn test_rate_limiter() { + let mut rate_limiter = super::RateLimitStorage::default(); + let mut now = super::InstantSecs::now(); + + let ips = [ + "123.123.123.123", + "1:2:3::", + "1:2:3:0400::", + "1:2:3:0405::", + "1:2:3:0405:6::", + ]; + for ip in ips { + let ip = ip.parse().unwrap(); + let message_passed = + rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now); + let post_passed = + rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now); + assert!(message_passed); + assert!(post_passed); + } + + #[allow(clippy::indexing_slicing)] + let expected_buckets = |factor: f32, tokens_consumed: f32| { + let mut buckets = super::RateLimitedGroup::<()>::new(now).total; + buckets[super::RateLimitType::Message] = super::RateLimitBucket { + last_checked: now, + tokens: (2.0 * factor) - tokens_consumed, + }; + buckets[super::RateLimitType::Post] = super::RateLimitBucket { + last_checked: now, + tokens: (3.0 * factor) - tokens_consumed, + }; + buckets + }; + + let bottom_group = |tokens_consumed| super::RateLimitedGroup { + total: expected_buckets(1.0, tokens_consumed), + children: (), + }; + + assert_eq!( + rate_limiter, + super::RateLimitStorage { + ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(), + ipv6_buckets: [( + [0, 1, 0, 2, 0, 3], + super::RateLimitedGroup { + total: expected_buckets(16.0, 4.0), + children: [ + ( + 0, + super::RateLimitedGroup { + total: expected_buckets(4.0, 1.0), + children: [(0, bottom_group(1.0)),].into(), + } ), - } + ( + 4, + super::RateLimitedGroup { + total: expected_buckets(4.0, 3.0), + children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(), + } + ), + ] .into(), - ) - } else { - if !check_only { - rate_limit.allowance -= 1.0; } - Ok(()) - } - } else { - Ok(()) + ),] + .into(), } - } else { - Ok(()) - } + ); + + now.secs += 2; + rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now); + assert!(rate_limiter.ipv4_buckets.is_empty()); + assert!(rate_limiter.ipv6_buckets.is_empty()); } }