* Rate limit websocket joins.
* Removing async on mutex lock fn.
* Removing redundant ip
* Return early if check fails.
"lemmy_utils",
"lemmy_websocket",
"once_cell",
+ "parking_lot 0.12.0",
"percent-encoding",
"rand 0.8.4",
"reqwest",
"openssl",
"opentelemetry",
"opentelemetry-otlp",
+ "parking_lot 0.12.0",
"reqwest",
"reqwest-middleware",
"reqwest-tracing",
"lettre",
"once_cell",
"openssl",
+ "parking_lot 0.12.0",
"percent-encoding",
"rand 0.8.4",
"regex",
"lemmy_db_views_actor",
"lemmy_utils",
"opentelemetry",
+ "parking_lot 0.12.0",
"rand 0.8.4",
"reqwest",
"reqwest-middleware",
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
opentelemetry-otlp = "0.9"
tracing-opentelemetry = "0.16"
+parking_lot = "0.12"
reqwest = { version = "0.11.7", features = ["json"] }
html2md = "0.2.13"
once_cell = "1.8.0"
+parking_lot = "0.12"
[dev-dependencies]
serial_test = "0.5.1"
LemmyError,
};
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
+ use parking_lot::Mutex;
use reqwest::Client;
use reqwest_middleware::ClientBuilder;
use std::sync::Arc;
- use tokio::sync::Mutex;
// TODO: would be nice if we didnt have to use a full context for tests.
// or at least write a helper function so this code is shared with main.rs
encoding = "0.2.33"
html2text = "0.2.1"
rosetta-i18n = "0.1"
+parking_lot = "0.12"
[build-dependencies]
rosetta-build = "0.1"
HttpResponse,
};
use futures::future::{ok, Ready};
+use parking_lot::Mutex;
use rate_limiter::{RateLimitType, RateLimiter};
use std::{
future::Future,
sync::Arc,
task::{Context, Poll},
};
-use tokio::sync::Mutex;
pub mod rate_limiter;
impl RateLimited {
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
- pub async fn check(self, ip_addr: IpAddr) -> bool {
+ pub fn check(self, ip_addr: IpAddr) -> bool {
// Does not need to be blocking because the RwLock in settings never held across await points,
// and the operation here locks only long enough to clone
let rate_limit = self.rate_limit_config;
- let mut limiter = self.rate_limiter.lock().await;
-
let (kind, interval) = match self.type_ {
RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
};
+ let mut limiter = self.rate_limiter.lock();
+
limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
}
}
let service = self.service.clone();
Box::pin(async move {
- if rate_limited.check(ip_addr).await {
+ if rate_limited.check(ip_addr) {
service.call(req).await
} else {
let (http_req, _) = req.into_parts();
actix-web-actors = { version = "4.1.0", default-features = false }
opentelemetry = "0.16"
tracing-opentelemetry = "0.16"
+parking_lot = "0.12"
// check if api call passes the rate limit, and generate future for later execution
let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
let passed = match user_operation_crud {
- UserOperationCrud::Register => rate_limiter.register().check(ip).await,
- UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await,
- UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await,
- UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await,
- _ => rate_limiter.message().check(ip).await,
+ UserOperationCrud::Register => rate_limiter.register().check(ip),
+ UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
+ UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
+ UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
+ _ => rate_limiter.message().check(ip),
};
let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
(passed, fut)
} else {
let user_operation = UserOperation::from_str(op)?;
let passed = match user_operation {
- UserOperation::GetCaptcha => rate_limiter.post().check(ip).await,
- _ => rate_limiter.message().check(ip).await,
+ UserOperation::GetCaptcha => rate_limiter.post().check(ip),
+ _ => rate_limiter.message().check(ip),
};
let fut = (message_handler)(context, msg.id, user_operation, data);
(passed, fut)
use actix::prelude::*;
use actix_web::{web, Error, HttpRequest, HttpResponse};
use actix_web_actors::ws;
-use lemmy_utils::{utils::get_ip, ConnectionId, IpAddr};
+use lemmy_utils::{rate_limit::RateLimit, utils::get_ip, ConnectionId, IpAddr};
use std::time::{Duration, Instant};
use tracing::{debug, error, info};
req: HttpRequest,
stream: web::Payload,
context: web::Data<LemmyContext>,
+ rate_limiter: web::Data<RateLimit>,
) -> Result<HttpResponse, Error> {
ws::start(
WsSession {
id: 0,
hb: Instant::now(),
ip: get_ip(&req.connection_info()),
+ rate_limiter: rate_limiter.as_ref().to_owned(),
},
&req,
stream,
/// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
/// otherwise we drop connection.
hb: Instant,
+ /// A rate limiter for websocket joins
+ rate_limiter: RateLimit,
}
impl Actor for WsSession {
// before processing any other events.
// across all routes within application
let addr = ctx.address();
+
+ if !self.rate_limit_check(ctx) {
+ return;
+ }
+
self
.cs_addr
.send(Connect {
/// WebSocket message handler
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
+ if !self.rate_limit_check(ctx) {
+ return;
+ }
+
let message = match result {
Ok(m) => m,
Err(e) => {
ctx.ping(b"");
});
}
+
+ /// Check the rate limit, and stop the ctx if it fails
+ fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext<Self>) -> bool {
+ let check = self.rate_limiter.message().check(self.ip.to_owned());
+ if !check {
+ debug!("Websocket join with IP: {} has been rate limited.", self.ip);
+ ctx.stop()
+ }
+ check
+ }
}
REQWEST_TIMEOUT,
};
use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
+use parking_lot::Mutex;
use reqwest::Client;
use reqwest_middleware::ClientBuilder;
use reqwest_tracing::TracingMiddleware;
use std::{env, sync::Arc, thread};
-use tokio::sync::Mutex;
use tracing_actix_web::TracingLogger;
embed_migrations!();
.wrap(actix_web::middleware::Logger::default())
.wrap(TracingLogger::<QuieterRootSpanBuilder>::new())
.app_data(Data::new(context))
+ .app_data(Data::new(rate_limiter.clone()))
// The routes
.configure(|cfg| api_routes::config(cfg, &rate_limiter))
.configure(|cfg| lemmy_apub::http::routes::config(cfg, &settings))