2 use std::{collections::HashMap, time::Instant};
3 use strum::IntoEnumIterator;
6 #[derive(Debug, Clone)]
7 struct RateLimitBucket {
12 #[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)]
13 pub(crate) enum RateLimitType {
21 /// Rate limiting based on rate type and IP addr
22 #[derive(Debug, Clone, Default)]
23 pub struct RateLimiter {
24 buckets: HashMap<RateLimitType, HashMap<IpAddr, RateLimitBucket>>,
28 fn insert_ip(&mut self, ip: &IpAddr) {
29 for rate_limit_type in RateLimitType::iter() {
30 if self.buckets.get(&rate_limit_type).is_none() {
31 self.buckets.insert(rate_limit_type, HashMap::new());
34 if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) {
35 if bucket.get(ip).is_none() {
39 last_checked: Instant::now(),
48 /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
50 /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
51 #[allow(clippy::float_cmp)]
52 pub(super) fn check_rate_limit_full(
60 if let Some(bucket) = self.buckets.get_mut(&type_) {
61 if let Some(rate_limit) = bucket.get_mut(ip) {
62 let current = Instant::now();
63 let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64;
66 if rate_limit.allowance == -2f64 {
67 rate_limit.allowance = rate as f64;
70 rate_limit.last_checked = current;
71 rate_limit.allowance += time_passed * (rate as f64 / per as f64);
72 if rate_limit.allowance > rate as f64 {
73 rate_limit.allowance = rate as f64;
76 if rate_limit.allowance < 1.0 {
78 "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}",
86 rate_limit.allowance -= 1.0;