From: Nutomic Date: Fri, 9 Dec 2022 15:31:47 +0000 (+0000) Subject: Rework websocket (#2598) X-Git-Url: http://these/git/readmes/%7B%7D/static/%7B%60https:/%7Burl%7D?a=commitdiff_plain;h=2732a5bf0707cfca38cf6e826d867d69eb6f4888;p=lemmy.git Rework websocket (#2598) * Merge websocket crate into api_common * Add SendActivity trait so that api crates compile in parallel with lemmy_apub * Rework websocket code * fix websocket heartbeat --- diff --git a/Cargo.lock b/Cargo.lock index 5f68a511..72876f85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -51,7 +51,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f728064aca1c318585bf4bb04ffcfac9e75e508ab4e8b1bd9ba5dfe04e2cbed5" dependencies = [ "actix-rt", - "actix_derive", "bitflags", "bytes", "crossbeam-channel", @@ -283,14 +282,16 @@ dependencies = [ ] [[package]] -name = "actix_derive" -version = "0.6.0" +name = "actix-ws" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d44b8fee1ced9671ba043476deddef739dd0959bf77030b26b738cc591737a7" +checksum = "535aec173810be3ca6f25dd5b4d431ae7125d62000aa3cbae1ec739921b02cf3" dependencies = [ - "proc-macro2 1.0.47", - "quote 1.0.21", - "syn 1.0.103", + "actix-codec", + "actix-http", + "actix-web", + "futures-core", + "tokio", ] [[package]] @@ -2060,15 +2061,15 @@ dependencies = [ name = "lemmy_api_common" version = "0.16.5" dependencies = [ - "actix", "actix-rt", "actix-web", - "actix-web-actors", + "actix-ws", "anyhow", "background-jobs", "chrono", "diesel", "encoding", + "futures", "lemmy_db_schema", "lemmy_db_views", "lemmy_db_views_actor", @@ -2118,7 +2119,6 @@ version = "0.16.5" dependencies = [ "activitypub_federation", "activitystreams-kinds", - "actix", "actix-rt", "actix-web", "anyhow", @@ -2248,14 +2248,17 @@ name = "lemmy_server" version = "0.16.5" dependencies = [ "activitypub_federation", - "actix", + "actix-rt", "actix-web", + "actix-web-actors", + "actix-ws", "clokwerk", "console-subscriber", "diesel", "diesel-async", "diesel_migrations", "doku", + "futures", "lemmy_api", "lemmy_api_common", "lemmy_api_crud", @@ -2266,6 +2269,7 @@ dependencies = [ "opentelemetry 0.17.0", "opentelemetry-otlp", "parking_lot", + "rand 0.8.5", "reqwest", "reqwest-middleware", "reqwest-retry", diff --git a/Cargo.toml b/Cargo.toml index e83f0c83..8b62bedc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,6 @@ diesel = "2.0.2" diesel_migrations = "2.0.0" diesel-async = "0.1.1" serde = { version = "1.0.147", features = ["derive"] } -actix = "0.13.0" actix-web = { version = "4.2.1", default-features = false, features = ["macros", "rustls"] } tracing = "0.1.36" tracing-actix-web = { version = "0.6.1", default-features = false } @@ -106,6 +105,7 @@ rosetta-i18n = "0.1.2" rand = "0.8.5" opentelemetry = { version = "0.17.0", features = ["rt-tokio"] } tracing-opentelemetry = { version = "0.17.2" } +actix-ws = "0.2.0" [dependencies] lemmy_api = { workspace = true } @@ -120,7 +120,6 @@ diesel = { workspace = true } diesel_migrations = { workspace = true } diesel-async = { workspace = true } serde = { workspace = true } -actix = { workspace = true } actix-web = { workspace = true } tracing = { workspace = true } tracing-actix-web = { workspace = true } @@ -136,7 +135,12 @@ doku = { workspace = true } parking_lot = { workspace = true } reqwest-retry = { workspace = true } serde_json = { workspace = true } +futures = { workspace = true } +actix-ws = { workspace = true } tracing-opentelemetry = { workspace = true, optional = true } opentelemetry = { workspace = true, optional = true } +actix-web-actors = { version = "4.1.0", default-features = false } +actix-rt = "2.6" +rand = { workspace = true } console-subscriber = { version = "0.1.8", optional = true } opentelemetry-otlp = { version = "0.10.0", optional = true } diff --git a/crates/api/src/comment_report/create.rs b/crates/api/src/comment_report/create.rs index bf3fec0a..c026d166 100644 --- a/crates/api/src/comment_report/create.rs +++ b/crates/api/src/comment_report/create.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ comment::{CommentReportResponse, CreateCommentReport}, context::LemmyContext, utils::{check_community_ban, get_local_user_view_from_jwt}, - websocket::{messages::SendModRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ source::{ @@ -58,12 +58,15 @@ impl Perform for CreateCommentReport { comment_report_view, }; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::CreateCommentReport, - response: res.clone(), - community_id: comment_view.community.id, - websocket_id, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::CreateCommentReport, + &res, + comment_view.community.id, + websocket_id, + ) + .await?; Ok(res) } diff --git a/crates/api/src/comment_report/resolve.rs b/crates/api/src/comment_report/resolve.rs index 9df11fc2..de4f418d 100644 --- a/crates/api/src/comment_report/resolve.rs +++ b/crates/api/src/comment_report/resolve.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ comment::{CommentReportResponse, ResolveCommentReport}, context::LemmyContext, utils::{get_local_user_view_from_jwt, is_mod_or_admin}, - websocket::{messages::SendModRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{source::comment_report::CommentReport, traits::Reportable}; use lemmy_db_views::structs::CommentReportView; @@ -49,12 +49,15 @@ impl Perform for ResolveCommentReport { comment_report_view, }; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::ResolveCommentReport, - response: res.clone(), - community_id: report.community.id, - websocket_id, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::ResolveCommentReport, + &res, + report.community.id, + websocket_id, + ) + .await?; Ok(res) } diff --git a/crates/api/src/community/add_mod.rs b/crates/api/src/community/add_mod.rs index ce3082f7..28f56cff 100644 --- a/crates/api/src/community/add_mod.rs +++ b/crates/api/src/community/add_mod.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ community::{AddModToCommunity, AddModToCommunityResponse}, context::LemmyContext, utils::{get_local_user_view_from_jwt, is_mod_or_admin}, - websocket::{messages::SendCommunityRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ source::{ @@ -70,12 +70,15 @@ impl Perform for AddModToCommunity { let moderators = CommunityModeratorView::for_community(context.pool(), community_id).await?; let res = AddModToCommunityResponse { moderators }; - context.chat_server().do_send(SendCommunityRoomMessage { - op: UserOperation::AddModToCommunity, - response: res.clone(), - community_id, - websocket_id, - }); + context + .chat_server() + .send_community_room_message( + &UserOperation::AddModToCommunity, + &res, + community_id, + websocket_id, + ) + .await?; Ok(res) } } diff --git a/crates/api/src/community/ban.rs b/crates/api/src/community/ban.rs index fb5e7fcf..e7962d24 100644 --- a/crates/api/src/community/ban.rs +++ b/crates/api/src/community/ban.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ community::{BanFromCommunity, BanFromCommunityResponse}, context::LemmyContext, utils::{get_local_user_view_from_jwt, is_mod_or_admin, remove_user_data_in_community}, - websocket::{messages::SendCommunityRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ source::{ @@ -95,12 +95,15 @@ impl Perform for BanFromCommunity { banned: data.ban, }; - context.chat_server().do_send(SendCommunityRoomMessage { - op: UserOperation::BanFromCommunity, - response: res.clone(), - community_id, - websocket_id, - }); + context + .chat_server() + .send_community_room_message( + &UserOperation::BanFromCommunity, + &res, + community_id, + websocket_id, + ) + .await?; Ok(res) } diff --git a/crates/api/src/local_user/add_admin.rs b/crates/api/src/local_user/add_admin.rs index 78357f0c..3b22c447 100644 --- a/crates/api/src/local_user/add_admin.rs +++ b/crates/api/src/local_user/add_admin.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ context::LemmyContext, person::{AddAdmin, AddAdminResponse}, utils::{get_local_user_view_from_jwt, is_admin}, - websocket::{messages::SendAllMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ source::{ @@ -56,11 +56,10 @@ impl Perform for AddAdmin { let res = AddAdminResponse { admins }; - context.chat_server().do_send(SendAllMessage { - op: UserOperation::AddAdmin, - response: res.clone(), - websocket_id, - }); + context + .chat_server() + .send_all_message(UserOperation::AddAdmin, &res, websocket_id) + .await?; Ok(res) } diff --git a/crates/api/src/local_user/ban_person.rs b/crates/api/src/local_user/ban_person.rs index 0bb25234..d45528ef 100644 --- a/crates/api/src/local_user/ban_person.rs +++ b/crates/api/src/local_user/ban_person.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ context::LemmyContext, person::{BanPerson, BanPersonResponse}, utils::{get_local_user_view_from_jwt, is_admin, remove_user_data}, - websocket::{messages::SendAllMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ source::{ @@ -79,11 +79,10 @@ impl Perform for BanPerson { banned: data.ban, }; - context.chat_server().do_send(SendAllMessage { - op: UserOperation::BanPerson, - response: res.clone(), - websocket_id, - }); + context + .chat_server() + .send_all_message(UserOperation::BanPerson, &res, websocket_id) + .await?; Ok(res) } diff --git a/crates/api/src/local_user/get_captcha.rs b/crates/api/src/local_user/get_captcha.rs index 50a2bdba..3671f79c 100644 --- a/crates/api/src/local_user/get_captcha.rs +++ b/crates/api/src/local_user/get_captcha.rs @@ -5,7 +5,7 @@ use chrono::Duration; use lemmy_api_common::{ context::LemmyContext, person::{CaptchaResponse, GetCaptcha, GetCaptchaResponse}, - websocket::messages::CaptchaItem, + websocket::structs::CaptchaItem, }; use lemmy_db_schema::{source::local_site::LocalSite, utils::naive_now}; use lemmy_utils::{error::LemmyError, ConnectionId}; @@ -47,7 +47,7 @@ impl Perform for GetCaptcha { }; // Stores the captcha item on the queue - context.chat_server().do_send(captcha_item); + context.chat_server().add_captcha(captcha_item)?; Ok(GetCaptchaResponse { ok: Some(CaptchaResponse { png, wav, uuid }), diff --git a/crates/api/src/post_report/create.rs b/crates/api/src/post_report/create.rs index 71be4362..7396d3f8 100644 --- a/crates/api/src/post_report/create.rs +++ b/crates/api/src/post_report/create.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ context::LemmyContext, post::{CreatePostReport, PostReportResponse}, utils::{check_community_ban, get_local_user_view_from_jwt}, - websocket::{messages::SendModRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ source::{ @@ -58,12 +58,15 @@ impl Perform for CreatePostReport { let res = PostReportResponse { post_report_view }; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::CreatePostReport, - response: res.clone(), - community_id: post_view.community.id, - websocket_id, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::CreatePostReport, + &res, + post_view.community.id, + websocket_id, + ) + .await?; Ok(res) } diff --git a/crates/api/src/post_report/resolve.rs b/crates/api/src/post_report/resolve.rs index 0e2b9e16..615b7f82 100644 --- a/crates/api/src/post_report/resolve.rs +++ b/crates/api/src/post_report/resolve.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ context::LemmyContext, post::{PostReportResponse, ResolvePostReport}, utils::{get_local_user_view_from_jwt, is_mod_or_admin}, - websocket::{messages::SendModRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{source::post_report::PostReport, traits::Reportable}; use lemmy_db_views::structs::PostReportView; @@ -46,12 +46,15 @@ impl Perform for ResolvePostReport { let res = PostReportResponse { post_report_view }; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::ResolvePostReport, - response: res.clone(), - community_id: report.community.id, - websocket_id, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::ResolvePostReport, + &res, + report.community.id, + websocket_id, + ) + .await?; Ok(res) } diff --git a/crates/api/src/private_message_report/create.rs b/crates/api/src/private_message_report/create.rs index 66875c61..9267ee7f 100644 --- a/crates/api/src/private_message_report/create.rs +++ b/crates/api/src/private_message_report/create.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ context::LemmyContext, private_message::{CreatePrivateMessageReport, PrivateMessageReportResponse}, utils::get_local_user_view_from_jwt, - websocket::{messages::SendModRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ newtypes::CommunityId, @@ -57,12 +57,15 @@ impl Perform for CreatePrivateMessageReport { private_message_report_view, }; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::CreatePrivateMessageReport, - response: res.clone(), - community_id: CommunityId(0), - websocket_id, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::CreatePrivateMessageReport, + &res, + CommunityId(0), + websocket_id, + ) + .await?; // TODO: consider federating this diff --git a/crates/api/src/private_message_report/resolve.rs b/crates/api/src/private_message_report/resolve.rs index 2a3f677a..a48a458d 100644 --- a/crates/api/src/private_message_report/resolve.rs +++ b/crates/api/src/private_message_report/resolve.rs @@ -4,7 +4,7 @@ use lemmy_api_common::{ context::LemmyContext, private_message::{PrivateMessageReportResponse, ResolvePrivateMessageReport}, utils::{get_local_user_view_from_jwt, is_admin}, - websocket::{messages::SendModRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ newtypes::CommunityId, @@ -48,12 +48,15 @@ impl Perform for ResolvePrivateMessageReport { private_message_report_view, }; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::ResolvePrivateMessageReport, - response: res.clone(), - community_id: CommunityId(0), - websocket_id, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::ResolvePrivateMessageReport, + &res, + CommunityId(0), + websocket_id, + ) + .await?; Ok(res) } diff --git a/crates/api/src/websocket.rs b/crates/api/src/websocket.rs index ca755bde..041cb775 100644 --- a/crates/api/src/websocket.rs +++ b/crates/api/src/websocket.rs @@ -3,18 +3,15 @@ use actix_web::web::Data; use lemmy_api_common::{ context::LemmyContext, utils::get_local_user_view_from_jwt, - websocket::{ - messages::{JoinCommunityRoom, JoinModRoom, JoinPostRoom, JoinUserRoom}, - structs::{ - CommunityJoin, - CommunityJoinResponse, - ModJoin, - ModJoinResponse, - PostJoin, - PostJoinResponse, - UserJoin, - UserJoinResponse, - }, + websocket::structs::{ + CommunityJoin, + CommunityJoinResponse, + ModJoin, + ModJoinResponse, + PostJoin, + PostJoinResponse, + UserJoin, + UserJoinResponse, }, }; use lemmy_utils::{error::LemmyError, ConnectionId}; @@ -34,10 +31,9 @@ impl Perform for UserJoin { get_local_user_view_from_jwt(&data.auth, context.pool(), context.secret()).await?; if let Some(ws_id) = websocket_id { - context.chat_server().do_send(JoinUserRoom { - local_user_id: local_user_view.local_user.id, - id: ws_id, - }); + context + .chat_server() + .join_user_room(local_user_view.local_user.id, ws_id)?; } Ok(UserJoinResponse { joined: true }) @@ -57,10 +53,9 @@ impl Perform for CommunityJoin { let data: &CommunityJoin = self; if let Some(ws_id) = websocket_id { - context.chat_server().do_send(JoinCommunityRoom { - community_id: data.community_id, - id: ws_id, - }); + context + .chat_server() + .join_community_room(data.community_id, ws_id)?; } Ok(CommunityJoinResponse { joined: true }) @@ -80,10 +75,9 @@ impl Perform for ModJoin { let data: &ModJoin = self; if let Some(ws_id) = websocket_id { - context.chat_server().do_send(JoinModRoom { - community_id: data.community_id, - id: ws_id, - }); + context + .chat_server() + .join_mod_room(data.community_id, ws_id)?; } Ok(ModJoinResponse { joined: true }) @@ -103,10 +97,7 @@ impl Perform for PostJoin { let data: &PostJoin = self; if let Some(ws_id) = websocket_id { - context.chat_server().do_send(JoinPostRoom { - post_id: data.post_id, - id: ws_id, - }); + context.chat_server().join_post_room(data.post_id, ws_id)?; } Ok(PostJoinResponse { joined: true }) diff --git a/crates/api_common/Cargo.toml b/crates/api_common/Cargo.toml index a4c80da5..21eed1c6 100644 --- a/crates/api_common/Cargo.toml +++ b/crates/api_common/Cargo.toml @@ -38,14 +38,14 @@ webpage = { version = "1.4.0", default-features = false, features = ["serde"], o encoding = { version = "0.2.33", optional = true } rand = { workspace = true } serde_json = { workspace = true } -actix = { workspace = true } anyhow = { workspace = true } tokio = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } opentelemetry = { workspace = true } tracing-opentelemetry = { workspace = true } -actix-web-actors = { version = "4.1.0", default-features = false } +actix-ws = { workspace = true } +futures = { workspace = true } background-jobs = "0.13.0" [dev-dependencies] diff --git a/crates/api_common/src/context.rs b/crates/api_common/src/context.rs index e552af53..eb53b2af 100644 --- a/crates/api_common/src/context.rs +++ b/crates/api_common/src/context.rs @@ -1,15 +1,15 @@ use crate::websocket::chat_server::ChatServer; -use actix::Addr; use lemmy_db_schema::{source::secret::Secret, utils::DbPool}; use lemmy_utils::{ rate_limit::RateLimitCell, settings::{structs::Settings, SETTINGS}, }; use reqwest_middleware::ClientWithMiddleware; +use std::sync::Arc; pub struct LemmyContext { pool: DbPool, - chat_server: Addr, + chat_server: Arc, client: ClientWithMiddleware, settings: Settings, secret: Secret, @@ -19,7 +19,7 @@ pub struct LemmyContext { impl LemmyContext { pub fn create( pool: DbPool, - chat_server: Addr, + chat_server: Arc, client: ClientWithMiddleware, settings: Settings, secret: Secret, @@ -37,7 +37,7 @@ impl LemmyContext { pub fn pool(&self) -> &DbPool { &self.pool } - pub fn chat_server(&self) -> &Addr { + pub fn chat_server(&self) -> &Arc { &self.chat_server } pub fn client(&self) -> &ClientWithMiddleware { diff --git a/crates/api_common/src/websocket/chat_server.rs b/crates/api_common/src/websocket/chat_server.rs index b669a4fd..bb9e7b6d 100644 --- a/crates/api_common/src/websocket/chat_server.rs +++ b/crates/api_common/src/websocket/chat_server.rs @@ -1,68 +1,30 @@ use crate::{ comment::CommentResponse, - context::LemmyContext, post::PostResponse, - websocket::{ - messages::{CaptchaItem, StandardMessage, WsMessage}, - serialize_websocket_message, - OperationType, - UserOperation, - UserOperationApub, - UserOperationCrud, - }, + websocket::{serialize_websocket_message, structs::CaptchaItem, OperationType}, }; -use actix::prelude::*; +use actix_ws::Session; use anyhow::Context as acontext; -use lemmy_db_schema::{ - newtypes::{CommunityId, LocalUserId, PostId}, - source::secret::Secret, - utils::DbPool, -}; -use lemmy_utils::{ - error::LemmyError, - location_info, - rate_limit::RateLimitCell, - settings::structs::Settings, - ConnectionId, - IpAddr, -}; -use rand::rngs::ThreadRng; -use reqwest_middleware::ClientWithMiddleware; +use futures::future::try_join_all; +use lemmy_db_schema::newtypes::{CommunityId, LocalUserId, PostId}; +use lemmy_utils::{error::LemmyError, location_info, ConnectionId}; +use rand::{rngs::StdRng, SeedableRng}; use serde::Serialize; -use serde_json::Value; use std::{ collections::{HashMap, HashSet}, - future::Future, - str::FromStr, + sync::{Mutex, MutexGuard}, }; -use tokio::macros::support::Pin; - -type MessageHandlerType = fn( - context: LemmyContext, - id: ConnectionId, - op: UserOperation, - data: &str, -) -> Pin> + '_>>; - -type MessageHandlerCrudType = fn( - context: LemmyContext, - id: ConnectionId, - op: UserOperationCrud, - data: &str, -) -> Pin> + '_>>; - -type MessageHandlerApubType = fn( - context: LemmyContext, - id: ConnectionId, - op: UserOperationApub, - data: &str, -) -> Pin> + '_>>; +use tracing::log::warn; /// `ChatServer` manages chat rooms and responsible for coordinating chat /// session. pub struct ChatServer { + inner: Mutex, +} + +pub struct ChatServerInner { /// A map from generated random ID to session addr - pub sessions: HashMap, + pub sessions: HashMap, /// A map from post_id to set of connectionIDs pub post_rooms: HashMap>, @@ -76,91 +38,53 @@ pub struct ChatServer { /// sessions (IE clients) pub(super) user_rooms: HashMap>, - pub(super) rng: ThreadRng, - - /// The DB Pool - pub(super) pool: DbPool, - - /// The Settings - pub(super) settings: Settings, - - /// The Secrets - pub(super) secret: Secret, + pub(super) rng: StdRng, /// A list of the current captchas pub(super) captchas: Vec, - - message_handler: MessageHandlerType, - message_handler_crud: MessageHandlerCrudType, - message_handler_apub: MessageHandlerApubType, - - /// An HTTP Client - client: ClientWithMiddleware, - - rate_limit_cell: RateLimitCell, -} - -pub struct SessionInfo { - pub addr: Recipient, - pub ip: IpAddr, } /// `ChatServer` is an actor. It maintains list of connection client session. /// And manages available rooms. Peers send messages to other peers in same /// room through `ChatServer`. impl ChatServer { - #![allow(clippy::too_many_arguments)] - pub fn startup( - pool: DbPool, - message_handler: MessageHandlerType, - message_handler_crud: MessageHandlerCrudType, - message_handler_apub: MessageHandlerApubType, - client: ClientWithMiddleware, - settings: Settings, - secret: Secret, - rate_limit_cell: RateLimitCell, - ) -> ChatServer { + pub fn startup() -> ChatServer { ChatServer { - sessions: HashMap::new(), - post_rooms: HashMap::new(), - community_rooms: HashMap::new(), - mod_rooms: HashMap::new(), - user_rooms: HashMap::new(), - rng: rand::thread_rng(), - pool, - captchas: Vec::new(), - message_handler, - message_handler_crud, - message_handler_apub, - client, - settings, - secret, - rate_limit_cell, + inner: Mutex::new(ChatServerInner { + sessions: Default::default(), + post_rooms: Default::default(), + community_rooms: Default::default(), + mod_rooms: Default::default(), + user_rooms: Default::default(), + rng: StdRng::from_entropy(), + captchas: vec![], + }), } } pub fn join_community_room( - &mut self, + &self, community_id: CommunityId, id: ConnectionId, ) -> Result<(), LemmyError> { + let mut inner = self.inner()?; // remove session from all rooms - for sessions in self.community_rooms.values_mut() { + for sessions in inner.community_rooms.values_mut() { sessions.remove(&id); } // Also leave all post rooms // This avoids double messages - for sessions in self.post_rooms.values_mut() { + for sessions in inner.post_rooms.values_mut() { sessions.remove(&id); } // If the room doesn't exist yet - if self.community_rooms.get_mut(&community_id).is_none() { - self.community_rooms.insert(community_id, HashSet::new()); + if inner.community_rooms.get_mut(&community_id).is_none() { + inner.community_rooms.insert(community_id, HashSet::new()); } - self + inner .community_rooms .get_mut(&community_id) .context(location_info!())? @@ -169,21 +93,22 @@ impl ChatServer { } pub fn join_mod_room( - &mut self, + &self, community_id: CommunityId, id: ConnectionId, ) -> Result<(), LemmyError> { + let mut inner = self.inner()?; // remove session from all rooms - for sessions in self.mod_rooms.values_mut() { + for sessions in inner.mod_rooms.values_mut() { sessions.remove(&id); } // If the room doesn't exist yet - if self.mod_rooms.get_mut(&community_id).is_none() { - self.mod_rooms.insert(community_id, HashSet::new()); + if inner.mod_rooms.get_mut(&community_id).is_none() { + inner.mod_rooms.insert(community_id, HashSet::new()); } - self + inner .mod_rooms .get_mut(&community_id) .context(location_info!())? @@ -191,9 +116,10 @@ impl ChatServer { Ok(()) } - pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> { + pub fn join_post_room(&self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> { + let mut inner = self.inner()?; // remove session from all rooms - for sessions in self.post_rooms.values_mut() { + for sessions in inner.post_rooms.values_mut() { sessions.remove(&id); } @@ -202,16 +128,16 @@ impl ChatServer { // TODO found a bug, whereby community messages like // delete and remove aren't sent, because // you left the community room - for sessions in self.community_rooms.values_mut() { + for sessions in inner.community_rooms.values_mut() { sessions.remove(&id); } // If the room doesn't exist yet - if self.post_rooms.get_mut(&post_id).is_none() { - self.post_rooms.insert(post_id, HashSet::new()); + if inner.post_rooms.get_mut(&post_id).is_none() { + inner.post_rooms.insert(post_id, HashSet::new()); } - self + inner .post_rooms .get_mut(&post_id) .context(location_info!())? @@ -220,31 +146,27 @@ impl ChatServer { Ok(()) } - pub fn join_user_room( - &mut self, - user_id: LocalUserId, - id: ConnectionId, - ) -> Result<(), LemmyError> { + pub fn join_user_room(&self, user_id: LocalUserId, id: ConnectionId) -> Result<(), LemmyError> { + let mut inner = self.inner()?; // remove session from all rooms - for sessions in self.user_rooms.values_mut() { + for sessions in inner.user_rooms.values_mut() { sessions.remove(&id); } // If the room doesn't exist yet - if self.user_rooms.get_mut(&user_id).is_none() { - self.user_rooms.insert(user_id, HashSet::new()); + if inner.user_rooms.get_mut(&user_id).is_none() { + inner.user_rooms.insert(user_id, HashSet::new()); } - self + inner .user_rooms .get_mut(&user_id) .context(location_info!())? .insert(id); - Ok(()) } - fn send_post_room_message( + async fn send_post_room_message( &self, op: &OP, response: &Response, @@ -255,21 +177,14 @@ impl ChatServer { OP: OperationType + ToString, Response: Serialize, { - let res_str = &serialize_websocket_message(op, response)?; - if let Some(sessions) = self.post_rooms.get(&post_id) { - for id in sessions { - if let Some(my_id) = websocket_id { - if *id == my_id { - continue; - } - } - self.sendit(res_str, *id); - } - } + let msg = serialize_websocket_message(op, response)?; + let room = self.inner()?.post_rooms.get(&post_id).cloned(); + self.send_message_in_room(&msg, room, websocket_id).await?; Ok(()) } - pub fn send_community_room_message( + /// Send message to all users viewing the given community. + pub async fn send_community_room_message( &self, op: &OP, response: &Response, @@ -280,23 +195,16 @@ impl ChatServer { OP: OperationType + ToString, Response: Serialize, { - let res_str = &serialize_websocket_message(op, response)?; - if let Some(sessions) = self.community_rooms.get(&community_id) { - for id in sessions { - if let Some(my_id) = websocket_id { - if *id == my_id { - continue; - } - } - self.sendit(res_str, *id); - } - } + let msg = serialize_websocket_message(op, response)?; + let room = self.inner()?.community_rooms.get(&community_id).cloned(); + self.send_message_in_room(&msg, room, websocket_id).await?; Ok(()) } - pub fn send_mod_room_message( + /// Send message to mods of a given community. Set community_id = 0 to send to site admins. + pub async fn send_mod_room_message( &self, - op: &OP, + op: OP, response: &Response, community_id: CommunityId, websocket_id: Option, @@ -305,43 +213,35 @@ impl ChatServer { OP: OperationType + ToString, Response: Serialize, { - let res_str = &serialize_websocket_message(op, response)?; - if let Some(sessions) = self.mod_rooms.get(&community_id) { - for id in sessions { - if let Some(my_id) = websocket_id { - if *id == my_id { - continue; - } - } - self.sendit(res_str, *id); - } - } + let msg = serialize_websocket_message(&op, response)?; + let room = self.inner()?.mod_rooms.get(&community_id).cloned(); + self.send_message_in_room(&msg, room, websocket_id).await?; Ok(()) } - pub fn send_all_message( + pub async fn send_all_message( &self, - op: &OP, + op: OP, response: &Response, - websocket_id: Option, + exclude_connection: Option, ) -> Result<(), LemmyError> where OP: OperationType + ToString, Response: Serialize, { - let res_str = &serialize_websocket_message(op, response)?; - for id in self.sessions.keys() { - if let Some(my_id) = websocket_id { - if *id == my_id { - continue; - } - } - self.sendit(res_str, *id); - } + let msg = &serialize_websocket_message(&op, response)?; + let sessions = self.inner()?.sessions.clone(); + try_join_all( + sessions + .into_iter() + .filter(|(id, _)| Some(id) != exclude_connection.as_ref()) + .map(|(_, mut s): (_, Session)| async move { s.text(msg).await }), + ) + .await?; Ok(()) } - pub fn send_user_room_message( + pub async fn send_user_room_message( &self, op: &OP, response: &Response, @@ -352,21 +252,13 @@ impl ChatServer { OP: OperationType + ToString, Response: Serialize, { - let res_str = &serialize_websocket_message(op, response)?; - if let Some(sessions) = self.user_rooms.get(&recipient_id) { - for id in sessions { - if let Some(my_id) = websocket_id { - if *id == my_id { - continue; - } - } - self.sendit(res_str, *id); - } - } + let msg = serialize_websocket_message(op, response)?; + let room = self.inner()?.user_rooms.get(&recipient_id).cloned(); + self.send_message_in_room(&msg, room, websocket_id).await?; Ok(()) } - pub fn send_comment( + pub async fn send_comment( &self, user_operation: &OP, comment: &CommentResponse, @@ -384,41 +276,49 @@ impl ChatServer { let mut comment_post_sent = comment_reply_sent.clone(); // Remove the recipients here to separate mentions / user messages from post or community comments comment_post_sent.recipient_ids = Vec::new(); - self.send_post_room_message( - user_operation, - &comment_post_sent, - comment_post_sent.comment_view.post.id, - websocket_id, - )?; + self + .send_post_room_message( + user_operation, + &comment_post_sent, + comment_post_sent.comment_view.post.id, + websocket_id, + ) + .await?; // Send it to the community too - self.send_community_room_message( - user_operation, - &comment_post_sent, - CommunityId(0), - websocket_id, - )?; - self.send_community_room_message( - user_operation, - &comment_post_sent, - comment.comment_view.community.id, - websocket_id, - )?; + self + .send_community_room_message( + user_operation, + &comment_post_sent, + CommunityId(0), + websocket_id, + ) + .await?; + self + .send_community_room_message( + user_operation, + &comment_post_sent, + comment.comment_view.community.id, + websocket_id, + ) + .await?; // Send it to the recipient(s) including the mentioned users for recipient_id in &comment_reply_sent.recipient_ids { - self.send_user_room_message( - user_operation, - &comment_reply_sent, - *recipient_id, - websocket_id, - )?; + self + .send_user_room_message( + user_operation, + &comment_reply_sent, + *recipient_id, + websocket_id, + ) + .await?; } Ok(()) } - pub fn send_post( + pub async fn send_post( &self, user_operation: &OP, post_res: &PostResponse, @@ -434,89 +334,58 @@ impl ChatServer { post_sent.post_view.my_vote = None; // Send it to /c/all and that community - self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?; - self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?; + self + .send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id) + .await?; + self + .send_community_room_message(user_operation, &post_sent, community_id, websocket_id) + .await?; // Send it to the post room - self.send_post_room_message( - user_operation, - &post_sent, - post_res.post_view.post.id, - websocket_id, - )?; + self + .send_post_room_message( + user_operation, + &post_sent, + post_res.post_view.post.id, + websocket_id, + ) + .await?; Ok(()) } - fn sendit(&self, message: &str, id: ConnectionId) { - if let Some(info) = self.sessions.get(&id) { - info.addr.do_send(WsMessage(message.to_owned())); + /// Send websocket message in all sessions which joined a specific room. + /// + /// `message` - The json message body to send + /// `room` - Connection IDs which should receive the message + /// `exclude_connection` - Dont send to user who initiated the api call, as that + /// would result in duplicate notification + async fn send_message_in_room( + &self, + message: &str, + room: Option>, + exclude_connection: Option, + ) -> Result<(), LemmyError> { + let mut session = self.inner()?.sessions.clone(); + if let Some(room) = room { + try_join_all( + room + .into_iter() + .filter(|c| Some(c) != exclude_connection.as_ref()) + .filter_map(|c| session.remove(&c)) + .map(|mut s: Session| async move { s.text(message).await }), + ) + .await?; } + Ok(()) } - pub(super) fn parse_json_message( - &mut self, - msg: StandardMessage, - ctx: &mut Context, - ) -> impl Future> { - let ip: IpAddr = match self.sessions.get(&msg.id) { - Some(info) => info.ip.clone(), - None => IpAddr("blank_ip".to_string()), - }; - - let context = LemmyContext::create( - self.pool.clone(), - ctx.address(), - self.client.clone(), - self.settings.clone(), - self.secret.clone(), - self.rate_limit_cell.clone(), - ); - let message_handler_crud = self.message_handler_crud; - let message_handler = self.message_handler; - let message_handler_apub = self.message_handler_apub; - let rate_limiter = self.rate_limit_cell.clone(); - async move { - let json: Value = serde_json::from_str(&msg.msg)?; - let data = &json["data"].to_string(); - let op = &json["op"] - .as_str() - .ok_or_else(|| LemmyError::from_message("missing op"))?; - - // check if api call passes the rate limit, and generate future for later execution - let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) { - let passed = match user_operation_crud { - UserOperationCrud::Register => rate_limiter.register().check(ip), - UserOperationCrud::CreatePost => rate_limiter.post().check(ip), - UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip), - UserOperationCrud::CreateComment => rate_limiter.comment().check(ip), - _ => rate_limiter.message().check(ip), - }; - let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data); - (passed, fut) - } else if let Ok(user_operation) = UserOperation::from_str(op) { - let passed = match user_operation { - UserOperation::GetCaptcha => rate_limiter.post().check(ip), - _ => rate_limiter.message().check(ip), - }; - let fut = (message_handler)(context, msg.id, user_operation, data); - (passed, fut) - } else { - let user_operation = UserOperationApub::from_str(op)?; - let passed = match user_operation { - UserOperationApub::Search => rate_limiter.search().check(ip), - _ => rate_limiter.message().check(ip), - }; - let fut = (message_handler_apub)(context, msg.id, user_operation, data); - (passed, fut) - }; - - // if rate limit passed, execute api call future - if passed { - fut.await - } else { - // if rate limit was hit, respond with message - Err(LemmyError::from_message("rate_limit_error")) + pub(in crate::websocket) fn inner(&self) -> Result, LemmyError> { + match self.inner.lock() { + Ok(g) => Ok(g), + Err(e) => { + warn!("Failed to lock chatserver mutex: {}", e); + Err(LemmyError::from_message("Failed to lock chatserver mutex")) } } } diff --git a/crates/api_common/src/websocket/handlers.rs b/crates/api_common/src/websocket/handlers.rs index 6f3d164c..afdbfd59 100644 --- a/crates/api_common/src/websocket/handlers.rs +++ b/crates/api_common/src/websocket/handlers.rs @@ -1,303 +1,83 @@ -use crate::websocket::{ - chat_server::{ChatServer, SessionInfo}, - messages::{ - CaptchaItem, - CheckCaptcha, - Connect, - Disconnect, - GetCommunityUsersOnline, - GetPostUsersOnline, - GetUsersOnline, - JoinCommunityRoom, - JoinModRoom, - JoinPostRoom, - JoinUserRoom, - SendAllMessage, - SendComment, - SendCommunityRoomMessage, - SendModRoomMessage, - SendPost, - SendUserRoomMessage, - StandardMessage, - }, - OperationType, +use crate::websocket::{chat_server::ChatServer, structs::CaptchaItem}; +use actix_ws::Session; +use lemmy_db_schema::{ + newtypes::{CommunityId, PostId}, + utils::naive_now, }; -use actix::{Actor, Context, Handler, ResponseFuture}; -use lemmy_db_schema::utils::naive_now; -use lemmy_utils::ConnectionId; -use opentelemetry::trace::TraceContextExt; +use lemmy_utils::{error::LemmyError, ConnectionId}; use rand::Rng; -use serde::Serialize; -use tracing::{error, info}; -use tracing_opentelemetry::OpenTelemetrySpanExt; -/// Make actor from `ChatServer` -impl Actor for ChatServer { - /// We are going to use simple Context, we just need ability to communicate - /// with other actors. - type Context = Context; -} - -/// Handler for Connect message. -/// -/// Register new session and assign unique id to this session -impl Handler for ChatServer { - type Result = ConnectionId; - - fn handle(&mut self, msg: Connect, _ctx: &mut Context) -> Self::Result { +impl ChatServer { + /// Handler for Connect message. + /// + /// Register new session and assign unique id to this session + pub fn handle_connect(&self, session: Session) -> Result { + let mut inner = self.inner()?; // register session with random id - let id = self.rng.gen::(); - info!("{} joined", &msg.ip); + let id = inner.rng.gen::(); - self.sessions.insert( - id, - SessionInfo { - addr: msg.addr, - ip: msg.ip, - }, - ); - - id + inner.sessions.insert(id, session); + Ok(id) } -} -/// Handler for Disconnect message. -impl Handler for ChatServer { - type Result = (); - - fn handle(&mut self, msg: Disconnect, _: &mut Context) { + /// Handler for Disconnect message. + pub fn handle_disconnect(&self, connection_id: &ConnectionId) -> Result<(), LemmyError> { + let mut inner = self.inner()?; // Remove connections from sessions and all 3 scopes - if self.sessions.remove(&msg.id).is_some() { - for sessions in self.user_rooms.values_mut() { - sessions.remove(&msg.id); + if inner.sessions.remove(connection_id).is_some() { + for sessions in inner.user_rooms.values_mut() { + sessions.remove(connection_id); } - for sessions in self.post_rooms.values_mut() { - sessions.remove(&msg.id); + for sessions in inner.post_rooms.values_mut() { + sessions.remove(connection_id); } - for sessions in self.community_rooms.values_mut() { - sessions.remove(&msg.id); + for sessions in inner.community_rooms.values_mut() { + sessions.remove(connection_id); } } + Ok(()) } -} - -fn root_span() -> tracing::Span { - let span = tracing::info_span!( - parent: None, - "Websocket Request", - trace_id = tracing::field::Empty, - ); - { - let trace_id = span.context().span().span_context().trace_id().to_string(); - span.record("trace_id", &tracing::field::display(trace_id)); - } - - span -} - -/// Handler for Message message. -impl Handler for ChatServer { - type Result = ResponseFuture>; - - fn handle(&mut self, msg: StandardMessage, ctx: &mut Context) -> Self::Result { - use tracing::Instrument; - let fut = self.parse_json_message(msg, ctx); - let span = root_span(); - - Box::pin( - async move { - match fut.await { - Ok(m) => { - // info!("Message Sent: {}", m); - Ok(m) - } - Err(e) => { - error!("Error during message handling {}", e); - Ok( - e.to_json() - .unwrap_or_else(|_| String::from(r#"{"error":"failed to serialize json"}"#)), - ) - } - } - } - .instrument(span), - ) - } -} - -impl Handler> for ChatServer -where - OP: OperationType + ToString, - Response: Serialize, -{ - type Result = (); - - fn handle(&mut self, msg: SendAllMessage, _: &mut Context) { - self - .send_all_message(&msg.op, &msg.response, msg.websocket_id) - .ok(); - } -} - -impl Handler> for ChatServer -where - OP: OperationType + ToString, - Response: Serialize, -{ - type Result = (); - - fn handle(&mut self, msg: SendUserRoomMessage, _: &mut Context) { - self - .send_user_room_message( - &msg.op, - &msg.response, - msg.local_recipient_id, - msg.websocket_id, - ) - .ok(); - } -} - -impl Handler> for ChatServer -where - OP: OperationType + ToString, - Response: Serialize, -{ - type Result = (); - - fn handle(&mut self, msg: SendCommunityRoomMessage, _: &mut Context) { - self - .send_community_room_message(&msg.op, &msg.response, msg.community_id, msg.websocket_id) - .ok(); - } -} - -impl Handler> for ChatServer -where - Response: Serialize, -{ - type Result = (); - - fn handle(&mut self, msg: SendModRoomMessage, _: &mut Context) { - self - .send_mod_room_message(&msg.op, &msg.response, msg.community_id, msg.websocket_id) - .ok(); - } -} - -impl Handler> for ChatServer -where - OP: OperationType + ToString, -{ - type Result = (); - - fn handle(&mut self, msg: SendPost, _: &mut Context) { - self.send_post(&msg.op, &msg.post, msg.websocket_id).ok(); - } -} - -impl Handler> for ChatServer -where - OP: OperationType + ToString, -{ - type Result = (); - fn handle(&mut self, msg: SendComment, _: &mut Context) { - self - .send_comment(&msg.op, &msg.comment, msg.websocket_id) - .ok(); + pub fn get_users_online(&self) -> Result { + Ok(self.inner()?.sessions.len()) } -} - -impl Handler for ChatServer { - type Result = (); - fn handle(&mut self, msg: JoinUserRoom, _: &mut Context) { - self.join_user_room(msg.local_user_id, msg.id).ok(); - } -} - -impl Handler for ChatServer { - type Result = (); - - fn handle(&mut self, msg: JoinCommunityRoom, _: &mut Context) { - self.join_community_room(msg.community_id, msg.id).ok(); - } -} - -impl Handler for ChatServer { - type Result = (); - - fn handle(&mut self, msg: JoinModRoom, _: &mut Context) { - self.join_mod_room(msg.community_id, msg.id).ok(); - } -} - -impl Handler for ChatServer { - type Result = (); - - fn handle(&mut self, msg: JoinPostRoom, _: &mut Context) { - self.join_post_room(msg.post_id, msg.id).ok(); - } -} - -impl Handler for ChatServer { - type Result = usize; - - fn handle(&mut self, _msg: GetUsersOnline, _: &mut Context) -> Self::Result { - self.sessions.len() - } -} - -impl Handler for ChatServer { - type Result = usize; - - fn handle(&mut self, msg: GetPostUsersOnline, _: &mut Context) -> Self::Result { - if let Some(users) = self.post_rooms.get(&msg.post_id) { - users.len() + pub fn get_post_users_online(&self, post_id: PostId) -> Result { + if let Some(users) = self.inner()?.post_rooms.get(&post_id) { + Ok(users.len()) } else { - 0 + Ok(0) } } -} - -impl Handler for ChatServer { - type Result = usize; - fn handle(&mut self, msg: GetCommunityUsersOnline, _: &mut Context) -> Self::Result { - if let Some(users) = self.community_rooms.get(&msg.community_id) { - users.len() + pub fn get_community_users_online(&self, community_id: CommunityId) -> Result { + if let Some(users) = self.inner()?.community_rooms.get(&community_id) { + Ok(users.len()) } else { - 0 + Ok(0) } } -} - -impl Handler for ChatServer { - type Result = (); - fn handle(&mut self, msg: CaptchaItem, _: &mut Context) { - self.captchas.push(msg); + pub fn add_captcha(&self, captcha: CaptchaItem) -> Result<(), LemmyError> { + self.inner()?.captchas.push(captcha); + Ok(()) } -} - -impl Handler for ChatServer { - type Result = bool; - fn handle(&mut self, msg: CheckCaptcha, _: &mut Context) -> Self::Result { + pub fn check_captcha(&self, uuid: String, answer: String) -> Result { + let mut inner = self.inner()?; // Remove all the ones that are past the expire time - self.captchas.retain(|x| x.expires.gt(&naive_now())); + inner.captchas.retain(|x| x.expires.gt(&naive_now())); - let check = self + let check = inner .captchas .iter() - .any(|r| r.uuid == msg.uuid && r.answer.to_lowercase() == msg.answer.to_lowercase()); + .any(|r| r.uuid == uuid && r.answer.to_lowercase() == answer.to_lowercase()); // Remove this uuid so it can't be re-checked (Checks only work once) - self.captchas.retain(|x| x.uuid != msg.uuid); + inner.captchas.retain(|x| x.uuid != uuid); - check + Ok(check) } } diff --git a/crates/api_common/src/websocket/messages.rs b/crates/api_common/src/websocket/messages.rs deleted file mode 100644 index f8112416..00000000 --- a/crates/api_common/src/websocket/messages.rs +++ /dev/null @@ -1,150 +0,0 @@ -use crate::{comment::CommentResponse, post::PostResponse, websocket::UserOperation}; -use actix::{prelude::*, Recipient}; -use lemmy_db_schema::newtypes::{CommunityId, LocalUserId, PostId}; -use lemmy_utils::{ConnectionId, IpAddr}; -use serde::{Deserialize, Serialize}; - -/// Chat server sends this messages to session -#[derive(Message)] -#[rtype(result = "()")] -pub struct WsMessage(pub String); - -/// Message for chat server communications - -/// New chat session is created -#[derive(Message)] -#[rtype(usize)] -pub struct Connect { - pub addr: Recipient, - pub ip: IpAddr, -} - -/// Session is disconnected -#[derive(Message)] -#[rtype(result = "()")] -pub struct Disconnect { - pub id: ConnectionId, - pub ip: IpAddr, -} - -/// The messages sent to websocket clients -#[derive(Serialize, Deserialize, Message)] -#[rtype(result = "Result")] -pub struct StandardMessage { - /// Id of the client session - pub id: ConnectionId, - /// Peer message - pub msg: String, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub struct SendAllMessage { - pub op: OP, - pub response: Response, - pub websocket_id: Option, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub struct SendUserRoomMessage { - pub op: OP, - pub response: Response, - pub local_recipient_id: LocalUserId, - pub websocket_id: Option, -} - -/// Send message to all users viewing the given community. -#[derive(Message)] -#[rtype(result = "()")] -pub struct SendCommunityRoomMessage { - pub op: OP, - pub response: Response, - pub community_id: CommunityId, - pub websocket_id: Option, -} - -/// Send message to mods of a given community. Set community_id = 0 to send to site admins. -#[derive(Message)] -#[rtype(result = "()")] -pub struct SendModRoomMessage { - pub op: UserOperation, - pub response: Response, - pub community_id: CommunityId, - pub websocket_id: Option, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub(crate) struct SendPost { - pub op: OP, - pub post: PostResponse, - pub websocket_id: Option, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub(crate) struct SendComment { - pub op: OP, - pub comment: CommentResponse, - pub websocket_id: Option, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub struct JoinUserRoom { - pub local_user_id: LocalUserId, - pub id: ConnectionId, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub struct JoinCommunityRoom { - pub community_id: CommunityId, - pub id: ConnectionId, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub struct JoinModRoom { - pub community_id: CommunityId, - pub id: ConnectionId, -} - -#[derive(Message)] -#[rtype(result = "()")] -pub struct JoinPostRoom { - pub post_id: PostId, - pub id: ConnectionId, -} - -#[derive(Message)] -#[rtype(usize)] -pub struct GetUsersOnline; - -#[derive(Message)] -#[rtype(usize)] -pub struct GetPostUsersOnline { - pub post_id: PostId, -} - -#[derive(Message)] -#[rtype(usize)] -pub struct GetCommunityUsersOnline { - pub community_id: CommunityId, -} - -#[derive(Message, Debug)] -#[rtype(result = "()")] -pub struct CaptchaItem { - pub uuid: String, - pub answer: String, - pub expires: chrono::NaiveDateTime, -} - -#[derive(Message)] -#[rtype(bool)] -pub struct CheckCaptcha { - pub uuid: String, - pub answer: String, -} diff --git a/crates/api_common/src/websocket/mod.rs b/crates/api_common/src/websocket/mod.rs index 430027cf..8686b8e4 100644 --- a/crates/api_common/src/websocket/mod.rs +++ b/crates/api_common/src/websocket/mod.rs @@ -3,8 +3,6 @@ use serde::Serialize; pub mod chat_server; pub mod handlers; -pub mod messages; -pub mod routes; pub mod send; pub mod structs; diff --git a/crates/api_common/src/websocket/routes.rs b/crates/api_common/src/websocket/routes.rs deleted file mode 100644 index 936dc999..00000000 --- a/crates/api_common/src/websocket/routes.rs +++ /dev/null @@ -1,197 +0,0 @@ -use crate::{ - context::LemmyContext, - websocket::{ - chat_server::ChatServer, - messages::{Connect, Disconnect, StandardMessage, WsMessage}, - }, -}; -use actix::prelude::*; -use actix_web::{web, Error, HttpRequest, HttpResponse}; -use actix_web_actors::ws; -use lemmy_utils::{rate_limit::RateLimitCell, utils::get_ip, ConnectionId, IpAddr}; -use std::time::{Duration, Instant}; -use tracing::{debug, error, info}; - -/// How often heartbeat pings are sent -const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); -/// How long before lack of client response causes a timeout -const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); - -/// Entry point for our route -pub async fn chat_route( - req: HttpRequest, - stream: web::Payload, - context: web::Data, - rate_limiter: web::Data, -) -> Result { - ws::start( - WsSession { - cs_addr: context.chat_server().clone(), - id: 0, - hb: Instant::now(), - ip: get_ip(&req.connection_info()), - rate_limiter: rate_limiter.as_ref().clone(), - }, - &req, - stream, - ) -} - -struct WsSession { - cs_addr: Addr, - /// unique session id - id: ConnectionId, - ip: IpAddr, - /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT), - /// otherwise we drop connection. - hb: Instant, - /// A rate limiter for websocket joins - rate_limiter: RateLimitCell, -} - -impl Actor for WsSession { - type Context = ws::WebsocketContext; - - /// Method is called on actor start. - /// We register ws session with ChatServer - fn started(&mut self, ctx: &mut Self::Context) { - // we'll start heartbeat process on session start. - WsSession::hb(ctx); - - // register self in chat server. `AsyncContext::wait` register - // future within context, but context waits until this future resolves - // before processing any other events. - // across all routes within application - let addr = ctx.address(); - - if !self.rate_limit_check(ctx) { - return; - } - - self - .cs_addr - .send(Connect { - addr: addr.recipient(), - ip: self.ip.clone(), - }) - .into_actor(self) - .then(|res, act, ctx| { - match res { - Ok(res) => act.id = res, - // something is wrong with chat server - _ => ctx.stop(), - } - actix::fut::ready(()) - }) - .wait(ctx); - } - - fn stopping(&mut self, _ctx: &mut Self::Context) -> Running { - // notify chat server - self.cs_addr.do_send(Disconnect { - id: self.id, - ip: self.ip.clone(), - }); - Running::Stop - } -} - -/// Handle messages from chat server, we simply send it to peer websocket -/// These are room messages, IE sent to others in the room -impl Handler for WsSession { - type Result = (); - - fn handle(&mut self, msg: WsMessage, ctx: &mut Self::Context) { - ctx.text(msg.0); - } -} - -/// WebSocket message handler -impl StreamHandler> for WsSession { - fn handle(&mut self, result: Result, ctx: &mut Self::Context) { - if !self.rate_limit_check(ctx) { - return; - } - - let message = match result { - Ok(m) => m, - Err(e) => { - error!("{}", e); - return; - } - }; - match message { - ws::Message::Ping(msg) => { - self.hb = Instant::now(); - ctx.pong(&msg); - } - ws::Message::Pong(_) => { - self.hb = Instant::now(); - } - ws::Message::Text(text) => { - let m = text.trim().to_owned(); - - self - .cs_addr - .send(StandardMessage { - id: self.id, - msg: m, - }) - .into_actor(self) - .then(|res, _, ctx| { - match res { - Ok(Ok(res)) => ctx.text(res), - Ok(Err(_)) => {} - Err(e) => error!("{}", &e), - } - actix::fut::ready(()) - }) - .spawn(ctx); - } - ws::Message::Binary(_bin) => info!("Unexpected binary"), - ws::Message::Close(_) => { - ctx.stop(); - } - _ => {} - } - } -} - -impl WsSession { - /// helper method that sends ping to client every second. - /// - /// also this method checks heartbeats from client - fn hb(ctx: &mut ws::WebsocketContext) { - ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { - // check client heartbeats - if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { - // heartbeat timed out - debug!("Websocket Client heartbeat failed, disconnecting!"); - - // notify chat server - act.cs_addr.do_send(Disconnect { - id: act.id, - ip: act.ip.clone(), - }); - - // stop actor - ctx.stop(); - - // don't try to send a ping - return; - } - - ctx.ping(b""); - }); - } - - /// Check the rate limit, and stop the ctx if it fails - fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext) -> bool { - let check = self.rate_limiter.message().check(self.ip.clone()); - if !check { - debug!("Websocket join with IP: {} has been rate limited.", self.ip); - ctx.stop() - } - check - } -} diff --git a/crates/api_common/src/websocket/send.rs b/crates/api_common/src/websocket/send.rs index cd53955f..4f639452 100644 --- a/crates/api_common/src/websocket/send.rs +++ b/crates/api_common/src/websocket/send.rs @@ -5,10 +5,7 @@ use crate::{ post::PostResponse, private_message::PrivateMessageResponse, utils::{check_person_block, get_interface_language, send_email_to_user}, - websocket::{ - messages::{SendComment, SendCommunityRoomMessage, SendPost, SendUserRoomMessage}, - OperationType, - }, + websocket::OperationType, }; use lemmy_db_schema::{ newtypes::{CommentId, CommunityId, LocalUserId, PersonId, PostId, PrivateMessageId}, @@ -38,11 +35,10 @@ pub async fn send_post_ws_message let res = PostResponse { post_view }; - context.chat_server().do_send(SendPost { - op, - post: res.clone(), - websocket_id, - }); + context + .chat_server() + .send_post(&op, &res, websocket_id) + .await?; Ok(res) } @@ -81,11 +77,10 @@ pub async fn send_comment_ws_message Result { let community_view = CommunityView::read(context.pool(), community_id, person_id).await?; - let res = CommunityResponse { community_view }; + let mut res = CommunityResponse { community_view }; // Strip out the person id and subscribed when sending to others - let mut res_mut = res.clone(); - res_mut.community_view.subscribed = SubscribedType::NotSubscribed; + res.community_view.subscribed = SubscribedType::NotSubscribed; - context.chat_server().do_send(SendCommunityRoomMessage { - op, - response: res_mut, - community_id: res.community_view.community.id, - websocket_id, - }); + context + .chat_server() + .send_community_room_message(&op, &res, res.community_view.community.id, websocket_id) + .await?; Ok(res) } @@ -142,12 +134,11 @@ pub async fn send_pm_ws_message( if res.private_message_view.recipient.local { let recipient_id = res.private_message_view.recipient.id; let local_recipient = LocalUserView::read_person(context.pool(), recipient_id).await?; - context.chat_server().do_send(SendUserRoomMessage { - op, - response: res.clone(), - local_recipient_id: local_recipient.local_user.id, - websocket_id, - }); + + context + .chat_server() + .send_user_room_message(&op, &res, local_recipient.local_user.id, websocket_id) + .await?; } Ok(res) diff --git a/crates/api_common/src/websocket/structs.rs b/crates/api_common/src/websocket/structs.rs index 23f7b2bc..3418d05c 100644 --- a/crates/api_common/src/websocket/structs.rs +++ b/crates/api_common/src/websocket/structs.rs @@ -41,3 +41,10 @@ pub struct PostJoin { pub struct PostJoinResponse { pub joined: bool, } + +#[derive(Debug)] +pub struct CaptchaItem { + pub uuid: String, + pub answer: String, + pub expires: chrono::NaiveDateTime, +} diff --git a/crates/api_crud/src/post/read.rs b/crates/api_crud/src/post/read.rs index c6747c5e..49bb3765 100644 --- a/crates/api_crud/src/post/read.rs +++ b/crates/api_crud/src/post/read.rs @@ -4,7 +4,6 @@ use lemmy_api_common::{ context::LemmyContext, post::{GetPost, GetPostResponse}, utils::{check_private_instance, get_local_user_view_from_jwt_opt, mark_post_as_read}, - websocket::messages::GetPostUsersOnline, }; use lemmy_db_schema::{ aggregates::structs::{PersonPostAggregates, PersonPostAggregatesForm}, @@ -91,11 +90,7 @@ impl PerformCrud for GetPost { let moderators = CommunityModeratorView::for_community(context.pool(), community_id).await?; - let online = context - .chat_server() - .send(GetPostUsersOnline { post_id }) - .await - .unwrap_or(1); + let online = context.chat_server().get_post_users_online(post_id)?; // Return the jwt Ok(GetPostResponse { diff --git a/crates/api_crud/src/site/read.rs b/crates/api_crud/src/site/read.rs index 55dcd8fb..5b133c2c 100644 --- a/crates/api_crud/src/site/read.rs +++ b/crates/api_crud/src/site/read.rs @@ -4,7 +4,6 @@ use lemmy_api_common::{ context::LemmyContext, site::{GetSite, GetSiteResponse, MyUserInfo}, utils::{build_federated_instances, get_local_user_settings_view_from_jwt_opt}, - websocket::messages::GetUsersOnline, }; use lemmy_db_schema::source::{actor_language::SiteLanguage, language::Language, tagline::Tagline}; use lemmy_db_views::structs::{LocalUserDiscussionLanguageView, SiteView}; @@ -33,11 +32,7 @@ impl PerformCrud for GetSite { let admins = PersonViewSafe::admins(context.pool()).await?; - let online = context - .chat_server() - .send(GetUsersOnline) - .await - .unwrap_or(1); + let online = context.chat_server().get_users_online()?; // Build the local user let my_user = if let Some(local_user_view) = get_local_user_settings_view_from_jwt_opt( diff --git a/crates/api_crud/src/site/update.rs b/crates/api_crud/src/site/update.rs index 4a96925e..3909a297 100644 --- a/crates/api_crud/src/site/update.rs +++ b/crates/api_crud/src/site/update.rs @@ -10,7 +10,7 @@ use lemmy_api_common::{ local_site_to_slur_regex, site_description_length_check, }, - websocket::{messages::SendAllMessage, UserOperationCrud}, + websocket::UserOperationCrud, }; use lemmy_db_schema::{ source::{ @@ -189,11 +189,10 @@ impl PerformCrud for EditSite { let res = SiteResponse { site_view }; - context.chat_server().do_send(SendAllMessage { - op: UserOperationCrud::EditSite, - response: res.clone(), - websocket_id, - }); + context + .chat_server() + .send_all_message(UserOperationCrud::EditSite, &res, websocket_id) + .await?; Ok(res) } diff --git a/crates/api_crud/src/user/create.rs b/crates/api_crud/src/user/create.rs index 34b3b69e..f11c45b3 100644 --- a/crates/api_crud/src/user/create.rs +++ b/crates/api_crud/src/user/create.rs @@ -15,7 +15,6 @@ use lemmy_api_common::{ send_verification_email, EndpointType, }, - websocket::messages::CheckCaptcha, }; use lemmy_db_schema::{ aggregates::structs::PersonAggregates, @@ -73,13 +72,10 @@ impl PerformCrud for Register { // If the site is set up, check the captcha if local_site.site_setup && local_site.captcha_enabled { - let check = context - .chat_server() - .send(CheckCaptcha { - uuid: data.captcha_uuid.clone().unwrap_or_default(), - answer: data.captcha_answer.clone().unwrap_or_default(), - }) - .await?; + let check = context.chat_server().check_captcha( + data.captcha_uuid.clone().unwrap_or_default(), + data.captcha_answer.clone().unwrap_or_default(), + )?; if !check { return Err(LemmyError::from_message("captcha_incorrect")); } diff --git a/crates/apub/Cargo.toml b/crates/apub/Cargo.toml index 720fa593..3a4db66d 100644 --- a/crates/apub/Cargo.toml +++ b/crates/apub/Cargo.toml @@ -24,7 +24,6 @@ diesel = { workspace = true } chrono = { workspace = true } serde_json = { workspace = true } serde = { workspace = true } -actix = { workspace = true } actix-web = { workspace = true } actix-rt = { workspace = true } tracing = { workspace = true } diff --git a/crates/apub/src/activities/community/report.rs b/crates/apub/src/activities/community/report.rs index 57b1244d..dfd0a83c 100644 --- a/crates/apub/src/activities/community/report.rs +++ b/crates/apub/src/activities/community/report.rs @@ -18,7 +18,7 @@ use lemmy_api_common::{ context::LemmyContext, post::{CreatePostReport, PostReportResponse}, utils::get_local_user_view_from_jwt, - websocket::{messages::SendModRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{ source::{ @@ -158,12 +158,15 @@ impl ActivityHandler for Report { let post_report_view = PostReportView::read(context.pool(), report.id, actor.id).await?; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::CreateCommentReport, - response: PostReportResponse { post_report_view }, - community_id: post.community_id, - websocket_id: None, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::CreateCommentReport, + &PostReportResponse { post_report_view }, + post.community_id, + None, + ) + .await?; } PostOrComment::Comment(comment) => { let report_form = CommentReportForm { @@ -179,14 +182,17 @@ impl ActivityHandler for Report { CommentReportView::read(context.pool(), report.id, actor.id).await?; let community_id = comment_report_view.community.id; - context.chat_server().do_send(SendModRoomMessage { - op: UserOperation::CreateCommentReport, - response: CommentReportResponse { - comment_report_view, - }, - community_id, - websocket_id: None, - }); + context + .chat_server() + .send_mod_room_message( + UserOperation::CreateCommentReport, + &CommentReportResponse { + comment_report_view, + }, + community_id, + None, + ) + .await?; } }; Ok(()) diff --git a/crates/apub/src/activities/following/accept.rs b/crates/apub/src/activities/following/accept.rs index 6702c8c9..fbde090f 100644 --- a/crates/apub/src/activities/following/accept.rs +++ b/crates/apub/src/activities/following/accept.rs @@ -14,7 +14,7 @@ use activitystreams_kinds::activity::AcceptType; use lemmy_api_common::{ community::CommunityResponse, context::LemmyContext, - websocket::{messages::SendUserRoomMessage, UserOperation}, + websocket::UserOperation, }; use lemmy_db_schema::{source::community::CommunityFollower, traits::Followable}; use lemmy_db_views::structs::LocalUserView; @@ -106,12 +106,15 @@ impl ActivityHandler for AcceptFollow { let response = CommunityResponse { community_view }; - context.chat_server().do_send(SendUserRoomMessage { - op: UserOperation::FollowCommunity, - response, - local_recipient_id, - websocket_id: None, - }); + context + .chat_server() + .send_user_room_message( + &UserOperation::FollowCommunity, + &response, + local_recipient_id, + None, + ) + .await?; Ok(()) } diff --git a/crates/apub/src/api/read_community.rs b/crates/apub/src/api/read_community.rs index b1615d7d..8ffa48ba 100644 --- a/crates/apub/src/api/read_community.rs +++ b/crates/apub/src/api/read_community.rs @@ -8,7 +8,6 @@ use lemmy_api_common::{ community::{GetCommunity, GetCommunityResponse}, context::LemmyContext, utils::{check_private_instance, get_local_user_view_from_jwt_opt}, - websocket::messages::GetCommunityUsersOnline, }; use lemmy_db_schema::{ impls::actor_language::default_post_language, @@ -74,9 +73,7 @@ impl PerformApub for GetCommunity { let online = context .chat_server() - .send(GetCommunityUsersOnline { community_id }) - .await - .unwrap_or(1); + .get_community_users_online(community_id)?; let site_id = Site::instance_actor_id_from_url(community_view.community.actor_id.clone().into()); diff --git a/crates/apub/src/objects/mod.rs b/crates/apub/src/objects/mod.rs index 58e1f23f..29df3d64 100644 --- a/crates/apub/src/objects/mod.rs +++ b/crates/apub/src/objects/mod.rs @@ -54,7 +54,6 @@ pub(crate) fn verify_is_remote_object(id: &Url, settings: &Settings) -> Result<( #[cfg(test)] pub(crate) mod tests { - use actix::Actor; use anyhow::anyhow; use lemmy_api_common::{ context::LemmyContext, @@ -69,6 +68,7 @@ pub(crate) mod tests { }; use reqwest::{Client, Request, Response}; use reqwest_middleware::{ClientBuilder, Middleware, Next}; + use std::sync::Arc; use task_local_extensions::Extensions; struct BlockedMiddleware; @@ -109,17 +109,7 @@ pub(crate) mod tests { let rate_limit_config = RateLimitConfig::builder().build(); let rate_limit_cell = RateLimitCell::new(rate_limit_config).await; - let chat_server = ChatServer::startup( - pool.clone(), - |_, _, _, _| Box::pin(x()), - |_, _, _, _| Box::pin(x()), - |_, _, _, _| Box::pin(x()), - client.clone(), - settings.clone(), - secret.clone(), - rate_limit_cell.clone(), - ) - .start(); + let chat_server = Arc::new(ChatServer::startup()); LemmyContext::create( pool, chat_server, diff --git a/src/api_routes_http.rs b/src/api_routes_http.rs new file mode 100644 index 00000000..34083a48 --- /dev/null +++ b/src/api_routes_http.rs @@ -0,0 +1,463 @@ +use crate::api_routes_websocket::websocket; +use actix_web::{guard, web, Error, HttpResponse, Result}; +use lemmy_api::Perform; +use lemmy_api_common::{ + comment::{ + CreateComment, + CreateCommentLike, + CreateCommentReport, + DeleteComment, + EditComment, + GetComment, + GetComments, + ListCommentReports, + RemoveComment, + ResolveCommentReport, + SaveComment, + }, + community::{ + AddModToCommunity, + BanFromCommunity, + BlockCommunity, + CreateCommunity, + DeleteCommunity, + EditCommunity, + FollowCommunity, + GetCommunity, + HideCommunity, + ListCommunities, + RemoveCommunity, + TransferCommunity, + }, + context::LemmyContext, + person::{ + AddAdmin, + BanPerson, + BlockPerson, + ChangePassword, + DeleteAccount, + GetBannedPersons, + GetCaptcha, + GetPersonDetails, + GetPersonMentions, + GetReplies, + GetReportCount, + GetUnreadCount, + Login, + MarkAllAsRead, + MarkCommentReplyAsRead, + MarkPersonMentionAsRead, + PasswordChangeAfterReset, + PasswordReset, + Register, + SaveUserSettings, + VerifyEmail, + }, + post::{ + CreatePost, + CreatePostLike, + CreatePostReport, + DeletePost, + EditPost, + GetPost, + GetPosts, + GetSiteMetadata, + ListPostReports, + LockPost, + MarkPostAsRead, + RemovePost, + ResolvePostReport, + SavePost, + StickyPost, + }, + private_message::{ + CreatePrivateMessage, + CreatePrivateMessageReport, + DeletePrivateMessage, + EditPrivateMessage, + GetPrivateMessages, + ListPrivateMessageReports, + MarkPrivateMessageAsRead, + ResolvePrivateMessageReport, + }, + site::{ + ApproveRegistrationApplication, + CreateSite, + EditSite, + GetModlog, + GetSite, + GetUnreadRegistrationApplicationCount, + LeaveAdmin, + ListRegistrationApplications, + PurgeComment, + PurgeCommunity, + PurgePerson, + PurgePost, + ResolveObject, + Search, + }, + websocket::structs::{CommunityJoin, ModJoin, PostJoin, UserJoin}, +}; +use lemmy_api_crud::PerformCrud; +use lemmy_apub::{api::PerformApub, SendActivity}; +use lemmy_utils::rate_limit::RateLimitCell; +use serde::Deserialize; + +pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) { + cfg.service( + web::scope("/api/v3") + // Websocket + .service(web::resource("/ws").to(websocket)) + // Site + .service( + web::scope("/site") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get_crud::)) + // Admin Actions + .route("", web::post().to(route_post_crud::)) + .route("", web::put().to(route_post_crud::)), + ) + .service( + web::resource("/modlog") + .wrap(rate_limit.message()) + .route(web::get().to(route_get::)), + ) + .service( + web::resource("/search") + .wrap(rate_limit.search()) + .route(web::get().to(route_get_apub::)), + ) + .service( + web::resource("/resolve_object") + .wrap(rate_limit.message()) + .route(web::get().to(route_get_apub::)), + ) + // Community + .service( + web::resource("/community") + .guard(guard::Post()) + .wrap(rate_limit.register()) + .route(web::post().to(route_post_crud::)), + ) + .service( + web::scope("/community") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get_apub::)) + .route("", web::put().to(route_post_crud::)) + .route("/hide", web::put().to(route_post::)) + .route("/list", web::get().to(route_get_crud::)) + .route("/follow", web::post().to(route_post::)) + .route("/block", web::post().to(route_post::)) + .route( + "/delete", + web::post().to(route_post_crud::), + ) + // Mod Actions + .route( + "/remove", + web::post().to(route_post_crud::), + ) + .route("/transfer", web::post().to(route_post::)) + .route("/ban_user", web::post().to(route_post::)) + .route("/mod", web::post().to(route_post::)) + .route("/join", web::post().to(route_post::)) + .route("/mod/join", web::post().to(route_post::)), + ) + // Post + .service( + // Handle POST to /post separately to add the post() rate limitter + web::resource("/post") + .guard(guard::Post()) + .wrap(rate_limit.post()) + .route(web::post().to(route_post_crud::)), + ) + .service( + web::scope("/post") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get_crud::)) + .route("", web::put().to(route_post_crud::)) + .route("/delete", web::post().to(route_post_crud::)) + .route("/remove", web::post().to(route_post_crud::)) + .route( + "/mark_as_read", + web::post().to(route_post::), + ) + .route("/lock", web::post().to(route_post::)) + .route("/sticky", web::post().to(route_post::)) + .route("/list", web::get().to(route_get_apub::)) + .route("/like", web::post().to(route_post::)) + .route("/save", web::put().to(route_post::)) + .route("/join", web::post().to(route_post::)) + .route("/report", web::post().to(route_post::)) + .route( + "/report/resolve", + web::put().to(route_post::), + ) + .route("/report/list", web::get().to(route_get::)) + .route( + "/site_metadata", + web::get().to(route_get::), + ), + ) + // Comment + .service( + // Handle POST to /comment separately to add the comment() rate limitter + web::resource("/comment") + .guard(guard::Post()) + .wrap(rate_limit.comment()) + .route(web::post().to(route_post_crud::)), + ) + .service( + web::scope("/comment") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get_crud::)) + .route("", web::put().to(route_post_crud::)) + .route("/delete", web::post().to(route_post_crud::)) + .route("/remove", web::post().to(route_post_crud::)) + .route( + "/mark_as_read", + web::post().to(route_post::), + ) + .route("/like", web::post().to(route_post::)) + .route("/save", web::put().to(route_post::)) + .route("/list", web::get().to(route_get_apub::)) + .route("/report", web::post().to(route_post::)) + .route( + "/report/resolve", + web::put().to(route_post::), + ) + .route( + "/report/list", + web::get().to(route_get::), + ), + ) + // Private Message + .service( + web::scope("/private_message") + .wrap(rate_limit.message()) + .route("/list", web::get().to(route_get_crud::)) + .route("", web::post().to(route_post_crud::)) + .route("", web::put().to(route_post_crud::)) + .route( + "/delete", + web::post().to(route_post_crud::), + ) + .route( + "/mark_as_read", + web::post().to(route_post::), + ) + .route( + "/report", + web::post().to(route_post::), + ) + .route( + "/report/resolve", + web::put().to(route_post::), + ) + .route( + "/report/list", + web::get().to(route_get::), + ), + ) + // User + .service( + // Account action, I don't like that it's in /user maybe /accounts + // Handle /user/register separately to add the register() rate limitter + web::resource("/user/register") + .guard(guard::Post()) + .wrap(rate_limit.register()) + .route(web::post().to(route_post_crud::)), + ) + .service( + // Handle captcha separately + web::resource("/user/get_captcha") + .wrap(rate_limit.post()) + .route(web::get().to(route_get::)), + ) + // User actions + .service( + web::scope("/user") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get_apub::)) + .route("/mention", web::get().to(route_get::)) + .route( + "/mention/mark_as_read", + web::post().to(route_post::), + ) + .route("/replies", web::get().to(route_get::)) + .route("/join", web::post().to(route_post::)) + // Admin action. I don't like that it's in /user + .route("/ban", web::post().to(route_post::)) + .route("/banned", web::get().to(route_get::)) + .route("/block", web::post().to(route_post::)) + // Account actions. I don't like that they're in /user maybe /accounts + .route("/login", web::post().to(route_post::)) + .route( + "/delete_account", + web::post().to(route_post_crud::), + ) + .route( + "/password_reset", + web::post().to(route_post::), + ) + .route( + "/password_change", + web::post().to(route_post::), + ) + // mark_all_as_read feels off being in this section as well + .route( + "/mark_all_as_read", + web::post().to(route_post::), + ) + .route( + "/save_user_settings", + web::put().to(route_post::), + ) + .route( + "/change_password", + web::put().to(route_post::), + ) + .route("/report_count", web::get().to(route_get::)) + .route("/unread_count", web::get().to(route_get::)) + .route("/verify_email", web::post().to(route_post::)) + .route("/leave_admin", web::post().to(route_post::)), + ) + // Admin Actions + .service( + web::scope("/admin") + .wrap(rate_limit.message()) + .route("/add", web::post().to(route_post::)) + .route( + "/registration_application/count", + web::get().to(route_get::), + ) + .route( + "/registration_application/list", + web::get().to(route_get::), + ) + .route( + "/registration_application/approve", + web::put().to(route_post::), + ), + ) + .service( + web::scope("/admin/purge") + .wrap(rate_limit.message()) + .route("/person", web::post().to(route_post::)) + .route("/community", web::post().to(route_post::)) + .route("/post", web::post().to(route_post::)) + .route("/comment", web::post().to(route_post::)), + ), + ); +} + +async fn perform<'a, Data>( + data: Data, + context: web::Data, +) -> Result +where + Data: Perform + + SendActivity::Response> + + Clone + + Deserialize<'a> + + Send + + 'static, +{ + let res = data.perform(&context, None).await?; + SendActivity::send_activity(&data, &res, &context).await?; + Ok(HttpResponse::Ok().json(res)) +} + +async fn route_get<'a, Data>( + data: web::Query, + context: web::Data, +) -> Result +where + Data: Perform + + SendActivity::Response> + + Clone + + Deserialize<'a> + + Send + + 'static, +{ + perform::(data.0, context).await +} + +async fn route_get_apub<'a, Data>( + data: web::Query, + context: web::Data, +) -> Result +where + Data: PerformApub + + SendActivity::Response> + + Clone + + Deserialize<'a> + + Send + + 'static, +{ + let res = data.perform(&context, None).await?; + SendActivity::send_activity(&data.0, &res, &context).await?; + Ok(HttpResponse::Ok().json(res)) +} + +async fn route_post<'a, Data>( + data: web::Json, + context: web::Data, +) -> Result +where + Data: Perform + + SendActivity::Response> + + Clone + + Deserialize<'a> + + Send + + 'static, +{ + perform::(data.0, context).await +} + +async fn perform_crud<'a, Data>( + data: Data, + context: web::Data, +) -> Result +where + Data: PerformCrud + + SendActivity::Response> + + Clone + + Deserialize<'a> + + Send + + 'static, +{ + let res = data.perform(&context, None).await?; + SendActivity::send_activity(&data, &res, &context).await?; + Ok(HttpResponse::Ok().json(res)) +} + +async fn route_get_crud<'a, Data>( + data: web::Query, + context: web::Data, +) -> Result +where + Data: PerformCrud + + SendActivity::Response> + + Clone + + Deserialize<'a> + + Send + + 'static, +{ + perform_crud::(data.0, context).await +} + +async fn route_post_crud<'a, Data>( + data: web::Json, + context: web::Data, +) -> Result +where + Data: PerformCrud + + SendActivity::Response> + + Clone + + Deserialize<'a> + + Send + + 'static, +{ + perform_crud::(data.0, context).await +} diff --git a/src/api_routes.rs b/src/api_routes_websocket.rs similarity index 54% rename from src/api_routes.rs rename to src/api_routes_websocket.rs index 77e5c27d..7a865600 100644 --- a/src/api_routes.rs +++ b/src/api_routes_websocket.rs @@ -1,4 +1,7 @@ -use actix_web::{guard, web, Error, HttpResponse, Result}; +use actix_web::{web, Error, HttpRequest, HttpResponse}; +use actix_web_actors::ws; +use actix_ws::{MessageStream, Session}; +use futures::stream::StreamExt; use lemmy_api::Perform; use lemmy_api_common::{ comment::{ @@ -23,7 +26,6 @@ use lemmy_api_common::{ EditCommunity, FollowCommunity, GetCommunity, - HideCommunity, ListCommunities, RemoveCommunity, TransferCommunity, @@ -96,7 +98,6 @@ use lemmy_api_common::{ Search, }, websocket::{ - routes::chat_route, serialize_websocket_message, structs::{CommunityJoin, ModJoin, PostJoin, UserJoin}, UserOperation, @@ -106,367 +107,187 @@ use lemmy_api_common::{ }; use lemmy_api_crud::PerformCrud; use lemmy_apub::{api::PerformApub, SendActivity}; -use lemmy_utils::{error::LemmyError, rate_limit::RateLimitCell, ConnectionId}; +use lemmy_utils::{error::LemmyError, rate_limit::RateLimitCell, ConnectionId, IpAddr}; use serde::Deserialize; -use std::result; +use serde_json::Value; +use std::{ + result, + str::FromStr, + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; +use tracing::{debug, error, info}; + +/// Entry point for our route +pub async fn websocket( + req: HttpRequest, + body: web::Payload, + context: web::Data, + rate_limiter: web::Data, +) -> Result { + let (response, session, stream) = actix_ws::handle(&req, body)?; -pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) { - cfg.service( - web::scope("/api/v3") - // Websocket - .service(web::resource("/ws").to(chat_route)) - // Site - .service( - web::scope("/site") - .wrap(rate_limit.message()) - .route("", web::get().to(route_get_crud::)) - // Admin Actions - .route("", web::post().to(route_post_crud::)) - .route("", web::put().to(route_post_crud::)), - ) - .service( - web::resource("/modlog") - .wrap(rate_limit.message()) - .route(web::get().to(route_get::)), - ) - .service( - web::resource("/search") - .wrap(rate_limit.search()) - .route(web::get().to(route_get_apub::)), - ) - .service( - web::resource("/resolve_object") - .wrap(rate_limit.message()) - .route(web::get().to(route_get_apub::)), - ) - // Community - .service( - web::resource("/community") - .guard(guard::Post()) - .wrap(rate_limit.register()) - .route(web::post().to(route_post_crud::)), - ) - .service( - web::scope("/community") - .wrap(rate_limit.message()) - .route("", web::get().to(route_get_apub::)) - .route("", web::put().to(route_post_crud::)) - .route("/hide", web::put().to(route_post::)) - .route("/list", web::get().to(route_get_crud::)) - .route("/follow", web::post().to(route_post::)) - .route("/block", web::post().to(route_post::)) - .route( - "/delete", - web::post().to(route_post_crud::), - ) - // Mod Actions - .route( - "/remove", - web::post().to(route_post_crud::), - ) - .route("/transfer", web::post().to(route_post::)) - .route("/ban_user", web::post().to(route_post::)) - .route("/mod", web::post().to(route_post::)) - .route("/join", web::post().to(route_post::)) - .route("/mod/join", web::post().to(route_post::)), - ) - // Post - .service( - // Handle POST to /post separately to add the post() rate limitter - web::resource("/post") - .guard(guard::Post()) - .wrap(rate_limit.post()) - .route(web::post().to(route_post_crud::)), - ) - .service( - web::scope("/post") - .wrap(rate_limit.message()) - .route("", web::get().to(route_get_crud::)) - .route("", web::put().to(route_post_crud::)) - .route("/delete", web::post().to(route_post_crud::)) - .route("/remove", web::post().to(route_post_crud::)) - .route( - "/mark_as_read", - web::post().to(route_post::), - ) - .route("/lock", web::post().to(route_post::)) - .route("/sticky", web::post().to(route_post::)) - .route("/list", web::get().to(route_get_apub::)) - .route("/like", web::post().to(route_post::)) - .route("/save", web::put().to(route_post::)) - .route("/join", web::post().to(route_post::)) - .route("/report", web::post().to(route_post::)) - .route( - "/report/resolve", - web::put().to(route_post::), - ) - .route("/report/list", web::get().to(route_get::)) - .route( - "/site_metadata", - web::get().to(route_get::), - ), - ) - // Comment - .service( - // Handle POST to /comment separately to add the comment() rate limitter - web::resource("/comment") - .guard(guard::Post()) - .wrap(rate_limit.comment()) - .route(web::post().to(route_post_crud::)), - ) - .service( - web::scope("/comment") - .wrap(rate_limit.message()) - .route("", web::get().to(route_get_crud::)) - .route("", web::put().to(route_post_crud::)) - .route("/delete", web::post().to(route_post_crud::)) - .route("/remove", web::post().to(route_post_crud::)) - .route( - "/mark_as_read", - web::post().to(route_post::), - ) - .route("/like", web::post().to(route_post::)) - .route("/save", web::put().to(route_post::)) - .route("/list", web::get().to(route_get_apub::)) - .route("/report", web::post().to(route_post::)) - .route( - "/report/resolve", - web::put().to(route_post::), - ) - .route( - "/report/list", - web::get().to(route_get::), - ), - ) - // Private Message - .service( - web::scope("/private_message") - .wrap(rate_limit.message()) - .route("/list", web::get().to(route_get_crud::)) - .route("", web::post().to(route_post_crud::)) - .route("", web::put().to(route_post_crud::)) - .route( - "/delete", - web::post().to(route_post_crud::), - ) - .route( - "/mark_as_read", - web::post().to(route_post::), - ) - .route( - "/report", - web::post().to(route_post::), - ) - .route( - "/report/resolve", - web::put().to(route_post::), - ) - .route( - "/report/list", - web::get().to(route_get::), - ), - ) - // User - .service( - // Account action, I don't like that it's in /user maybe /accounts - // Handle /user/register separately to add the register() rate limitter - web::resource("/user/register") - .guard(guard::Post()) - .wrap(rate_limit.register()) - .route(web::post().to(route_post_crud::)), - ) - .service( - // Handle captcha separately - web::resource("/user/get_captcha") - .wrap(rate_limit.post()) - .route(web::get().to(route_get::)), - ) - // User actions - .service( - web::scope("/user") - .wrap(rate_limit.message()) - .route("", web::get().to(route_get_apub::)) - .route("/mention", web::get().to(route_get::)) - .route( - "/mention/mark_as_read", - web::post().to(route_post::), - ) - .route("/replies", web::get().to(route_get::)) - .route("/join", web::post().to(route_post::)) - // Admin action. I don't like that it's in /user - .route("/ban", web::post().to(route_post::)) - .route("/banned", web::get().to(route_get::)) - .route("/block", web::post().to(route_post::)) - // Account actions. I don't like that they're in /user maybe /accounts - .route("/login", web::post().to(route_post::)) - .route( - "/delete_account", - web::post().to(route_post_crud::), - ) - .route( - "/password_reset", - web::post().to(route_post::), - ) - .route( - "/password_change", - web::post().to(route_post::), - ) - // mark_all_as_read feels off being in this section as well - .route( - "/mark_all_as_read", - web::post().to(route_post::), - ) - .route( - "/save_user_settings", - web::put().to(route_post::), - ) - .route( - "/change_password", - web::put().to(route_post::), - ) - .route("/report_count", web::get().to(route_get::)) - .route("/unread_count", web::get().to(route_get::)) - .route("/verify_email", web::post().to(route_post::)) - .route("/leave_admin", web::post().to(route_post::)), - ) - // Admin Actions - .service( - web::scope("/admin") - .wrap(rate_limit.message()) - .route("/add", web::post().to(route_post::)) - .route( - "/registration_application/count", - web::get().to(route_get::), - ) - .route( - "/registration_application/list", - web::get().to(route_get::), - ) - .route( - "/registration_application/approve", - web::put().to(route_post::), - ), - ) - .service( - web::scope("/admin/purge") - .wrap(rate_limit.message()) - .route("/person", web::post().to(route_post::)) - .route("/community", web::post().to(route_post::)) - .route("/post", web::post().to(route_post::)) - .route("/comment", web::post().to(route_post::)), - ), + let client_ip = IpAddr( + req + .connection_info() + .realip_remote_addr() + .unwrap_or("blank_ip") + .to_string(), ); -} -async fn perform<'a, Data>( - data: Data, - context: web::Data, -) -> Result -where - Data: Perform - + SendActivity::Response> - + Clone - + Deserialize<'a> - + Send - + 'static, -{ - let res = data.perform(&context, None).await?; - SendActivity::send_activity(&data, &res, &context).await?; - Ok(HttpResponse::Ok().json(res)) -} + let check = rate_limiter.message().check(client_ip.clone()); + if !check { + debug!( + "Websocket join with IP: {} has been rate limited.", + &client_ip + ); + session.close(None).await.map_err(LemmyError::from)?; + return Ok(response); + } -async fn route_get<'a, Data>( - data: web::Query, - context: web::Data, -) -> Result -where - Data: Perform - + SendActivity::Response> - + Clone - + Deserialize<'a> - + Send - + 'static, -{ - perform::(data.0, context).await -} + let connection_id = context.chat_server().handle_connect(session.clone())?; + info!("{} joined", &client_ip); -async fn route_get_apub<'a, Data>( - data: web::Query, - context: web::Data, -) -> Result -where - Data: PerformApub - + SendActivity::Response> - + Clone - + Deserialize<'a> - + Send - + 'static, -{ - let res = data.perform(&context, None).await?; - SendActivity::send_activity(&data.0, &res, &context).await?; - Ok(HttpResponse::Ok().json(res)) + let alive = Arc::new(Mutex::new(Instant::now())); + heartbeat(session.clone(), alive.clone()); + + actix_rt::spawn(handle_messages( + stream, + client_ip, + session, + connection_id, + alive, + rate_limiter, + context, + )); + + Ok(response) } -async fn route_post<'a, Data>( - data: web::Json, +async fn handle_messages( + mut stream: MessageStream, + client_ip: IpAddr, + mut session: Session, + connection_id: ConnectionId, + alive: Arc>, + rate_limiter: web::Data, context: web::Data, -) -> Result -where - Data: Perform - + SendActivity::Response> - + Clone - + Deserialize<'a> - + Send - + 'static, -{ - perform::(data.0, context).await +) -> Result<(), LemmyError> { + while let Some(Ok(msg)) = stream.next().await { + match msg { + ws::Message::Ping(bytes) => { + if session.pong(&bytes).await.is_err() { + break; + } + } + ws::Message::Pong(_) => { + let mut lock = alive + .lock() + .expect("Failed to acquire websocket heartbeat alive lock"); + *lock = Instant::now(); + } + ws::Message::Text(text) => { + let msg = text.trim().to_string(); + let executed = parse_json_message( + msg, + client_ip.clone(), + connection_id, + rate_limiter.get_ref(), + context.get_ref().clone(), + ) + .await; + + let res = executed.unwrap_or_else(|e| { + error!("Error during message handling {}", e); + e.to_json() + .unwrap_or_else(|_| String::from(r#"{"error":"failed to serialize json"}"#)) + }); + session.text(res).await?; + } + ws::Message::Close(_) => { + session.close(None).await?; + context.chat_server().handle_disconnect(&connection_id)?; + break; + } + ws::Message::Binary(_) => info!("Unexpected binary"), + _ => {} + } + } + Ok(()) } -async fn perform_crud<'a, Data>( - data: Data, - context: web::Data, -) -> Result -where - Data: PerformCrud - + SendActivity::Response> - + Clone - + Deserialize<'a> - + Send - + 'static, -{ - let res = data.perform(&context, None).await?; - SendActivity::send_activity(&data, &res, &context).await?; - Ok(HttpResponse::Ok().json(res)) +fn heartbeat(mut session: Session, alive: Arc>) { + actix_rt::spawn(async move { + let mut interval = actix_rt::time::interval(Duration::from_secs(5)); + loop { + if session.ping(b"").await.is_err() { + break; + } + + let duration_since = { + let alive_lock = alive + .lock() + .expect("Failed to acquire websocket heartbeat alive lock"); + Instant::now().duration_since(*alive_lock) + }; + if duration_since > Duration::from_secs(10) { + let _ = session.close(None).await; + break; + } + interval.tick().await; + } + }); } -async fn route_get_crud<'a, Data>( - data: web::Query, - context: web::Data, -) -> Result -where - Data: PerformCrud - + SendActivity::Response> - + Clone - + Deserialize<'a> - + Send - + 'static, -{ - perform_crud::(data.0, context).await +async fn parse_json_message( + msg: String, + ip: IpAddr, + connection_id: ConnectionId, + rate_limiter: &RateLimitCell, + context: LemmyContext, +) -> Result { + let json: Value = serde_json::from_str(&msg)?; + let data = &json["data"].to_string(); + let op = &json["op"] + .as_str() + .ok_or_else(|| LemmyError::from_message("missing op"))?; + + // check if api call passes the rate limit, and generate future for later execution + if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) { + let passed = match user_operation_crud { + UserOperationCrud::Register => rate_limiter.register().check(ip), + UserOperationCrud::CreatePost => rate_limiter.post().check(ip), + UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip), + UserOperationCrud::CreateComment => rate_limiter.comment().check(ip), + _ => rate_limiter.message().check(ip), + }; + check_rate_limit_passed(passed)?; + match_websocket_operation_crud(context, connection_id, user_operation_crud, data).await + } else if let Ok(user_operation) = UserOperation::from_str(op) { + let passed = match user_operation { + UserOperation::GetCaptcha => rate_limiter.post().check(ip), + _ => rate_limiter.message().check(ip), + }; + check_rate_limit_passed(passed)?; + match_websocket_operation(context, connection_id, user_operation, data).await + } else { + let user_operation = UserOperationApub::from_str(op)?; + let passed = match user_operation { + UserOperationApub::Search => rate_limiter.search().check(ip), + _ => rate_limiter.message().check(ip), + }; + check_rate_limit_passed(passed)?; + match_websocket_operation_apub(context, connection_id, user_operation, data).await + } } -async fn route_post_crud<'a, Data>( - data: web::Json, - context: web::Data, -) -> Result -where - Data: PerformCrud - + SendActivity::Response> - + Clone - + Deserialize<'a> - + Send - + 'static, -{ - perform_crud::(data.0, context).await +fn check_rate_limit_passed(passed: bool) -> Result<(), LemmyError> { + if passed { + Ok(()) + } else { + // if rate limit was hit, respond with message + Err(LemmyError::from_message("rate_limit_error")) + } } pub async fn match_websocket_operation_crud( diff --git a/src/lib.rs b/src/lib.rs index ffac400f..a9e390a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ -pub mod api_routes; +pub mod api_routes_http; +pub mod api_routes_websocket; pub mod code_migrations; pub mod root_span_builder; pub mod scheduled_tasks; diff --git a/src/main.rs b/src/main.rs index c60c1823..70f6d034 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ #[macro_use] extern crate diesel_migrations; -use actix::prelude::*; use actix_web::{middleware, web::Data, App, HttpServer, Result}; use diesel_migrations::EmbeddedMigrations; use doku::json::{AutoComments, CommentsStyle, Formatting, ObjectsStyle}; @@ -21,12 +20,7 @@ use lemmy_db_schema::{ }; use lemmy_routes::{feeds, images, nodeinfo, webfinger}; use lemmy_server::{ - api_routes, - api_routes::{ - match_websocket_operation, - match_websocket_operation_apub, - match_websocket_operation_crud, - }, + api_routes_http, code_migrations::run_advanced_migrations, init_logging, root_span_builder::QuieterRootSpanBuilder, @@ -41,7 +35,7 @@ use reqwest::Client; use reqwest_middleware::ClientBuilder; use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use reqwest_tracing::TracingMiddleware; -use std::{env, thread, time::Duration}; +use std::{env, sync::Arc, thread, time::Duration}; use tracing_actix_web::TracingLogger; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); @@ -135,17 +129,7 @@ async fn main() -> Result<(), LemmyError> { .with(TracingMiddleware::default()) .build(); - let chat_server = ChatServer::startup( - pool.clone(), - |c, i, o, d| Box::pin(match_websocket_operation(c, i, o, d)), - |c, i, o, d| Box::pin(match_websocket_operation_crud(c, i, o, d)), - |c, i, o, d| Box::pin(match_websocket_operation_apub(c, i, o, d)), - client.clone(), - settings.clone(), - secret.clone(), - rate_limit_cell.clone(), - ) - .start(); + let chat_server = Arc::new(ChatServer::startup()); // Create Http server with websocket support let settings_bind = settings.clone(); @@ -164,7 +148,7 @@ async fn main() -> Result<(), LemmyError> { .app_data(Data::new(context)) .app_data(Data::new(rate_limit_cell.clone())) // The routes - .configure(|cfg| api_routes::config(cfg, rate_limit_cell)) + .configure(|cfg| api_routes_http::config(cfg, rate_limit_cell)) .configure(|cfg| { if federation_enabled { lemmy_apub::http::routes::config(cfg);