]> Untitled Git - lemmy.git/blobdiff - crates/websocket/src/chat_server.rs
Fix API and clippy warnings
[lemmy.git] / crates / websocket / src / chat_server.rs
index f1c936d6d4e208721515e085a7b5f592989af8f1..e08aa94a5d4a523b7e369cec037b4b4946c9e6de 100644 (file)
@@ -1,4 +1,11 @@
-use crate::{messages::*, serialize_websocket_message, LemmyContext, UserOperation};
+use crate::{
+  messages::*,
+  serialize_websocket_message,
+  LemmyContext,
+  OperationType,
+  UserOperation,
+  UserOperationCrud,
+};
 use actix::prelude::*;
 use anyhow::Context as acontext;
 use background_jobs::QueueHandle;
@@ -33,6 +40,13 @@ type MessageHandlerType = fn(
   data: &str,
 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
 
+type MessageHandlerCrudType = fn(
+  context: LemmyContext,
+  id: ConnectionId,
+  op: UserOperationCrud,
+  data: &str,
+) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
+
 /// `ChatServer` manages chat rooms and responsible for coordinating chat
 /// session.
 pub struct ChatServer {
@@ -63,6 +77,7 @@ pub struct ChatServer {
   pub(super) captchas: Vec<CaptchaItem>,
 
   message_handler: MessageHandlerType,
+  message_handler_crud: MessageHandlerCrudType,
 
   /// An HTTP Client
   client: Client,
@@ -83,6 +98,7 @@ impl ChatServer {
     pool: Pool<ConnectionManager<PgConnection>>,
     rate_limiter: RateLimit,
     message_handler: MessageHandlerType,
+    message_handler_crud: MessageHandlerCrudType,
     client: Client,
     activity_queue: QueueHandle,
   ) -> ChatServer {
@@ -97,6 +113,7 @@ impl ChatServer {
       rate_limiter,
       captchas: Vec::new(),
       message_handler,
+      message_handler_crud,
       client,
       activity_queue,
     }
@@ -207,14 +224,15 @@ impl ChatServer {
     Ok(())
   }
 
-  fn send_post_room_message<Response>(
+  fn send_post_room_message<OP, Response>(
     &self,
-    op: &UserOperation,
+    op: &OP,
     response: &Response,
     post_id: PostId,
     websocket_id: Option<ConnectionId>,
   ) -> Result<(), LemmyError>
   where
+    OP: OperationType + ToString,
     Response: Serialize,
   {
     let res_str = &serialize_websocket_message(op, response)?;
@@ -231,14 +249,15 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn send_community_room_message<Response>(
+  pub fn send_community_room_message<OP, Response>(
     &self,
-    op: &UserOperation,
+    op: &OP,
     response: &Response,
     community_id: CommunityId,
     websocket_id: Option<ConnectionId>,
   ) -> Result<(), LemmyError>
   where
+    OP: OperationType + ToString,
     Response: Serialize,
   {
     let res_str = &serialize_websocket_message(op, response)?;
@@ -255,14 +274,15 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn send_mod_room_message<Response>(
+  pub fn send_mod_room_message<OP, Response>(
     &self,
-    op: &UserOperation,
+    op: &OP,
     response: &Response,
     community_id: CommunityId,
     websocket_id: Option<ConnectionId>,
   ) -> Result<(), LemmyError>
   where
+    OP: OperationType + ToString,
     Response: Serialize,
   {
     let res_str = &serialize_websocket_message(op, response)?;
@@ -279,13 +299,14 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn send_all_message<Response>(
+  pub fn send_all_message<OP, Response>(
     &self,
-    op: &UserOperation,
+    op: &OP,
     response: &Response,
     websocket_id: Option<ConnectionId>,
   ) -> Result<(), LemmyError>
   where
+    OP: OperationType + ToString,
     Response: Serialize,
   {
     let res_str = &serialize_websocket_message(op, response)?;
@@ -300,14 +321,15 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn send_user_room_message<Response>(
+  pub fn send_user_room_message<OP, Response>(
     &self,
-    op: &UserOperation,
+    op: &OP,
     response: &Response,
     recipient_id: LocalUserId,
     websocket_id: Option<ConnectionId>,
   ) -> Result<(), LemmyError>
   where
+    OP: OperationType + ToString,
     Response: Serialize,
   {
     let res_str = &serialize_websocket_message(op, response)?;
@@ -324,12 +346,15 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn send_comment(
+  pub fn send_comment<OP>(
     &self,
-    user_operation: &UserOperation,
+    user_operation: &OP,
     comment: &CommentResponse,
     websocket_id: Option<ConnectionId>,
-  ) -> Result<(), LemmyError> {
+  ) -> Result<(), LemmyError>
+  where
+    OP: OperationType + ToString,
+  {
     let mut comment_reply_sent = comment.clone();
 
     // Strip out my specific user info
@@ -373,12 +398,15 @@ impl ChatServer {
     Ok(())
   }
 
-  pub fn send_post(
+  pub fn send_post<OP>(
     &self,
-    user_operation: &UserOperation,
+    user_operation: &OP,
     post_res: &PostResponse,
     websocket_id: Option<ConnectionId>,
-  ) -> Result<(), LemmyError> {
+  ) -> Result<(), LemmyError>
+  where
+    OP: OperationType + ToString,
+  {
     let community_id = post_res.post_view.community.id;
 
     // Don't send my data with it
@@ -424,6 +452,7 @@ impl ChatServer {
       client: self.client.to_owned(),
       activity_queue: self.activity_queue.to_owned(),
     };
+    let message_handler_crud = self.message_handler_crud;
     let message_handler = self.message_handler;
     async move {
       let json: Value = serde_json::from_str(&msg.msg)?;
@@ -432,13 +461,18 @@ impl ChatServer {
         message: "Unknown op type".to_string(),
       })?;
 
-      let user_operation = UserOperation::from_str(&op)?;
-      let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
-      match user_operation {
-        UserOperation::Register => rate_limiter.register().wrap(ip, fut).await,
-        UserOperation::CreatePost => rate_limiter.post().wrap(ip, fut).await,
-        UserOperation::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
-        _ => rate_limiter.message().wrap(ip, fut).await,
+      if let Ok(user_operation_crud) = UserOperationCrud::from_str(&op) {
+        let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
+        match user_operation_crud {
+          UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
+          UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
+          UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
+          _ => rate_limiter.message().wrap(ip, fut).await,
+        }
+      } else {
+        let user_operation = UserOperation::from_str(&op)?;
+        let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
+        rate_limiter.message().wrap(ip, fut).await
       }
     }
   }