]> Untitled Git - lemmy.git/blob - crates/websocket/src/chat_server.rs
Consolidate reqwest clients, use reqwest-middleware for tracing
[lemmy.git] / crates / websocket / src / chat_server.rs
1 use crate::{
2   messages::*,
3   serialize_websocket_message,
4   LemmyContext,
5   OperationType,
6   UserOperation,
7   UserOperationCrud,
8 };
9 use actix::prelude::*;
10 use anyhow::Context as acontext;
11 use background_jobs::QueueHandle;
12 use diesel::{
13   r2d2::{ConnectionManager, Pool},
14   PgConnection,
15 };
16 use lemmy_api_common::{comment::*, post::*};
17 use lemmy_db_schema::{
18   newtypes::{CommunityId, LocalUserId, PostId},
19   source::secret::Secret,
20 };
21 use lemmy_utils::{
22   location_info,
23   rate_limit::RateLimit,
24   settings::structs::Settings,
25   ConnectionId,
26   IpAddr,
27   LemmyError,
28 };
29 use rand::rngs::ThreadRng;
30 use reqwest_middleware::ClientWithMiddleware;
31 use serde::Serialize;
32 use serde_json::Value;
33 use std::{
34   collections::{HashMap, HashSet},
35   future::Future,
36   str::FromStr,
37 };
38 use tokio::macros::support::Pin;
39
40 type MessageHandlerType = fn(
41   context: LemmyContext,
42   id: ConnectionId,
43   op: UserOperation,
44   data: &str,
45 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
46
47 type MessageHandlerCrudType = fn(
48   context: LemmyContext,
49   id: ConnectionId,
50   op: UserOperationCrud,
51   data: &str,
52 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
53
54 /// `ChatServer` manages chat rooms and responsible for coordinating chat
55 /// session.
56 pub struct ChatServer {
57   /// A map from generated random ID to session addr
58   pub sessions: HashMap<ConnectionId, SessionInfo>,
59
60   /// A map from post_id to set of connectionIDs
61   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
62
63   /// A map from community to set of connectionIDs
64   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
65
66   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
67
68   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
69   /// sessions (IE clients)
70   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
71
72   pub(super) rng: ThreadRng,
73
74   /// The DB Pool
75   pub(super) pool: Pool<ConnectionManager<PgConnection>>,
76
77   /// The Settings
78   pub(super) settings: Settings,
79
80   /// The Secrets
81   pub(super) secret: Secret,
82
83   /// Rate limiting based on rate type and IP addr
84   pub(super) rate_limiter: RateLimit,
85
86   /// A list of the current captchas
87   pub(super) captchas: Vec<CaptchaItem>,
88
89   message_handler: MessageHandlerType,
90   message_handler_crud: MessageHandlerCrudType,
91
92   /// An HTTP Client
93   client: ClientWithMiddleware,
94
95   activity_queue: QueueHandle,
96 }
97
98 pub struct SessionInfo {
99   pub addr: Recipient<WsMessage>,
100   pub ip: IpAddr,
101 }
102
103 /// `ChatServer` is an actor. It maintains list of connection client session.
104 /// And manages available rooms. Peers send messages to other peers in same
105 /// room through `ChatServer`.
106 impl ChatServer {
107   #![allow(clippy::too_many_arguments)]
108   pub fn startup(
109     pool: Pool<ConnectionManager<PgConnection>>,
110     rate_limiter: RateLimit,
111     message_handler: MessageHandlerType,
112     message_handler_crud: MessageHandlerCrudType,
113     client: ClientWithMiddleware,
114     activity_queue: QueueHandle,
115     settings: Settings,
116     secret: Secret,
117   ) -> ChatServer {
118     ChatServer {
119       sessions: HashMap::new(),
120       post_rooms: HashMap::new(),
121       community_rooms: HashMap::new(),
122       mod_rooms: HashMap::new(),
123       user_rooms: HashMap::new(),
124       rng: rand::thread_rng(),
125       pool,
126       rate_limiter,
127       captchas: Vec::new(),
128       message_handler,
129       message_handler_crud,
130       client,
131       activity_queue,
132       settings,
133       secret,
134     }
135   }
136
137   pub fn join_community_room(
138     &mut self,
139     community_id: CommunityId,
140     id: ConnectionId,
141   ) -> Result<(), LemmyError> {
142     // remove session from all rooms
143     for sessions in self.community_rooms.values_mut() {
144       sessions.remove(&id);
145     }
146
147     // Also leave all post rooms
148     // This avoids double messages
149     for sessions in self.post_rooms.values_mut() {
150       sessions.remove(&id);
151     }
152
153     // If the room doesn't exist yet
154     if self.community_rooms.get_mut(&community_id).is_none() {
155       self.community_rooms.insert(community_id, HashSet::new());
156     }
157
158     self
159       .community_rooms
160       .get_mut(&community_id)
161       .context(location_info!())?
162       .insert(id);
163     Ok(())
164   }
165
166   pub fn join_mod_room(
167     &mut self,
168     community_id: CommunityId,
169     id: ConnectionId,
170   ) -> Result<(), LemmyError> {
171     // remove session from all rooms
172     for sessions in self.mod_rooms.values_mut() {
173       sessions.remove(&id);
174     }
175
176     // If the room doesn't exist yet
177     if self.mod_rooms.get_mut(&community_id).is_none() {
178       self.mod_rooms.insert(community_id, HashSet::new());
179     }
180
181     self
182       .mod_rooms
183       .get_mut(&community_id)
184       .context(location_info!())?
185       .insert(id);
186     Ok(())
187   }
188
189   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
190     // remove session from all rooms
191     for sessions in self.post_rooms.values_mut() {
192       sessions.remove(&id);
193     }
194
195     // Also leave all communities
196     // This avoids double messages
197     // TODO found a bug, whereby community messages like
198     // delete and remove aren't sent, because
199     // you left the community room
200     for sessions in self.community_rooms.values_mut() {
201       sessions.remove(&id);
202     }
203
204     // If the room doesn't exist yet
205     if self.post_rooms.get_mut(&post_id).is_none() {
206       self.post_rooms.insert(post_id, HashSet::new());
207     }
208
209     self
210       .post_rooms
211       .get_mut(&post_id)
212       .context(location_info!())?
213       .insert(id);
214
215     Ok(())
216   }
217
218   pub fn join_user_room(
219     &mut self,
220     user_id: LocalUserId,
221     id: ConnectionId,
222   ) -> Result<(), LemmyError> {
223     // remove session from all rooms
224     for sessions in self.user_rooms.values_mut() {
225       sessions.remove(&id);
226     }
227
228     // If the room doesn't exist yet
229     if self.user_rooms.get_mut(&user_id).is_none() {
230       self.user_rooms.insert(user_id, HashSet::new());
231     }
232
233     self
234       .user_rooms
235       .get_mut(&user_id)
236       .context(location_info!())?
237       .insert(id);
238
239     Ok(())
240   }
241
242   fn send_post_room_message<OP, Response>(
243     &self,
244     op: &OP,
245     response: &Response,
246     post_id: PostId,
247     websocket_id: Option<ConnectionId>,
248   ) -> Result<(), LemmyError>
249   where
250     OP: OperationType + ToString,
251     Response: Serialize,
252   {
253     let res_str = &serialize_websocket_message(op, response)?;
254     if let Some(sessions) = self.post_rooms.get(&post_id) {
255       for id in sessions {
256         if let Some(my_id) = websocket_id {
257           if *id == my_id {
258             continue;
259           }
260         }
261         self.sendit(res_str, *id);
262       }
263     }
264     Ok(())
265   }
266
267   pub fn send_community_room_message<OP, Response>(
268     &self,
269     op: &OP,
270     response: &Response,
271     community_id: CommunityId,
272     websocket_id: Option<ConnectionId>,
273   ) -> Result<(), LemmyError>
274   where
275     OP: OperationType + ToString,
276     Response: Serialize,
277   {
278     let res_str = &serialize_websocket_message(op, response)?;
279     if let Some(sessions) = self.community_rooms.get(&community_id) {
280       for id in sessions {
281         if let Some(my_id) = websocket_id {
282           if *id == my_id {
283             continue;
284           }
285         }
286         self.sendit(res_str, *id);
287       }
288     }
289     Ok(())
290   }
291
292   pub fn send_mod_room_message<OP, Response>(
293     &self,
294     op: &OP,
295     response: &Response,
296     community_id: CommunityId,
297     websocket_id: Option<ConnectionId>,
298   ) -> Result<(), LemmyError>
299   where
300     OP: OperationType + ToString,
301     Response: Serialize,
302   {
303     let res_str = &serialize_websocket_message(op, response)?;
304     if let Some(sessions) = self.mod_rooms.get(&community_id) {
305       for id in sessions {
306         if let Some(my_id) = websocket_id {
307           if *id == my_id {
308             continue;
309           }
310         }
311         self.sendit(res_str, *id);
312       }
313     }
314     Ok(())
315   }
316
317   pub fn send_all_message<OP, Response>(
318     &self,
319     op: &OP,
320     response: &Response,
321     websocket_id: Option<ConnectionId>,
322   ) -> Result<(), LemmyError>
323   where
324     OP: OperationType + ToString,
325     Response: Serialize,
326   {
327     let res_str = &serialize_websocket_message(op, response)?;
328     for id in self.sessions.keys() {
329       if let Some(my_id) = websocket_id {
330         if *id == my_id {
331           continue;
332         }
333       }
334       self.sendit(res_str, *id);
335     }
336     Ok(())
337   }
338
339   pub fn send_user_room_message<OP, Response>(
340     &self,
341     op: &OP,
342     response: &Response,
343     recipient_id: LocalUserId,
344     websocket_id: Option<ConnectionId>,
345   ) -> Result<(), LemmyError>
346   where
347     OP: OperationType + ToString,
348     Response: Serialize,
349   {
350     let res_str = &serialize_websocket_message(op, response)?;
351     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
352       for id in sessions {
353         if let Some(my_id) = websocket_id {
354           if *id == my_id {
355             continue;
356           }
357         }
358         self.sendit(res_str, *id);
359       }
360     }
361     Ok(())
362   }
363
364   pub fn send_comment<OP>(
365     &self,
366     user_operation: &OP,
367     comment: &CommentResponse,
368     websocket_id: Option<ConnectionId>,
369   ) -> Result<(), LemmyError>
370   where
371     OP: OperationType + ToString,
372   {
373     let mut comment_reply_sent = comment.clone();
374
375     // Strip out my specific user info
376     comment_reply_sent.comment_view.my_vote = None;
377
378     // Send it to the post room
379     let mut comment_post_sent = comment_reply_sent.clone();
380     // Remove the recipients here to separate mentions / user messages from post or community comments
381     comment_post_sent.recipient_ids = Vec::new();
382     self.send_post_room_message(
383       user_operation,
384       &comment_post_sent,
385       comment_post_sent.comment_view.post.id,
386       websocket_id,
387     )?;
388
389     // Send it to the community too
390     self.send_community_room_message(
391       user_operation,
392       &comment_post_sent,
393       CommunityId(0),
394       websocket_id,
395     )?;
396     self.send_community_room_message(
397       user_operation,
398       &comment_post_sent,
399       comment.comment_view.community.id,
400       websocket_id,
401     )?;
402
403     // Send it to the recipient(s) including the mentioned users
404     for recipient_id in &comment_reply_sent.recipient_ids {
405       self.send_user_room_message(
406         user_operation,
407         &comment_reply_sent,
408         *recipient_id,
409         websocket_id,
410       )?;
411     }
412
413     Ok(())
414   }
415
416   pub fn send_post<OP>(
417     &self,
418     user_operation: &OP,
419     post_res: &PostResponse,
420     websocket_id: Option<ConnectionId>,
421   ) -> Result<(), LemmyError>
422   where
423     OP: OperationType + ToString,
424   {
425     let community_id = post_res.post_view.community.id;
426
427     // Don't send my data with it
428     let mut post_sent = post_res.clone();
429     post_sent.post_view.my_vote = None;
430
431     // Send it to /c/all and that community
432     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
433     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
434
435     // Send it to the post room
436     self.send_post_room_message(
437       user_operation,
438       &post_sent,
439       post_res.post_view.post.id,
440       websocket_id,
441     )?;
442
443     Ok(())
444   }
445
446   fn sendit(&self, message: &str, id: ConnectionId) {
447     if let Some(info) = self.sessions.get(&id) {
448       let _ = info.addr.do_send(WsMessage(message.to_owned()));
449     }
450   }
451
452   pub(super) fn parse_json_message(
453     &mut self,
454     msg: StandardMessage,
455     ctx: &mut Context<Self>,
456   ) -> impl Future<Output = Result<String, LemmyError>> {
457     let rate_limiter = self.rate_limiter.clone();
458
459     let ip: IpAddr = match self.sessions.get(&msg.id) {
460       Some(info) => info.ip.to_owned(),
461       None => IpAddr("blank_ip".to_string()),
462     };
463
464     let context = LemmyContext {
465       pool: self.pool.clone(),
466       chat_server: ctx.address(),
467       client: self.client.to_owned(),
468       activity_queue: self.activity_queue.to_owned(),
469       settings: self.settings.to_owned(),
470       secret: self.secret.to_owned(),
471     };
472     let message_handler_crud = self.message_handler_crud;
473     let message_handler = self.message_handler;
474     async move {
475       let json: Value = serde_json::from_str(&msg.msg)?;
476       let data = &json["data"].to_string();
477       let op = &json["op"]
478         .as_str()
479         .ok_or_else(|| LemmyError::from_message("missing op"))?;
480
481       if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
482         let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
483         match user_operation_crud {
484           UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
485           UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
486           UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
487           UserOperationCrud::CreateComment => rate_limiter.comment().wrap(ip, fut).await,
488           _ => rate_limiter.message().wrap(ip, fut).await,
489         }
490       } else {
491         let user_operation = UserOperation::from_str(op)?;
492         let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
493         match user_operation {
494           UserOperation::GetCaptcha => rate_limiter.post().wrap(ip, fut).await,
495           _ => rate_limiter.message().wrap(ip, fut).await,
496         }
497       }
498     }
499   }
500 }