]> Untitled Git - lemmy.git/blob - crates/websocket/src/chat_server.rs
Upgrading deps, running clippy fix on nightly 1.55.0 (#1638)
[lemmy.git] / crates / websocket / src / chat_server.rs
1 use crate::{
2   messages::*,
3   serialize_websocket_message,
4   LemmyContext,
5   OperationType,
6   UserOperation,
7   UserOperationCrud,
8 };
9 use actix::prelude::*;
10 use anyhow::Context as acontext;
11 use background_jobs::QueueHandle;
12 use diesel::{
13   r2d2::{ConnectionManager, Pool},
14   PgConnection,
15 };
16 use lemmy_api_common::{comment::*, post::*};
17 use lemmy_db_schema::{CommunityId, LocalUserId, PostId};
18 use lemmy_utils::{
19   location_info,
20   rate_limit::RateLimit,
21   ApiError,
22   ConnectionId,
23   IpAddr,
24   LemmyError,
25 };
26 use rand::rngs::ThreadRng;
27 use reqwest::Client;
28 use serde::Serialize;
29 use serde_json::Value;
30 use std::{
31   collections::{HashMap, HashSet},
32   future::Future,
33   str::FromStr,
34 };
35 use tokio::macros::support::Pin;
36
37 type MessageHandlerType = fn(
38   context: LemmyContext,
39   id: ConnectionId,
40   op: UserOperation,
41   data: &str,
42 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
43
44 type MessageHandlerCrudType = fn(
45   context: LemmyContext,
46   id: ConnectionId,
47   op: UserOperationCrud,
48   data: &str,
49 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
50
51 /// `ChatServer` manages chat rooms and responsible for coordinating chat
52 /// session.
53 pub struct ChatServer {
54   /// A map from generated random ID to session addr
55   pub sessions: HashMap<ConnectionId, SessionInfo>,
56
57   /// A map from post_id to set of connectionIDs
58   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
59
60   /// A map from community to set of connectionIDs
61   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
62
63   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
64
65   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
66   /// sessions (IE clients)
67   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
68
69   pub(super) rng: ThreadRng,
70
71   /// The DB Pool
72   pub(super) pool: Pool<ConnectionManager<PgConnection>>,
73
74   /// Rate limiting based on rate type and IP addr
75   pub(super) rate_limiter: RateLimit,
76
77   /// A list of the current captchas
78   pub(super) captchas: Vec<CaptchaItem>,
79
80   message_handler: MessageHandlerType,
81   message_handler_crud: MessageHandlerCrudType,
82
83   /// An HTTP Client
84   client: Client,
85
86   activity_queue: QueueHandle,
87 }
88
89 pub struct SessionInfo {
90   pub addr: Recipient<WsMessage>,
91   pub ip: IpAddr,
92 }
93
94 /// `ChatServer` is an actor. It maintains list of connection client session.
95 /// And manages available rooms. Peers send messages to other peers in same
96 /// room through `ChatServer`.
97 impl ChatServer {
98   pub fn startup(
99     pool: Pool<ConnectionManager<PgConnection>>,
100     rate_limiter: RateLimit,
101     message_handler: MessageHandlerType,
102     message_handler_crud: MessageHandlerCrudType,
103     client: Client,
104     activity_queue: QueueHandle,
105   ) -> ChatServer {
106     ChatServer {
107       sessions: HashMap::new(),
108       post_rooms: HashMap::new(),
109       community_rooms: HashMap::new(),
110       mod_rooms: HashMap::new(),
111       user_rooms: HashMap::new(),
112       rng: rand::thread_rng(),
113       pool,
114       rate_limiter,
115       captchas: Vec::new(),
116       message_handler,
117       message_handler_crud,
118       client,
119       activity_queue,
120     }
121   }
122
123   pub fn join_community_room(
124     &mut self,
125     community_id: CommunityId,
126     id: ConnectionId,
127   ) -> Result<(), LemmyError> {
128     // remove session from all rooms
129     for sessions in self.community_rooms.values_mut() {
130       sessions.remove(&id);
131     }
132
133     // Also leave all post rooms
134     // This avoids double messages
135     for sessions in self.post_rooms.values_mut() {
136       sessions.remove(&id);
137     }
138
139     // If the room doesn't exist yet
140     if self.community_rooms.get_mut(&community_id).is_none() {
141       self.community_rooms.insert(community_id, HashSet::new());
142     }
143
144     self
145       .community_rooms
146       .get_mut(&community_id)
147       .context(location_info!())?
148       .insert(id);
149     Ok(())
150   }
151
152   pub fn join_mod_room(
153     &mut self,
154     community_id: CommunityId,
155     id: ConnectionId,
156   ) -> Result<(), LemmyError> {
157     // remove session from all rooms
158     for sessions in self.mod_rooms.values_mut() {
159       sessions.remove(&id);
160     }
161
162     // If the room doesn't exist yet
163     if self.mod_rooms.get_mut(&community_id).is_none() {
164       self.mod_rooms.insert(community_id, HashSet::new());
165     }
166
167     self
168       .mod_rooms
169       .get_mut(&community_id)
170       .context(location_info!())?
171       .insert(id);
172     Ok(())
173   }
174
175   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
176     // remove session from all rooms
177     for sessions in self.post_rooms.values_mut() {
178       sessions.remove(&id);
179     }
180
181     // Also leave all communities
182     // This avoids double messages
183     // TODO found a bug, whereby community messages like
184     // delete and remove aren't sent, because
185     // you left the community room
186     for sessions in self.community_rooms.values_mut() {
187       sessions.remove(&id);
188     }
189
190     // If the room doesn't exist yet
191     if self.post_rooms.get_mut(&post_id).is_none() {
192       self.post_rooms.insert(post_id, HashSet::new());
193     }
194
195     self
196       .post_rooms
197       .get_mut(&post_id)
198       .context(location_info!())?
199       .insert(id);
200
201     Ok(())
202   }
203
204   pub fn join_user_room(
205     &mut self,
206     user_id: LocalUserId,
207     id: ConnectionId,
208   ) -> Result<(), LemmyError> {
209     // remove session from all rooms
210     for sessions in self.user_rooms.values_mut() {
211       sessions.remove(&id);
212     }
213
214     // If the room doesn't exist yet
215     if self.user_rooms.get_mut(&user_id).is_none() {
216       self.user_rooms.insert(user_id, HashSet::new());
217     }
218
219     self
220       .user_rooms
221       .get_mut(&user_id)
222       .context(location_info!())?
223       .insert(id);
224
225     Ok(())
226   }
227
228   fn send_post_room_message<OP, Response>(
229     &self,
230     op: &OP,
231     response: &Response,
232     post_id: PostId,
233     websocket_id: Option<ConnectionId>,
234   ) -> Result<(), LemmyError>
235   where
236     OP: OperationType + ToString,
237     Response: Serialize,
238   {
239     let res_str = &serialize_websocket_message(op, response)?;
240     if let Some(sessions) = self.post_rooms.get(&post_id) {
241       for id in sessions {
242         if let Some(my_id) = websocket_id {
243           if *id == my_id {
244             continue;
245           }
246         }
247         self.sendit(res_str, *id);
248       }
249     }
250     Ok(())
251   }
252
253   pub fn send_community_room_message<OP, Response>(
254     &self,
255     op: &OP,
256     response: &Response,
257     community_id: CommunityId,
258     websocket_id: Option<ConnectionId>,
259   ) -> Result<(), LemmyError>
260   where
261     OP: OperationType + ToString,
262     Response: Serialize,
263   {
264     let res_str = &serialize_websocket_message(op, response)?;
265     if let Some(sessions) = self.community_rooms.get(&community_id) {
266       for id in sessions {
267         if let Some(my_id) = websocket_id {
268           if *id == my_id {
269             continue;
270           }
271         }
272         self.sendit(res_str, *id);
273       }
274     }
275     Ok(())
276   }
277
278   pub fn send_mod_room_message<OP, Response>(
279     &self,
280     op: &OP,
281     response: &Response,
282     community_id: CommunityId,
283     websocket_id: Option<ConnectionId>,
284   ) -> Result<(), LemmyError>
285   where
286     OP: OperationType + ToString,
287     Response: Serialize,
288   {
289     let res_str = &serialize_websocket_message(op, response)?;
290     if let Some(sessions) = self.mod_rooms.get(&community_id) {
291       for id in sessions {
292         if let Some(my_id) = websocket_id {
293           if *id == my_id {
294             continue;
295           }
296         }
297         self.sendit(res_str, *id);
298       }
299     }
300     Ok(())
301   }
302
303   pub fn send_all_message<OP, Response>(
304     &self,
305     op: &OP,
306     response: &Response,
307     websocket_id: Option<ConnectionId>,
308   ) -> Result<(), LemmyError>
309   where
310     OP: OperationType + ToString,
311     Response: Serialize,
312   {
313     let res_str = &serialize_websocket_message(op, response)?;
314     for id in self.sessions.keys() {
315       if let Some(my_id) = websocket_id {
316         if *id == my_id {
317           continue;
318         }
319       }
320       self.sendit(res_str, *id);
321     }
322     Ok(())
323   }
324
325   pub fn send_user_room_message<OP, Response>(
326     &self,
327     op: &OP,
328     response: &Response,
329     recipient_id: LocalUserId,
330     websocket_id: Option<ConnectionId>,
331   ) -> Result<(), LemmyError>
332   where
333     OP: OperationType + ToString,
334     Response: Serialize,
335   {
336     let res_str = &serialize_websocket_message(op, response)?;
337     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
338       for id in sessions {
339         if let Some(my_id) = websocket_id {
340           if *id == my_id {
341             continue;
342           }
343         }
344         self.sendit(res_str, *id);
345       }
346     }
347     Ok(())
348   }
349
350   pub fn send_comment<OP>(
351     &self,
352     user_operation: &OP,
353     comment: &CommentResponse,
354     websocket_id: Option<ConnectionId>,
355   ) -> Result<(), LemmyError>
356   where
357     OP: OperationType + ToString,
358   {
359     let mut comment_reply_sent = comment.clone();
360
361     // Strip out my specific user info
362     comment_reply_sent.comment_view.my_vote = None;
363
364     // Send it to the post room
365     let mut comment_post_sent = comment_reply_sent.clone();
366     // Remove the recipients here to separate mentions / user messages from post or community comments
367     comment_post_sent.recipient_ids = Vec::new();
368     self.send_post_room_message(
369       user_operation,
370       &comment_post_sent,
371       comment_post_sent.comment_view.post.id,
372       websocket_id,
373     )?;
374
375     // Send it to the community too
376     self.send_community_room_message(
377       user_operation,
378       &comment_post_sent,
379       CommunityId(0),
380       websocket_id,
381     )?;
382     self.send_community_room_message(
383       user_operation,
384       &comment_post_sent,
385       comment.comment_view.community.id,
386       websocket_id,
387     )?;
388
389     // Send it to the recipient(s) including the mentioned users
390     for recipient_id in &comment_reply_sent.recipient_ids {
391       self.send_user_room_message(
392         user_operation,
393         &comment_reply_sent,
394         *recipient_id,
395         websocket_id,
396       )?;
397     }
398
399     Ok(())
400   }
401
402   pub fn send_post<OP>(
403     &self,
404     user_operation: &OP,
405     post_res: &PostResponse,
406     websocket_id: Option<ConnectionId>,
407   ) -> Result<(), LemmyError>
408   where
409     OP: OperationType + ToString,
410   {
411     let community_id = post_res.post_view.community.id;
412
413     // Don't send my data with it
414     let mut post_sent = post_res.clone();
415     post_sent.post_view.my_vote = None;
416
417     // Send it to /c/all and that community
418     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
419     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
420
421     // Send it to the post room
422     self.send_post_room_message(
423       user_operation,
424       &post_sent,
425       post_res.post_view.post.id,
426       websocket_id,
427     )?;
428
429     Ok(())
430   }
431
432   fn sendit(&self, message: &str, id: ConnectionId) {
433     if let Some(info) = self.sessions.get(&id) {
434       let _ = info.addr.do_send(WsMessage(message.to_owned()));
435     }
436   }
437
438   pub(super) fn parse_json_message(
439     &mut self,
440     msg: StandardMessage,
441     ctx: &mut Context<Self>,
442   ) -> impl Future<Output = Result<String, LemmyError>> {
443     let rate_limiter = self.rate_limiter.clone();
444
445     let ip: IpAddr = match self.sessions.get(&msg.id) {
446       Some(info) => info.ip.to_owned(),
447       None => IpAddr("blank_ip".to_string()),
448     };
449
450     let context = LemmyContext {
451       pool: self.pool.clone(),
452       chat_server: ctx.address(),
453       client: self.client.to_owned(),
454       activity_queue: self.activity_queue.to_owned(),
455     };
456     let message_handler_crud = self.message_handler_crud;
457     let message_handler = self.message_handler;
458     async move {
459       let json: Value = serde_json::from_str(&msg.msg)?;
460       let data = &json["data"].to_string();
461       let op = &json["op"].as_str().ok_or(ApiError {
462         message: "Unknown op type".to_string(),
463       })?;
464
465       if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
466         let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
467         match user_operation_crud {
468           UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
469           UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
470           UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
471           _ => rate_limiter.message().wrap(ip, fut).await,
472         }
473       } else {
474         let user_operation = UserOperation::from_str(op)?;
475         let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
476         rate_limiter.message().wrap(ip, fut).await
477       }
478     }
479   }
480 }