]> Untitled Git - lemmy.git/blob - crates/websocket/src/chat_server.rs
Fix API and clippy warnings
[lemmy.git] / crates / websocket / src / chat_server.rs
1 use crate::{
2   messages::*,
3   serialize_websocket_message,
4   LemmyContext,
5   OperationType,
6   UserOperation,
7   UserOperationCrud,
8 };
9 use actix::prelude::*;
10 use anyhow::Context as acontext;
11 use background_jobs::QueueHandle;
12 use diesel::{
13   r2d2::{ConnectionManager, Pool},
14   PgConnection,
15 };
16 use lemmy_api_common::{comment::*, post::*};
17 use lemmy_db_schema::{CommunityId, LocalUserId, PostId};
18 use lemmy_utils::{
19   location_info,
20   rate_limit::RateLimit,
21   ApiError,
22   ConnectionId,
23   IpAddr,
24   LemmyError,
25 };
26 use rand::rngs::ThreadRng;
27 use reqwest::Client;
28 use serde::Serialize;
29 use serde_json::Value;
30 use std::{
31   collections::{HashMap, HashSet},
32   str::FromStr,
33 };
34 use tokio::macros::support::Pin;
35
36 type MessageHandlerType = fn(
37   context: LemmyContext,
38   id: ConnectionId,
39   op: UserOperation,
40   data: &str,
41 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
42
43 type MessageHandlerCrudType = fn(
44   context: LemmyContext,
45   id: ConnectionId,
46   op: UserOperationCrud,
47   data: &str,
48 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
49
50 /// `ChatServer` manages chat rooms and responsible for coordinating chat
51 /// session.
52 pub struct ChatServer {
53   /// A map from generated random ID to session addr
54   pub sessions: HashMap<ConnectionId, SessionInfo>,
55
56   /// A map from post_id to set of connectionIDs
57   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
58
59   /// A map from community to set of connectionIDs
60   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
61
62   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
63
64   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
65   /// sessions (IE clients)
66   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
67
68   pub(super) rng: ThreadRng,
69
70   /// The DB Pool
71   pub(super) pool: Pool<ConnectionManager<PgConnection>>,
72
73   /// Rate limiting based on rate type and IP addr
74   pub(super) rate_limiter: RateLimit,
75
76   /// A list of the current captchas
77   pub(super) captchas: Vec<CaptchaItem>,
78
79   message_handler: MessageHandlerType,
80   message_handler_crud: MessageHandlerCrudType,
81
82   /// An HTTP Client
83   client: Client,
84
85   activity_queue: QueueHandle,
86 }
87
88 pub struct SessionInfo {
89   pub addr: Recipient<WsMessage>,
90   pub ip: IpAddr,
91 }
92
93 /// `ChatServer` is an actor. It maintains list of connection client session.
94 /// And manages available rooms. Peers send messages to other peers in same
95 /// room through `ChatServer`.
96 impl ChatServer {
97   pub fn startup(
98     pool: Pool<ConnectionManager<PgConnection>>,
99     rate_limiter: RateLimit,
100     message_handler: MessageHandlerType,
101     message_handler_crud: MessageHandlerCrudType,
102     client: Client,
103     activity_queue: QueueHandle,
104   ) -> ChatServer {
105     ChatServer {
106       sessions: HashMap::new(),
107       post_rooms: HashMap::new(),
108       community_rooms: HashMap::new(),
109       mod_rooms: HashMap::new(),
110       user_rooms: HashMap::new(),
111       rng: rand::thread_rng(),
112       pool,
113       rate_limiter,
114       captchas: Vec::new(),
115       message_handler,
116       message_handler_crud,
117       client,
118       activity_queue,
119     }
120   }
121
122   pub fn join_community_room(
123     &mut self,
124     community_id: CommunityId,
125     id: ConnectionId,
126   ) -> Result<(), LemmyError> {
127     // remove session from all rooms
128     for sessions in self.community_rooms.values_mut() {
129       sessions.remove(&id);
130     }
131
132     // Also leave all post rooms
133     // This avoids double messages
134     for sessions in self.post_rooms.values_mut() {
135       sessions.remove(&id);
136     }
137
138     // If the room doesn't exist yet
139     if self.community_rooms.get_mut(&community_id).is_none() {
140       self.community_rooms.insert(community_id, HashSet::new());
141     }
142
143     self
144       .community_rooms
145       .get_mut(&community_id)
146       .context(location_info!())?
147       .insert(id);
148     Ok(())
149   }
150
151   pub fn join_mod_room(
152     &mut self,
153     community_id: CommunityId,
154     id: ConnectionId,
155   ) -> Result<(), LemmyError> {
156     // remove session from all rooms
157     for sessions in self.mod_rooms.values_mut() {
158       sessions.remove(&id);
159     }
160
161     // If the room doesn't exist yet
162     if self.mod_rooms.get_mut(&community_id).is_none() {
163       self.mod_rooms.insert(community_id, HashSet::new());
164     }
165
166     self
167       .mod_rooms
168       .get_mut(&community_id)
169       .context(location_info!())?
170       .insert(id);
171     Ok(())
172   }
173
174   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
175     // remove session from all rooms
176     for sessions in self.post_rooms.values_mut() {
177       sessions.remove(&id);
178     }
179
180     // Also leave all communities
181     // This avoids double messages
182     // TODO found a bug, whereby community messages like
183     // delete and remove aren't sent, because
184     // you left the community room
185     for sessions in self.community_rooms.values_mut() {
186       sessions.remove(&id);
187     }
188
189     // If the room doesn't exist yet
190     if self.post_rooms.get_mut(&post_id).is_none() {
191       self.post_rooms.insert(post_id, HashSet::new());
192     }
193
194     self
195       .post_rooms
196       .get_mut(&post_id)
197       .context(location_info!())?
198       .insert(id);
199
200     Ok(())
201   }
202
203   pub fn join_user_room(
204     &mut self,
205     user_id: LocalUserId,
206     id: ConnectionId,
207   ) -> Result<(), LemmyError> {
208     // remove session from all rooms
209     for sessions in self.user_rooms.values_mut() {
210       sessions.remove(&id);
211     }
212
213     // If the room doesn't exist yet
214     if self.user_rooms.get_mut(&user_id).is_none() {
215       self.user_rooms.insert(user_id, HashSet::new());
216     }
217
218     self
219       .user_rooms
220       .get_mut(&user_id)
221       .context(location_info!())?
222       .insert(id);
223
224     Ok(())
225   }
226
227   fn send_post_room_message<OP, Response>(
228     &self,
229     op: &OP,
230     response: &Response,
231     post_id: PostId,
232     websocket_id: Option<ConnectionId>,
233   ) -> Result<(), LemmyError>
234   where
235     OP: OperationType + ToString,
236     Response: Serialize,
237   {
238     let res_str = &serialize_websocket_message(op, response)?;
239     if let Some(sessions) = self.post_rooms.get(&post_id) {
240       for id in sessions {
241         if let Some(my_id) = websocket_id {
242           if *id == my_id {
243             continue;
244           }
245         }
246         self.sendit(res_str, *id);
247       }
248     }
249     Ok(())
250   }
251
252   pub fn send_community_room_message<OP, Response>(
253     &self,
254     op: &OP,
255     response: &Response,
256     community_id: CommunityId,
257     websocket_id: Option<ConnectionId>,
258   ) -> Result<(), LemmyError>
259   where
260     OP: OperationType + ToString,
261     Response: Serialize,
262   {
263     let res_str = &serialize_websocket_message(op, response)?;
264     if let Some(sessions) = self.community_rooms.get(&community_id) {
265       for id in sessions {
266         if let Some(my_id) = websocket_id {
267           if *id == my_id {
268             continue;
269           }
270         }
271         self.sendit(res_str, *id);
272       }
273     }
274     Ok(())
275   }
276
277   pub fn send_mod_room_message<OP, Response>(
278     &self,
279     op: &OP,
280     response: &Response,
281     community_id: CommunityId,
282     websocket_id: Option<ConnectionId>,
283   ) -> Result<(), LemmyError>
284   where
285     OP: OperationType + ToString,
286     Response: Serialize,
287   {
288     let res_str = &serialize_websocket_message(op, response)?;
289     if let Some(sessions) = self.mod_rooms.get(&community_id) {
290       for id in sessions {
291         if let Some(my_id) = websocket_id {
292           if *id == my_id {
293             continue;
294           }
295         }
296         self.sendit(res_str, *id);
297       }
298     }
299     Ok(())
300   }
301
302   pub fn send_all_message<OP, Response>(
303     &self,
304     op: &OP,
305     response: &Response,
306     websocket_id: Option<ConnectionId>,
307   ) -> Result<(), LemmyError>
308   where
309     OP: OperationType + ToString,
310     Response: Serialize,
311   {
312     let res_str = &serialize_websocket_message(op, response)?;
313     for id in self.sessions.keys() {
314       if let Some(my_id) = websocket_id {
315         if *id == my_id {
316           continue;
317         }
318       }
319       self.sendit(res_str, *id);
320     }
321     Ok(())
322   }
323
324   pub fn send_user_room_message<OP, Response>(
325     &self,
326     op: &OP,
327     response: &Response,
328     recipient_id: LocalUserId,
329     websocket_id: Option<ConnectionId>,
330   ) -> Result<(), LemmyError>
331   where
332     OP: OperationType + ToString,
333     Response: Serialize,
334   {
335     let res_str = &serialize_websocket_message(op, response)?;
336     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
337       for id in sessions {
338         if let Some(my_id) = websocket_id {
339           if *id == my_id {
340             continue;
341           }
342         }
343         self.sendit(res_str, *id);
344       }
345     }
346     Ok(())
347   }
348
349   pub fn send_comment<OP>(
350     &self,
351     user_operation: &OP,
352     comment: &CommentResponse,
353     websocket_id: Option<ConnectionId>,
354   ) -> Result<(), LemmyError>
355   where
356     OP: OperationType + ToString,
357   {
358     let mut comment_reply_sent = comment.clone();
359
360     // Strip out my specific user info
361     comment_reply_sent.comment_view.my_vote = None;
362
363     // Send it to the post room
364     let mut comment_post_sent = comment_reply_sent.clone();
365     // Remove the recipients here to separate mentions / user messages from post or community comments
366     comment_post_sent.recipient_ids = Vec::new();
367     self.send_post_room_message(
368       user_operation,
369       &comment_post_sent,
370       comment_post_sent.comment_view.post.id,
371       websocket_id,
372     )?;
373
374     // Send it to the community too
375     self.send_community_room_message(
376       user_operation,
377       &comment_post_sent,
378       CommunityId(0),
379       websocket_id,
380     )?;
381     self.send_community_room_message(
382       user_operation,
383       &comment_post_sent,
384       comment.comment_view.community.id,
385       websocket_id,
386     )?;
387
388     // Send it to the recipient(s) including the mentioned users
389     for recipient_id in &comment_reply_sent.recipient_ids {
390       self.send_user_room_message(
391         user_operation,
392         &comment_reply_sent,
393         *recipient_id,
394         websocket_id,
395       )?;
396     }
397
398     Ok(())
399   }
400
401   pub fn send_post<OP>(
402     &self,
403     user_operation: &OP,
404     post_res: &PostResponse,
405     websocket_id: Option<ConnectionId>,
406   ) -> Result<(), LemmyError>
407   where
408     OP: OperationType + ToString,
409   {
410     let community_id = post_res.post_view.community.id;
411
412     // Don't send my data with it
413     let mut post_sent = post_res.clone();
414     post_sent.post_view.my_vote = None;
415
416     // Send it to /c/all and that community
417     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
418     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
419
420     // Send it to the post room
421     self.send_post_room_message(
422       user_operation,
423       &post_sent,
424       post_res.post_view.post.id,
425       websocket_id,
426     )?;
427
428     Ok(())
429   }
430
431   fn sendit(&self, message: &str, id: ConnectionId) {
432     if let Some(info) = self.sessions.get(&id) {
433       let _ = info.addr.do_send(WsMessage(message.to_owned()));
434     }
435   }
436
437   pub(super) fn parse_json_message(
438     &mut self,
439     msg: StandardMessage,
440     ctx: &mut Context<Self>,
441   ) -> impl Future<Output = Result<String, LemmyError>> {
442     let rate_limiter = self.rate_limiter.clone();
443
444     let ip: IpAddr = match self.sessions.get(&msg.id) {
445       Some(info) => info.ip.to_owned(),
446       None => IpAddr("blank_ip".to_string()),
447     };
448
449     let context = LemmyContext {
450       pool: self.pool.clone(),
451       chat_server: ctx.address(),
452       client: self.client.to_owned(),
453       activity_queue: self.activity_queue.to_owned(),
454     };
455     let message_handler_crud = self.message_handler_crud;
456     let message_handler = self.message_handler;
457     async move {
458       let json: Value = serde_json::from_str(&msg.msg)?;
459       let data = &json["data"].to_string();
460       let op = &json["op"].as_str().ok_or(ApiError {
461         message: "Unknown op type".to_string(),
462       })?;
463
464       if let Ok(user_operation_crud) = UserOperationCrud::from_str(&op) {
465         let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
466         match user_operation_crud {
467           UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
468           UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
469           UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
470           _ => rate_limiter.message().wrap(ip, fut).await,
471         }
472       } else {
473         let user_operation = UserOperation::from_str(&op)?;
474         let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
475         rate_limiter.message().wrap(ip, fut).await
476       }
477     }
478   }
479 }