]> Untitled Git - lemmy.git/blob - crates/websocket/src/chat_server.rs
8dcc81d1c1acbe5bc63357bb211bed073f1772e2
[lemmy.git] / crates / websocket / src / chat_server.rs
1 use crate::{
2   messages::*,
3   serialize_websocket_message,
4   LemmyContext,
5   OperationType,
6   UserOperation,
7   UserOperationCrud,
8 };
9 use actix::prelude::*;
10 use anyhow::Context as acontext;
11 use diesel::{
12   r2d2::{ConnectionManager, Pool},
13   PgConnection,
14 };
15 use lemmy_api_common::{comment::*, post::*};
16 use lemmy_db_schema::{
17   newtypes::{CommunityId, LocalUserId, PostId},
18   source::secret::Secret,
19 };
20 use lemmy_utils::{
21   error::LemmyError,
22   location_info,
23   rate_limit::RateLimit,
24   settings::structs::Settings,
25   ConnectionId,
26   IpAddr,
27 };
28 use rand::rngs::ThreadRng;
29 use reqwest_middleware::ClientWithMiddleware;
30 use serde::Serialize;
31 use serde_json::Value;
32 use std::{
33   collections::{HashMap, HashSet},
34   future::Future,
35   str::FromStr,
36 };
37 use tokio::macros::support::Pin;
38
39 type MessageHandlerType = fn(
40   context: LemmyContext,
41   id: ConnectionId,
42   op: UserOperation,
43   data: &str,
44 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
45
46 type MessageHandlerCrudType = fn(
47   context: LemmyContext,
48   id: ConnectionId,
49   op: UserOperationCrud,
50   data: &str,
51 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
52
53 /// `ChatServer` manages chat rooms and responsible for coordinating chat
54 /// session.
55 pub struct ChatServer {
56   /// A map from generated random ID to session addr
57   pub sessions: HashMap<ConnectionId, SessionInfo>,
58
59   /// A map from post_id to set of connectionIDs
60   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
61
62   /// A map from community to set of connectionIDs
63   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
64
65   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
66
67   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
68   /// sessions (IE clients)
69   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
70
71   pub(super) rng: ThreadRng,
72
73   /// The DB Pool
74   pub(super) pool: Pool<ConnectionManager<PgConnection>>,
75
76   /// The Settings
77   pub(super) settings: Settings,
78
79   /// The Secrets
80   pub(super) secret: Secret,
81
82   /// Rate limiting based on rate type and IP addr
83   pub(super) rate_limiter: RateLimit,
84
85   /// A list of the current captchas
86   pub(super) captchas: Vec<CaptchaItem>,
87
88   message_handler: MessageHandlerType,
89   message_handler_crud: MessageHandlerCrudType,
90
91   /// An HTTP Client
92   client: ClientWithMiddleware,
93 }
94
95 pub struct SessionInfo {
96   pub addr: Recipient<WsMessage>,
97   pub ip: IpAddr,
98 }
99
100 /// `ChatServer` is an actor. It maintains list of connection client session.
101 /// And manages available rooms. Peers send messages to other peers in same
102 /// room through `ChatServer`.
103 impl ChatServer {
104   #![allow(clippy::too_many_arguments)]
105   pub fn startup(
106     pool: Pool<ConnectionManager<PgConnection>>,
107     rate_limiter: RateLimit,
108     message_handler: MessageHandlerType,
109     message_handler_crud: MessageHandlerCrudType,
110     client: ClientWithMiddleware,
111     settings: Settings,
112     secret: Secret,
113   ) -> ChatServer {
114     ChatServer {
115       sessions: HashMap::new(),
116       post_rooms: HashMap::new(),
117       community_rooms: HashMap::new(),
118       mod_rooms: HashMap::new(),
119       user_rooms: HashMap::new(),
120       rng: rand::thread_rng(),
121       pool,
122       rate_limiter,
123       captchas: Vec::new(),
124       message_handler,
125       message_handler_crud,
126       client,
127       settings,
128       secret,
129     }
130   }
131
132   pub fn join_community_room(
133     &mut self,
134     community_id: CommunityId,
135     id: ConnectionId,
136   ) -> Result<(), LemmyError> {
137     // remove session from all rooms
138     for sessions in self.community_rooms.values_mut() {
139       sessions.remove(&id);
140     }
141
142     // Also leave all post rooms
143     // This avoids double messages
144     for sessions in self.post_rooms.values_mut() {
145       sessions.remove(&id);
146     }
147
148     // If the room doesn't exist yet
149     if self.community_rooms.get_mut(&community_id).is_none() {
150       self.community_rooms.insert(community_id, HashSet::new());
151     }
152
153     self
154       .community_rooms
155       .get_mut(&community_id)
156       .context(location_info!())?
157       .insert(id);
158     Ok(())
159   }
160
161   pub fn join_mod_room(
162     &mut self,
163     community_id: CommunityId,
164     id: ConnectionId,
165   ) -> Result<(), LemmyError> {
166     // remove session from all rooms
167     for sessions in self.mod_rooms.values_mut() {
168       sessions.remove(&id);
169     }
170
171     // If the room doesn't exist yet
172     if self.mod_rooms.get_mut(&community_id).is_none() {
173       self.mod_rooms.insert(community_id, HashSet::new());
174     }
175
176     self
177       .mod_rooms
178       .get_mut(&community_id)
179       .context(location_info!())?
180       .insert(id);
181     Ok(())
182   }
183
184   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
185     // remove session from all rooms
186     for sessions in self.post_rooms.values_mut() {
187       sessions.remove(&id);
188     }
189
190     // Also leave all communities
191     // This avoids double messages
192     // TODO found a bug, whereby community messages like
193     // delete and remove aren't sent, because
194     // you left the community room
195     for sessions in self.community_rooms.values_mut() {
196       sessions.remove(&id);
197     }
198
199     // If the room doesn't exist yet
200     if self.post_rooms.get_mut(&post_id).is_none() {
201       self.post_rooms.insert(post_id, HashSet::new());
202     }
203
204     self
205       .post_rooms
206       .get_mut(&post_id)
207       .context(location_info!())?
208       .insert(id);
209
210     Ok(())
211   }
212
213   pub fn join_user_room(
214     &mut self,
215     user_id: LocalUserId,
216     id: ConnectionId,
217   ) -> Result<(), LemmyError> {
218     // remove session from all rooms
219     for sessions in self.user_rooms.values_mut() {
220       sessions.remove(&id);
221     }
222
223     // If the room doesn't exist yet
224     if self.user_rooms.get_mut(&user_id).is_none() {
225       self.user_rooms.insert(user_id, HashSet::new());
226     }
227
228     self
229       .user_rooms
230       .get_mut(&user_id)
231       .context(location_info!())?
232       .insert(id);
233
234     Ok(())
235   }
236
237   fn send_post_room_message<OP, Response>(
238     &self,
239     op: &OP,
240     response: &Response,
241     post_id: PostId,
242     websocket_id: Option<ConnectionId>,
243   ) -> Result<(), LemmyError>
244   where
245     OP: OperationType + ToString,
246     Response: Serialize,
247   {
248     let res_str = &serialize_websocket_message(op, response)?;
249     if let Some(sessions) = self.post_rooms.get(&post_id) {
250       for id in sessions {
251         if let Some(my_id) = websocket_id {
252           if *id == my_id {
253             continue;
254           }
255         }
256         self.sendit(res_str, *id);
257       }
258     }
259     Ok(())
260   }
261
262   pub fn send_community_room_message<OP, Response>(
263     &self,
264     op: &OP,
265     response: &Response,
266     community_id: CommunityId,
267     websocket_id: Option<ConnectionId>,
268   ) -> Result<(), LemmyError>
269   where
270     OP: OperationType + ToString,
271     Response: Serialize,
272   {
273     let res_str = &serialize_websocket_message(op, response)?;
274     if let Some(sessions) = self.community_rooms.get(&community_id) {
275       for id in sessions {
276         if let Some(my_id) = websocket_id {
277           if *id == my_id {
278             continue;
279           }
280         }
281         self.sendit(res_str, *id);
282       }
283     }
284     Ok(())
285   }
286
287   pub fn send_mod_room_message<OP, Response>(
288     &self,
289     op: &OP,
290     response: &Response,
291     community_id: CommunityId,
292     websocket_id: Option<ConnectionId>,
293   ) -> Result<(), LemmyError>
294   where
295     OP: OperationType + ToString,
296     Response: Serialize,
297   {
298     let res_str = &serialize_websocket_message(op, response)?;
299     if let Some(sessions) = self.mod_rooms.get(&community_id) {
300       for id in sessions {
301         if let Some(my_id) = websocket_id {
302           if *id == my_id {
303             continue;
304           }
305         }
306         self.sendit(res_str, *id);
307       }
308     }
309     Ok(())
310   }
311
312   pub fn send_all_message<OP, Response>(
313     &self,
314     op: &OP,
315     response: &Response,
316     websocket_id: Option<ConnectionId>,
317   ) -> Result<(), LemmyError>
318   where
319     OP: OperationType + ToString,
320     Response: Serialize,
321   {
322     let res_str = &serialize_websocket_message(op, response)?;
323     for id in self.sessions.keys() {
324       if let Some(my_id) = websocket_id {
325         if *id == my_id {
326           continue;
327         }
328       }
329       self.sendit(res_str, *id);
330     }
331     Ok(())
332   }
333
334   pub fn send_user_room_message<OP, Response>(
335     &self,
336     op: &OP,
337     response: &Response,
338     recipient_id: LocalUserId,
339     websocket_id: Option<ConnectionId>,
340   ) -> Result<(), LemmyError>
341   where
342     OP: OperationType + ToString,
343     Response: Serialize,
344   {
345     let res_str = &serialize_websocket_message(op, response)?;
346     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
347       for id in sessions {
348         if let Some(my_id) = websocket_id {
349           if *id == my_id {
350             continue;
351           }
352         }
353         self.sendit(res_str, *id);
354       }
355     }
356     Ok(())
357   }
358
359   pub fn send_comment<OP>(
360     &self,
361     user_operation: &OP,
362     comment: &CommentResponse,
363     websocket_id: Option<ConnectionId>,
364   ) -> Result<(), LemmyError>
365   where
366     OP: OperationType + ToString,
367   {
368     let mut comment_reply_sent = comment.clone();
369
370     // Strip out my specific user info
371     comment_reply_sent.comment_view.my_vote = None;
372
373     // Send it to the post room
374     let mut comment_post_sent = comment_reply_sent.clone();
375     // Remove the recipients here to separate mentions / user messages from post or community comments
376     comment_post_sent.recipient_ids = Vec::new();
377     self.send_post_room_message(
378       user_operation,
379       &comment_post_sent,
380       comment_post_sent.comment_view.post.id,
381       websocket_id,
382     )?;
383
384     // Send it to the community too
385     self.send_community_room_message(
386       user_operation,
387       &comment_post_sent,
388       CommunityId(0),
389       websocket_id,
390     )?;
391     self.send_community_room_message(
392       user_operation,
393       &comment_post_sent,
394       comment.comment_view.community.id,
395       websocket_id,
396     )?;
397
398     // Send it to the recipient(s) including the mentioned users
399     for recipient_id in &comment_reply_sent.recipient_ids {
400       self.send_user_room_message(
401         user_operation,
402         &comment_reply_sent,
403         *recipient_id,
404         websocket_id,
405       )?;
406     }
407
408     Ok(())
409   }
410
411   pub fn send_post<OP>(
412     &self,
413     user_operation: &OP,
414     post_res: &PostResponse,
415     websocket_id: Option<ConnectionId>,
416   ) -> Result<(), LemmyError>
417   where
418     OP: OperationType + ToString,
419   {
420     let community_id = post_res.post_view.community.id;
421
422     // Don't send my data with it
423     let mut post_sent = post_res.clone();
424     post_sent.post_view.my_vote = None;
425
426     // Send it to /c/all and that community
427     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
428     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
429
430     // Send it to the post room
431     self.send_post_room_message(
432       user_operation,
433       &post_sent,
434       post_res.post_view.post.id,
435       websocket_id,
436     )?;
437
438     Ok(())
439   }
440
441   fn sendit(&self, message: &str, id: ConnectionId) {
442     if let Some(info) = self.sessions.get(&id) {
443       info.addr.do_send(WsMessage(message.to_owned()));
444     }
445   }
446
447   pub(super) fn parse_json_message(
448     &mut self,
449     msg: StandardMessage,
450     ctx: &mut Context<Self>,
451   ) -> impl Future<Output = Result<String, LemmyError>> {
452     let rate_limiter = self.rate_limiter.clone();
453
454     let ip: IpAddr = match self.sessions.get(&msg.id) {
455       Some(info) => info.ip.to_owned(),
456       None => IpAddr("blank_ip".to_string()),
457     };
458
459     let context = LemmyContext {
460       pool: self.pool.clone(),
461       chat_server: ctx.address(),
462       client: self.client.to_owned(),
463       settings: self.settings.to_owned(),
464       secret: self.secret.to_owned(),
465     };
466     let message_handler_crud = self.message_handler_crud;
467     let message_handler = self.message_handler;
468     async move {
469       let json: Value = serde_json::from_str(&msg.msg)?;
470       let data = &json["data"].to_string();
471       let op = &json["op"]
472         .as_str()
473         .ok_or_else(|| LemmyError::from_message("missing op"))?;
474
475       // check if api call passes the rate limit, and generate future for later execution
476       let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
477         let passed = match user_operation_crud {
478           UserOperationCrud::Register => rate_limiter.register().check(ip),
479           UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
480           UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
481           UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
482           _ => rate_limiter.message().check(ip),
483         };
484         let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
485         (passed, fut)
486       } else {
487         let user_operation = UserOperation::from_str(op)?;
488         let passed = match user_operation {
489           UserOperation::GetCaptcha => rate_limiter.post().check(ip),
490           UserOperation::Search => rate_limiter.search().check(ip),
491           _ => rate_limiter.message().check(ip),
492         };
493         let fut = (message_handler)(context, msg.id, user_operation, data);
494         (passed, fut)
495       };
496
497       // if rate limit passed, execute api call future
498       if passed {
499         fut.await
500       } else {
501         // if rate limit was hit, respond with empty message
502         Ok("".to_string())
503       }
504     }
505   }
506 }