]> Untitled Git - lemmy.git/blob - crates/websocket/src/chat_server.rs
Moving settings and secrets to context.
[lemmy.git] / crates / websocket / src / chat_server.rs
1 use crate::{
2   messages::*,
3   serialize_websocket_message,
4   LemmyContext,
5   OperationType,
6   UserOperation,
7   UserOperationCrud,
8 };
9 use actix::prelude::*;
10 use anyhow::Context as acontext;
11 use background_jobs::QueueHandle;
12 use diesel::{
13   r2d2::{ConnectionManager, Pool},
14   PgConnection,
15 };
16 use lemmy_api_common::{comment::*, post::*};
17 use lemmy_db_schema::{source::secret::Secret, CommunityId, LocalUserId, PostId};
18 use lemmy_utils::{
19   location_info,
20   rate_limit::RateLimit,
21   settings::structs::Settings,
22   ApiError,
23   ConnectionId,
24   IpAddr,
25   LemmyError,
26 };
27 use rand::rngs::ThreadRng;
28 use reqwest::Client;
29 use serde::Serialize;
30 use serde_json::Value;
31 use std::{
32   collections::{HashMap, HashSet},
33   future::Future,
34   str::FromStr,
35 };
36 use tokio::macros::support::Pin;
37
38 type MessageHandlerType = fn(
39   context: LemmyContext,
40   id: ConnectionId,
41   op: UserOperation,
42   data: &str,
43 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
44
45 type MessageHandlerCrudType = fn(
46   context: LemmyContext,
47   id: ConnectionId,
48   op: UserOperationCrud,
49   data: &str,
50 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
51
52 /// `ChatServer` manages chat rooms and responsible for coordinating chat
53 /// session.
54 pub struct ChatServer {
55   /// A map from generated random ID to session addr
56   pub sessions: HashMap<ConnectionId, SessionInfo>,
57
58   /// A map from post_id to set of connectionIDs
59   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
60
61   /// A map from community to set of connectionIDs
62   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
63
64   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
65
66   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
67   /// sessions (IE clients)
68   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
69
70   pub(super) rng: ThreadRng,
71
72   /// The DB Pool
73   pub(super) pool: Pool<ConnectionManager<PgConnection>>,
74
75   /// The Settings
76   pub(super) settings: Settings,
77
78   /// The Secrets
79   pub(super) secret: Secret,
80
81   /// Rate limiting based on rate type and IP addr
82   pub(super) rate_limiter: RateLimit,
83
84   /// A list of the current captchas
85   pub(super) captchas: Vec<CaptchaItem>,
86
87   message_handler: MessageHandlerType,
88   message_handler_crud: MessageHandlerCrudType,
89
90   /// An HTTP Client
91   client: Client,
92
93   activity_queue: QueueHandle,
94 }
95
96 pub struct SessionInfo {
97   pub addr: Recipient<WsMessage>,
98   pub ip: IpAddr,
99 }
100
101 /// `ChatServer` is an actor. It maintains list of connection client session.
102 /// And manages available rooms. Peers send messages to other peers in same
103 /// room through `ChatServer`.
104 impl ChatServer {
105   #![allow(clippy::too_many_arguments)]
106   pub fn startup(
107     pool: Pool<ConnectionManager<PgConnection>>,
108     rate_limiter: RateLimit,
109     message_handler: MessageHandlerType,
110     message_handler_crud: MessageHandlerCrudType,
111     client: Client,
112     activity_queue: QueueHandle,
113     settings: Settings,
114     secret: Secret,
115   ) -> ChatServer {
116     ChatServer {
117       sessions: HashMap::new(),
118       post_rooms: HashMap::new(),
119       community_rooms: HashMap::new(),
120       mod_rooms: HashMap::new(),
121       user_rooms: HashMap::new(),
122       rng: rand::thread_rng(),
123       pool,
124       rate_limiter,
125       captchas: Vec::new(),
126       message_handler,
127       message_handler_crud,
128       client,
129       activity_queue,
130       settings,
131       secret,
132     }
133   }
134
135   pub fn join_community_room(
136     &mut self,
137     community_id: CommunityId,
138     id: ConnectionId,
139   ) -> Result<(), LemmyError> {
140     // remove session from all rooms
141     for sessions in self.community_rooms.values_mut() {
142       sessions.remove(&id);
143     }
144
145     // Also leave all post rooms
146     // This avoids double messages
147     for sessions in self.post_rooms.values_mut() {
148       sessions.remove(&id);
149     }
150
151     // If the room doesn't exist yet
152     if self.community_rooms.get_mut(&community_id).is_none() {
153       self.community_rooms.insert(community_id, HashSet::new());
154     }
155
156     self
157       .community_rooms
158       .get_mut(&community_id)
159       .context(location_info!())?
160       .insert(id);
161     Ok(())
162   }
163
164   pub fn join_mod_room(
165     &mut self,
166     community_id: CommunityId,
167     id: ConnectionId,
168   ) -> Result<(), LemmyError> {
169     // remove session from all rooms
170     for sessions in self.mod_rooms.values_mut() {
171       sessions.remove(&id);
172     }
173
174     // If the room doesn't exist yet
175     if self.mod_rooms.get_mut(&community_id).is_none() {
176       self.mod_rooms.insert(community_id, HashSet::new());
177     }
178
179     self
180       .mod_rooms
181       .get_mut(&community_id)
182       .context(location_info!())?
183       .insert(id);
184     Ok(())
185   }
186
187   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
188     // remove session from all rooms
189     for sessions in self.post_rooms.values_mut() {
190       sessions.remove(&id);
191     }
192
193     // Also leave all communities
194     // This avoids double messages
195     // TODO found a bug, whereby community messages like
196     // delete and remove aren't sent, because
197     // you left the community room
198     for sessions in self.community_rooms.values_mut() {
199       sessions.remove(&id);
200     }
201
202     // If the room doesn't exist yet
203     if self.post_rooms.get_mut(&post_id).is_none() {
204       self.post_rooms.insert(post_id, HashSet::new());
205     }
206
207     self
208       .post_rooms
209       .get_mut(&post_id)
210       .context(location_info!())?
211       .insert(id);
212
213     Ok(())
214   }
215
216   pub fn join_user_room(
217     &mut self,
218     user_id: LocalUserId,
219     id: ConnectionId,
220   ) -> Result<(), LemmyError> {
221     // remove session from all rooms
222     for sessions in self.user_rooms.values_mut() {
223       sessions.remove(&id);
224     }
225
226     // If the room doesn't exist yet
227     if self.user_rooms.get_mut(&user_id).is_none() {
228       self.user_rooms.insert(user_id, HashSet::new());
229     }
230
231     self
232       .user_rooms
233       .get_mut(&user_id)
234       .context(location_info!())?
235       .insert(id);
236
237     Ok(())
238   }
239
240   fn send_post_room_message<OP, Response>(
241     &self,
242     op: &OP,
243     response: &Response,
244     post_id: PostId,
245     websocket_id: Option<ConnectionId>,
246   ) -> Result<(), LemmyError>
247   where
248     OP: OperationType + ToString,
249     Response: Serialize,
250   {
251     let res_str = &serialize_websocket_message(op, response)?;
252     if let Some(sessions) = self.post_rooms.get(&post_id) {
253       for id in sessions {
254         if let Some(my_id) = websocket_id {
255           if *id == my_id {
256             continue;
257           }
258         }
259         self.sendit(res_str, *id);
260       }
261     }
262     Ok(())
263   }
264
265   pub fn send_community_room_message<OP, Response>(
266     &self,
267     op: &OP,
268     response: &Response,
269     community_id: CommunityId,
270     websocket_id: Option<ConnectionId>,
271   ) -> Result<(), LemmyError>
272   where
273     OP: OperationType + ToString,
274     Response: Serialize,
275   {
276     let res_str = &serialize_websocket_message(op, response)?;
277     if let Some(sessions) = self.community_rooms.get(&community_id) {
278       for id in sessions {
279         if let Some(my_id) = websocket_id {
280           if *id == my_id {
281             continue;
282           }
283         }
284         self.sendit(res_str, *id);
285       }
286     }
287     Ok(())
288   }
289
290   pub fn send_mod_room_message<OP, Response>(
291     &self,
292     op: &OP,
293     response: &Response,
294     community_id: CommunityId,
295     websocket_id: Option<ConnectionId>,
296   ) -> Result<(), LemmyError>
297   where
298     OP: OperationType + ToString,
299     Response: Serialize,
300   {
301     let res_str = &serialize_websocket_message(op, response)?;
302     if let Some(sessions) = self.mod_rooms.get(&community_id) {
303       for id in sessions {
304         if let Some(my_id) = websocket_id {
305           if *id == my_id {
306             continue;
307           }
308         }
309         self.sendit(res_str, *id);
310       }
311     }
312     Ok(())
313   }
314
315   pub fn send_all_message<OP, Response>(
316     &self,
317     op: &OP,
318     response: &Response,
319     websocket_id: Option<ConnectionId>,
320   ) -> Result<(), LemmyError>
321   where
322     OP: OperationType + ToString,
323     Response: Serialize,
324   {
325     let res_str = &serialize_websocket_message(op, response)?;
326     for id in self.sessions.keys() {
327       if let Some(my_id) = websocket_id {
328         if *id == my_id {
329           continue;
330         }
331       }
332       self.sendit(res_str, *id);
333     }
334     Ok(())
335   }
336
337   pub fn send_user_room_message<OP, Response>(
338     &self,
339     op: &OP,
340     response: &Response,
341     recipient_id: LocalUserId,
342     websocket_id: Option<ConnectionId>,
343   ) -> Result<(), LemmyError>
344   where
345     OP: OperationType + ToString,
346     Response: Serialize,
347   {
348     let res_str = &serialize_websocket_message(op, response)?;
349     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
350       for id in sessions {
351         if let Some(my_id) = websocket_id {
352           if *id == my_id {
353             continue;
354           }
355         }
356         self.sendit(res_str, *id);
357       }
358     }
359     Ok(())
360   }
361
362   pub fn send_comment<OP>(
363     &self,
364     user_operation: &OP,
365     comment: &CommentResponse,
366     websocket_id: Option<ConnectionId>,
367   ) -> Result<(), LemmyError>
368   where
369     OP: OperationType + ToString,
370   {
371     let mut comment_reply_sent = comment.clone();
372
373     // Strip out my specific user info
374     comment_reply_sent.comment_view.my_vote = None;
375
376     // Send it to the post room
377     let mut comment_post_sent = comment_reply_sent.clone();
378     // Remove the recipients here to separate mentions / user messages from post or community comments
379     comment_post_sent.recipient_ids = Vec::new();
380     self.send_post_room_message(
381       user_operation,
382       &comment_post_sent,
383       comment_post_sent.comment_view.post.id,
384       websocket_id,
385     )?;
386
387     // Send it to the community too
388     self.send_community_room_message(
389       user_operation,
390       &comment_post_sent,
391       CommunityId(0),
392       websocket_id,
393     )?;
394     self.send_community_room_message(
395       user_operation,
396       &comment_post_sent,
397       comment.comment_view.community.id,
398       websocket_id,
399     )?;
400
401     // Send it to the recipient(s) including the mentioned users
402     for recipient_id in &comment_reply_sent.recipient_ids {
403       self.send_user_room_message(
404         user_operation,
405         &comment_reply_sent,
406         *recipient_id,
407         websocket_id,
408       )?;
409     }
410
411     Ok(())
412   }
413
414   pub fn send_post<OP>(
415     &self,
416     user_operation: &OP,
417     post_res: &PostResponse,
418     websocket_id: Option<ConnectionId>,
419   ) -> Result<(), LemmyError>
420   where
421     OP: OperationType + ToString,
422   {
423     let community_id = post_res.post_view.community.id;
424
425     // Don't send my data with it
426     let mut post_sent = post_res.clone();
427     post_sent.post_view.my_vote = None;
428
429     // Send it to /c/all and that community
430     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
431     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
432
433     // Send it to the post room
434     self.send_post_room_message(
435       user_operation,
436       &post_sent,
437       post_res.post_view.post.id,
438       websocket_id,
439     )?;
440
441     Ok(())
442   }
443
444   fn sendit(&self, message: &str, id: ConnectionId) {
445     if let Some(info) = self.sessions.get(&id) {
446       let _ = info.addr.do_send(WsMessage(message.to_owned()));
447     }
448   }
449
450   pub(super) fn parse_json_message(
451     &mut self,
452     msg: StandardMessage,
453     ctx: &mut Context<Self>,
454   ) -> impl Future<Output = Result<String, LemmyError>> {
455     let rate_limiter = self.rate_limiter.clone();
456
457     let ip: IpAddr = match self.sessions.get(&msg.id) {
458       Some(info) => info.ip.to_owned(),
459       None => IpAddr("blank_ip".to_string()),
460     };
461
462     let context = LemmyContext {
463       pool: self.pool.clone(),
464       chat_server: ctx.address(),
465       client: self.client.to_owned(),
466       activity_queue: self.activity_queue.to_owned(),
467       settings: self.settings.to_owned(),
468       secret: self.secret.to_owned(),
469     };
470     let message_handler_crud = self.message_handler_crud;
471     let message_handler = self.message_handler;
472     async move {
473       let json: Value = serde_json::from_str(&msg.msg)?;
474       let data = &json["data"].to_string();
475       let op = &json["op"].as_str().ok_or(ApiError {
476         message: "Unknown op type".to_string(),
477       })?;
478
479       if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
480         let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
481         match user_operation_crud {
482           UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
483           UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
484           UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
485           _ => rate_limiter.message().wrap(ip, fut).await,
486         }
487       } else {
488         let user_operation = UserOperation::from_str(op)?;
489         let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
490         rate_limiter.message().wrap(ip, fut).await
491       }
492     }
493   }
494 }