-use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr, LemmyError};
-use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
+use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr};
+use actix_web::{
+ dev::{Service, ServiceRequest, ServiceResponse, Transform},
+ HttpResponse,
+};
use futures::future::{ok, Ready};
use rate_limiter::{RateLimitType, RateLimiter};
use std::{
future::Future,
pin::Pin,
+ rc::Rc,
sync::Arc,
task::{Context, Poll},
};
pub struct RateLimitedMiddleware<S> {
rate_limited: RateLimited,
- service: S,
+ service: Rc<S>,
}
impl RateLimit {
}
impl RateLimited {
- pub async fn wrap<T, E>(
- self,
- ip_addr: IpAddr,
- fut: impl Future<Output = Result<T, E>>,
- ) -> Result<T, E>
- where
- E: From<LemmyError>,
- {
+ /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
+ pub async fn check(self, ip_addr: IpAddr) -> bool {
// Does not need to be blocking because the RwLock in settings never held across await points,
// and the operation here locks only long enough to clone
let rate_limit = self.rate_limit_config;
- // before
- {
- let mut limiter = self.rate_limiter.lock().await;
-
- match self.type_ {
- RateLimitType::Message => {
- limiter.check_rate_limit_full(
- self.type_,
- &ip_addr,
- rate_limit.message,
- rate_limit.message_per_second,
- )?;
-
- drop(limiter);
- return fut.await;
- }
- RateLimitType::Post => {
- limiter.check_rate_limit_full(
- self.type_,
- &ip_addr,
- rate_limit.post,
- rate_limit.post_per_second,
- )?;
- }
- RateLimitType::Register => {
- limiter.check_rate_limit_full(
- self.type_,
- &ip_addr,
- rate_limit.register,
- rate_limit.register_per_second,
- )?;
- }
- RateLimitType::Image => {
- limiter.check_rate_limit_full(
- self.type_,
- &ip_addr,
- rate_limit.image,
- rate_limit.image_per_second,
- )?;
- }
- RateLimitType::Comment => {
- limiter.check_rate_limit_full(
- self.type_,
- &ip_addr,
- rate_limit.comment,
- rate_limit.comment_per_second,
- )?;
- }
- };
- }
-
- let res = fut.await;
+ let mut limiter = self.rate_limiter.lock().await;
- res
+ let (kind, interval) = match self.type_ {
+ RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
+ RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
+ RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
+ RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
+ RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
+ };
+ limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
}
}
impl<S> Transform<S, ServiceRequest> for RateLimited
where
- S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
+ S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
S::Future: 'static,
{
type Response = S::Response;
fn new_transform(&self, service: S) -> Self::Future {
ok(RateLimitedMiddleware {
rate_limited: self.clone(),
- service,
+ service: Rc::new(service),
})
}
}
impl<S> Service<ServiceRequest> for RateLimitedMiddleware<S>
where
- S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
+ S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
S::Future: 'static,
{
type Response = S::Response;
fn call(&self, req: ServiceRequest) -> Self::Future {
let ip_addr = get_ip(&req.connection_info());
- let fut = self
- .rate_limited
- .clone()
- .wrap(ip_addr, self.service.call(req));
-
- Box::pin(async move { fut.await.map_err(actix_web::Error::from) })
+ let rate_limited = self.rate_limited.clone();
+ let service = self.service.clone();
+
+ Box::pin(async move {
+ if rate_limited.check(ip_addr).await {
+ service.call(req).await
+ } else {
+ let (http_req, _) = req.into_parts();
+ // if rate limit was hit, respond with http 400
+ Ok(ServiceResponse::new(
+ http_req,
+ HttpResponse::BadRequest().finish(),
+ ))
+ }
+ })
}
}
-use crate::{IpAddr, LemmyError};
-use std::{collections::HashMap, time::SystemTime};
+use crate::IpAddr;
+use std::{collections::HashMap, time::Instant};
use strum::IntoEnumIterator;
use tracing::debug;
#[derive(Debug, Clone)]
struct RateLimitBucket {
- last_checked: SystemTime,
+ last_checked: Instant,
allowance: f64,
}
bucket.insert(
ip.clone(),
RateLimitBucket {
- last_checked: SystemTime::now(),
+ last_checked: Instant::now(),
allowance: -2f64,
},
);
}
/// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
+ ///
+ /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
#[allow(clippy::float_cmp)]
pub(super) fn check_rate_limit_full(
&mut self,
ip: &IpAddr,
rate: i32,
per: i32,
- ) -> Result<(), LemmyError> {
+ ) -> bool {
self.insert_ip(ip);
if let Some(bucket) = self.buckets.get_mut(&type_) {
if let Some(rate_limit) = bucket.get_mut(ip) {
- let current = SystemTime::now();
- let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
+ let current = Instant::now();
+ let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64;
// The initial value
if rate_limit.allowance == -2f64 {
time_passed,
rate_limit.allowance
);
- Err(LemmyError::from_error_message(
- anyhow::anyhow!(
- "Too many requests. type: {}, IP: {}, {} per {} seconds",
- type_.as_ref(),
- ip,
- rate,
- per
- ),
- "too_many_requests",
- ))
+ false
} else {
rate_limit.allowance -= 1.0;
- Ok(())
+ true
}
} else {
- Ok(())
+ true
}
} else {
- Ok(())
+ true
}
}
}
.as_str()
.ok_or_else(|| LemmyError::from_message("missing op"))?;
- if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
- let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
- match user_operation_crud {
- UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
- UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
- UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
- UserOperationCrud::CreateComment => rate_limiter.comment().wrap(ip, fut).await,
- _ => rate_limiter.message().wrap(ip, fut).await,
- }
+ // check if api call passes the rate limit, and generate future for later execution
+ let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
+ let passed = match user_operation_crud {
+ UserOperationCrud::Register => rate_limiter.register().check(ip).await,
+ UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await,
+ UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await,
+ UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await,
+ _ => rate_limiter.message().check(ip).await,
+ };
+ let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
+ (passed, fut)
} else {
let user_operation = UserOperation::from_str(op)?;
- let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
- match user_operation {
- UserOperation::GetCaptcha => rate_limiter.post().wrap(ip, fut).await,
- _ => rate_limiter.message().wrap(ip, fut).await,
- }
+ let passed = match user_operation {
+ UserOperation::GetCaptcha => rate_limiter.post().check(ip).await,
+ _ => rate_limiter.message().check(ip).await,
+ };
+ let fut = (message_handler)(context, msg.id, user_operation, data);
+ (passed, fut)
+ };
+
+ // if rate limit passed, execute api call future
+ if passed {
+ fut.await
+ } else {
+ // if rate limit was hit, respond with empty message
+ Ok("".to_string())
}
}
}