use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr}; use actix_web::{ dev::{Service, ServiceRequest, ServiceResponse, Transform}, HttpResponse, }; use futures::future::{ok, Ready}; use rate_limiter::{RateLimitType, RateLimiter}; use std::{ future::Future, pin::Pin, rc::Rc, sync::Arc, task::{Context, Poll}, }; use tokio::sync::Mutex; pub mod rate_limiter; #[derive(Debug, Clone)] pub struct RateLimit { // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this // across await points pub rate_limiter: Arc>, pub rate_limit_config: RateLimitConfig, } #[derive(Debug, Clone)] pub struct RateLimited { rate_limiter: Arc>, rate_limit_config: RateLimitConfig, type_: RateLimitType, } pub struct RateLimitedMiddleware { rate_limited: RateLimited, service: Rc, } impl RateLimit { pub fn message(&self) -> RateLimited { self.kind(RateLimitType::Message) } pub fn post(&self) -> RateLimited { self.kind(RateLimitType::Post) } pub fn register(&self) -> RateLimited { self.kind(RateLimitType::Register) } pub fn image(&self) -> RateLimited { self.kind(RateLimitType::Image) } pub fn comment(&self) -> RateLimited { self.kind(RateLimitType::Comment) } fn kind(&self, type_: RateLimitType) -> RateLimited { RateLimited { rate_limiter: self.rate_limiter.clone(), rate_limit_config: self.rate_limit_config.clone(), type_, } } } impl RateLimited { /// Returns true if the request passed the rate limit, false if it failed and should be rejected. pub async fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone let rate_limit = self.rate_limit_config; let mut limiter = self.rate_limiter.lock().await; let (kind, interval) = match self.type_ { RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second), RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second), RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second), RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), }; limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) } } impl Transform for RateLimited where S: Service + 'static, S::Future: 'static, { type Response = S::Response; type Error = actix_web::Error; type InitError = (); type Transform = RateLimitedMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(RateLimitedMiddleware { rate_limited: self.clone(), service: Rc::new(service), }) } } type FutResult = dyn Future>; impl Service for RateLimitedMiddleware where S: Service + 'static, S::Future: 'static, { type Response = S::Response; type Error = actix_web::Error; type Future = Pin>>; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } fn call(&self, req: ServiceRequest) -> Self::Future { let ip_addr = get_ip(&req.connection_info()); let rate_limited = self.rate_limited.clone(); let service = self.service.clone(); Box::pin(async move { if rate_limited.check(ip_addr).await { service.call(req).await } else { let (http_req, _) = req.into_parts(); // if rate limit was hit, respond with http 400 Ok(ServiceResponse::new( http_req, HttpResponse::BadRequest().finish(), )) } }) } }