Rate limit websocket joins. (#2165)
authorDessalines <dessalines@users.noreply.github.com>
Sun, 27 Mar 2022 00:29:05 +0000 (00:29 +0000)
committerGitHub <noreply@github.com>
Sun, 27 Mar 2022 00:29:05 +0000 (00:29 +0000)
* Rate limit websocket joins.

* Removing async on mutex lock fn.

* Removing redundant ip

* Return early if check fails.

Cargo.lock
Cargo.toml
crates/apub/Cargo.toml
crates/apub/src/objects/mod.rs
crates/utils/Cargo.toml
crates/utils/src/rate_limit/mod.rs
crates/websocket/Cargo.toml
crates/websocket/src/chat_server.rs
crates/websocket/src/routes.rs
src/main.rs

index 430f18d137f75143abb82c5a35eec2e3891e2611..74b66e88980d0726ccdd6700fa89e47bb034f11a 100644 (file)
@@ -1962,6 +1962,7 @@ dependencies = [
  "lemmy_utils",
  "lemmy_websocket",
  "once_cell",
+ "parking_lot 0.12.0",
  "percent-encoding",
  "rand 0.8.4",
  "reqwest",
@@ -2129,6 +2130,7 @@ dependencies = [
  "openssl",
  "opentelemetry",
  "opentelemetry-otlp",
+ "parking_lot 0.12.0",
  "reqwest",
  "reqwest-middleware",
  "reqwest-tracing",
@@ -2166,6 +2168,7 @@ dependencies = [
  "lettre",
  "once_cell",
  "openssl",
+ "parking_lot 0.12.0",
  "percent-encoding",
  "rand 0.8.4",
  "regex",
@@ -2204,6 +2207,7 @@ dependencies = [
  "lemmy_db_views_actor",
  "lemmy_utils",
  "opentelemetry",
+ "parking_lot 0.12.0",
  "rand 0.8.4",
  "reqwest",
  "reqwest-middleware",
index 90fe4658ff4096a2f312017cc6ef789a505479b6..e3a294e080234ac3a82230db2756326685473599 100644 (file)
@@ -75,3 +75,4 @@ doku = "0.10.2"
 opentelemetry = { version = "0.16", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.9"
 tracing-opentelemetry = "0.16"
+parking_lot = "0.12"
index cc640c512178a196dac478edfe22d9b36c292770..fce4f406c77d81168341836257b265ea1c5b0a1a 100644 (file)
@@ -50,6 +50,7 @@ background-jobs = "0.11.0"
 reqwest = { version = "0.11.7", features = ["json"] }
 html2md = "0.2.13"
 once_cell = "1.8.0"
+parking_lot = "0.12"
 
 [dev-dependencies]
 serial_test = "0.5.1"
index 9f939aaa86636aa60dbb9687ec6fd735b2bc6eea..b387564a78abd57bc38fc34ccac9d507b537cfa2 100644 (file)
@@ -58,10 +58,10 @@ pub(crate) mod tests {
     LemmyError,
   };
   use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
+  use parking_lot::Mutex;
   use reqwest::Client;
   use reqwest_middleware::ClientBuilder;
   use std::sync::Arc;
-  use tokio::sync::Mutex;
 
   // TODO: would be nice if we didnt have to use a full context for tests.
   //       or at least write a helper function so this code is shared with main.rs
index 0a2519517664c513ac9c5286552acb0b3b6d0096..475d397cc71b05ab790811d071a19ce0b514fe18 100644 (file)
@@ -48,6 +48,7 @@ uuid = { version = "0.8.2", features = ["serde", "v4"] }
 encoding = "0.2.33"
 html2text = "0.2.1"
 rosetta-i18n = "0.1"
+parking_lot = "0.12"
 
 [build-dependencies]
 rosetta-build = "0.1"
index 6027520f0ac1038e7a7d6c938a14b9bc6c728848..69bcbcecdacd7371430431d7fb97fe1fbcac0478 100644 (file)
@@ -4,6 +4,7 @@ use actix_web::{
   HttpResponse,
 };
 use futures::future::{ok, Ready};
+use parking_lot::Mutex;
 use rate_limiter::{RateLimitType, RateLimiter};
 use std::{
   future::Future,
@@ -12,7 +13,6 @@ use std::{
   sync::Arc,
   task::{Context, Poll},
 };
-use tokio::sync::Mutex;
 
 pub mod rate_limiter;
 
@@ -68,13 +68,11 @@ impl RateLimit {
 
 impl RateLimited {
   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
-  pub async fn check(self, ip_addr: IpAddr) -> bool {
+  pub fn check(self, ip_addr: IpAddr) -> bool {
     // Does not need to be blocking because the RwLock in settings never held across await points,
     // and the operation here locks only long enough to clone
     let rate_limit = self.rate_limit_config;
 
-    let mut limiter = self.rate_limiter.lock().await;
-
     let (kind, interval) = match self.type_ {
       RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
       RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
@@ -82,6 +80,8 @@ impl RateLimited {
       RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
       RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
     };
+    let mut limiter = self.rate_limiter.lock();
+
     limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
   }
 }
@@ -127,7 +127,7 @@ where
     let service = self.service.clone();
 
     Box::pin(async move {
-      if rate_limited.check(ip_addr).await {
+      if rate_limited.check(ip_addr) {
         service.call(req).await
       } else {
         let (http_req, _) = req.into_parts();
index fd10a746285f8b1fa55bfaef046ae4d3d001afc0..0dd30149dc165f9430a7bcfb60b91dee16783182 100644 (file)
@@ -36,3 +36,4 @@ actix-web = { version = "4.0.0", default-features = false, features = ["rustls"]
 actix-web-actors = { version = "4.1.0", default-features = false }
 opentelemetry = "0.16"
 tracing-opentelemetry = "0.16"
+parking_lot = "0.12"
index 1e95344b0c226bed760b4246e7e63486ccc36afb..d9de90dbef8c54d700f6651713cd169a3c850028 100644 (file)
@@ -481,19 +481,19 @@ impl ChatServer {
       // check if api call passes the rate limit, and generate future for later execution
       let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
         let passed = match user_operation_crud {
-          UserOperationCrud::Register => rate_limiter.register().check(ip).await,
-          UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await,
-          UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await,
-          UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await,
-          _ => rate_limiter.message().check(ip).await,
+          UserOperationCrud::Register => rate_limiter.register().check(ip),
+          UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
+          UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
+          UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
+          _ => rate_limiter.message().check(ip),
         };
         let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
         (passed, fut)
       } else {
         let user_operation = UserOperation::from_str(op)?;
         let passed = match user_operation {
-          UserOperation::GetCaptcha => rate_limiter.post().check(ip).await,
-          _ => rate_limiter.message().check(ip).await,
+          UserOperation::GetCaptcha => rate_limiter.post().check(ip),
+          _ => rate_limiter.message().check(ip),
         };
         let fut = (message_handler)(context, msg.id, user_operation, data);
         (passed, fut)
index 41b1782466bffc1a3f89e499e51b98a39e176e87..e99e683eb623683fc41bd34e56f6aa8dc7ac2b56 100644 (file)
@@ -6,7 +6,7 @@ use crate::{
 use actix::prelude::*;
 use actix_web::{web, Error, HttpRequest, HttpResponse};
 use actix_web_actors::ws;
-use lemmy_utils::{utils::get_ip, ConnectionId, IpAddr};
+use lemmy_utils::{rate_limit::RateLimit, utils::get_ip, ConnectionId, IpAddr};
 use std::time::{Duration, Instant};
 use tracing::{debug, error, info};
 
@@ -20,6 +20,7 @@ pub async fn chat_route(
   req: HttpRequest,
   stream: web::Payload,
   context: web::Data<LemmyContext>,
+  rate_limiter: web::Data<RateLimit>,
 ) -> Result<HttpResponse, Error> {
   ws::start(
     WsSession {
@@ -27,6 +28,7 @@ pub async fn chat_route(
       id: 0,
       hb: Instant::now(),
       ip: get_ip(&req.connection_info()),
+      rate_limiter: rate_limiter.as_ref().to_owned(),
     },
     &req,
     stream,
@@ -41,6 +43,8 @@ struct WsSession {
   /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
   /// otherwise we drop connection.
   hb: Instant,
+  /// A rate limiter for websocket joins
+  rate_limiter: RateLimit,
 }
 
 impl Actor for WsSession {
@@ -57,6 +61,11 @@ impl Actor for WsSession {
     // before processing any other events.
     // across all routes within application
     let addr = ctx.address();
+
+    if !self.rate_limit_check(ctx) {
+      return;
+    }
+
     self
       .cs_addr
       .send(Connect {
@@ -98,6 +107,10 @@ impl Handler<WsMessage> for WsSession {
 /// WebSocket message handler
 impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
   fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
+    if !self.rate_limit_check(ctx) {
+      return;
+    }
+
     let message = match result {
       Ok(m) => m,
       Err(e) => {
@@ -169,4 +182,14 @@ impl WsSession {
       ctx.ping(b"");
     });
   }
+
+  /// Check the rate limit, and stop the ctx if it fails
+  fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext<Self>) -> bool {
+    let check = self.rate_limiter.message().check(self.ip.to_owned());
+    if !check {
+      debug!("Websocket join with IP: {} has been rate limited.", self.ip);
+      ctx.stop()
+    }
+    check
+  }
 }
index d0b3c04ee4067d7059c949adbf34dea5a6c73b0e..9297334e58ceea1c352e986631467c40cc500df1 100644 (file)
@@ -29,11 +29,11 @@ use lemmy_utils::{
   REQWEST_TIMEOUT,
 };
 use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
+use parking_lot::Mutex;
 use reqwest::Client;
 use reqwest_middleware::ClientBuilder;
 use reqwest_tracing::TracingMiddleware;
 use std::{env, sync::Arc, thread};
-use tokio::sync::Mutex;
 use tracing_actix_web::TracingLogger;
 
 embed_migrations!();
@@ -136,6 +136,7 @@ async fn main() -> Result<(), LemmyError> {
       .wrap(actix_web::middleware::Logger::default())
       .wrap(TracingLogger::<QuieterRootSpanBuilder>::new())
       .app_data(Data::new(context))
+      .app_data(Data::new(rate_limiter.clone()))
       // The routes
       .configure(|cfg| api_routes::config(cfg, &rate_limiter))
       .configure(|cfg| lemmy_apub::http::routes::config(cfg, &settings))