]> Untitled Git - lemmy.git/blob - crates/api_common/src/websocket/chat_server.rs
Merge pull request #2593 from LemmyNet/refactor-notifications
[lemmy.git] / crates / api_common / src / websocket / chat_server.rs
1 use crate::{
2   comment::CommentResponse,
3   context::LemmyContext,
4   post::PostResponse,
5   websocket::{
6     messages::{CaptchaItem, StandardMessage, WsMessage},
7     serialize_websocket_message,
8     OperationType,
9     UserOperation,
10     UserOperationApub,
11     UserOperationCrud,
12   },
13 };
14 use actix::prelude::*;
15 use anyhow::Context as acontext;
16 use lemmy_db_schema::{
17   newtypes::{CommunityId, LocalUserId, PostId},
18   source::secret::Secret,
19   utils::DbPool,
20 };
21 use lemmy_utils::{
22   error::LemmyError,
23   location_info,
24   rate_limit::RateLimitCell,
25   settings::structs::Settings,
26   ConnectionId,
27   IpAddr,
28 };
29 use rand::rngs::ThreadRng;
30 use reqwest_middleware::ClientWithMiddleware;
31 use serde::Serialize;
32 use serde_json::Value;
33 use std::{
34   collections::{HashMap, HashSet},
35   future::Future,
36   str::FromStr,
37 };
38 use tokio::macros::support::Pin;
39
40 type MessageHandlerType = fn(
41   context: LemmyContext,
42   id: ConnectionId,
43   op: UserOperation,
44   data: &str,
45 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
46
47 type MessageHandlerCrudType = fn(
48   context: LemmyContext,
49   id: ConnectionId,
50   op: UserOperationCrud,
51   data: &str,
52 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
53
54 type MessageHandlerApubType = fn(
55   context: LemmyContext,
56   id: ConnectionId,
57   op: UserOperationApub,
58   data: &str,
59 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
60
61 /// `ChatServer` manages chat rooms and responsible for coordinating chat
62 /// session.
63 pub struct ChatServer {
64   /// A map from generated random ID to session addr
65   pub sessions: HashMap<ConnectionId, SessionInfo>,
66
67   /// A map from post_id to set of connectionIDs
68   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
69
70   /// A map from community to set of connectionIDs
71   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
72
73   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
74
75   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
76   /// sessions (IE clients)
77   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
78
79   pub(super) rng: ThreadRng,
80
81   /// The DB Pool
82   pub(super) pool: DbPool,
83
84   /// The Settings
85   pub(super) settings: Settings,
86
87   /// The Secrets
88   pub(super) secret: Secret,
89
90   /// A list of the current captchas
91   pub(super) captchas: Vec<CaptchaItem>,
92
93   message_handler: MessageHandlerType,
94   message_handler_crud: MessageHandlerCrudType,
95   message_handler_apub: MessageHandlerApubType,
96
97   /// An HTTP Client
98   client: ClientWithMiddleware,
99
100   rate_limit_cell: RateLimitCell,
101 }
102
103 pub struct SessionInfo {
104   pub addr: Recipient<WsMessage>,
105   pub ip: IpAddr,
106 }
107
108 /// `ChatServer` is an actor. It maintains list of connection client session.
109 /// And manages available rooms. Peers send messages to other peers in same
110 /// room through `ChatServer`.
111 impl ChatServer {
112   #![allow(clippy::too_many_arguments)]
113   pub fn startup(
114     pool: DbPool,
115     message_handler: MessageHandlerType,
116     message_handler_crud: MessageHandlerCrudType,
117     message_handler_apub: MessageHandlerApubType,
118     client: ClientWithMiddleware,
119     settings: Settings,
120     secret: Secret,
121     rate_limit_cell: RateLimitCell,
122   ) -> ChatServer {
123     ChatServer {
124       sessions: HashMap::new(),
125       post_rooms: HashMap::new(),
126       community_rooms: HashMap::new(),
127       mod_rooms: HashMap::new(),
128       user_rooms: HashMap::new(),
129       rng: rand::thread_rng(),
130       pool,
131       captchas: Vec::new(),
132       message_handler,
133       message_handler_crud,
134       message_handler_apub,
135       client,
136       settings,
137       secret,
138       rate_limit_cell,
139     }
140   }
141
142   pub fn join_community_room(
143     &mut self,
144     community_id: CommunityId,
145     id: ConnectionId,
146   ) -> Result<(), LemmyError> {
147     // remove session from all rooms
148     for sessions in self.community_rooms.values_mut() {
149       sessions.remove(&id);
150     }
151
152     // Also leave all post rooms
153     // This avoids double messages
154     for sessions in self.post_rooms.values_mut() {
155       sessions.remove(&id);
156     }
157
158     // If the room doesn't exist yet
159     if self.community_rooms.get_mut(&community_id).is_none() {
160       self.community_rooms.insert(community_id, HashSet::new());
161     }
162
163     self
164       .community_rooms
165       .get_mut(&community_id)
166       .context(location_info!())?
167       .insert(id);
168     Ok(())
169   }
170
171   pub fn join_mod_room(
172     &mut self,
173     community_id: CommunityId,
174     id: ConnectionId,
175   ) -> Result<(), LemmyError> {
176     // remove session from all rooms
177     for sessions in self.mod_rooms.values_mut() {
178       sessions.remove(&id);
179     }
180
181     // If the room doesn't exist yet
182     if self.mod_rooms.get_mut(&community_id).is_none() {
183       self.mod_rooms.insert(community_id, HashSet::new());
184     }
185
186     self
187       .mod_rooms
188       .get_mut(&community_id)
189       .context(location_info!())?
190       .insert(id);
191     Ok(())
192   }
193
194   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
195     // remove session from all rooms
196     for sessions in self.post_rooms.values_mut() {
197       sessions.remove(&id);
198     }
199
200     // Also leave all communities
201     // This avoids double messages
202     // TODO found a bug, whereby community messages like
203     // delete and remove aren't sent, because
204     // you left the community room
205     for sessions in self.community_rooms.values_mut() {
206       sessions.remove(&id);
207     }
208
209     // If the room doesn't exist yet
210     if self.post_rooms.get_mut(&post_id).is_none() {
211       self.post_rooms.insert(post_id, HashSet::new());
212     }
213
214     self
215       .post_rooms
216       .get_mut(&post_id)
217       .context(location_info!())?
218       .insert(id);
219
220     Ok(())
221   }
222
223   pub fn join_user_room(
224     &mut self,
225     user_id: LocalUserId,
226     id: ConnectionId,
227   ) -> Result<(), LemmyError> {
228     // remove session from all rooms
229     for sessions in self.user_rooms.values_mut() {
230       sessions.remove(&id);
231     }
232
233     // If the room doesn't exist yet
234     if self.user_rooms.get_mut(&user_id).is_none() {
235       self.user_rooms.insert(user_id, HashSet::new());
236     }
237
238     self
239       .user_rooms
240       .get_mut(&user_id)
241       .context(location_info!())?
242       .insert(id);
243
244     Ok(())
245   }
246
247   fn send_post_room_message<OP, Response>(
248     &self,
249     op: &OP,
250     response: &Response,
251     post_id: PostId,
252     websocket_id: Option<ConnectionId>,
253   ) -> Result<(), LemmyError>
254   where
255     OP: OperationType + ToString,
256     Response: Serialize,
257   {
258     let res_str = &serialize_websocket_message(op, response)?;
259     if let Some(sessions) = self.post_rooms.get(&post_id) {
260       for id in sessions {
261         if let Some(my_id) = websocket_id {
262           if *id == my_id {
263             continue;
264           }
265         }
266         self.sendit(res_str, *id);
267       }
268     }
269     Ok(())
270   }
271
272   pub fn send_community_room_message<OP, Response>(
273     &self,
274     op: &OP,
275     response: &Response,
276     community_id: CommunityId,
277     websocket_id: Option<ConnectionId>,
278   ) -> Result<(), LemmyError>
279   where
280     OP: OperationType + ToString,
281     Response: Serialize,
282   {
283     let res_str = &serialize_websocket_message(op, response)?;
284     if let Some(sessions) = self.community_rooms.get(&community_id) {
285       for id in sessions {
286         if let Some(my_id) = websocket_id {
287           if *id == my_id {
288             continue;
289           }
290         }
291         self.sendit(res_str, *id);
292       }
293     }
294     Ok(())
295   }
296
297   pub fn send_mod_room_message<OP, Response>(
298     &self,
299     op: &OP,
300     response: &Response,
301     community_id: CommunityId,
302     websocket_id: Option<ConnectionId>,
303   ) -> Result<(), LemmyError>
304   where
305     OP: OperationType + ToString,
306     Response: Serialize,
307   {
308     let res_str = &serialize_websocket_message(op, response)?;
309     if let Some(sessions) = self.mod_rooms.get(&community_id) {
310       for id in sessions {
311         if let Some(my_id) = websocket_id {
312           if *id == my_id {
313             continue;
314           }
315         }
316         self.sendit(res_str, *id);
317       }
318     }
319     Ok(())
320   }
321
322   pub fn send_all_message<OP, Response>(
323     &self,
324     op: &OP,
325     response: &Response,
326     websocket_id: Option<ConnectionId>,
327   ) -> Result<(), LemmyError>
328   where
329     OP: OperationType + ToString,
330     Response: Serialize,
331   {
332     let res_str = &serialize_websocket_message(op, response)?;
333     for id in self.sessions.keys() {
334       if let Some(my_id) = websocket_id {
335         if *id == my_id {
336           continue;
337         }
338       }
339       self.sendit(res_str, *id);
340     }
341     Ok(())
342   }
343
344   pub fn send_user_room_message<OP, Response>(
345     &self,
346     op: &OP,
347     response: &Response,
348     recipient_id: LocalUserId,
349     websocket_id: Option<ConnectionId>,
350   ) -> Result<(), LemmyError>
351   where
352     OP: OperationType + ToString,
353     Response: Serialize,
354   {
355     let res_str = &serialize_websocket_message(op, response)?;
356     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
357       for id in sessions {
358         if let Some(my_id) = websocket_id {
359           if *id == my_id {
360             continue;
361           }
362         }
363         self.sendit(res_str, *id);
364       }
365     }
366     Ok(())
367   }
368
369   pub fn send_comment<OP>(
370     &self,
371     user_operation: &OP,
372     comment: &CommentResponse,
373     websocket_id: Option<ConnectionId>,
374   ) -> Result<(), LemmyError>
375   where
376     OP: OperationType + ToString,
377   {
378     let mut comment_reply_sent = comment.clone();
379
380     // Strip out my specific user info
381     comment_reply_sent.comment_view.my_vote = None;
382
383     // Send it to the post room
384     let mut comment_post_sent = comment_reply_sent.clone();
385     // Remove the recipients here to separate mentions / user messages from post or community comments
386     comment_post_sent.recipient_ids = Vec::new();
387     self.send_post_room_message(
388       user_operation,
389       &comment_post_sent,
390       comment_post_sent.comment_view.post.id,
391       websocket_id,
392     )?;
393
394     // Send it to the community too
395     self.send_community_room_message(
396       user_operation,
397       &comment_post_sent,
398       CommunityId(0),
399       websocket_id,
400     )?;
401     self.send_community_room_message(
402       user_operation,
403       &comment_post_sent,
404       comment.comment_view.community.id,
405       websocket_id,
406     )?;
407
408     // Send it to the recipient(s) including the mentioned users
409     for recipient_id in &comment_reply_sent.recipient_ids {
410       self.send_user_room_message(
411         user_operation,
412         &comment_reply_sent,
413         *recipient_id,
414         websocket_id,
415       )?;
416     }
417
418     Ok(())
419   }
420
421   pub fn send_post<OP>(
422     &self,
423     user_operation: &OP,
424     post_res: &PostResponse,
425     websocket_id: Option<ConnectionId>,
426   ) -> Result<(), LemmyError>
427   where
428     OP: OperationType + ToString,
429   {
430     let community_id = post_res.post_view.community.id;
431
432     // Don't send my data with it
433     let mut post_sent = post_res.clone();
434     post_sent.post_view.my_vote = None;
435
436     // Send it to /c/all and that community
437     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
438     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
439
440     // Send it to the post room
441     self.send_post_room_message(
442       user_operation,
443       &post_sent,
444       post_res.post_view.post.id,
445       websocket_id,
446     )?;
447
448     Ok(())
449   }
450
451   fn sendit(&self, message: &str, id: ConnectionId) {
452     if let Some(info) = self.sessions.get(&id) {
453       info.addr.do_send(WsMessage(message.to_owned()));
454     }
455   }
456
457   pub(super) fn parse_json_message(
458     &mut self,
459     msg: StandardMessage,
460     ctx: &mut Context<Self>,
461   ) -> impl Future<Output = Result<String, LemmyError>> {
462     let ip: IpAddr = match self.sessions.get(&msg.id) {
463       Some(info) => info.ip.clone(),
464       None => IpAddr("blank_ip".to_string()),
465     };
466
467     let context = LemmyContext::create(
468       self.pool.clone(),
469       ctx.address(),
470       self.client.clone(),
471       self.settings.clone(),
472       self.secret.clone(),
473       self.rate_limit_cell.clone(),
474     );
475     let message_handler_crud = self.message_handler_crud;
476     let message_handler = self.message_handler;
477     let message_handler_apub = self.message_handler_apub;
478     let rate_limiter = self.rate_limit_cell.clone();
479     async move {
480       let json: Value = serde_json::from_str(&msg.msg)?;
481       let data = &json["data"].to_string();
482       let op = &json["op"]
483         .as_str()
484         .ok_or_else(|| LemmyError::from_message("missing op"))?;
485
486       // check if api call passes the rate limit, and generate future for later execution
487       let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
488         let passed = match user_operation_crud {
489           UserOperationCrud::Register => rate_limiter.register().check(ip),
490           UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
491           UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
492           UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
493           _ => rate_limiter.message().check(ip),
494         };
495         let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
496         (passed, fut)
497       } else if let Ok(user_operation) = UserOperation::from_str(op) {
498         let passed = match user_operation {
499           UserOperation::GetCaptcha => rate_limiter.post().check(ip),
500           _ => rate_limiter.message().check(ip),
501         };
502         let fut = (message_handler)(context, msg.id, user_operation, data);
503         (passed, fut)
504       } else {
505         let user_operation = UserOperationApub::from_str(op)?;
506         let passed = match user_operation {
507           UserOperationApub::Search => rate_limiter.search().check(ip),
508           _ => rate_limiter.message().check(ip),
509         };
510         let fut = (message_handler_apub)(context, msg.id, user_operation, data);
511         (passed, fut)
512       };
513
514       // if rate limit passed, execute api call future
515       if passed {
516         fut.await
517       } else {
518         // if rate limit was hit, respond with message
519         Err(LemmyError::from_message("rate_limit_error"))
520       }
521     }
522   }
523 }