]> Untitled Git - lemmy.git/blob - crates/websocket/src/chat_server.rs
2b58b2c1effd03818ca54b1d5e6bc23a3eb24fd1
[lemmy.git] / crates / websocket / src / chat_server.rs
1 use crate::{
2   messages::*,
3   serialize_websocket_message,
4   LemmyContext,
5   OperationType,
6   UserOperation,
7   UserOperationCrud,
8 };
9 use actix::prelude::*;
10 use anyhow::Context as acontext;
11 use background_jobs::QueueHandle;
12 use diesel::{
13   r2d2::{ConnectionManager, Pool},
14   PgConnection,
15 };
16 use lemmy_api_common::{comment::*, post::*};
17 use lemmy_db_schema::{
18   newtypes::{CommunityId, LocalUserId, PostId},
19   source::secret::Secret,
20 };
21 use lemmy_utils::{
22   location_info,
23   rate_limit::RateLimit,
24   settings::structs::Settings,
25   ApiError,
26   ConnectionId,
27   IpAddr,
28   LemmyError,
29 };
30 use rand::rngs::ThreadRng;
31 use reqwest::Client;
32 use serde::Serialize;
33 use serde_json::Value;
34 use std::{
35   collections::{HashMap, HashSet},
36   future::Future,
37   str::FromStr,
38 };
39 use tokio::macros::support::Pin;
40
41 type MessageHandlerType = fn(
42   context: LemmyContext,
43   id: ConnectionId,
44   op: UserOperation,
45   data: &str,
46 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
47
48 type MessageHandlerCrudType = fn(
49   context: LemmyContext,
50   id: ConnectionId,
51   op: UserOperationCrud,
52   data: &str,
53 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
54
55 /// `ChatServer` manages chat rooms and responsible for coordinating chat
56 /// session.
57 pub struct ChatServer {
58   /// A map from generated random ID to session addr
59   pub sessions: HashMap<ConnectionId, SessionInfo>,
60
61   /// A map from post_id to set of connectionIDs
62   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
63
64   /// A map from community to set of connectionIDs
65   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
66
67   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
68
69   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
70   /// sessions (IE clients)
71   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
72
73   pub(super) rng: ThreadRng,
74
75   /// The DB Pool
76   pub(super) pool: Pool<ConnectionManager<PgConnection>>,
77
78   /// The Settings
79   pub(super) settings: Settings,
80
81   /// The Secrets
82   pub(super) secret: Secret,
83
84   /// Rate limiting based on rate type and IP addr
85   pub(super) rate_limiter: RateLimit,
86
87   /// A list of the current captchas
88   pub(super) captchas: Vec<CaptchaItem>,
89
90   message_handler: MessageHandlerType,
91   message_handler_crud: MessageHandlerCrudType,
92
93   /// An HTTP Client
94   client: Client,
95
96   activity_queue: QueueHandle,
97 }
98
99 pub struct SessionInfo {
100   pub addr: Recipient<WsMessage>,
101   pub ip: IpAddr,
102 }
103
104 /// `ChatServer` is an actor. It maintains list of connection client session.
105 /// And manages available rooms. Peers send messages to other peers in same
106 /// room through `ChatServer`.
107 impl ChatServer {
108   #![allow(clippy::too_many_arguments)]
109   pub fn startup(
110     pool: Pool<ConnectionManager<PgConnection>>,
111     rate_limiter: RateLimit,
112     message_handler: MessageHandlerType,
113     message_handler_crud: MessageHandlerCrudType,
114     client: Client,
115     activity_queue: QueueHandle,
116     settings: Settings,
117     secret: Secret,
118   ) -> ChatServer {
119     ChatServer {
120       sessions: HashMap::new(),
121       post_rooms: HashMap::new(),
122       community_rooms: HashMap::new(),
123       mod_rooms: HashMap::new(),
124       user_rooms: HashMap::new(),
125       rng: rand::thread_rng(),
126       pool,
127       rate_limiter,
128       captchas: Vec::new(),
129       message_handler,
130       message_handler_crud,
131       client,
132       activity_queue,
133       settings,
134       secret,
135     }
136   }
137
138   pub fn join_community_room(
139     &mut self,
140     community_id: CommunityId,
141     id: ConnectionId,
142   ) -> Result<(), LemmyError> {
143     // remove session from all rooms
144     for sessions in self.community_rooms.values_mut() {
145       sessions.remove(&id);
146     }
147
148     // Also leave all post rooms
149     // This avoids double messages
150     for sessions in self.post_rooms.values_mut() {
151       sessions.remove(&id);
152     }
153
154     // If the room doesn't exist yet
155     if self.community_rooms.get_mut(&community_id).is_none() {
156       self.community_rooms.insert(community_id, HashSet::new());
157     }
158
159     self
160       .community_rooms
161       .get_mut(&community_id)
162       .context(location_info!())?
163       .insert(id);
164     Ok(())
165   }
166
167   pub fn join_mod_room(
168     &mut self,
169     community_id: CommunityId,
170     id: ConnectionId,
171   ) -> Result<(), LemmyError> {
172     // remove session from all rooms
173     for sessions in self.mod_rooms.values_mut() {
174       sessions.remove(&id);
175     }
176
177     // If the room doesn't exist yet
178     if self.mod_rooms.get_mut(&community_id).is_none() {
179       self.mod_rooms.insert(community_id, HashSet::new());
180     }
181
182     self
183       .mod_rooms
184       .get_mut(&community_id)
185       .context(location_info!())?
186       .insert(id);
187     Ok(())
188   }
189
190   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
191     // remove session from all rooms
192     for sessions in self.post_rooms.values_mut() {
193       sessions.remove(&id);
194     }
195
196     // Also leave all communities
197     // This avoids double messages
198     // TODO found a bug, whereby community messages like
199     // delete and remove aren't sent, because
200     // you left the community room
201     for sessions in self.community_rooms.values_mut() {
202       sessions.remove(&id);
203     }
204
205     // If the room doesn't exist yet
206     if self.post_rooms.get_mut(&post_id).is_none() {
207       self.post_rooms.insert(post_id, HashSet::new());
208     }
209
210     self
211       .post_rooms
212       .get_mut(&post_id)
213       .context(location_info!())?
214       .insert(id);
215
216     Ok(())
217   }
218
219   pub fn join_user_room(
220     &mut self,
221     user_id: LocalUserId,
222     id: ConnectionId,
223   ) -> Result<(), LemmyError> {
224     // remove session from all rooms
225     for sessions in self.user_rooms.values_mut() {
226       sessions.remove(&id);
227     }
228
229     // If the room doesn't exist yet
230     if self.user_rooms.get_mut(&user_id).is_none() {
231       self.user_rooms.insert(user_id, HashSet::new());
232     }
233
234     self
235       .user_rooms
236       .get_mut(&user_id)
237       .context(location_info!())?
238       .insert(id);
239
240     Ok(())
241   }
242
243   fn send_post_room_message<OP, Response>(
244     &self,
245     op: &OP,
246     response: &Response,
247     post_id: PostId,
248     websocket_id: Option<ConnectionId>,
249   ) -> Result<(), LemmyError>
250   where
251     OP: OperationType + ToString,
252     Response: Serialize,
253   {
254     let res_str = &serialize_websocket_message(op, response)?;
255     if let Some(sessions) = self.post_rooms.get(&post_id) {
256       for id in sessions {
257         if let Some(my_id) = websocket_id {
258           if *id == my_id {
259             continue;
260           }
261         }
262         self.sendit(res_str, *id);
263       }
264     }
265     Ok(())
266   }
267
268   pub fn send_community_room_message<OP, Response>(
269     &self,
270     op: &OP,
271     response: &Response,
272     community_id: CommunityId,
273     websocket_id: Option<ConnectionId>,
274   ) -> Result<(), LemmyError>
275   where
276     OP: OperationType + ToString,
277     Response: Serialize,
278   {
279     let res_str = &serialize_websocket_message(op, response)?;
280     if let Some(sessions) = self.community_rooms.get(&community_id) {
281       for id in sessions {
282         if let Some(my_id) = websocket_id {
283           if *id == my_id {
284             continue;
285           }
286         }
287         self.sendit(res_str, *id);
288       }
289     }
290     Ok(())
291   }
292
293   pub fn send_mod_room_message<OP, Response>(
294     &self,
295     op: &OP,
296     response: &Response,
297     community_id: CommunityId,
298     websocket_id: Option<ConnectionId>,
299   ) -> Result<(), LemmyError>
300   where
301     OP: OperationType + ToString,
302     Response: Serialize,
303   {
304     let res_str = &serialize_websocket_message(op, response)?;
305     if let Some(sessions) = self.mod_rooms.get(&community_id) {
306       for id in sessions {
307         if let Some(my_id) = websocket_id {
308           if *id == my_id {
309             continue;
310           }
311         }
312         self.sendit(res_str, *id);
313       }
314     }
315     Ok(())
316   }
317
318   pub fn send_all_message<OP, Response>(
319     &self,
320     op: &OP,
321     response: &Response,
322     websocket_id: Option<ConnectionId>,
323   ) -> Result<(), LemmyError>
324   where
325     OP: OperationType + ToString,
326     Response: Serialize,
327   {
328     let res_str = &serialize_websocket_message(op, response)?;
329     for id in self.sessions.keys() {
330       if let Some(my_id) = websocket_id {
331         if *id == my_id {
332           continue;
333         }
334       }
335       self.sendit(res_str, *id);
336     }
337     Ok(())
338   }
339
340   pub fn send_user_room_message<OP, Response>(
341     &self,
342     op: &OP,
343     response: &Response,
344     recipient_id: LocalUserId,
345     websocket_id: Option<ConnectionId>,
346   ) -> Result<(), LemmyError>
347   where
348     OP: OperationType + ToString,
349     Response: Serialize,
350   {
351     let res_str = &serialize_websocket_message(op, response)?;
352     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
353       for id in sessions {
354         if let Some(my_id) = websocket_id {
355           if *id == my_id {
356             continue;
357           }
358         }
359         self.sendit(res_str, *id);
360       }
361     }
362     Ok(())
363   }
364
365   pub fn send_comment<OP>(
366     &self,
367     user_operation: &OP,
368     comment: &CommentResponse,
369     websocket_id: Option<ConnectionId>,
370   ) -> Result<(), LemmyError>
371   where
372     OP: OperationType + ToString,
373   {
374     let mut comment_reply_sent = comment.clone();
375
376     // Strip out my specific user info
377     comment_reply_sent.comment_view.my_vote = None;
378
379     // Send it to the post room
380     let mut comment_post_sent = comment_reply_sent.clone();
381     // Remove the recipients here to separate mentions / user messages from post or community comments
382     comment_post_sent.recipient_ids = Vec::new();
383     self.send_post_room_message(
384       user_operation,
385       &comment_post_sent,
386       comment_post_sent.comment_view.post.id,
387       websocket_id,
388     )?;
389
390     // Send it to the community too
391     self.send_community_room_message(
392       user_operation,
393       &comment_post_sent,
394       CommunityId(0),
395       websocket_id,
396     )?;
397     self.send_community_room_message(
398       user_operation,
399       &comment_post_sent,
400       comment.comment_view.community.id,
401       websocket_id,
402     )?;
403
404     // Send it to the recipient(s) including the mentioned users
405     for recipient_id in &comment_reply_sent.recipient_ids {
406       self.send_user_room_message(
407         user_operation,
408         &comment_reply_sent,
409         *recipient_id,
410         websocket_id,
411       )?;
412     }
413
414     Ok(())
415   }
416
417   pub fn send_post<OP>(
418     &self,
419     user_operation: &OP,
420     post_res: &PostResponse,
421     websocket_id: Option<ConnectionId>,
422   ) -> Result<(), LemmyError>
423   where
424     OP: OperationType + ToString,
425   {
426     let community_id = post_res.post_view.community.id;
427
428     // Don't send my data with it
429     let mut post_sent = post_res.clone();
430     post_sent.post_view.my_vote = None;
431
432     // Send it to /c/all and that community
433     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
434     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
435
436     // Send it to the post room
437     self.send_post_room_message(
438       user_operation,
439       &post_sent,
440       post_res.post_view.post.id,
441       websocket_id,
442     )?;
443
444     Ok(())
445   }
446
447   fn sendit(&self, message: &str, id: ConnectionId) {
448     if let Some(info) = self.sessions.get(&id) {
449       let _ = info.addr.do_send(WsMessage(message.to_owned()));
450     }
451   }
452
453   pub(super) fn parse_json_message(
454     &mut self,
455     msg: StandardMessage,
456     ctx: &mut Context<Self>,
457   ) -> impl Future<Output = Result<String, LemmyError>> {
458     let rate_limiter = self.rate_limiter.clone();
459
460     let ip: IpAddr = match self.sessions.get(&msg.id) {
461       Some(info) => info.ip.to_owned(),
462       None => IpAddr("blank_ip".to_string()),
463     };
464
465     let context = LemmyContext {
466       pool: self.pool.clone(),
467       chat_server: ctx.address(),
468       client: self.client.to_owned(),
469       activity_queue: self.activity_queue.to_owned(),
470       settings: self.settings.to_owned(),
471       secret: self.secret.to_owned(),
472     };
473     let message_handler_crud = self.message_handler_crud;
474     let message_handler = self.message_handler;
475     async move {
476       let json: Value = serde_json::from_str(&msg.msg)?;
477       let data = &json["data"].to_string();
478       let op = &json["op"]
479         .as_str()
480         .ok_or_else(|| ApiError::err_plain("missing op"))?;
481
482       if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
483         let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
484         match user_operation_crud {
485           UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
486           UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
487           UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
488           UserOperationCrud::CreateComment => rate_limiter.comment().wrap(ip, fut).await,
489           _ => rate_limiter.message().wrap(ip, fut).await,
490         }
491       } else {
492         let user_operation = UserOperation::from_str(op)?;
493         let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
494         match user_operation {
495           UserOperation::GetCaptcha => rate_limiter.post().wrap(ip, fut).await,
496           _ => rate_limiter.message().wrap(ip, fut).await,
497         }
498       }
499     }
500   }
501 }