]> Untitled Git - lemmy.git/commitdiff
Rework websocket (#2598)
authorNutomic <me@nutomic.com>
Fri, 9 Dec 2022 15:31:47 +0000 (15:31 +0000)
committerGitHub <noreply@github.com>
Fri, 9 Dec 2022 15:31:47 +0000 (10:31 -0500)
* Merge websocket crate into api_common

* Add SendActivity trait so that api crates compile in parallel with lemmy_apub

* Rework websocket code

* fix websocket heartbeat

36 files changed:
Cargo.lock
Cargo.toml
crates/api/src/comment_report/create.rs
crates/api/src/comment_report/resolve.rs
crates/api/src/community/add_mod.rs
crates/api/src/community/ban.rs
crates/api/src/local_user/add_admin.rs
crates/api/src/local_user/ban_person.rs
crates/api/src/local_user/get_captcha.rs
crates/api/src/post_report/create.rs
crates/api/src/post_report/resolve.rs
crates/api/src/private_message_report/create.rs
crates/api/src/private_message_report/resolve.rs
crates/api/src/websocket.rs
crates/api_common/Cargo.toml
crates/api_common/src/context.rs
crates/api_common/src/websocket/chat_server.rs
crates/api_common/src/websocket/handlers.rs
crates/api_common/src/websocket/messages.rs [deleted file]
crates/api_common/src/websocket/mod.rs
crates/api_common/src/websocket/routes.rs [deleted file]
crates/api_common/src/websocket/send.rs
crates/api_common/src/websocket/structs.rs
crates/api_crud/src/post/read.rs
crates/api_crud/src/site/read.rs
crates/api_crud/src/site/update.rs
crates/api_crud/src/user/create.rs
crates/apub/Cargo.toml
crates/apub/src/activities/community/report.rs
crates/apub/src/activities/following/accept.rs
crates/apub/src/api/read_community.rs
crates/apub/src/objects/mod.rs
src/api_routes_http.rs [new file with mode: 0644]
src/api_routes_websocket.rs [moved from src/api_routes.rs with 54% similarity]
src/lib.rs
src/main.rs

index 5f68a511e5a032b09f7595b332f09c11dc5d46c8..72876f85a51bb9833ca5e8d679502734b1f5f31e 100644 (file)
@@ -51,7 +51,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f728064aca1c318585bf4bb04ffcfac9e75e508ab4e8b1bd9ba5dfe04e2cbed5"
 dependencies = [
  "actix-rt",
- "actix_derive",
  "bitflags",
  "bytes",
  "crossbeam-channel",
@@ -283,14 +282,16 @@ dependencies = [
 ]
 
 [[package]]
-name = "actix_derive"
-version = "0.6.0"
+name = "actix-ws"
+version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6d44b8fee1ced9671ba043476deddef739dd0959bf77030b26b738cc591737a7"
+checksum = "535aec173810be3ca6f25dd5b4d431ae7125d62000aa3cbae1ec739921b02cf3"
 dependencies = [
- "proc-macro2 1.0.47",
- "quote 1.0.21",
- "syn 1.0.103",
+ "actix-codec",
+ "actix-http",
+ "actix-web",
+ "futures-core",
+ "tokio",
 ]
 
 [[package]]
@@ -2060,15 +2061,15 @@ dependencies = [
 name = "lemmy_api_common"
 version = "0.16.5"
 dependencies = [
- "actix",
  "actix-rt",
  "actix-web",
- "actix-web-actors",
+ "actix-ws",
  "anyhow",
  "background-jobs",
  "chrono",
  "diesel",
  "encoding",
+ "futures",
  "lemmy_db_schema",
  "lemmy_db_views",
  "lemmy_db_views_actor",
@@ -2118,7 +2119,6 @@ version = "0.16.5"
 dependencies = [
  "activitypub_federation",
  "activitystreams-kinds",
- "actix",
  "actix-rt",
  "actix-web",
  "anyhow",
@@ -2248,14 +2248,17 @@ name = "lemmy_server"
 version = "0.16.5"
 dependencies = [
  "activitypub_federation",
- "actix",
+ "actix-rt",
  "actix-web",
+ "actix-web-actors",
+ "actix-ws",
  "clokwerk",
  "console-subscriber",
  "diesel",
  "diesel-async",
  "diesel_migrations",
  "doku",
+ "futures",
  "lemmy_api",
  "lemmy_api_common",
  "lemmy_api_crud",
@@ -2266,6 +2269,7 @@ dependencies = [
  "opentelemetry 0.17.0",
  "opentelemetry-otlp",
  "parking_lot",
+ "rand 0.8.5",
  "reqwest",
  "reqwest-middleware",
  "reqwest-retry",
index e83f0c837cabca1ef52b59d760a555c1b6ee9d90..8b62bedcbfa4c4e62332f2d354c1df14c9b81e5d 100644 (file)
@@ -64,7 +64,6 @@ diesel = "2.0.2"
 diesel_migrations = "2.0.0"
 diesel-async = "0.1.1"
 serde = { version = "1.0.147", features = ["derive"] }
-actix = "0.13.0"
 actix-web = { version = "4.2.1", default-features = false, features = ["macros", "rustls"] }
 tracing = "0.1.36"
 tracing-actix-web = { version = "0.6.1", default-features = false }
@@ -106,6 +105,7 @@ rosetta-i18n = "0.1.2"
 rand = "0.8.5"
 opentelemetry = { version = "0.17.0", features = ["rt-tokio"] }
 tracing-opentelemetry = { version = "0.17.2" }
+actix-ws = "0.2.0"
 
 [dependencies]
 lemmy_api = { workspace = true }
@@ -120,7 +120,6 @@ diesel = { workspace = true }
 diesel_migrations = { workspace = true }
 diesel-async = { workspace = true }
 serde = { workspace = true }
-actix = { workspace = true }
 actix-web = { workspace = true }
 tracing = { workspace = true }
 tracing-actix-web = { workspace = true }
@@ -136,7 +135,12 @@ doku = { workspace = true }
 parking_lot = { workspace = true }
 reqwest-retry = { workspace = true }
 serde_json = { workspace = true }
+futures = { workspace = true }
+actix-ws = { workspace = true }
 tracing-opentelemetry = { workspace = true, optional = true }
 opentelemetry = { workspace = true, optional = true }
+actix-web-actors = { version = "4.1.0", default-features = false }
+actix-rt = "2.6"
+rand = { workspace = true }
 console-subscriber = { version = "0.1.8", optional = true }
 opentelemetry-otlp = { version = "0.10.0", optional = true }
index bf3fec0a144869df1043afff5f714e3e7f007cab..c026d166da61e317efbf5485122316bdc0b4b133 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   comment::{CommentReportResponse, CreateCommentReport},
   context::LemmyContext,
   utils::{check_community_ban, get_local_user_view_from_jwt},
-  websocket::{messages::SendModRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   source::{
@@ -58,12 +58,15 @@ impl Perform for CreateCommentReport {
       comment_report_view,
     };
 
-    context.chat_server().do_send(SendModRoomMessage {
-      op: UserOperation::CreateCommentReport,
-      response: res.clone(),
-      community_id: comment_view.community.id,
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_mod_room_message(
+        UserOperation::CreateCommentReport,
+        &res,
+        comment_view.community.id,
+        websocket_id,
+      )
+      .await?;
 
     Ok(res)
   }
index 9df11fc23876f2bcc0d2c1cd93c40b76826adfdf..de4f418da2353d62c8bd9a1ecf02e695a12727e1 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   comment::{CommentReportResponse, ResolveCommentReport},
   context::LemmyContext,
   utils::{get_local_user_view_from_jwt, is_mod_or_admin},
-  websocket::{messages::SendModRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{source::comment_report::CommentReport, traits::Reportable};
 use lemmy_db_views::structs::CommentReportView;
@@ -49,12 +49,15 @@ impl Perform for ResolveCommentReport {
       comment_report_view,
     };
 
-    context.chat_server().do_send(SendModRoomMessage {
-      op: UserOperation::ResolveCommentReport,
-      response: res.clone(),
-      community_id: report.community.id,
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_mod_room_message(
+        UserOperation::ResolveCommentReport,
+        &res,
+        report.community.id,
+        websocket_id,
+      )
+      .await?;
 
     Ok(res)
   }
index ce3082f7144e796f6d2b7db907e88b70d3d8c742..28f56cffd2b71df24fef98dbb223dd895890c3d8 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   community::{AddModToCommunity, AddModToCommunityResponse},
   context::LemmyContext,
   utils::{get_local_user_view_from_jwt, is_mod_or_admin},
-  websocket::{messages::SendCommunityRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   source::{
@@ -70,12 +70,15 @@ impl Perform for AddModToCommunity {
     let moderators = CommunityModeratorView::for_community(context.pool(), community_id).await?;
 
     let res = AddModToCommunityResponse { moderators };
-    context.chat_server().do_send(SendCommunityRoomMessage {
-      op: UserOperation::AddModToCommunity,
-      response: res.clone(),
-      community_id,
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_community_room_message(
+        &UserOperation::AddModToCommunity,
+        &res,
+        community_id,
+        websocket_id,
+      )
+      .await?;
     Ok(res)
   }
 }
index fb5e7fcfefce7c5c54bccb78875ddfc6edec316d..e7962d24298de44d727e92c42528868d765b5af8 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   community::{BanFromCommunity, BanFromCommunityResponse},
   context::LemmyContext,
   utils::{get_local_user_view_from_jwt, is_mod_or_admin, remove_user_data_in_community},
-  websocket::{messages::SendCommunityRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   source::{
@@ -95,12 +95,15 @@ impl Perform for BanFromCommunity {
       banned: data.ban,
     };
 
-    context.chat_server().do_send(SendCommunityRoomMessage {
-      op: UserOperation::BanFromCommunity,
-      response: res.clone(),
-      community_id,
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_community_room_message(
+        &UserOperation::BanFromCommunity,
+        &res,
+        community_id,
+        websocket_id,
+      )
+      .await?;
 
     Ok(res)
   }
index 78357f0c58792d65386dc998d1869aee695b5d8b..3b22c447693e5e91a74d69aefe57e3d9d9da36c6 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   context::LemmyContext,
   person::{AddAdmin, AddAdminResponse},
   utils::{get_local_user_view_from_jwt, is_admin},
-  websocket::{messages::SendAllMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   source::{
@@ -56,11 +56,10 @@ impl Perform for AddAdmin {
 
     let res = AddAdminResponse { admins };
 
-    context.chat_server().do_send(SendAllMessage {
-      op: UserOperation::AddAdmin,
-      response: res.clone(),
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_all_message(UserOperation::AddAdmin, &res, websocket_id)
+      .await?;
 
     Ok(res)
   }
index 0bb2523497b4cb70a28dc5c4e124b2d5dbff67b1..d45528efe91020b3576044b13016abdb6ddbf02e 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   context::LemmyContext,
   person::{BanPerson, BanPersonResponse},
   utils::{get_local_user_view_from_jwt, is_admin, remove_user_data},
-  websocket::{messages::SendAllMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   source::{
@@ -79,11 +79,10 @@ impl Perform for BanPerson {
       banned: data.ban,
     };
 
-    context.chat_server().do_send(SendAllMessage {
-      op: UserOperation::BanPerson,
-      response: res.clone(),
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_all_message(UserOperation::BanPerson, &res, websocket_id)
+      .await?;
 
     Ok(res)
   }
index 50a2bdba2d8f79cb06f94fa734c3689fe930c618..3671f79cc23cd80db1b2dd9914a284759dcad314 100644 (file)
@@ -5,7 +5,7 @@ use chrono::Duration;
 use lemmy_api_common::{
   context::LemmyContext,
   person::{CaptchaResponse, GetCaptcha, GetCaptchaResponse},
-  websocket::messages::CaptchaItem,
+  websocket::structs::CaptchaItem,
 };
 use lemmy_db_schema::{source::local_site::LocalSite, utils::naive_now};
 use lemmy_utils::{error::LemmyError, ConnectionId};
@@ -47,7 +47,7 @@ impl Perform for GetCaptcha {
     };
 
     // Stores the captcha item on the queue
-    context.chat_server().do_send(captcha_item);
+    context.chat_server().add_captcha(captcha_item)?;
 
     Ok(GetCaptchaResponse {
       ok: Some(CaptchaResponse { png, wav, uuid }),
index 71be43623ab7912b757f92c165991e177b3d628d..7396d3f8c3512b1f8c5e04c997e41618013879d4 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   context::LemmyContext,
   post::{CreatePostReport, PostReportResponse},
   utils::{check_community_ban, get_local_user_view_from_jwt},
-  websocket::{messages::SendModRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   source::{
@@ -58,12 +58,15 @@ impl Perform for CreatePostReport {
 
     let res = PostReportResponse { post_report_view };
 
-    context.chat_server().do_send(SendModRoomMessage {
-      op: UserOperation::CreatePostReport,
-      response: res.clone(),
-      community_id: post_view.community.id,
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_mod_room_message(
+        UserOperation::CreatePostReport,
+        &res,
+        post_view.community.id,
+        websocket_id,
+      )
+      .await?;
 
     Ok(res)
   }
index 0e2b9e16a0b857d2b56f817276d625f4926eb7bd..615b7f828b1666fd83fc3de62de3a7d1d78121f2 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   context::LemmyContext,
   post::{PostReportResponse, ResolvePostReport},
   utils::{get_local_user_view_from_jwt, is_mod_or_admin},
-  websocket::{messages::SendModRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{source::post_report::PostReport, traits::Reportable};
 use lemmy_db_views::structs::PostReportView;
@@ -46,12 +46,15 @@ impl Perform for ResolvePostReport {
 
     let res = PostReportResponse { post_report_view };
 
-    context.chat_server().do_send(SendModRoomMessage {
-      op: UserOperation::ResolvePostReport,
-      response: res.clone(),
-      community_id: report.community.id,
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_mod_room_message(
+        UserOperation::ResolvePostReport,
+        &res,
+        report.community.id,
+        websocket_id,
+      )
+      .await?;
 
     Ok(res)
   }
index 66875c614c97febbdc32eacb7d5e545c24f291d5..9267ee7fa3225cb4fcb64bf2120f533528866bcd 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   context::LemmyContext,
   private_message::{CreatePrivateMessageReport, PrivateMessageReportResponse},
   utils::get_local_user_view_from_jwt,
-  websocket::{messages::SendModRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   newtypes::CommunityId,
@@ -57,12 +57,15 @@ impl Perform for CreatePrivateMessageReport {
       private_message_report_view,
     };
 
-    context.chat_server().do_send(SendModRoomMessage {
-      op: UserOperation::CreatePrivateMessageReport,
-      response: res.clone(),
-      community_id: CommunityId(0),
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_mod_room_message(
+        UserOperation::CreatePrivateMessageReport,
+        &res,
+        CommunityId(0),
+        websocket_id,
+      )
+      .await?;
 
     // TODO: consider federating this
 
index 2a3f677ad8eb48dea53d66de0e290cce1607ca40..a48a458d01c41536e5d9de136b30fe30178ac20e 100644 (file)
@@ -4,7 +4,7 @@ use lemmy_api_common::{
   context::LemmyContext,
   private_message::{PrivateMessageReportResponse, ResolvePrivateMessageReport},
   utils::{get_local_user_view_from_jwt, is_admin},
-  websocket::{messages::SendModRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   newtypes::CommunityId,
@@ -48,12 +48,15 @@ impl Perform for ResolvePrivateMessageReport {
       private_message_report_view,
     };
 
-    context.chat_server().do_send(SendModRoomMessage {
-      op: UserOperation::ResolvePrivateMessageReport,
-      response: res.clone(),
-      community_id: CommunityId(0),
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_mod_room_message(
+        UserOperation::ResolvePrivateMessageReport,
+        &res,
+        CommunityId(0),
+        websocket_id,
+      )
+      .await?;
 
     Ok(res)
   }
index ca755bde939efd9409a8a272acf4681564e264fd..041cb775e8a3670c31e173e25284e6bdeb0502d2 100644 (file)
@@ -3,18 +3,15 @@ use actix_web::web::Data;
 use lemmy_api_common::{
   context::LemmyContext,
   utils::get_local_user_view_from_jwt,
-  websocket::{
-    messages::{JoinCommunityRoom, JoinModRoom, JoinPostRoom, JoinUserRoom},
-    structs::{
-      CommunityJoin,
-      CommunityJoinResponse,
-      ModJoin,
-      ModJoinResponse,
-      PostJoin,
-      PostJoinResponse,
-      UserJoin,
-      UserJoinResponse,
-    },
+  websocket::structs::{
+    CommunityJoin,
+    CommunityJoinResponse,
+    ModJoin,
+    ModJoinResponse,
+    PostJoin,
+    PostJoinResponse,
+    UserJoin,
+    UserJoinResponse,
   },
 };
 use lemmy_utils::{error::LemmyError, ConnectionId};
@@ -34,10 +31,9 @@ impl Perform for UserJoin {
       get_local_user_view_from_jwt(&data.auth, context.pool(), context.secret()).await?;
 
     if let Some(ws_id) = websocket_id {
-      context.chat_server().do_send(JoinUserRoom {
-        local_user_id: local_user_view.local_user.id,
-        id: ws_id,
-      });
+      context
+        .chat_server()
+        .join_user_room(local_user_view.local_user.id, ws_id)?;
     }
 
     Ok(UserJoinResponse { joined: true })
@@ -57,10 +53,9 @@ impl Perform for CommunityJoin {
     let data: &CommunityJoin = self;
 
     if let Some(ws_id) = websocket_id {
-      context.chat_server().do_send(JoinCommunityRoom {
-        community_id: data.community_id,
-        id: ws_id,
-      });
+      context
+        .chat_server()
+        .join_community_room(data.community_id, ws_id)?;
     }
 
     Ok(CommunityJoinResponse { joined: true })
@@ -80,10 +75,9 @@ impl Perform for ModJoin {
     let data: &ModJoin = self;
 
     if let Some(ws_id) = websocket_id {
-      context.chat_server().do_send(JoinModRoom {
-        community_id: data.community_id,
-        id: ws_id,
-      });
+      context
+        .chat_server()
+        .join_mod_room(data.community_id, ws_id)?;
     }
 
     Ok(ModJoinResponse { joined: true })
@@ -103,10 +97,7 @@ impl Perform for PostJoin {
     let data: &PostJoin = self;
 
     if let Some(ws_id) = websocket_id {
-      context.chat_server().do_send(JoinPostRoom {
-        post_id: data.post_id,
-        id: ws_id,
-      });
+      context.chat_server().join_post_room(data.post_id, ws_id)?;
     }
 
     Ok(PostJoinResponse { joined: true })
index a4c80da523bd2bf55f59965d27cb7054eb976ed3..21eed1c6d4ae9d2f6f0146106c585d53b0dcec4d 100644 (file)
@@ -38,14 +38,14 @@ webpage = { version = "1.4.0", default-features = false, features = ["serde"], o
 encoding = { version = "0.2.33", optional = true }
 rand = { workspace = true }
 serde_json = { workspace = true }
-actix = { workspace = true }
 anyhow = { workspace = true }
 tokio = { workspace = true }
 strum = { workspace = true }
 strum_macros = { workspace = true }
 opentelemetry = { workspace = true }
 tracing-opentelemetry = { workspace = true }
-actix-web-actors = { version = "4.1.0", default-features = false }
+actix-ws = { workspace = true }
+futures = { workspace = true }
 background-jobs = "0.13.0"
 
 [dev-dependencies]
index e552af53040143417573c62f98a6225073849dfe..eb53b2af3109567807851102df99bdd6da25d12c 100644 (file)
@@ -1,15 +1,15 @@
 use crate::websocket::chat_server::ChatServer;
-use actix::Addr;
 use lemmy_db_schema::{source::secret::Secret, utils::DbPool};
 use lemmy_utils::{
   rate_limit::RateLimitCell,
   settings::{structs::Settings, SETTINGS},
 };
 use reqwest_middleware::ClientWithMiddleware;
+use std::sync::Arc;
 
 pub struct LemmyContext {
   pool: DbPool,
-  chat_server: Addr<ChatServer>,
+  chat_server: Arc<ChatServer>,
   client: ClientWithMiddleware,
   settings: Settings,
   secret: Secret,
@@ -19,7 +19,7 @@ pub struct LemmyContext {
 impl LemmyContext {
   pub fn create(
     pool: DbPool,
-    chat_server: Addr<ChatServer>,
+    chat_server: Arc<ChatServer>,
     client: ClientWithMiddleware,
     settings: Settings,
     secret: Secret,
@@ -37,7 +37,7 @@ impl LemmyContext {
   pub fn pool(&self) -> &DbPool {
     &self.pool
   }
-  pub fn chat_server(&self) -> &Addr<ChatServer> {
+  pub fn chat_server(&self) -> &Arc<ChatServer> {
     &self.chat_server
   }
   pub fn client(&self) -> &ClientWithMiddleware {
index b669a4fd30fd7bfe21326362df9c91b54b7abf73..bb9e7b6df3a83d6ef21dfe96e7553937d6cd5874 100644 (file)
@@ -1,68 +1,30 @@
 use crate::{
   comment::CommentResponse,
-  context::LemmyContext,
   post::PostResponse,
-  websocket::{
-    messages::{CaptchaItem, StandardMessage, WsMessage},
-    serialize_websocket_message,
-    OperationType,
-    UserOperation,
-    UserOperationApub,
-    UserOperationCrud,
-  },
+  websocket::{serialize_websocket_message, structs::CaptchaItem, OperationType},
 };
-use actix::prelude::*;
+use actix_ws::Session;
 use anyhow::Context as acontext;
-use lemmy_db_schema::{
-  newtypes::{CommunityId, LocalUserId, PostId},
-  source::secret::Secret,
-  utils::DbPool,
-};
-use lemmy_utils::{
-  error::LemmyError,
-  location_info,
-  rate_limit::RateLimitCell,
-  settings::structs::Settings,
-  ConnectionId,
-  IpAddr,
-};
-use rand::rngs::ThreadRng;
-use reqwest_middleware::ClientWithMiddleware;
+use futures::future::try_join_all;
+use lemmy_db_schema::newtypes::{CommunityId, LocalUserId, PostId};
+use lemmy_utils::{error::LemmyError, location_info, ConnectionId};
+use rand::{rngs::StdRng, SeedableRng};
 use serde::Serialize;
-use serde_json::Value;
 use std::{
   collections::{HashMap, HashSet},
-  future::Future,
-  str::FromStr,
+  sync::{Mutex, MutexGuard},
 };
-use tokio::macros::support::Pin;
-
-type MessageHandlerType = fn(
-  context: LemmyContext,
-  id: ConnectionId,
-  op: UserOperation,
-  data: &str,
-) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
-
-type MessageHandlerCrudType = fn(
-  context: LemmyContext,
-  id: ConnectionId,
-  op: UserOperationCrud,
-  data: &str,
-) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
-
-type MessageHandlerApubType = fn(
-  context: LemmyContext,
-  id: ConnectionId,
-  op: UserOperationApub,
-  data: &str,
-) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
+use tracing::log::warn;
 
 /// `ChatServer` manages chat rooms and responsible for coordinating chat
 /// session.
 pub struct ChatServer {
+  inner: Mutex<ChatServerInner>,
+}
+
+pub struct ChatServerInner {
   /// A map from generated random ID to session addr
-  pub sessions: HashMap<ConnectionId, SessionInfo>,
+  pub sessions: HashMap<ConnectionId, Session>,
 
   /// A map from post_id to set of connectionIDs
   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
@@ -76,91 +38,53 @@ pub struct ChatServer {
   /// sessions (IE clients)
   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
 
-  pub(super) rng: ThreadRng,
-
-  /// The DB Pool
-  pub(super) pool: DbPool,
-
-  /// The Settings
-  pub(super) settings: Settings,
-
-  /// The Secrets
-  pub(super) secret: Secret,
+  pub(super) rng: StdRng,
 
   /// A list of the current captchas
   pub(super) captchas: Vec<CaptchaItem>,
-
-  message_handler: MessageHandlerType,
-  message_handler_crud: MessageHandlerCrudType,
-  message_handler_apub: MessageHandlerApubType,
-
-  /// An HTTP Client
-  client: ClientWithMiddleware,
-
-  rate_limit_cell: RateLimitCell,
-}
-
-pub struct SessionInfo {
-  pub addr: Recipient<WsMessage>,
-  pub ip: IpAddr,
 }
 
 /// `ChatServer` is an actor. It maintains list of connection client session.
 /// And manages available rooms. Peers send messages to other peers in same
 /// room through `ChatServer`.
 impl ChatServer {
-  #![allow(clippy::too_many_arguments)]
-  pub fn startup(
-    pool: DbPool,
-    message_handler: MessageHandlerType,
-    message_handler_crud: MessageHandlerCrudType,
-    message_handler_apub: MessageHandlerApubType,
-    client: ClientWithMiddleware,
-    settings: Settings,
-    secret: Secret,
-    rate_limit_cell: RateLimitCell,
-  ) -> ChatServer {
+  pub fn startup() -> ChatServer {
     ChatServer {
-      sessions: HashMap::new(),
-      post_rooms: HashMap::new(),
-      community_rooms: HashMap::new(),
-      mod_rooms: HashMap::new(),
-      user_rooms: HashMap::new(),
-      rng: rand::thread_rng(),
-      pool,
-      captchas: Vec::new(),
-      message_handler,
-      message_handler_crud,
-      message_handler_apub,
-      client,
-      settings,
-      secret,
-      rate_limit_cell,
+      inner: Mutex::new(ChatServerInner {
+        sessions: Default::default(),
+        post_rooms: Default::default(),
+        community_rooms: Default::default(),
+        mod_rooms: Default::default(),
+        user_rooms: Default::default(),
+        rng: StdRng::from_entropy(),
+        captchas: vec![],
+      }),
     }
   }
 
   pub fn join_community_room(
-    &mut self,
+    &self,
     community_id: CommunityId,
     id: ConnectionId,
   ) -> Result<(), LemmyError> {
+    let mut inner = self.inner()?;
     // remove session from all rooms
-    for sessions in self.community_rooms.values_mut() {
+    for sessions in inner.community_rooms.values_mut() {
       sessions.remove(&id);
     }
 
     // Also leave all post rooms
     // This avoids double messages
-    for sessions in self.post_rooms.values_mut() {
+    for sessions in inner.post_rooms.values_mut() {
       sessions.remove(&id);
     }
 
     // If the room doesn't exist yet
-    if self.community_rooms.get_mut(&community_id).is_none() {
-      self.community_rooms.insert(community_id, HashSet::new());
+    if inner.community_rooms.get_mut(&community_id).is_none() {
+      inner.community_rooms.insert(community_id, HashSet::new());
     }
 
-    self
+    inner
       .community_rooms
       .get_mut(&community_id)
       .context(location_info!())?
@@ -169,21 +93,22 @@ impl ChatServer {
   }
 
   pub fn join_mod_room(
-    &mut self,
+    &self,
     community_id: CommunityId,
     id: ConnectionId,
   ) -> Result<(), LemmyError> {
+    let mut inner = self.inner()?;
     // remove session from all rooms
-    for sessions in self.mod_rooms.values_mut() {
+    for sessions in inner.mod_rooms.values_mut() {
       sessions.remove(&id);
     }
 
     // If the room doesn't exist yet
-    if self.mod_rooms.get_mut(&community_id).is_none() {
-      self.mod_rooms.insert(community_id, HashSet::new());
+    if inner.mod_rooms.get_mut(&community_id).is_none() {
+      inner.mod_rooms.insert(community_id, HashSet::new());
     }
 
-    self
+    inner
       .mod_rooms
       .get_mut(&community_id)
       .context(location_info!())?
@@ -191,9 +116,10 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
+  pub fn join_post_room(&self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
+    let mut inner = self.inner()?;
     // remove session from all rooms
-    for sessions in self.post_rooms.values_mut() {
+    for sessions in inner.post_rooms.values_mut() {
       sessions.remove(&id);
     }
 
@@ -202,16 +128,16 @@ impl ChatServer {
     // TODO found a bug, whereby community messages like
     // delete and remove aren't sent, because
     // you left the community room
-    for sessions in self.community_rooms.values_mut() {
+    for sessions in inner.community_rooms.values_mut() {
       sessions.remove(&id);
     }
 
     // If the room doesn't exist yet
-    if self.post_rooms.get_mut(&post_id).is_none() {
-      self.post_rooms.insert(post_id, HashSet::new());
+    if inner.post_rooms.get_mut(&post_id).is_none() {
+      inner.post_rooms.insert(post_id, HashSet::new());
     }
 
-    self
+    inner
       .post_rooms
       .get_mut(&post_id)
       .context(location_info!())?
@@ -220,31 +146,27 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn join_user_room(
-    &mut self,
-    user_id: LocalUserId,
-    id: ConnectionId,
-  ) -> Result<(), LemmyError> {
+  pub fn join_user_room(&self, user_id: LocalUserId, id: ConnectionId) -> Result<(), LemmyError> {
+    let mut inner = self.inner()?;
     // remove session from all rooms
-    for sessions in self.user_rooms.values_mut() {
+    for sessions in inner.user_rooms.values_mut() {
       sessions.remove(&id);
     }
 
     // If the room doesn't exist yet
-    if self.user_rooms.get_mut(&user_id).is_none() {
-      self.user_rooms.insert(user_id, HashSet::new());
+    if inner.user_rooms.get_mut(&user_id).is_none() {
+      inner.user_rooms.insert(user_id, HashSet::new());
     }
 
-    self
+    inner
       .user_rooms
       .get_mut(&user_id)
       .context(location_info!())?
       .insert(id);
-
     Ok(())
   }
 
-  fn send_post_room_message<OP, Response>(
+  async fn send_post_room_message<OP, Response>(
     &self,
     op: &OP,
     response: &Response,
@@ -255,21 +177,14 @@ impl ChatServer {
     OP: OperationType + ToString,
     Response: Serialize,
   {
-    let res_str = &serialize_websocket_message(op, response)?;
-    if let Some(sessions) = self.post_rooms.get(&post_id) {
-      for id in sessions {
-        if let Some(my_id) = websocket_id {
-          if *id == my_id {
-            continue;
-          }
-        }
-        self.sendit(res_str, *id);
-      }
-    }
+    let msg = serialize_websocket_message(op, response)?;
+    let room = self.inner()?.post_rooms.get(&post_id).cloned();
+    self.send_message_in_room(&msg, room, websocket_id).await?;
     Ok(())
   }
 
-  pub fn send_community_room_message<OP, Response>(
+  /// Send message to all users viewing the given community.
+  pub async fn send_community_room_message<OP, Response>(
     &self,
     op: &OP,
     response: &Response,
@@ -280,23 +195,16 @@ impl ChatServer {
     OP: OperationType + ToString,
     Response: Serialize,
   {
-    let res_str = &serialize_websocket_message(op, response)?;
-    if let Some(sessions) = self.community_rooms.get(&community_id) {
-      for id in sessions {
-        if let Some(my_id) = websocket_id {
-          if *id == my_id {
-            continue;
-          }
-        }
-        self.sendit(res_str, *id);
-      }
-    }
+    let msg = serialize_websocket_message(op, response)?;
+    let room = self.inner()?.community_rooms.get(&community_id).cloned();
+    self.send_message_in_room(&msg, room, websocket_id).await?;
     Ok(())
   }
 
-  pub fn send_mod_room_message<OP, Response>(
+  /// Send message to mods of a given community. Set community_id = 0 to send to site admins.
+  pub async fn send_mod_room_message<OP, Response>(
     &self,
-    op: &OP,
+    op: OP,
     response: &Response,
     community_id: CommunityId,
     websocket_id: Option<ConnectionId>,
@@ -305,43 +213,35 @@ impl ChatServer {
     OP: OperationType + ToString,
     Response: Serialize,
   {
-    let res_str = &serialize_websocket_message(op, response)?;
-    if let Some(sessions) = self.mod_rooms.get(&community_id) {
-      for id in sessions {
-        if let Some(my_id) = websocket_id {
-          if *id == my_id {
-            continue;
-          }
-        }
-        self.sendit(res_str, *id);
-      }
-    }
+    let msg = serialize_websocket_message(&op, response)?;
+    let room = self.inner()?.mod_rooms.get(&community_id).cloned();
+    self.send_message_in_room(&msg, room, websocket_id).await?;
     Ok(())
   }
 
-  pub fn send_all_message<OP, Response>(
+  pub async fn send_all_message<OP, Response>(
     &self,
-    op: &OP,
+    op: OP,
     response: &Response,
-    websocket_id: Option<ConnectionId>,
+    exclude_connection: Option<ConnectionId>,
   ) -> Result<(), LemmyError>
   where
     OP: OperationType + ToString,
     Response: Serialize,
   {
-    let res_str = &serialize_websocket_message(op, response)?;
-    for id in self.sessions.keys() {
-      if let Some(my_id) = websocket_id {
-        if *id == my_id {
-          continue;
-        }
-      }
-      self.sendit(res_str, *id);
-    }
+    let msg = &serialize_websocket_message(&op, response)?;
+    let sessions = self.inner()?.sessions.clone();
+    try_join_all(
+      sessions
+        .into_iter()
+        .filter(|(id, _)| Some(id) != exclude_connection.as_ref())
+        .map(|(_, mut s): (_, Session)| async move { s.text(msg).await }),
+    )
+    .await?;
     Ok(())
   }
 
-  pub fn send_user_room_message<OP, Response>(
+  pub async fn send_user_room_message<OP, Response>(
     &self,
     op: &OP,
     response: &Response,
@@ -352,21 +252,13 @@ impl ChatServer {
     OP: OperationType + ToString,
     Response: Serialize,
   {
-    let res_str = &serialize_websocket_message(op, response)?;
-    if let Some(sessions) = self.user_rooms.get(&recipient_id) {
-      for id in sessions {
-        if let Some(my_id) = websocket_id {
-          if *id == my_id {
-            continue;
-          }
-        }
-        self.sendit(res_str, *id);
-      }
-    }
+    let msg = serialize_websocket_message(op, response)?;
+    let room = self.inner()?.user_rooms.get(&recipient_id).cloned();
+    self.send_message_in_room(&msg, room, websocket_id).await?;
     Ok(())
   }
 
-  pub fn send_comment<OP>(
+  pub async fn send_comment<OP>(
     &self,
     user_operation: &OP,
     comment: &CommentResponse,
@@ -384,41 +276,49 @@ impl ChatServer {
     let mut comment_post_sent = comment_reply_sent.clone();
     // Remove the recipients here to separate mentions / user messages from post or community comments
     comment_post_sent.recipient_ids = Vec::new();
-    self.send_post_room_message(
-      user_operation,
-      &comment_post_sent,
-      comment_post_sent.comment_view.post.id,
-      websocket_id,
-    )?;
+    self
+      .send_post_room_message(
+        user_operation,
+        &comment_post_sent,
+        comment_post_sent.comment_view.post.id,
+        websocket_id,
+      )
+      .await?;
 
     // Send it to the community too
-    self.send_community_room_message(
-      user_operation,
-      &comment_post_sent,
-      CommunityId(0),
-      websocket_id,
-    )?;
-    self.send_community_room_message(
-      user_operation,
-      &comment_post_sent,
-      comment.comment_view.community.id,
-      websocket_id,
-    )?;
+    self
+      .send_community_room_message(
+        user_operation,
+        &comment_post_sent,
+        CommunityId(0),
+        websocket_id,
+      )
+      .await?;
+    self
+      .send_community_room_message(
+        user_operation,
+        &comment_post_sent,
+        comment.comment_view.community.id,
+        websocket_id,
+      )
+      .await?;
 
     // Send it to the recipient(s) including the mentioned users
     for recipient_id in &comment_reply_sent.recipient_ids {
-      self.send_user_room_message(
-        user_operation,
-        &comment_reply_sent,
-        *recipient_id,
-        websocket_id,
-      )?;
+      self
+        .send_user_room_message(
+          user_operation,
+          &comment_reply_sent,
+          *recipient_id,
+          websocket_id,
+        )
+        .await?;
     }
 
     Ok(())
   }
 
-  pub fn send_post<OP>(
+  pub async fn send_post<OP>(
     &self,
     user_operation: &OP,
     post_res: &PostResponse,
@@ -434,89 +334,58 @@ impl ChatServer {
     post_sent.post_view.my_vote = None;
 
     // Send it to /c/all and that community
-    self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
-    self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
+    self
+      .send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)
+      .await?;
+    self
+      .send_community_room_message(user_operation, &post_sent, community_id, websocket_id)
+      .await?;
 
     // Send it to the post room
-    self.send_post_room_message(
-      user_operation,
-      &post_sent,
-      post_res.post_view.post.id,
-      websocket_id,
-    )?;
+    self
+      .send_post_room_message(
+        user_operation,
+        &post_sent,
+        post_res.post_view.post.id,
+        websocket_id,
+      )
+      .await?;
 
     Ok(())
   }
 
-  fn sendit(&self, message: &str, id: ConnectionId) {
-    if let Some(info) = self.sessions.get(&id) {
-      info.addr.do_send(WsMessage(message.to_owned()));
+  /// Send websocket message in all sessions which joined a specific room.
+  ///
+  /// `message` - The json message body to send
+  /// `room` - Connection IDs which should receive the message
+  /// `exclude_connection` - Dont send to user who initiated the api call, as that
+  ///                        would result in duplicate notification
+  async fn send_message_in_room(
+    &self,
+    message: &str,
+    room: Option<HashSet<ConnectionId>>,
+    exclude_connection: Option<ConnectionId>,
+  ) -> Result<(), LemmyError> {
+    let mut session = self.inner()?.sessions.clone();
+    if let Some(room) = room {
+      try_join_all(
+        room
+          .into_iter()
+          .filter(|c| Some(c) != exclude_connection.as_ref())
+          .filter_map(|c| session.remove(&c))
+          .map(|mut s: Session| async move { s.text(message).await }),
+      )
+      .await?;
     }
+    Ok(())
   }
 
-  pub(super) fn parse_json_message(
-    &mut self,
-    msg: StandardMessage,
-    ctx: &mut Context<Self>,
-  ) -> impl Future<Output = Result<String, LemmyError>> {
-    let ip: IpAddr = match self.sessions.get(&msg.id) {
-      Some(info) => info.ip.clone(),
-      None => IpAddr("blank_ip".to_string()),
-    };
-
-    let context = LemmyContext::create(
-      self.pool.clone(),
-      ctx.address(),
-      self.client.clone(),
-      self.settings.clone(),
-      self.secret.clone(),
-      self.rate_limit_cell.clone(),
-    );
-    let message_handler_crud = self.message_handler_crud;
-    let message_handler = self.message_handler;
-    let message_handler_apub = self.message_handler_apub;
-    let rate_limiter = self.rate_limit_cell.clone();
-    async move {
-      let json: Value = serde_json::from_str(&msg.msg)?;
-      let data = &json["data"].to_string();
-      let op = &json["op"]
-        .as_str()
-        .ok_or_else(|| LemmyError::from_message("missing op"))?;
-
-      // check if api call passes the rate limit, and generate future for later execution
-      let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
-        let passed = match user_operation_crud {
-          UserOperationCrud::Register => rate_limiter.register().check(ip),
-          UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
-          UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
-          UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
-          _ => rate_limiter.message().check(ip),
-        };
-        let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
-        (passed, fut)
-      } else if let Ok(user_operation) = UserOperation::from_str(op) {
-        let passed = match user_operation {
-          UserOperation::GetCaptcha => rate_limiter.post().check(ip),
-          _ => rate_limiter.message().check(ip),
-        };
-        let fut = (message_handler)(context, msg.id, user_operation, data);
-        (passed, fut)
-      } else {
-        let user_operation = UserOperationApub::from_str(op)?;
-        let passed = match user_operation {
-          UserOperationApub::Search => rate_limiter.search().check(ip),
-          _ => rate_limiter.message().check(ip),
-        };
-        let fut = (message_handler_apub)(context, msg.id, user_operation, data);
-        (passed, fut)
-      };
-
-      // if rate limit passed, execute api call future
-      if passed {
-        fut.await
-      } else {
-        // if rate limit was hit, respond with message
-        Err(LemmyError::from_message("rate_limit_error"))
+  pub(in crate::websocket) fn inner(&self) -> Result<MutexGuard<'_, ChatServerInner>, LemmyError> {
+    match self.inner.lock() {
+      Ok(g) => Ok(g),
+      Err(e) => {
+        warn!("Failed to lock chatserver mutex: {}", e);
+        Err(LemmyError::from_message("Failed to lock chatserver mutex"))
       }
     }
   }
index 6f3d164c920d80381ce6b1619c277ab287741279..afdbfd59716377d86bc912dddd7032bb0f03a067 100644 (file)
-use crate::websocket::{
-  chat_server::{ChatServer, SessionInfo},
-  messages::{
-    CaptchaItem,
-    CheckCaptcha,
-    Connect,
-    Disconnect,
-    GetCommunityUsersOnline,
-    GetPostUsersOnline,
-    GetUsersOnline,
-    JoinCommunityRoom,
-    JoinModRoom,
-    JoinPostRoom,
-    JoinUserRoom,
-    SendAllMessage,
-    SendComment,
-    SendCommunityRoomMessage,
-    SendModRoomMessage,
-    SendPost,
-    SendUserRoomMessage,
-    StandardMessage,
-  },
-  OperationType,
+use crate::websocket::{chat_server::ChatServer, structs::CaptchaItem};
+use actix_ws::Session;
+use lemmy_db_schema::{
+  newtypes::{CommunityId, PostId},
+  utils::naive_now,
 };
-use actix::{Actor, Context, Handler, ResponseFuture};
-use lemmy_db_schema::utils::naive_now;
-use lemmy_utils::ConnectionId;
-use opentelemetry::trace::TraceContextExt;
+use lemmy_utils::{error::LemmyError, ConnectionId};
 use rand::Rng;
-use serde::Serialize;
-use tracing::{error, info};
-use tracing_opentelemetry::OpenTelemetrySpanExt;
 
-/// Make actor from `ChatServer`
-impl Actor for ChatServer {
-  /// We are going to use simple Context, we just need ability to communicate
-  /// with other actors.
-  type Context = Context<Self>;
-}
-
-/// Handler for Connect message.
-///
-/// Register new session and assign unique id to this session
-impl Handler<Connect> for ChatServer {
-  type Result = ConnectionId;
-
-  fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
+impl ChatServer {
+  /// Handler for Connect message.
+  ///
+  /// Register new session and assign unique id to this session
+  pub fn handle_connect(&self, session: Session) -> Result<ConnectionId, LemmyError> {
+    let mut inner = self.inner()?;
     // register session with random id
-    let id = self.rng.gen::<usize>();
-    info!("{} joined", &msg.ip);
+    let id = inner.rng.gen::<usize>();
 
-    self.sessions.insert(
-      id,
-      SessionInfo {
-        addr: msg.addr,
-        ip: msg.ip,
-      },
-    );
-
-    id
+    inner.sessions.insert(id, session);
+    Ok(id)
   }
-}
 
-/// Handler for Disconnect message.
-impl Handler<Disconnect> for ChatServer {
-  type Result = ();
-
-  fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) {
+  /// Handler for Disconnect message.
+  pub fn handle_disconnect(&self, connection_id: &ConnectionId) -> Result<(), LemmyError> {
+    let mut inner = self.inner()?;
     // Remove connections from sessions and all 3 scopes
-    if self.sessions.remove(&msg.id).is_some() {
-      for sessions in self.user_rooms.values_mut() {
-        sessions.remove(&msg.id);
+    if inner.sessions.remove(connection_id).is_some() {
+      for sessions in inner.user_rooms.values_mut() {
+        sessions.remove(connection_id);
       }
 
-      for sessions in self.post_rooms.values_mut() {
-        sessions.remove(&msg.id);
+      for sessions in inner.post_rooms.values_mut() {
+        sessions.remove(connection_id);
       }
 
-      for sessions in self.community_rooms.values_mut() {
-        sessions.remove(&msg.id);
+      for sessions in inner.community_rooms.values_mut() {
+        sessions.remove(connection_id);
       }
     }
+    Ok(())
   }
-}
-
-fn root_span() -> tracing::Span {
-  let span = tracing::info_span!(
-    parent: None,
-    "Websocket Request",
-    trace_id = tracing::field::Empty,
-  );
-  {
-    let trace_id = span.context().span().span_context().trace_id().to_string();
-    span.record("trace_id", &tracing::field::display(trace_id));
-  }
-
-  span
-}
-
-/// Handler for Message message.
-impl Handler<StandardMessage> for ChatServer {
-  type Result = ResponseFuture<Result<String, std::convert::Infallible>>;
-
-  fn handle(&mut self, msg: StandardMessage, ctx: &mut Context<Self>) -> Self::Result {
-    use tracing::Instrument;
-    let fut = self.parse_json_message(msg, ctx);
-    let span = root_span();
-
-    Box::pin(
-      async move {
-        match fut.await {
-          Ok(m) => {
-            // info!("Message Sent: {}", m);
-            Ok(m)
-          }
-          Err(e) => {
-            error!("Error during message handling {}", e);
-            Ok(
-              e.to_json()
-                .unwrap_or_else(|_| String::from(r#"{"error":"failed to serialize json"}"#)),
-            )
-          }
-        }
-      }
-      .instrument(span),
-    )
-  }
-}
-
-impl<OP, Response> Handler<SendAllMessage<OP, Response>> for ChatServer
-where
-  OP: OperationType + ToString,
-  Response: Serialize,
-{
-  type Result = ();
-
-  fn handle(&mut self, msg: SendAllMessage<OP, Response>, _: &mut Context<Self>) {
-    self
-      .send_all_message(&msg.op, &msg.response, msg.websocket_id)
-      .ok();
-  }
-}
-
-impl<OP, Response> Handler<SendUserRoomMessage<OP, Response>> for ChatServer
-where
-  OP: OperationType + ToString,
-  Response: Serialize,
-{
-  type Result = ();
-
-  fn handle(&mut self, msg: SendUserRoomMessage<OP, Response>, _: &mut Context<Self>) {
-    self
-      .send_user_room_message(
-        &msg.op,
-        &msg.response,
-        msg.local_recipient_id,
-        msg.websocket_id,
-      )
-      .ok();
-  }
-}
-
-impl<OP, Response> Handler<SendCommunityRoomMessage<OP, Response>> for ChatServer
-where
-  OP: OperationType + ToString,
-  Response: Serialize,
-{
-  type Result = ();
-
-  fn handle(&mut self, msg: SendCommunityRoomMessage<OP, Response>, _: &mut Context<Self>) {
-    self
-      .send_community_room_message(&msg.op, &msg.response, msg.community_id, msg.websocket_id)
-      .ok();
-  }
-}
-
-impl<Response> Handler<SendModRoomMessage<Response>> for ChatServer
-where
-  Response: Serialize,
-{
-  type Result = ();
-
-  fn handle(&mut self, msg: SendModRoomMessage<Response>, _: &mut Context<Self>) {
-    self
-      .send_mod_room_message(&msg.op, &msg.response, msg.community_id, msg.websocket_id)
-      .ok();
-  }
-}
-
-impl<OP> Handler<SendPost<OP>> for ChatServer
-where
-  OP: OperationType + ToString,
-{
-  type Result = ();
-
-  fn handle(&mut self, msg: SendPost<OP>, _: &mut Context<Self>) {
-    self.send_post(&msg.op, &msg.post, msg.websocket_id).ok();
-  }
-}
-
-impl<OP> Handler<SendComment<OP>> for ChatServer
-where
-  OP: OperationType + ToString,
-{
-  type Result = ();
 
-  fn handle(&mut self, msg: SendComment<OP>, _: &mut Context<Self>) {
-    self
-      .send_comment(&msg.op, &msg.comment, msg.websocket_id)
-      .ok();
+  pub fn get_users_online(&self) -> Result<usize, LemmyError> {
+    Ok(self.inner()?.sessions.len())
   }
-}
-
-impl Handler<JoinUserRoom> for ChatServer {
-  type Result = ();
 
-  fn handle(&mut self, msg: JoinUserRoom, _: &mut Context<Self>) {
-    self.join_user_room(msg.local_user_id, msg.id).ok();
-  }
-}
-
-impl Handler<JoinCommunityRoom> for ChatServer {
-  type Result = ();
-
-  fn handle(&mut self, msg: JoinCommunityRoom, _: &mut Context<Self>) {
-    self.join_community_room(msg.community_id, msg.id).ok();
-  }
-}
-
-impl Handler<JoinModRoom> for ChatServer {
-  type Result = ();
-
-  fn handle(&mut self, msg: JoinModRoom, _: &mut Context<Self>) {
-    self.join_mod_room(msg.community_id, msg.id).ok();
-  }
-}
-
-impl Handler<JoinPostRoom> for ChatServer {
-  type Result = ();
-
-  fn handle(&mut self, msg: JoinPostRoom, _: &mut Context<Self>) {
-    self.join_post_room(msg.post_id, msg.id).ok();
-  }
-}
-
-impl Handler<GetUsersOnline> for ChatServer {
-  type Result = usize;
-
-  fn handle(&mut self, _msg: GetUsersOnline, _: &mut Context<Self>) -> Self::Result {
-    self.sessions.len()
-  }
-}
-
-impl Handler<GetPostUsersOnline> for ChatServer {
-  type Result = usize;
-
-  fn handle(&mut self, msg: GetPostUsersOnline, _: &mut Context<Self>) -> Self::Result {
-    if let Some(users) = self.post_rooms.get(&msg.post_id) {
-      users.len()
+  pub fn get_post_users_online(&self, post_id: PostId) -> Result<usize, LemmyError> {
+    if let Some(users) = self.inner()?.post_rooms.get(&post_id) {
+      Ok(users.len())
     } else {
-      0
+      Ok(0)
     }
   }
-}
-
-impl Handler<GetCommunityUsersOnline> for ChatServer {
-  type Result = usize;
 
-  fn handle(&mut self, msg: GetCommunityUsersOnline, _: &mut Context<Self>) -> Self::Result {
-    if let Some(users) = self.community_rooms.get(&msg.community_id) {
-      users.len()
+  pub fn get_community_users_online(&self, community_id: CommunityId) -> Result<usize, LemmyError> {
+    if let Some(users) = self.inner()?.community_rooms.get(&community_id) {
+      Ok(users.len())
     } else {
-      0
+      Ok(0)
     }
   }
-}
-
-impl Handler<CaptchaItem> for ChatServer {
-  type Result = ();
 
-  fn handle(&mut self, msg: CaptchaItem, _: &mut Context<Self>) {
-    self.captchas.push(msg);
+  pub fn add_captcha(&self, captcha: CaptchaItem) -> Result<(), LemmyError> {
+    self.inner()?.captchas.push(captcha);
+    Ok(())
   }
-}
-
-impl Handler<CheckCaptcha> for ChatServer {
-  type Result = bool;
 
-  fn handle(&mut self, msg: CheckCaptcha, _: &mut Context<Self>) -> Self::Result {
+  pub fn check_captcha(&self, uuid: String, answer: String) -> Result<bool, LemmyError> {
+    let mut inner = self.inner()?;
     // Remove all the ones that are past the expire time
-    self.captchas.retain(|x| x.expires.gt(&naive_now()));
+    inner.captchas.retain(|x| x.expires.gt(&naive_now()));
 
-    let check = self
+    let check = inner
       .captchas
       .iter()
-      .any(|r| r.uuid == msg.uuid && r.answer.to_lowercase() == msg.answer.to_lowercase());
+      .any(|r| r.uuid == uuid && r.answer.to_lowercase() == answer.to_lowercase());
 
     // Remove this uuid so it can't be re-checked (Checks only work once)
-    self.captchas.retain(|x| x.uuid != msg.uuid);
+    inner.captchas.retain(|x| x.uuid != uuid);
 
-    check
+    Ok(check)
   }
 }
diff --git a/crates/api_common/src/websocket/messages.rs b/crates/api_common/src/websocket/messages.rs
deleted file mode 100644 (file)
index f811241..0000000
+++ /dev/null
@@ -1,150 +0,0 @@
-use crate::{comment::CommentResponse, post::PostResponse, websocket::UserOperation};
-use actix::{prelude::*, Recipient};
-use lemmy_db_schema::newtypes::{CommunityId, LocalUserId, PostId};
-use lemmy_utils::{ConnectionId, IpAddr};
-use serde::{Deserialize, Serialize};
-
-/// Chat server sends this messages to session
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct WsMessage(pub String);
-
-/// Message for chat server communications
-
-/// New chat session is created
-#[derive(Message)]
-#[rtype(usize)]
-pub struct Connect {
-  pub addr: Recipient<WsMessage>,
-  pub ip: IpAddr,
-}
-
-/// Session is disconnected
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct Disconnect {
-  pub id: ConnectionId,
-  pub ip: IpAddr,
-}
-
-/// The messages sent to websocket clients
-#[derive(Serialize, Deserialize, Message)]
-#[rtype(result = "Result<String, std::convert::Infallible>")]
-pub struct StandardMessage {
-  /// Id of the client session
-  pub id: ConnectionId,
-  /// Peer message
-  pub msg: String,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct SendAllMessage<OP: ToString, Response> {
-  pub op: OP,
-  pub response: Response,
-  pub websocket_id: Option<ConnectionId>,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct SendUserRoomMessage<OP: ToString, Response> {
-  pub op: OP,
-  pub response: Response,
-  pub local_recipient_id: LocalUserId,
-  pub websocket_id: Option<ConnectionId>,
-}
-
-/// Send message to all users viewing the given community.
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct SendCommunityRoomMessage<OP: ToString, Response> {
-  pub op: OP,
-  pub response: Response,
-  pub community_id: CommunityId,
-  pub websocket_id: Option<ConnectionId>,
-}
-
-/// Send message to mods of a given community. Set community_id = 0 to send to site admins.
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct SendModRoomMessage<Response> {
-  pub op: UserOperation,
-  pub response: Response,
-  pub community_id: CommunityId,
-  pub websocket_id: Option<ConnectionId>,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub(crate) struct SendPost<OP: ToString> {
-  pub op: OP,
-  pub post: PostResponse,
-  pub websocket_id: Option<ConnectionId>,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub(crate) struct SendComment<OP: ToString> {
-  pub op: OP,
-  pub comment: CommentResponse,
-  pub websocket_id: Option<ConnectionId>,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct JoinUserRoom {
-  pub local_user_id: LocalUserId,
-  pub id: ConnectionId,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct JoinCommunityRoom {
-  pub community_id: CommunityId,
-  pub id: ConnectionId,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct JoinModRoom {
-  pub community_id: CommunityId,
-  pub id: ConnectionId,
-}
-
-#[derive(Message)]
-#[rtype(result = "()")]
-pub struct JoinPostRoom {
-  pub post_id: PostId,
-  pub id: ConnectionId,
-}
-
-#[derive(Message)]
-#[rtype(usize)]
-pub struct GetUsersOnline;
-
-#[derive(Message)]
-#[rtype(usize)]
-pub struct GetPostUsersOnline {
-  pub post_id: PostId,
-}
-
-#[derive(Message)]
-#[rtype(usize)]
-pub struct GetCommunityUsersOnline {
-  pub community_id: CommunityId,
-}
-
-#[derive(Message, Debug)]
-#[rtype(result = "()")]
-pub struct CaptchaItem {
-  pub uuid: String,
-  pub answer: String,
-  pub expires: chrono::NaiveDateTime,
-}
-
-#[derive(Message)]
-#[rtype(bool)]
-pub struct CheckCaptcha {
-  pub uuid: String,
-  pub answer: String,
-}
index 430027cfaebabf15055b4f9ed31953a3c2e09aa5..8686b8e4f8289163fdf0814375195202cedaf9d5 100644 (file)
@@ -3,8 +3,6 @@ use serde::Serialize;
 
 pub mod chat_server;
 pub mod handlers;
-pub mod messages;
-pub mod routes;
 pub mod send;
 pub mod structs;
 
diff --git a/crates/api_common/src/websocket/routes.rs b/crates/api_common/src/websocket/routes.rs
deleted file mode 100644 (file)
index 936dc99..0000000
+++ /dev/null
@@ -1,197 +0,0 @@
-use crate::{
-  context::LemmyContext,
-  websocket::{
-    chat_server::ChatServer,
-    messages::{Connect, Disconnect, StandardMessage, WsMessage},
-  },
-};
-use actix::prelude::*;
-use actix_web::{web, Error, HttpRequest, HttpResponse};
-use actix_web_actors::ws;
-use lemmy_utils::{rate_limit::RateLimitCell, utils::get_ip, ConnectionId, IpAddr};
-use std::time::{Duration, Instant};
-use tracing::{debug, error, info};
-
-/// How often heartbeat pings are sent
-const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
-/// How long before lack of client response causes a timeout
-const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
-
-/// Entry point for our route
-pub async fn chat_route(
-  req: HttpRequest,
-  stream: web::Payload,
-  context: web::Data<LemmyContext>,
-  rate_limiter: web::Data<RateLimitCell>,
-) -> Result<HttpResponse, Error> {
-  ws::start(
-    WsSession {
-      cs_addr: context.chat_server().clone(),
-      id: 0,
-      hb: Instant::now(),
-      ip: get_ip(&req.connection_info()),
-      rate_limiter: rate_limiter.as_ref().clone(),
-    },
-    &req,
-    stream,
-  )
-}
-
-struct WsSession {
-  cs_addr: Addr<ChatServer>,
-  /// unique session id
-  id: ConnectionId,
-  ip: IpAddr,
-  /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
-  /// otherwise we drop connection.
-  hb: Instant,
-  /// A rate limiter for websocket joins
-  rate_limiter: RateLimitCell,
-}
-
-impl Actor for WsSession {
-  type Context = ws::WebsocketContext<Self>;
-
-  /// Method is called on actor start.
-  /// We register ws session with ChatServer
-  fn started(&mut self, ctx: &mut Self::Context) {
-    // we'll start heartbeat process on session start.
-    WsSession::hb(ctx);
-
-    // register self in chat server. `AsyncContext::wait` register
-    // future within context, but context waits until this future resolves
-    // before processing any other events.
-    // across all routes within application
-    let addr = ctx.address();
-
-    if !self.rate_limit_check(ctx) {
-      return;
-    }
-
-    self
-      .cs_addr
-      .send(Connect {
-        addr: addr.recipient(),
-        ip: self.ip.clone(),
-      })
-      .into_actor(self)
-      .then(|res, act, ctx| {
-        match res {
-          Ok(res) => act.id = res,
-          // something is wrong with chat server
-          _ => ctx.stop(),
-        }
-        actix::fut::ready(())
-      })
-      .wait(ctx);
-  }
-
-  fn stopping(&mut self, _ctx: &mut Self::Context) -> Running {
-    // notify chat server
-    self.cs_addr.do_send(Disconnect {
-      id: self.id,
-      ip: self.ip.clone(),
-    });
-    Running::Stop
-  }
-}
-
-/// Handle messages from chat server, we simply send it to peer websocket
-/// These are room messages, IE sent to others in the room
-impl Handler<WsMessage> for WsSession {
-  type Result = ();
-
-  fn handle(&mut self, msg: WsMessage, ctx: &mut Self::Context) {
-    ctx.text(msg.0);
-  }
-}
-
-/// WebSocket message handler
-impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
-  fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
-    if !self.rate_limit_check(ctx) {
-      return;
-    }
-
-    let message = match result {
-      Ok(m) => m,
-      Err(e) => {
-        error!("{}", e);
-        return;
-      }
-    };
-    match message {
-      ws::Message::Ping(msg) => {
-        self.hb = Instant::now();
-        ctx.pong(&msg);
-      }
-      ws::Message::Pong(_) => {
-        self.hb = Instant::now();
-      }
-      ws::Message::Text(text) => {
-        let m = text.trim().to_owned();
-
-        self
-          .cs_addr
-          .send(StandardMessage {
-            id: self.id,
-            msg: m,
-          })
-          .into_actor(self)
-          .then(|res, _, ctx| {
-            match res {
-              Ok(Ok(res)) => ctx.text(res),
-              Ok(Err(_)) => {}
-              Err(e) => error!("{}", &e),
-            }
-            actix::fut::ready(())
-          })
-          .spawn(ctx);
-      }
-      ws::Message::Binary(_bin) => info!("Unexpected binary"),
-      ws::Message::Close(_) => {
-        ctx.stop();
-      }
-      _ => {}
-    }
-  }
-}
-
-impl WsSession {
-  /// helper method that sends ping to client every second.
-  ///
-  /// also this method checks heartbeats from client
-  fn hb(ctx: &mut ws::WebsocketContext<Self>) {
-    ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
-      // check client heartbeats
-      if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
-        // heartbeat timed out
-        debug!("Websocket Client heartbeat failed, disconnecting!");
-
-        // notify chat server
-        act.cs_addr.do_send(Disconnect {
-          id: act.id,
-          ip: act.ip.clone(),
-        });
-
-        // stop actor
-        ctx.stop();
-
-        // don't try to send a ping
-        return;
-      }
-
-      ctx.ping(b"");
-    });
-  }
-
-  /// Check the rate limit, and stop the ctx if it fails
-  fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext<Self>) -> bool {
-    let check = self.rate_limiter.message().check(self.ip.clone());
-    if !check {
-      debug!("Websocket join with IP: {} has been rate limited.", self.ip);
-      ctx.stop()
-    }
-    check
-  }
-}
index cd53955f9f42791d602668e871c2883f0135cd79..4f639452e94902b6bb4a8eee925646c01ef211a8 100644 (file)
@@ -5,10 +5,7 @@ use crate::{
   post::PostResponse,
   private_message::PrivateMessageResponse,
   utils::{check_person_block, get_interface_language, send_email_to_user},
-  websocket::{
-    messages::{SendComment, SendCommunityRoomMessage, SendPost, SendUserRoomMessage},
-    OperationType,
-  },
+  websocket::OperationType,
 };
 use lemmy_db_schema::{
   newtypes::{CommentId, CommunityId, LocalUserId, PersonId, PostId, PrivateMessageId},
@@ -38,11 +35,10 @@ pub async fn send_post_ws_message<OP: ToString + Send + OperationType + 'static>
 
   let res = PostResponse { post_view };
 
-  context.chat_server().do_send(SendPost {
-    op,
-    post: res.clone(),
-    websocket_id,
-  });
+  context
+    .chat_server()
+    .send_post(&op, &res, websocket_id)
+    .await?;
 
   Ok(res)
 }
@@ -81,11 +77,10 @@ pub async fn send_comment_ws_message<OP: ToString + Send + OperationType + 'stat
     form_id: None,
   };
 
-  context.chat_server().do_send(SendComment {
-    op,
-    comment: res.clone(),
-    websocket_id,
-  });
+  context
+    .chat_server()
+    .send_comment(&op, &res, websocket_id)
+    .await?;
 
   // The recipient_ids should be empty for returns
   res.recipient_ids = Vec::new();
@@ -104,18 +99,15 @@ pub async fn send_community_ws_message<OP: ToString + Send + OperationType + 'st
 ) -> Result<CommunityResponse, LemmyError> {
   let community_view = CommunityView::read(context.pool(), community_id, person_id).await?;
 
-  let res = CommunityResponse { community_view };
+  let mut res = CommunityResponse { community_view };
 
   // Strip out the person id and subscribed when sending to others
-  let mut res_mut = res.clone();
-  res_mut.community_view.subscribed = SubscribedType::NotSubscribed;
+  res.community_view.subscribed = SubscribedType::NotSubscribed;
 
-  context.chat_server().do_send(SendCommunityRoomMessage {
-    op,
-    response: res_mut,
-    community_id: res.community_view.community.id,
-    websocket_id,
-  });
+  context
+    .chat_server()
+    .send_community_room_message(&op, &res, res.community_view.community.id, websocket_id)
+    .await?;
 
   Ok(res)
 }
@@ -142,12 +134,11 @@ pub async fn send_pm_ws_message<OP: ToString + Send + OperationType + 'static>(
   if res.private_message_view.recipient.local {
     let recipient_id = res.private_message_view.recipient.id;
     let local_recipient = LocalUserView::read_person(context.pool(), recipient_id).await?;
-    context.chat_server().do_send(SendUserRoomMessage {
-      op,
-      response: res.clone(),
-      local_recipient_id: local_recipient.local_user.id,
-      websocket_id,
-    });
+
+    context
+      .chat_server()
+      .send_user_room_message(&op, &res, local_recipient.local_user.id, websocket_id)
+      .await?;
   }
 
   Ok(res)
index 23f7b2bc5305187aa4110efb74e830bb84b4f51b..3418d05c26c6c13727715502b8f1c6e7355f138f 100644 (file)
@@ -41,3 +41,10 @@ pub struct PostJoin {
 pub struct PostJoinResponse {
   pub joined: bool,
 }
+
+#[derive(Debug)]
+pub struct CaptchaItem {
+  pub uuid: String,
+  pub answer: String,
+  pub expires: chrono::NaiveDateTime,
+}
index c6747c5e8361b5512c23b89220152c883ed377b7..49bb3765dfb491506a559943c5f7e4d71c2b77c6 100644 (file)
@@ -4,7 +4,6 @@ use lemmy_api_common::{
   context::LemmyContext,
   post::{GetPost, GetPostResponse},
   utils::{check_private_instance, get_local_user_view_from_jwt_opt, mark_post_as_read},
-  websocket::messages::GetPostUsersOnline,
 };
 use lemmy_db_schema::{
   aggregates::structs::{PersonPostAggregates, PersonPostAggregatesForm},
@@ -91,11 +90,7 @@ impl PerformCrud for GetPost {
 
     let moderators = CommunityModeratorView::for_community(context.pool(), community_id).await?;
 
-    let online = context
-      .chat_server()
-      .send(GetPostUsersOnline { post_id })
-      .await
-      .unwrap_or(1);
+    let online = context.chat_server().get_post_users_online(post_id)?;
 
     // Return the jwt
     Ok(GetPostResponse {
index 55dcd8fb2d915d5d1cd565f110bcb725212732b5..5b133c2c7b6e0a25730cbb1de915db7e03615c37 100644 (file)
@@ -4,7 +4,6 @@ use lemmy_api_common::{
   context::LemmyContext,
   site::{GetSite, GetSiteResponse, MyUserInfo},
   utils::{build_federated_instances, get_local_user_settings_view_from_jwt_opt},
-  websocket::messages::GetUsersOnline,
 };
 use lemmy_db_schema::source::{actor_language::SiteLanguage, language::Language, tagline::Tagline};
 use lemmy_db_views::structs::{LocalUserDiscussionLanguageView, SiteView};
@@ -33,11 +32,7 @@ impl PerformCrud for GetSite {
 
     let admins = PersonViewSafe::admins(context.pool()).await?;
 
-    let online = context
-      .chat_server()
-      .send(GetUsersOnline)
-      .await
-      .unwrap_or(1);
+    let online = context.chat_server().get_users_online()?;
 
     // Build the local user
     let my_user = if let Some(local_user_view) = get_local_user_settings_view_from_jwt_opt(
index 4a96925edbb8b4719302451dbc54a3304d258ac9..3909a297254913f8a251aaadce5275b6b1be4268 100644 (file)
@@ -10,7 +10,7 @@ use lemmy_api_common::{
     local_site_to_slur_regex,
     site_description_length_check,
   },
-  websocket::{messages::SendAllMessage, UserOperationCrud},
+  websocket::UserOperationCrud,
 };
 use lemmy_db_schema::{
   source::{
@@ -189,11 +189,10 @@ impl PerformCrud for EditSite {
 
     let res = SiteResponse { site_view };
 
-    context.chat_server().do_send(SendAllMessage {
-      op: UserOperationCrud::EditSite,
-      response: res.clone(),
-      websocket_id,
-    });
+    context
+      .chat_server()
+      .send_all_message(UserOperationCrud::EditSite, &res, websocket_id)
+      .await?;
 
     Ok(res)
   }
index 34b3b69e47004d154d9018662d570314879a51d4..f11c45b31a9957be9a356adb33a310e18cfad7fc 100644 (file)
@@ -15,7 +15,6 @@ use lemmy_api_common::{
     send_verification_email,
     EndpointType,
   },
-  websocket::messages::CheckCaptcha,
 };
 use lemmy_db_schema::{
   aggregates::structs::PersonAggregates,
@@ -73,13 +72,10 @@ impl PerformCrud for Register {
 
     // If the site is set up, check the captcha
     if local_site.site_setup && local_site.captcha_enabled {
-      let check = context
-        .chat_server()
-        .send(CheckCaptcha {
-          uuid: data.captcha_uuid.clone().unwrap_or_default(),
-          answer: data.captcha_answer.clone().unwrap_or_default(),
-        })
-        .await?;
+      let check = context.chat_server().check_captcha(
+        data.captcha_uuid.clone().unwrap_or_default(),
+        data.captcha_answer.clone().unwrap_or_default(),
+      )?;
       if !check {
         return Err(LemmyError::from_message("captcha_incorrect"));
       }
index 720fa5937ab9cbd4c753a82df17f52affb178d97..3a4db66dc421bd31e4d676941d161669010e62c4 100644 (file)
@@ -24,7 +24,6 @@ diesel = { workspace = true }
 chrono = { workspace = true }
 serde_json = { workspace = true }
 serde = { workspace = true }
-actix = { workspace = true }
 actix-web = { workspace = true }
 actix-rt = { workspace = true }
 tracing = { workspace = true }
index 57b1244d106803500c8e46079c26bda412a2bd51..dfd0a83c8086361ef2c10ee6b23009cf9cb0a0f8 100644 (file)
@@ -18,7 +18,7 @@ use lemmy_api_common::{
   context::LemmyContext,
   post::{CreatePostReport, PostReportResponse},
   utils::get_local_user_view_from_jwt,
-  websocket::{messages::SendModRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{
   source::{
@@ -158,12 +158,15 @@ impl ActivityHandler for Report {
 
         let post_report_view = PostReportView::read(context.pool(), report.id, actor.id).await?;
 
-        context.chat_server().do_send(SendModRoomMessage {
-          op: UserOperation::CreateCommentReport,
-          response: PostReportResponse { post_report_view },
-          community_id: post.community_id,
-          websocket_id: None,
-        });
+        context
+          .chat_server()
+          .send_mod_room_message(
+            UserOperation::CreateCommentReport,
+            &PostReportResponse { post_report_view },
+            post.community_id,
+            None,
+          )
+          .await?;
       }
       PostOrComment::Comment(comment) => {
         let report_form = CommentReportForm {
@@ -179,14 +182,17 @@ impl ActivityHandler for Report {
           CommentReportView::read(context.pool(), report.id, actor.id).await?;
         let community_id = comment_report_view.community.id;
 
-        context.chat_server().do_send(SendModRoomMessage {
-          op: UserOperation::CreateCommentReport,
-          response: CommentReportResponse {
-            comment_report_view,
-          },
-          community_id,
-          websocket_id: None,
-        });
+        context
+          .chat_server()
+          .send_mod_room_message(
+            UserOperation::CreateCommentReport,
+            &CommentReportResponse {
+              comment_report_view,
+            },
+            community_id,
+            None,
+          )
+          .await?;
       }
     };
     Ok(())
index 6702c8c9590c3893d56a08c12c1b5b552b95e860..fbde090f7b434c771e3f4d8ab0c4981e2e7a318c 100644 (file)
@@ -14,7 +14,7 @@ use activitystreams_kinds::activity::AcceptType;
 use lemmy_api_common::{
   community::CommunityResponse,
   context::LemmyContext,
-  websocket::{messages::SendUserRoomMessage, UserOperation},
+  websocket::UserOperation,
 };
 use lemmy_db_schema::{source::community::CommunityFollower, traits::Followable};
 use lemmy_db_views::structs::LocalUserView;
@@ -106,12 +106,15 @@ impl ActivityHandler for AcceptFollow {
 
     let response = CommunityResponse { community_view };
 
-    context.chat_server().do_send(SendUserRoomMessage {
-      op: UserOperation::FollowCommunity,
-      response,
-      local_recipient_id,
-      websocket_id: None,
-    });
+    context
+      .chat_server()
+      .send_user_room_message(
+        &UserOperation::FollowCommunity,
+        &response,
+        local_recipient_id,
+        None,
+      )
+      .await?;
 
     Ok(())
   }
index b1615d7d75aa466ef6f386a3e8cc58f512a1b90a..8ffa48ba79c1adcdf53c7d094feb80c5232f6d0c 100644 (file)
@@ -8,7 +8,6 @@ use lemmy_api_common::{
   community::{GetCommunity, GetCommunityResponse},
   context::LemmyContext,
   utils::{check_private_instance, get_local_user_view_from_jwt_opt},
-  websocket::messages::GetCommunityUsersOnline,
 };
 use lemmy_db_schema::{
   impls::actor_language::default_post_language,
@@ -74,9 +73,7 @@ impl PerformApub for GetCommunity {
 
     let online = context
       .chat_server()
-      .send(GetCommunityUsersOnline { community_id })
-      .await
-      .unwrap_or(1);
+      .get_community_users_online(community_id)?;
 
     let site_id =
       Site::instance_actor_id_from_url(community_view.community.actor_id.clone().into());
index 58e1f23f55e5c0298adccf225e36dbbdf6b4f5cf..29df3d64a3ecb342f4cec6167774c0afb62cfc27 100644 (file)
@@ -54,7 +54,6 @@ pub(crate) fn verify_is_remote_object(id: &Url, settings: &Settings) -> Result<(
 
 #[cfg(test)]
 pub(crate) mod tests {
-  use actix::Actor;
   use anyhow::anyhow;
   use lemmy_api_common::{
     context::LemmyContext,
@@ -69,6 +68,7 @@ pub(crate) mod tests {
   };
   use reqwest::{Client, Request, Response};
   use reqwest_middleware::{ClientBuilder, Middleware, Next};
+  use std::sync::Arc;
   use task_local_extensions::Extensions;
 
   struct BlockedMiddleware;
@@ -109,17 +109,7 @@ pub(crate) mod tests {
     let rate_limit_config = RateLimitConfig::builder().build();
     let rate_limit_cell = RateLimitCell::new(rate_limit_config).await;
 
-    let chat_server = ChatServer::startup(
-      pool.clone(),
-      |_, _, _, _| Box::pin(x()),
-      |_, _, _, _| Box::pin(x()),
-      |_, _, _, _| Box::pin(x()),
-      client.clone(),
-      settings.clone(),
-      secret.clone(),
-      rate_limit_cell.clone(),
-    )
-    .start();
+    let chat_server = Arc::new(ChatServer::startup());
     LemmyContext::create(
       pool,
       chat_server,
diff --git a/src/api_routes_http.rs b/src/api_routes_http.rs
new file mode 100644 (file)
index 0000000..34083a4
--- /dev/null
@@ -0,0 +1,463 @@
+use crate::api_routes_websocket::websocket;
+use actix_web::{guard, web, Error, HttpResponse, Result};
+use lemmy_api::Perform;
+use lemmy_api_common::{
+  comment::{
+    CreateComment,
+    CreateCommentLike,
+    CreateCommentReport,
+    DeleteComment,
+    EditComment,
+    GetComment,
+    GetComments,
+    ListCommentReports,
+    RemoveComment,
+    ResolveCommentReport,
+    SaveComment,
+  },
+  community::{
+    AddModToCommunity,
+    BanFromCommunity,
+    BlockCommunity,
+    CreateCommunity,
+    DeleteCommunity,
+    EditCommunity,
+    FollowCommunity,
+    GetCommunity,
+    HideCommunity,
+    ListCommunities,
+    RemoveCommunity,
+    TransferCommunity,
+  },
+  context::LemmyContext,
+  person::{
+    AddAdmin,
+    BanPerson,
+    BlockPerson,
+    ChangePassword,
+    DeleteAccount,
+    GetBannedPersons,
+    GetCaptcha,
+    GetPersonDetails,
+    GetPersonMentions,
+    GetReplies,
+    GetReportCount,
+    GetUnreadCount,
+    Login,
+    MarkAllAsRead,
+    MarkCommentReplyAsRead,
+    MarkPersonMentionAsRead,
+    PasswordChangeAfterReset,
+    PasswordReset,
+    Register,
+    SaveUserSettings,
+    VerifyEmail,
+  },
+  post::{
+    CreatePost,
+    CreatePostLike,
+    CreatePostReport,
+    DeletePost,
+    EditPost,
+    GetPost,
+    GetPosts,
+    GetSiteMetadata,
+    ListPostReports,
+    LockPost,
+    MarkPostAsRead,
+    RemovePost,
+    ResolvePostReport,
+    SavePost,
+    StickyPost,
+  },
+  private_message::{
+    CreatePrivateMessage,
+    CreatePrivateMessageReport,
+    DeletePrivateMessage,
+    EditPrivateMessage,
+    GetPrivateMessages,
+    ListPrivateMessageReports,
+    MarkPrivateMessageAsRead,
+    ResolvePrivateMessageReport,
+  },
+  site::{
+    ApproveRegistrationApplication,
+    CreateSite,
+    EditSite,
+    GetModlog,
+    GetSite,
+    GetUnreadRegistrationApplicationCount,
+    LeaveAdmin,
+    ListRegistrationApplications,
+    PurgeComment,
+    PurgeCommunity,
+    PurgePerson,
+    PurgePost,
+    ResolveObject,
+    Search,
+  },
+  websocket::structs::{CommunityJoin, ModJoin, PostJoin, UserJoin},
+};
+use lemmy_api_crud::PerformCrud;
+use lemmy_apub::{api::PerformApub, SendActivity};
+use lemmy_utils::rate_limit::RateLimitCell;
+use serde::Deserialize;
+
+pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) {
+  cfg.service(
+    web::scope("/api/v3")
+      // Websocket
+      .service(web::resource("/ws").to(websocket))
+      // Site
+      .service(
+        web::scope("/site")
+          .wrap(rate_limit.message())
+          .route("", web::get().to(route_get_crud::<GetSite>))
+          // Admin Actions
+          .route("", web::post().to(route_post_crud::<CreateSite>))
+          .route("", web::put().to(route_post_crud::<EditSite>)),
+      )
+      .service(
+        web::resource("/modlog")
+          .wrap(rate_limit.message())
+          .route(web::get().to(route_get::<GetModlog>)),
+      )
+      .service(
+        web::resource("/search")
+          .wrap(rate_limit.search())
+          .route(web::get().to(route_get_apub::<Search>)),
+      )
+      .service(
+        web::resource("/resolve_object")
+          .wrap(rate_limit.message())
+          .route(web::get().to(route_get_apub::<ResolveObject>)),
+      )
+      // Community
+      .service(
+        web::resource("/community")
+          .guard(guard::Post())
+          .wrap(rate_limit.register())
+          .route(web::post().to(route_post_crud::<CreateCommunity>)),
+      )
+      .service(
+        web::scope("/community")
+          .wrap(rate_limit.message())
+          .route("", web::get().to(route_get_apub::<GetCommunity>))
+          .route("", web::put().to(route_post_crud::<EditCommunity>))
+          .route("/hide", web::put().to(route_post::<HideCommunity>))
+          .route("/list", web::get().to(route_get_crud::<ListCommunities>))
+          .route("/follow", web::post().to(route_post::<FollowCommunity>))
+          .route("/block", web::post().to(route_post::<BlockCommunity>))
+          .route(
+            "/delete",
+            web::post().to(route_post_crud::<DeleteCommunity>),
+          )
+          // Mod Actions
+          .route(
+            "/remove",
+            web::post().to(route_post_crud::<RemoveCommunity>),
+          )
+          .route("/transfer", web::post().to(route_post::<TransferCommunity>))
+          .route("/ban_user", web::post().to(route_post::<BanFromCommunity>))
+          .route("/mod", web::post().to(route_post::<AddModToCommunity>))
+          .route("/join", web::post().to(route_post::<CommunityJoin>))
+          .route("/mod/join", web::post().to(route_post::<ModJoin>)),
+      )
+      // Post
+      .service(
+        // Handle POST to /post separately to add the post() rate limitter
+        web::resource("/post")
+          .guard(guard::Post())
+          .wrap(rate_limit.post())
+          .route(web::post().to(route_post_crud::<CreatePost>)),
+      )
+      .service(
+        web::scope("/post")
+          .wrap(rate_limit.message())
+          .route("", web::get().to(route_get_crud::<GetPost>))
+          .route("", web::put().to(route_post_crud::<EditPost>))
+          .route("/delete", web::post().to(route_post_crud::<DeletePost>))
+          .route("/remove", web::post().to(route_post_crud::<RemovePost>))
+          .route(
+            "/mark_as_read",
+            web::post().to(route_post::<MarkPostAsRead>),
+          )
+          .route("/lock", web::post().to(route_post::<LockPost>))
+          .route("/sticky", web::post().to(route_post::<StickyPost>))
+          .route("/list", web::get().to(route_get_apub::<GetPosts>))
+          .route("/like", web::post().to(route_post::<CreatePostLike>))
+          .route("/save", web::put().to(route_post::<SavePost>))
+          .route("/join", web::post().to(route_post::<PostJoin>))
+          .route("/report", web::post().to(route_post::<CreatePostReport>))
+          .route(
+            "/report/resolve",
+            web::put().to(route_post::<ResolvePostReport>),
+          )
+          .route("/report/list", web::get().to(route_get::<ListPostReports>))
+          .route(
+            "/site_metadata",
+            web::get().to(route_get::<GetSiteMetadata>),
+          ),
+      )
+      // Comment
+      .service(
+        // Handle POST to /comment separately to add the comment() rate limitter
+        web::resource("/comment")
+          .guard(guard::Post())
+          .wrap(rate_limit.comment())
+          .route(web::post().to(route_post_crud::<CreateComment>)),
+      )
+      .service(
+        web::scope("/comment")
+          .wrap(rate_limit.message())
+          .route("", web::get().to(route_get_crud::<GetComment>))
+          .route("", web::put().to(route_post_crud::<EditComment>))
+          .route("/delete", web::post().to(route_post_crud::<DeleteComment>))
+          .route("/remove", web::post().to(route_post_crud::<RemoveComment>))
+          .route(
+            "/mark_as_read",
+            web::post().to(route_post::<MarkCommentReplyAsRead>),
+          )
+          .route("/like", web::post().to(route_post::<CreateCommentLike>))
+          .route("/save", web::put().to(route_post::<SaveComment>))
+          .route("/list", web::get().to(route_get_apub::<GetComments>))
+          .route("/report", web::post().to(route_post::<CreateCommentReport>))
+          .route(
+            "/report/resolve",
+            web::put().to(route_post::<ResolveCommentReport>),
+          )
+          .route(
+            "/report/list",
+            web::get().to(route_get::<ListCommentReports>),
+          ),
+      )
+      // Private Message
+      .service(
+        web::scope("/private_message")
+          .wrap(rate_limit.message())
+          .route("/list", web::get().to(route_get_crud::<GetPrivateMessages>))
+          .route("", web::post().to(route_post_crud::<CreatePrivateMessage>))
+          .route("", web::put().to(route_post_crud::<EditPrivateMessage>))
+          .route(
+            "/delete",
+            web::post().to(route_post_crud::<DeletePrivateMessage>),
+          )
+          .route(
+            "/mark_as_read",
+            web::post().to(route_post::<MarkPrivateMessageAsRead>),
+          )
+          .route(
+            "/report",
+            web::post().to(route_post::<CreatePrivateMessageReport>),
+          )
+          .route(
+            "/report/resolve",
+            web::put().to(route_post::<ResolvePrivateMessageReport>),
+          )
+          .route(
+            "/report/list",
+            web::get().to(route_get::<ListPrivateMessageReports>),
+          ),
+      )
+      // User
+      .service(
+        // Account action, I don't like that it's in /user maybe /accounts
+        // Handle /user/register separately to add the register() rate limitter
+        web::resource("/user/register")
+          .guard(guard::Post())
+          .wrap(rate_limit.register())
+          .route(web::post().to(route_post_crud::<Register>)),
+      )
+      .service(
+        // Handle captcha separately
+        web::resource("/user/get_captcha")
+          .wrap(rate_limit.post())
+          .route(web::get().to(route_get::<GetCaptcha>)),
+      )
+      // User actions
+      .service(
+        web::scope("/user")
+          .wrap(rate_limit.message())
+          .route("", web::get().to(route_get_apub::<GetPersonDetails>))
+          .route("/mention", web::get().to(route_get::<GetPersonMentions>))
+          .route(
+            "/mention/mark_as_read",
+            web::post().to(route_post::<MarkPersonMentionAsRead>),
+          )
+          .route("/replies", web::get().to(route_get::<GetReplies>))
+          .route("/join", web::post().to(route_post::<UserJoin>))
+          // Admin action. I don't like that it's in /user
+          .route("/ban", web::post().to(route_post::<BanPerson>))
+          .route("/banned", web::get().to(route_get::<GetBannedPersons>))
+          .route("/block", web::post().to(route_post::<BlockPerson>))
+          // Account actions. I don't like that they're in /user maybe /accounts
+          .route("/login", web::post().to(route_post::<Login>))
+          .route(
+            "/delete_account",
+            web::post().to(route_post_crud::<DeleteAccount>),
+          )
+          .route(
+            "/password_reset",
+            web::post().to(route_post::<PasswordReset>),
+          )
+          .route(
+            "/password_change",
+            web::post().to(route_post::<PasswordChangeAfterReset>),
+          )
+          // mark_all_as_read feels off being in this section as well
+          .route(
+            "/mark_all_as_read",
+            web::post().to(route_post::<MarkAllAsRead>),
+          )
+          .route(
+            "/save_user_settings",
+            web::put().to(route_post::<SaveUserSettings>),
+          )
+          .route(
+            "/change_password",
+            web::put().to(route_post::<ChangePassword>),
+          )
+          .route("/report_count", web::get().to(route_get::<GetReportCount>))
+          .route("/unread_count", web::get().to(route_get::<GetUnreadCount>))
+          .route("/verify_email", web::post().to(route_post::<VerifyEmail>))
+          .route("/leave_admin", web::post().to(route_post::<LeaveAdmin>)),
+      )
+      // Admin Actions
+      .service(
+        web::scope("/admin")
+          .wrap(rate_limit.message())
+          .route("/add", web::post().to(route_post::<AddAdmin>))
+          .route(
+            "/registration_application/count",
+            web::get().to(route_get::<GetUnreadRegistrationApplicationCount>),
+          )
+          .route(
+            "/registration_application/list",
+            web::get().to(route_get::<ListRegistrationApplications>),
+          )
+          .route(
+            "/registration_application/approve",
+            web::put().to(route_post::<ApproveRegistrationApplication>),
+          ),
+      )
+      .service(
+        web::scope("/admin/purge")
+          .wrap(rate_limit.message())
+          .route("/person", web::post().to(route_post::<PurgePerson>))
+          .route("/community", web::post().to(route_post::<PurgeCommunity>))
+          .route("/post", web::post().to(route_post::<PurgePost>))
+          .route("/comment", web::post().to(route_post::<PurgeComment>)),
+      ),
+  );
+}
+
+async fn perform<'a, Data>(
+  data: Data,
+  context: web::Data<LemmyContext>,
+) -> Result<HttpResponse, Error>
+where
+  Data: Perform
+    + SendActivity<Response = <Data as Perform>::Response>
+    + Clone
+    + Deserialize<'a>
+    + Send
+    + 'static,
+{
+  let res = data.perform(&context, None).await?;
+  SendActivity::send_activity(&data, &res, &context).await?;
+  Ok(HttpResponse::Ok().json(res))
+}
+
+async fn route_get<'a, Data>(
+  data: web::Query<Data>,
+  context: web::Data<LemmyContext>,
+) -> Result<HttpResponse, Error>
+where
+  Data: Perform
+    + SendActivity<Response = <Data as Perform>::Response>
+    + Clone
+    + Deserialize<'a>
+    + Send
+    + 'static,
+{
+  perform::<Data>(data.0, context).await
+}
+
+async fn route_get_apub<'a, Data>(
+  data: web::Query<Data>,
+  context: web::Data<LemmyContext>,
+) -> Result<HttpResponse, Error>
+where
+  Data: PerformApub
+    + SendActivity<Response = <Data as PerformApub>::Response>
+    + Clone
+    + Deserialize<'a>
+    + Send
+    + 'static,
+{
+  let res = data.perform(&context, None).await?;
+  SendActivity::send_activity(&data.0, &res, &context).await?;
+  Ok(HttpResponse::Ok().json(res))
+}
+
+async fn route_post<'a, Data>(
+  data: web::Json<Data>,
+  context: web::Data<LemmyContext>,
+) -> Result<HttpResponse, Error>
+where
+  Data: Perform
+    + SendActivity<Response = <Data as Perform>::Response>
+    + Clone
+    + Deserialize<'a>
+    + Send
+    + 'static,
+{
+  perform::<Data>(data.0, context).await
+}
+
+async fn perform_crud<'a, Data>(
+  data: Data,
+  context: web::Data<LemmyContext>,
+) -> Result<HttpResponse, Error>
+where
+  Data: PerformCrud
+    + SendActivity<Response = <Data as PerformCrud>::Response>
+    + Clone
+    + Deserialize<'a>
+    + Send
+    + 'static,
+{
+  let res = data.perform(&context, None).await?;
+  SendActivity::send_activity(&data, &res, &context).await?;
+  Ok(HttpResponse::Ok().json(res))
+}
+
+async fn route_get_crud<'a, Data>(
+  data: web::Query<Data>,
+  context: web::Data<LemmyContext>,
+) -> Result<HttpResponse, Error>
+where
+  Data: PerformCrud
+    + SendActivity<Response = <Data as PerformCrud>::Response>
+    + Clone
+    + Deserialize<'a>
+    + Send
+    + 'static,
+{
+  perform_crud::<Data>(data.0, context).await
+}
+
+async fn route_post_crud<'a, Data>(
+  data: web::Json<Data>,
+  context: web::Data<LemmyContext>,
+) -> Result<HttpResponse, Error>
+where
+  Data: PerformCrud
+    + SendActivity<Response = <Data as PerformCrud>::Response>
+    + Clone
+    + Deserialize<'a>
+    + Send
+    + 'static,
+{
+  perform_crud::<Data>(data.0, context).await
+}
similarity index 54%
rename from src/api_routes.rs
rename to src/api_routes_websocket.rs
index 77e5c27d73c6b96ca83ce119a6b3c4e020657834..7a8656005736b3e7672836d17749854813edd7fd 100644 (file)
@@ -1,4 +1,7 @@
-use actix_web::{guard, web, Error, HttpResponse, Result};
+use actix_web::{web, Error, HttpRequest, HttpResponse};
+use actix_web_actors::ws;
+use actix_ws::{MessageStream, Session};
+use futures::stream::StreamExt;
 use lemmy_api::Perform;
 use lemmy_api_common::{
   comment::{
@@ -23,7 +26,6 @@ use lemmy_api_common::{
     EditCommunity,
     FollowCommunity,
     GetCommunity,
-    HideCommunity,
     ListCommunities,
     RemoveCommunity,
     TransferCommunity,
@@ -96,7 +98,6 @@ use lemmy_api_common::{
     Search,
   },
   websocket::{
-    routes::chat_route,
     serialize_websocket_message,
     structs::{CommunityJoin, ModJoin, PostJoin, UserJoin},
     UserOperation,
@@ -106,367 +107,187 @@ use lemmy_api_common::{
 };
 use lemmy_api_crud::PerformCrud;
 use lemmy_apub::{api::PerformApub, SendActivity};
-use lemmy_utils::{error::LemmyError, rate_limit::RateLimitCell, ConnectionId};
+use lemmy_utils::{error::LemmyError, rate_limit::RateLimitCell, ConnectionId, IpAddr};
 use serde::Deserialize;
-use std::result;
+use serde_json::Value;
+use std::{
+  result,
+  str::FromStr,
+  sync::{Arc, Mutex},
+  time::{Duration, Instant},
+};
+use tracing::{debug, error, info};
+
+/// Entry point for our route
+pub async fn websocket(
+  req: HttpRequest,
+  body: web::Payload,
+  context: web::Data<LemmyContext>,
+  rate_limiter: web::Data<RateLimitCell>,
+) -> Result<HttpResponse, Error> {
+  let (response, session, stream) = actix_ws::handle(&req, body)?;
 
-pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) {
-  cfg.service(
-    web::scope("/api/v3")
-      // Websocket
-      .service(web::resource("/ws").to(chat_route))
-      // Site
-      .service(
-        web::scope("/site")
-          .wrap(rate_limit.message())
-          .route("", web::get().to(route_get_crud::<GetSite>))
-          // Admin Actions
-          .route("", web::post().to(route_post_crud::<CreateSite>))
-          .route("", web::put().to(route_post_crud::<EditSite>)),
-      )
-      .service(
-        web::resource("/modlog")
-          .wrap(rate_limit.message())
-          .route(web::get().to(route_get::<GetModlog>)),
-      )
-      .service(
-        web::resource("/search")
-          .wrap(rate_limit.search())
-          .route(web::get().to(route_get_apub::<Search>)),
-      )
-      .service(
-        web::resource("/resolve_object")
-          .wrap(rate_limit.message())
-          .route(web::get().to(route_get_apub::<ResolveObject>)),
-      )
-      // Community
-      .service(
-        web::resource("/community")
-          .guard(guard::Post())
-          .wrap(rate_limit.register())
-          .route(web::post().to(route_post_crud::<CreateCommunity>)),
-      )
-      .service(
-        web::scope("/community")
-          .wrap(rate_limit.message())
-          .route("", web::get().to(route_get_apub::<GetCommunity>))
-          .route("", web::put().to(route_post_crud::<EditCommunity>))
-          .route("/hide", web::put().to(route_post::<HideCommunity>))
-          .route("/list", web::get().to(route_get_crud::<ListCommunities>))
-          .route("/follow", web::post().to(route_post::<FollowCommunity>))
-          .route("/block", web::post().to(route_post::<BlockCommunity>))
-          .route(
-            "/delete",
-            web::post().to(route_post_crud::<DeleteCommunity>),
-          )
-          // Mod Actions
-          .route(
-            "/remove",
-            web::post().to(route_post_crud::<RemoveCommunity>),
-          )
-          .route("/transfer", web::post().to(route_post::<TransferCommunity>))
-          .route("/ban_user", web::post().to(route_post::<BanFromCommunity>))
-          .route("/mod", web::post().to(route_post::<AddModToCommunity>))
-          .route("/join", web::post().to(route_post::<CommunityJoin>))
-          .route("/mod/join", web::post().to(route_post::<ModJoin>)),
-      )
-      // Post
-      .service(
-        // Handle POST to /post separately to add the post() rate limitter
-        web::resource("/post")
-          .guard(guard::Post())
-          .wrap(rate_limit.post())
-          .route(web::post().to(route_post_crud::<CreatePost>)),
-      )
-      .service(
-        web::scope("/post")
-          .wrap(rate_limit.message())
-          .route("", web::get().to(route_get_crud::<GetPost>))
-          .route("", web::put().to(route_post_crud::<EditPost>))
-          .route("/delete", web::post().to(route_post_crud::<DeletePost>))
-          .route("/remove", web::post().to(route_post_crud::<RemovePost>))
-          .route(
-            "/mark_as_read",
-            web::post().to(route_post::<MarkPostAsRead>),
-          )
-          .route("/lock", web::post().to(route_post::<LockPost>))
-          .route("/sticky", web::post().to(route_post::<StickyPost>))
-          .route("/list", web::get().to(route_get_apub::<GetPosts>))
-          .route("/like", web::post().to(route_post::<CreatePostLike>))
-          .route("/save", web::put().to(route_post::<SavePost>))
-          .route("/join", web::post().to(route_post::<PostJoin>))
-          .route("/report", web::post().to(route_post::<CreatePostReport>))
-          .route(
-            "/report/resolve",
-            web::put().to(route_post::<ResolvePostReport>),
-          )
-          .route("/report/list", web::get().to(route_get::<ListPostReports>))
-          .route(
-            "/site_metadata",
-            web::get().to(route_get::<GetSiteMetadata>),
-          ),
-      )
-      // Comment
-      .service(
-        // Handle POST to /comment separately to add the comment() rate limitter
-        web::resource("/comment")
-          .guard(guard::Post())
-          .wrap(rate_limit.comment())
-          .route(web::post().to(route_post_crud::<CreateComment>)),
-      )
-      .service(
-        web::scope("/comment")
-          .wrap(rate_limit.message())
-          .route("", web::get().to(route_get_crud::<GetComment>))
-          .route("", web::put().to(route_post_crud::<EditComment>))
-          .route("/delete", web::post().to(route_post_crud::<DeleteComment>))
-          .route("/remove", web::post().to(route_post_crud::<RemoveComment>))
-          .route(
-            "/mark_as_read",
-            web::post().to(route_post::<MarkCommentReplyAsRead>),
-          )
-          .route("/like", web::post().to(route_post::<CreateCommentLike>))
-          .route("/save", web::put().to(route_post::<SaveComment>))
-          .route("/list", web::get().to(route_get_apub::<GetComments>))
-          .route("/report", web::post().to(route_post::<CreateCommentReport>))
-          .route(
-            "/report/resolve",
-            web::put().to(route_post::<ResolveCommentReport>),
-          )
-          .route(
-            "/report/list",
-            web::get().to(route_get::<ListCommentReports>),
-          ),
-      )
-      // Private Message
-      .service(
-        web::scope("/private_message")
-          .wrap(rate_limit.message())
-          .route("/list", web::get().to(route_get_crud::<GetPrivateMessages>))
-          .route("", web::post().to(route_post_crud::<CreatePrivateMessage>))
-          .route("", web::put().to(route_post_crud::<EditPrivateMessage>))
-          .route(
-            "/delete",
-            web::post().to(route_post_crud::<DeletePrivateMessage>),
-          )
-          .route(
-            "/mark_as_read",
-            web::post().to(route_post::<MarkPrivateMessageAsRead>),
-          )
-          .route(
-            "/report",
-            web::post().to(route_post::<CreatePrivateMessageReport>),
-          )
-          .route(
-            "/report/resolve",
-            web::put().to(route_post::<ResolvePrivateMessageReport>),
-          )
-          .route(
-            "/report/list",
-            web::get().to(route_get::<ListPrivateMessageReports>),
-          ),
-      )
-      // User
-      .service(
-        // Account action, I don't like that it's in /user maybe /accounts
-        // Handle /user/register separately to add the register() rate limitter
-        web::resource("/user/register")
-          .guard(guard::Post())
-          .wrap(rate_limit.register())
-          .route(web::post().to(route_post_crud::<Register>)),
-      )
-      .service(
-        // Handle captcha separately
-        web::resource("/user/get_captcha")
-          .wrap(rate_limit.post())
-          .route(web::get().to(route_get::<GetCaptcha>)),
-      )
-      // User actions
-      .service(
-        web::scope("/user")
-          .wrap(rate_limit.message())
-          .route("", web::get().to(route_get_apub::<GetPersonDetails>))
-          .route("/mention", web::get().to(route_get::<GetPersonMentions>))
-          .route(
-            "/mention/mark_as_read",
-            web::post().to(route_post::<MarkPersonMentionAsRead>),
-          )
-          .route("/replies", web::get().to(route_get::<GetReplies>))
-          .route("/join", web::post().to(route_post::<UserJoin>))
-          // Admin action. I don't like that it's in /user
-          .route("/ban", web::post().to(route_post::<BanPerson>))
-          .route("/banned", web::get().to(route_get::<GetBannedPersons>))
-          .route("/block", web::post().to(route_post::<BlockPerson>))
-          // Account actions. I don't like that they're in /user maybe /accounts
-          .route("/login", web::post().to(route_post::<Login>))
-          .route(
-            "/delete_account",
-            web::post().to(route_post_crud::<DeleteAccount>),
-          )
-          .route(
-            "/password_reset",
-            web::post().to(route_post::<PasswordReset>),
-          )
-          .route(
-            "/password_change",
-            web::post().to(route_post::<PasswordChangeAfterReset>),
-          )
-          // mark_all_as_read feels off being in this section as well
-          .route(
-            "/mark_all_as_read",
-            web::post().to(route_post::<MarkAllAsRead>),
-          )
-          .route(
-            "/save_user_settings",
-            web::put().to(route_post::<SaveUserSettings>),
-          )
-          .route(
-            "/change_password",
-            web::put().to(route_post::<ChangePassword>),
-          )
-          .route("/report_count", web::get().to(route_get::<GetReportCount>))
-          .route("/unread_count", web::get().to(route_get::<GetUnreadCount>))
-          .route("/verify_email", web::post().to(route_post::<VerifyEmail>))
-          .route("/leave_admin", web::post().to(route_post::<LeaveAdmin>)),
-      )
-      // Admin Actions
-      .service(
-        web::scope("/admin")
-          .wrap(rate_limit.message())
-          .route("/add", web::post().to(route_post::<AddAdmin>))
-          .route(
-            "/registration_application/count",
-            web::get().to(route_get::<GetUnreadRegistrationApplicationCount>),
-          )
-          .route(
-            "/registration_application/list",
-            web::get().to(route_get::<ListRegistrationApplications>),
-          )
-          .route(
-            "/registration_application/approve",
-            web::put().to(route_post::<ApproveRegistrationApplication>),
-          ),
-      )
-      .service(
-        web::scope("/admin/purge")
-          .wrap(rate_limit.message())
-          .route("/person", web::post().to(route_post::<PurgePerson>))
-          .route("/community", web::post().to(route_post::<PurgeCommunity>))
-          .route("/post", web::post().to(route_post::<PurgePost>))
-          .route("/comment", web::post().to(route_post::<PurgeComment>)),
-      ),
+  let client_ip = IpAddr(
+    req
+      .connection_info()
+      .realip_remote_addr()
+      .unwrap_or("blank_ip")
+      .to_string(),
   );
-}
 
-async fn perform<'a, Data>(
-  data: Data,
-  context: web::Data<LemmyContext>,
-) -> Result<HttpResponse, Error>
-where
-  Data: Perform
-    + SendActivity<Response = <Data as Perform>::Response>
-    + Clone
-    + Deserialize<'a>
-    + Send
-    + 'static,
-{
-  let res = data.perform(&context, None).await?;
-  SendActivity::send_activity(&data, &res, &context).await?;
-  Ok(HttpResponse::Ok().json(res))
-}
+  let check = rate_limiter.message().check(client_ip.clone());
+  if !check {
+    debug!(
+      "Websocket join with IP: {} has been rate limited.",
+      &client_ip
+    );
+    session.close(None).await.map_err(LemmyError::from)?;
+    return Ok(response);
+  }
 
-async fn route_get<'a, Data>(
-  data: web::Query<Data>,
-  context: web::Data<LemmyContext>,
-) -> Result<HttpResponse, Error>
-where
-  Data: Perform
-    + SendActivity<Response = <Data as Perform>::Response>
-    + Clone
-    + Deserialize<'a>
-    + Send
-    + 'static,
-{
-  perform::<Data>(data.0, context).await
-}
+  let connection_id = context.chat_server().handle_connect(session.clone())?;
+  info!("{} joined", &client_ip);
 
-async fn route_get_apub<'a, Data>(
-  data: web::Query<Data>,
-  context: web::Data<LemmyContext>,
-) -> Result<HttpResponse, Error>
-where
-  Data: PerformApub
-    + SendActivity<Response = <Data as PerformApub>::Response>
-    + Clone
-    + Deserialize<'a>
-    + Send
-    + 'static,
-{
-  let res = data.perform(&context, None).await?;
-  SendActivity::send_activity(&data.0, &res, &context).await?;
-  Ok(HttpResponse::Ok().json(res))
+  let alive = Arc::new(Mutex::new(Instant::now()));
+  heartbeat(session.clone(), alive.clone());
+
+  actix_rt::spawn(handle_messages(
+    stream,
+    client_ip,
+    session,
+    connection_id,
+    alive,
+    rate_limiter,
+    context,
+  ));
+
+  Ok(response)
 }
 
-async fn route_post<'a, Data>(
-  data: web::Json<Data>,
+async fn handle_messages(
+  mut stream: MessageStream,
+  client_ip: IpAddr,
+  mut session: Session,
+  connection_id: ConnectionId,
+  alive: Arc<Mutex<Instant>>,
+  rate_limiter: web::Data<RateLimitCell>,
   context: web::Data<LemmyContext>,
-) -> Result<HttpResponse, Error>
-where
-  Data: Perform
-    + SendActivity<Response = <Data as Perform>::Response>
-    + Clone
-    + Deserialize<'a>
-    + Send
-    + 'static,
-{
-  perform::<Data>(data.0, context).await
+) -> Result<(), LemmyError> {
+  while let Some(Ok(msg)) = stream.next().await {
+    match msg {
+      ws::Message::Ping(bytes) => {
+        if session.pong(&bytes).await.is_err() {
+          break;
+        }
+      }
+      ws::Message::Pong(_) => {
+        let mut lock = alive
+          .lock()
+          .expect("Failed to acquire websocket heartbeat alive lock");
+        *lock = Instant::now();
+      }
+      ws::Message::Text(text) => {
+        let msg = text.trim().to_string();
+        let executed = parse_json_message(
+          msg,
+          client_ip.clone(),
+          connection_id,
+          rate_limiter.get_ref(),
+          context.get_ref().clone(),
+        )
+        .await;
+
+        let res = executed.unwrap_or_else(|e| {
+          error!("Error during message handling {}", e);
+          e.to_json()
+            .unwrap_or_else(|_| String::from(r#"{"error":"failed to serialize json"}"#))
+        });
+        session.text(res).await?;
+      }
+      ws::Message::Close(_) => {
+        session.close(None).await?;
+        context.chat_server().handle_disconnect(&connection_id)?;
+        break;
+      }
+      ws::Message::Binary(_) => info!("Unexpected binary"),
+      _ => {}
+    }
+  }
+  Ok(())
 }
 
-async fn perform_crud<'a, Data>(
-  data: Data,
-  context: web::Data<LemmyContext>,
-) -> Result<HttpResponse, Error>
-where
-  Data: PerformCrud
-    + SendActivity<Response = <Data as PerformCrud>::Response>
-    + Clone
-    + Deserialize<'a>
-    + Send
-    + 'static,
-{
-  let res = data.perform(&context, None).await?;
-  SendActivity::send_activity(&data, &res, &context).await?;
-  Ok(HttpResponse::Ok().json(res))
+fn heartbeat(mut session: Session, alive: Arc<Mutex<Instant>>) {
+  actix_rt::spawn(async move {
+    let mut interval = actix_rt::time::interval(Duration::from_secs(5));
+    loop {
+      if session.ping(b"").await.is_err() {
+        break;
+      }
+
+      let duration_since = {
+        let alive_lock = alive
+          .lock()
+          .expect("Failed to acquire websocket heartbeat alive lock");
+        Instant::now().duration_since(*alive_lock)
+      };
+      if duration_since > Duration::from_secs(10) {
+        let _ = session.close(None).await;
+        break;
+      }
+      interval.tick().await;
+    }
+  });
 }
 
-async fn route_get_crud<'a, Data>(
-  data: web::Query<Data>,
-  context: web::Data<LemmyContext>,
-) -> Result<HttpResponse, Error>
-where
-  Data: PerformCrud
-    + SendActivity<Response = <Data as PerformCrud>::Response>
-    + Clone
-    + Deserialize<'a>
-    + Send
-    + 'static,
-{
-  perform_crud::<Data>(data.0, context).await
+async fn parse_json_message(
+  msg: String,
+  ip: IpAddr,
+  connection_id: ConnectionId,
+  rate_limiter: &RateLimitCell,
+  context: LemmyContext,
+) -> Result<String, LemmyError> {
+  let json: Value = serde_json::from_str(&msg)?;
+  let data = &json["data"].to_string();
+  let op = &json["op"]
+    .as_str()
+    .ok_or_else(|| LemmyError::from_message("missing op"))?;
+
+  // check if api call passes the rate limit, and generate future for later execution
+  if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
+    let passed = match user_operation_crud {
+      UserOperationCrud::Register => rate_limiter.register().check(ip),
+      UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
+      UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
+      UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
+      _ => rate_limiter.message().check(ip),
+    };
+    check_rate_limit_passed(passed)?;
+    match_websocket_operation_crud(context, connection_id, user_operation_crud, data).await
+  } else if let Ok(user_operation) = UserOperation::from_str(op) {
+    let passed = match user_operation {
+      UserOperation::GetCaptcha => rate_limiter.post().check(ip),
+      _ => rate_limiter.message().check(ip),
+    };
+    check_rate_limit_passed(passed)?;
+    match_websocket_operation(context, connection_id, user_operation, data).await
+  } else {
+    let user_operation = UserOperationApub::from_str(op)?;
+    let passed = match user_operation {
+      UserOperationApub::Search => rate_limiter.search().check(ip),
+      _ => rate_limiter.message().check(ip),
+    };
+    check_rate_limit_passed(passed)?;
+    match_websocket_operation_apub(context, connection_id, user_operation, data).await
+  }
 }
 
-async fn route_post_crud<'a, Data>(
-  data: web::Json<Data>,
-  context: web::Data<LemmyContext>,
-) -> Result<HttpResponse, Error>
-where
-  Data: PerformCrud
-    + SendActivity<Response = <Data as PerformCrud>::Response>
-    + Clone
-    + Deserialize<'a>
-    + Send
-    + 'static,
-{
-  perform_crud::<Data>(data.0, context).await
+fn check_rate_limit_passed(passed: bool) -> Result<(), LemmyError> {
+  if passed {
+    Ok(())
+  } else {
+    // if rate limit was hit, respond with message
+    Err(LemmyError::from_message("rate_limit_error"))
+  }
 }
 
 pub async fn match_websocket_operation_crud(
index ffac400f9a0dc93899a072a9ef8b77e2dfce28b4..a9e390a21656718ae542755ce7dda55b8d421d2a 100644 (file)
@@ -1,4 +1,5 @@
-pub mod api_routes;
+pub mod api_routes_http;
+pub mod api_routes_websocket;
 pub mod code_migrations;
 pub mod root_span_builder;
 pub mod scheduled_tasks;
index c60c1823a3dbf7066490a52a67318a6fffca17d5..70f6d0348a2ed9bbe5d4ba2ad394b4271d517026 100644 (file)
@@ -1,7 +1,6 @@
 #[macro_use]
 extern crate diesel_migrations;
 
-use actix::prelude::*;
 use actix_web::{middleware, web::Data, App, HttpServer, Result};
 use diesel_migrations::EmbeddedMigrations;
 use doku::json::{AutoComments, CommentsStyle, Formatting, ObjectsStyle};
@@ -21,12 +20,7 @@ use lemmy_db_schema::{
 };
 use lemmy_routes::{feeds, images, nodeinfo, webfinger};
 use lemmy_server::{
-  api_routes,
-  api_routes::{
-    match_websocket_operation,
-    match_websocket_operation_apub,
-    match_websocket_operation_crud,
-  },
+  api_routes_http,
   code_migrations::run_advanced_migrations,
   init_logging,
   root_span_builder::QuieterRootSpanBuilder,
@@ -41,7 +35,7 @@ use reqwest::Client;
 use reqwest_middleware::ClientBuilder;
 use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use reqwest_tracing::TracingMiddleware;
-use std::{env, thread, time::Duration};
+use std::{env, sync::Arc, thread, time::Duration};
 use tracing_actix_web::TracingLogger;
 
 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
@@ -135,17 +129,7 @@ async fn main() -> Result<(), LemmyError> {
     .with(TracingMiddleware::default())
     .build();
 
-  let chat_server = ChatServer::startup(
-    pool.clone(),
-    |c, i, o, d| Box::pin(match_websocket_operation(c, i, o, d)),
-    |c, i, o, d| Box::pin(match_websocket_operation_crud(c, i, o, d)),
-    |c, i, o, d| Box::pin(match_websocket_operation_apub(c, i, o, d)),
-    client.clone(),
-    settings.clone(),
-    secret.clone(),
-    rate_limit_cell.clone(),
-  )
-  .start();
+  let chat_server = Arc::new(ChatServer::startup());
 
   // Create Http server with websocket support
   let settings_bind = settings.clone();
@@ -164,7 +148,7 @@ async fn main() -> Result<(), LemmyError> {
       .app_data(Data::new(context))
       .app_data(Data::new(rate_limit_cell.clone()))
       // The routes
-      .configure(|cfg| api_routes::config(cfg, rate_limit_cell))
+      .configure(|cfg| api_routes_http::config(cfg, rate_limit_cell))
       .configure(|cfg| {
         if federation_enabled {
           lemmy_apub::http::routes::config(cfg);