]> Untitled Git - lemmy.git/commitdiff
Give ratelimit fields explicit names
authorasonix <asonix@asonix.dog>
Mon, 20 Apr 2020 17:51:42 +0000 (12:51 -0500)
committerasonix <asonix@asonix.dog>
Mon, 20 Apr 2020 17:51:42 +0000 (12:51 -0500)
server/src/main.rs
server/src/rate_limit/mod.rs

index c92770f2a6147540c60c49a049c78fa574d02397..4e773ee576e5e4d36a509af30f51744e77df8081 100644 (file)
@@ -34,7 +34,9 @@ async fn main() -> io::Result<()> {
   embedded_migrations::run(&conn).unwrap();
 
   // Set up the rate limiter
-  let rate_limiter = RateLimit(Arc::new(Mutex::new(RateLimiter::default())));
+  let rate_limiter = RateLimit {
+    rate_limiter: Arc::new(Mutex::new(RateLimiter::default())),
+  };
 
   // Set up websocket server
   let server = ChatServer::startup(pool.clone(), rate_limiter.clone()).start();
index 9aeb11718e62dda1ee090e77ecfef38a7ea22a64..bb77db29c05ca25a01f8735afe9feeb4ec23b1de 100644 (file)
@@ -18,12 +18,20 @@ use strum::IntoEnumIterator;
 use tokio::sync::Mutex;
 
 #[derive(Debug, Clone)]
-pub struct RateLimit(pub Arc<Mutex<RateLimiter>>);
+pub struct RateLimit {
+  pub rate_limiter: Arc<Mutex<RateLimiter>>,
+}
 
 #[derive(Debug, Clone)]
-pub struct RateLimited(Arc<Mutex<RateLimiter>>, RateLimitType);
+pub struct RateLimited {
+  rate_limiter: Arc<Mutex<RateLimiter>>,
+  type_: RateLimitType,
+}
 
-pub struct RateLimitedMiddleware<S>(RateLimited, S);
+pub struct RateLimitedMiddleware<S> {
+  rate_limited: RateLimited,
+  service: S,
+}
 
 impl RateLimit {
   pub fn message(&self) -> RateLimited {
@@ -39,7 +47,10 @@ impl RateLimit {
   }
 
   fn kind(&self, type_: RateLimitType) -> RateLimited {
-    RateLimited(self.0.clone(), type_)
+    RateLimited {
+      rate_limiter: self.rate_limiter.clone(),
+      type_,
+    }
   }
 }
 
@@ -64,12 +75,12 @@ impl RateLimited {
 
     // before
     {
-      let mut limiter = self.0.lock().await;
+      let mut limiter = self.rate_limiter.lock().await;
 
-      match self.1 {
+      match self.type_ {
         RateLimitType::Message => {
           limiter.check_rate_limit_full(
-            self.1,
+            self.type_,
             &ip_addr,
             rate_limit.message,
             rate_limit.message_per_second,
@@ -80,7 +91,7 @@ impl RateLimited {
         }
         RateLimitType::Post => {
           limiter.check_rate_limit_full(
-            self.1.clone(),
+            self.type_.clone(),
             &ip_addr,
             rate_limit.post,
             rate_limit.post_per_second,
@@ -89,7 +100,7 @@ impl RateLimited {
         }
         RateLimitType::Register => {
           limiter.check_rate_limit_full(
-            self.1,
+            self.type_,
             &ip_addr,
             rate_limit.register,
             rate_limit.register_per_second,
@@ -103,12 +114,12 @@ impl RateLimited {
 
     // after
     {
-      let mut limiter = self.0.lock().await;
+      let mut limiter = self.rate_limiter.lock().await;
       if res.is_ok() {
-        match self.1 {
+        match self.type_ {
           RateLimitType::Post => {
             limiter.check_rate_limit_full(
-              self.1,
+              self.type_,
               &ip_addr,
               rate_limit.post,
               rate_limit.post_per_second,
@@ -117,7 +128,7 @@ impl RateLimited {
           }
           RateLimitType::Register => {
             limiter.check_rate_limit_full(
-              self.1,
+              self.type_,
               &ip_addr,
               rate_limit.register,
               rate_limit.register_per_second,
@@ -146,7 +157,10 @@ where
   type Future = Ready<Result<Self::Transform, Self::InitError>>;
 
   fn new_transform(&self, service: S) -> Self::Future {
-    ok(RateLimitedMiddleware(self.clone(), service))
+    ok(RateLimitedMiddleware {
+      rate_limited: self.clone(),
+      service,
+    })
   }
 }
 
@@ -163,7 +177,7 @@ where
   type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;
 
   fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-    self.1.poll_ready(cx)
+    self.service.poll_ready(cx)
   }
 
   fn call(&mut self, req: S::Request) -> Self::Future {
@@ -176,7 +190,10 @@ where
       .unwrap_or("127.0.0.1")
       .to_string();
 
-    let fut = self.0.clone().wrap(ip_addr, self.1.call(req));
+    let fut = self
+      .rate_limited
+      .clone()
+      .wrap(ip_addr, self.service.call(req));
 
     Box::pin(async move { fut.await.map_err(actix_web::Error::from) })
   }