From a5ff629b2492ff21ca24c9de79e41ad341e204db Mon Sep 17 00:00:00 2001
From: Nutomic <me@nutomic.com>
Date: Fri, 25 Mar 2022 16:41:38 +0100
Subject: [PATCH] Dont log errors when rate limit is hit (fixes #2157) (#2161)

* Dont log errors when rate limit is hit (fixes #2157)

* Clone service rather than http request

* some cleanup/refactoring

Co-authored-by: Aode (Lion) <asonix@asonix.dog>
---
 crates/utils/src/rate_limit/mod.rs          | 109 +++++++-------------
 crates/utils/src/rate_limit/rate_limiter.rs |  33 +++---
 crates/websocket/src/chat_server.rs         |  39 ++++---
 3 files changed, 74 insertions(+), 107 deletions(-)

diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs
index e2a155eb..6027520f 100644
--- a/crates/utils/src/rate_limit/mod.rs
+++ b/crates/utils/src/rate_limit/mod.rs
@@ -1,10 +1,14 @@
-use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr, LemmyError};
-use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
+use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr};
+use actix_web::{
+  dev::{Service, ServiceRequest, ServiceResponse, Transform},
+  HttpResponse,
+};
 use futures::future::{ok, Ready};
 use rate_limiter::{RateLimitType, RateLimiter};
 use std::{
   future::Future,
   pin::Pin,
+  rc::Rc,
   sync::Arc,
   task::{Context, Poll},
 };
@@ -29,7 +33,7 @@ pub struct RateLimited {
 
 pub struct RateLimitedMiddleware<S> {
   rate_limited: RateLimited,
-  service: S,
+  service: Rc<S>,
 }
 
 impl RateLimit {
@@ -63,78 +67,28 @@ impl RateLimit {
 }
 
 impl RateLimited {
-  pub async fn wrap<T, E>(
-    self,
-    ip_addr: IpAddr,
-    fut: impl Future<Output = Result<T, E>>,
-  ) -> Result<T, E>
-  where
-    E: From<LemmyError>,
-  {
+  /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
+  pub async fn check(self, ip_addr: IpAddr) -> bool {
     // Does not need to be blocking because the RwLock in settings never held across await points,
     // and the operation here locks only long enough to clone
     let rate_limit = self.rate_limit_config;
 
-    // before
-    {
-      let mut limiter = self.rate_limiter.lock().await;
-
-      match self.type_ {
-        RateLimitType::Message => {
-          limiter.check_rate_limit_full(
-            self.type_,
-            &ip_addr,
-            rate_limit.message,
-            rate_limit.message_per_second,
-          )?;
-
-          drop(limiter);
-          return fut.await;
-        }
-        RateLimitType::Post => {
-          limiter.check_rate_limit_full(
-            self.type_,
-            &ip_addr,
-            rate_limit.post,
-            rate_limit.post_per_second,
-          )?;
-        }
-        RateLimitType::Register => {
-          limiter.check_rate_limit_full(
-            self.type_,
-            &ip_addr,
-            rate_limit.register,
-            rate_limit.register_per_second,
-          )?;
-        }
-        RateLimitType::Image => {
-          limiter.check_rate_limit_full(
-            self.type_,
-            &ip_addr,
-            rate_limit.image,
-            rate_limit.image_per_second,
-          )?;
-        }
-        RateLimitType::Comment => {
-          limiter.check_rate_limit_full(
-            self.type_,
-            &ip_addr,
-            rate_limit.comment,
-            rate_limit.comment_per_second,
-          )?;
-        }
-      };
-    }
-
-    let res = fut.await;
+    let mut limiter = self.rate_limiter.lock().await;
 
-    res
+    let (kind, interval) = match self.type_ {
+      RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
+      RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
+      RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
+      RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
+      RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
+    };
+    limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
   }
 }
 
 impl<S> Transform<S, ServiceRequest> for RateLimited
 where
-  S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
+  S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
   S::Future: 'static,
 {
   type Response = S::Response;
@@ -146,7 +100,7 @@ where
   fn new_transform(&self, service: S) -> Self::Future {
     ok(RateLimitedMiddleware {
       rate_limited: self.clone(),
-      service,
+      service: Rc::new(service),
     })
   }
 }
@@ -155,7 +109,7 @@ type FutResult<T, E> = dyn Future<Output = Result<T, E>>;
 
 impl<S> Service<ServiceRequest> for RateLimitedMiddleware<S>
 where
-  S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
+  S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
   S::Future: 'static,
 {
   type Response = S::Response;
@@ -169,11 +123,20 @@ where
   fn call(&self, req: ServiceRequest) -> Self::Future {
     let ip_addr = get_ip(&req.connection_info());
 
-    let fut = self
-      .rate_limited
-      .clone()
-      .wrap(ip_addr, self.service.call(req));
-
-    Box::pin(async move { fut.await.map_err(actix_web::Error::from) })
+    let rate_limited = self.rate_limited.clone();
+    let service = self.service.clone();
+
+    Box::pin(async move {
+      if rate_limited.check(ip_addr).await {
+        service.call(req).await
+      } else {
+        let (http_req, _) = req.into_parts();
+        // if rate limit was hit, respond with http 400
+        Ok(ServiceResponse::new(
+          http_req,
+          HttpResponse::BadRequest().finish(),
+        ))
+      }
+    })
   }
 }
diff --git a/crates/utils/src/rate_limit/rate_limiter.rs b/crates/utils/src/rate_limit/rate_limiter.rs
index ccc483ed..31d91036 100644
--- a/crates/utils/src/rate_limit/rate_limiter.rs
+++ b/crates/utils/src/rate_limit/rate_limiter.rs
@@ -1,11 +1,11 @@
-use crate::{IpAddr, LemmyError};
-use std::{collections::HashMap, time::SystemTime};
+use crate::IpAddr;
+use std::{collections::HashMap, time::Instant};
 use strum::IntoEnumIterator;
 use tracing::debug;
 
 #[derive(Debug, Clone)]
 struct RateLimitBucket {
-  last_checked: SystemTime,
+  last_checked: Instant,
   allowance: f64,
 }
 
@@ -36,7 +36,7 @@ impl RateLimiter {
           bucket.insert(
             ip.clone(),
             RateLimitBucket {
-              last_checked: SystemTime::now(),
+              last_checked: Instant::now(),
               allowance: -2f64,
             },
           );
@@ -46,6 +46,8 @@ impl RateLimiter {
   }
 
   /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
+  ///
+  /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
   #[allow(clippy::float_cmp)]
   pub(super) fn check_rate_limit_full(
     &mut self,
@@ -53,12 +55,12 @@ impl RateLimiter {
     ip: &IpAddr,
     rate: i32,
     per: i32,
-  ) -> Result<(), LemmyError> {
+  ) -> bool {
     self.insert_ip(ip);
     if let Some(bucket) = self.buckets.get_mut(&type_) {
       if let Some(rate_limit) = bucket.get_mut(ip) {
-        let current = SystemTime::now();
-        let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
+        let current = Instant::now();
+        let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64;
 
         // The initial value
         if rate_limit.allowance == -2f64 {
@@ -79,25 +81,16 @@ impl RateLimiter {
             time_passed,
             rate_limit.allowance
           );
-          Err(LemmyError::from_error_message(
-            anyhow::anyhow!(
-              "Too many requests. type: {}, IP: {}, {} per {} seconds",
-              type_.as_ref(),
-              ip,
-              rate,
-              per
-            ),
-            "too_many_requests",
-          ))
+          false
         } else {
           rate_limit.allowance -= 1.0;
-          Ok(())
+          true
         }
       } else {
-        Ok(())
+        true
       }
     } else {
-      Ok(())
+      true
     }
   }
 }
diff --git a/crates/websocket/src/chat_server.rs b/crates/websocket/src/chat_server.rs
index 53274e38..1e95344b 100644
--- a/crates/websocket/src/chat_server.rs
+++ b/crates/websocket/src/chat_server.rs
@@ -478,22 +478,33 @@ impl ChatServer {
         .as_str()
         .ok_or_else(|| LemmyError::from_message("missing op"))?;
 
-      if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
-        let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
-        match user_operation_crud {
-          UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
-          UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
-          UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
-          UserOperationCrud::CreateComment => rate_limiter.comment().wrap(ip, fut).await,
-          _ => rate_limiter.message().wrap(ip, fut).await,
-        }
+      // check if api call passes the rate limit, and generate future for later execution
+      let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
+        let passed = match user_operation_crud {
+          UserOperationCrud::Register => rate_limiter.register().check(ip).await,
+          UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await,
+          UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await,
+          UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await,
+          _ => rate_limiter.message().check(ip).await,
+        };
+        let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
+        (passed, fut)
       } else {
         let user_operation = UserOperation::from_str(op)?;
-        let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
-        match user_operation {
-          UserOperation::GetCaptcha => rate_limiter.post().wrap(ip, fut).await,
-          _ => rate_limiter.message().wrap(ip, fut).await,
-        }
+        let passed = match user_operation {
+          UserOperation::GetCaptcha => rate_limiter.post().check(ip).await,
+          _ => rate_limiter.message().check(ip).await,
+        };
+        let fut = (message_handler)(context, msg.id, user_operation, data);
+        (passed, fut)
+      };
+
+      // if rate limit passed, execute api call future
+      if passed {
+        fut.await
+      } else {
+        // if rate limit was hit, respond with empty message
+        Ok("".to_string())
       }
     }
   }
-- 
2.44.1