]> Untitled Git - lemmy.git/blob - crates/websocket/src/chat_server.rs
Check user accepted before sending jwt in password reset (fixes #2591) (#2597)
[lemmy.git] / crates / websocket / src / chat_server.rs
1 use crate::{
2   messages::{CaptchaItem, StandardMessage, WsMessage},
3   serialize_websocket_message,
4   LemmyContext,
5   OperationType,
6   UserOperation,
7   UserOperationCrud,
8 };
9 use actix::prelude::*;
10 use anyhow::Context as acontext;
11 use lemmy_api_common::{comment::CommentResponse, post::PostResponse};
12 use lemmy_db_schema::{
13   newtypes::{CommunityId, LocalUserId, PostId},
14   source::secret::Secret,
15   utils::DbPool,
16 };
17 use lemmy_utils::{
18   error::LemmyError,
19   location_info,
20   rate_limit::RateLimitCell,
21   settings::structs::Settings,
22   ConnectionId,
23   IpAddr,
24 };
25 use rand::rngs::ThreadRng;
26 use reqwest_middleware::ClientWithMiddleware;
27 use serde::Serialize;
28 use serde_json::Value;
29 use std::{
30   collections::{HashMap, HashSet},
31   future::Future,
32   str::FromStr,
33 };
34 use tokio::macros::support::Pin;
35
36 type MessageHandlerType = fn(
37   context: LemmyContext,
38   id: ConnectionId,
39   op: UserOperation,
40   data: &str,
41 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
42
43 type MessageHandlerCrudType = fn(
44   context: LemmyContext,
45   id: ConnectionId,
46   op: UserOperationCrud,
47   data: &str,
48 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
49
50 /// `ChatServer` manages chat rooms and responsible for coordinating chat
51 /// session.
52 pub struct ChatServer {
53   /// A map from generated random ID to session addr
54   pub sessions: HashMap<ConnectionId, SessionInfo>,
55
56   /// A map from post_id to set of connectionIDs
57   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
58
59   /// A map from community to set of connectionIDs
60   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
61
62   pub mod_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
63
64   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
65   /// sessions (IE clients)
66   pub(super) user_rooms: HashMap<LocalUserId, HashSet<ConnectionId>>,
67
68   pub(super) rng: ThreadRng,
69
70   /// The DB Pool
71   pub(super) pool: DbPool,
72
73   /// The Settings
74   pub(super) settings: Settings,
75
76   /// The Secrets
77   pub(super) secret: Secret,
78
79   /// A list of the current captchas
80   pub(super) captchas: Vec<CaptchaItem>,
81
82   message_handler: MessageHandlerType,
83   message_handler_crud: MessageHandlerCrudType,
84
85   /// An HTTP Client
86   client: ClientWithMiddleware,
87
88   rate_limit_cell: RateLimitCell,
89 }
90
91 pub struct SessionInfo {
92   pub addr: Recipient<WsMessage>,
93   pub ip: IpAddr,
94 }
95
96 /// `ChatServer` is an actor. It maintains list of connection client session.
97 /// And manages available rooms. Peers send messages to other peers in same
98 /// room through `ChatServer`.
99 impl ChatServer {
100   #![allow(clippy::too_many_arguments)]
101   pub fn startup(
102     pool: DbPool,
103     message_handler: MessageHandlerType,
104     message_handler_crud: MessageHandlerCrudType,
105     client: ClientWithMiddleware,
106     settings: Settings,
107     secret: Secret,
108     rate_limit_cell: RateLimitCell,
109   ) -> ChatServer {
110     ChatServer {
111       sessions: HashMap::new(),
112       post_rooms: HashMap::new(),
113       community_rooms: HashMap::new(),
114       mod_rooms: HashMap::new(),
115       user_rooms: HashMap::new(),
116       rng: rand::thread_rng(),
117       pool,
118       captchas: Vec::new(),
119       message_handler,
120       message_handler_crud,
121       client,
122       settings,
123       secret,
124       rate_limit_cell,
125     }
126   }
127
128   pub fn join_community_room(
129     &mut self,
130     community_id: CommunityId,
131     id: ConnectionId,
132   ) -> Result<(), LemmyError> {
133     // remove session from all rooms
134     for sessions in self.community_rooms.values_mut() {
135       sessions.remove(&id);
136     }
137
138     // Also leave all post rooms
139     // This avoids double messages
140     for sessions in self.post_rooms.values_mut() {
141       sessions.remove(&id);
142     }
143
144     // If the room doesn't exist yet
145     if self.community_rooms.get_mut(&community_id).is_none() {
146       self.community_rooms.insert(community_id, HashSet::new());
147     }
148
149     self
150       .community_rooms
151       .get_mut(&community_id)
152       .context(location_info!())?
153       .insert(id);
154     Ok(())
155   }
156
157   pub fn join_mod_room(
158     &mut self,
159     community_id: CommunityId,
160     id: ConnectionId,
161   ) -> Result<(), LemmyError> {
162     // remove session from all rooms
163     for sessions in self.mod_rooms.values_mut() {
164       sessions.remove(&id);
165     }
166
167     // If the room doesn't exist yet
168     if self.mod_rooms.get_mut(&community_id).is_none() {
169       self.mod_rooms.insert(community_id, HashSet::new());
170     }
171
172     self
173       .mod_rooms
174       .get_mut(&community_id)
175       .context(location_info!())?
176       .insert(id);
177     Ok(())
178   }
179
180   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
181     // remove session from all rooms
182     for sessions in self.post_rooms.values_mut() {
183       sessions.remove(&id);
184     }
185
186     // Also leave all communities
187     // This avoids double messages
188     // TODO found a bug, whereby community messages like
189     // delete and remove aren't sent, because
190     // you left the community room
191     for sessions in self.community_rooms.values_mut() {
192       sessions.remove(&id);
193     }
194
195     // If the room doesn't exist yet
196     if self.post_rooms.get_mut(&post_id).is_none() {
197       self.post_rooms.insert(post_id, HashSet::new());
198     }
199
200     self
201       .post_rooms
202       .get_mut(&post_id)
203       .context(location_info!())?
204       .insert(id);
205
206     Ok(())
207   }
208
209   pub fn join_user_room(
210     &mut self,
211     user_id: LocalUserId,
212     id: ConnectionId,
213   ) -> Result<(), LemmyError> {
214     // remove session from all rooms
215     for sessions in self.user_rooms.values_mut() {
216       sessions.remove(&id);
217     }
218
219     // If the room doesn't exist yet
220     if self.user_rooms.get_mut(&user_id).is_none() {
221       self.user_rooms.insert(user_id, HashSet::new());
222     }
223
224     self
225       .user_rooms
226       .get_mut(&user_id)
227       .context(location_info!())?
228       .insert(id);
229
230     Ok(())
231   }
232
233   fn send_post_room_message<OP, Response>(
234     &self,
235     op: &OP,
236     response: &Response,
237     post_id: PostId,
238     websocket_id: Option<ConnectionId>,
239   ) -> Result<(), LemmyError>
240   where
241     OP: OperationType + ToString,
242     Response: Serialize,
243   {
244     let res_str = &serialize_websocket_message(op, response)?;
245     if let Some(sessions) = self.post_rooms.get(&post_id) {
246       for id in sessions {
247         if let Some(my_id) = websocket_id {
248           if *id == my_id {
249             continue;
250           }
251         }
252         self.sendit(res_str, *id);
253       }
254     }
255     Ok(())
256   }
257
258   pub fn send_community_room_message<OP, Response>(
259     &self,
260     op: &OP,
261     response: &Response,
262     community_id: CommunityId,
263     websocket_id: Option<ConnectionId>,
264   ) -> Result<(), LemmyError>
265   where
266     OP: OperationType + ToString,
267     Response: Serialize,
268   {
269     let res_str = &serialize_websocket_message(op, response)?;
270     if let Some(sessions) = self.community_rooms.get(&community_id) {
271       for id in sessions {
272         if let Some(my_id) = websocket_id {
273           if *id == my_id {
274             continue;
275           }
276         }
277         self.sendit(res_str, *id);
278       }
279     }
280     Ok(())
281   }
282
283   pub fn send_mod_room_message<OP, Response>(
284     &self,
285     op: &OP,
286     response: &Response,
287     community_id: CommunityId,
288     websocket_id: Option<ConnectionId>,
289   ) -> Result<(), LemmyError>
290   where
291     OP: OperationType + ToString,
292     Response: Serialize,
293   {
294     let res_str = &serialize_websocket_message(op, response)?;
295     if let Some(sessions) = self.mod_rooms.get(&community_id) {
296       for id in sessions {
297         if let Some(my_id) = websocket_id {
298           if *id == my_id {
299             continue;
300           }
301         }
302         self.sendit(res_str, *id);
303       }
304     }
305     Ok(())
306   }
307
308   pub fn send_all_message<OP, Response>(
309     &self,
310     op: &OP,
311     response: &Response,
312     websocket_id: Option<ConnectionId>,
313   ) -> Result<(), LemmyError>
314   where
315     OP: OperationType + ToString,
316     Response: Serialize,
317   {
318     let res_str = &serialize_websocket_message(op, response)?;
319     for id in self.sessions.keys() {
320       if let Some(my_id) = websocket_id {
321         if *id == my_id {
322           continue;
323         }
324       }
325       self.sendit(res_str, *id);
326     }
327     Ok(())
328   }
329
330   pub fn send_user_room_message<OP, Response>(
331     &self,
332     op: &OP,
333     response: &Response,
334     recipient_id: LocalUserId,
335     websocket_id: Option<ConnectionId>,
336   ) -> Result<(), LemmyError>
337   where
338     OP: OperationType + ToString,
339     Response: Serialize,
340   {
341     let res_str = &serialize_websocket_message(op, response)?;
342     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
343       for id in sessions {
344         if let Some(my_id) = websocket_id {
345           if *id == my_id {
346             continue;
347           }
348         }
349         self.sendit(res_str, *id);
350       }
351     }
352     Ok(())
353   }
354
355   pub fn send_comment<OP>(
356     &self,
357     user_operation: &OP,
358     comment: &CommentResponse,
359     websocket_id: Option<ConnectionId>,
360   ) -> Result<(), LemmyError>
361   where
362     OP: OperationType + ToString,
363   {
364     let mut comment_reply_sent = comment.clone();
365
366     // Strip out my specific user info
367     comment_reply_sent.comment_view.my_vote = None;
368
369     // Send it to the post room
370     let mut comment_post_sent = comment_reply_sent.clone();
371     // Remove the recipients here to separate mentions / user messages from post or community comments
372     comment_post_sent.recipient_ids = Vec::new();
373     self.send_post_room_message(
374       user_operation,
375       &comment_post_sent,
376       comment_post_sent.comment_view.post.id,
377       websocket_id,
378     )?;
379
380     // Send it to the community too
381     self.send_community_room_message(
382       user_operation,
383       &comment_post_sent,
384       CommunityId(0),
385       websocket_id,
386     )?;
387     self.send_community_room_message(
388       user_operation,
389       &comment_post_sent,
390       comment.comment_view.community.id,
391       websocket_id,
392     )?;
393
394     // Send it to the recipient(s) including the mentioned users
395     for recipient_id in &comment_reply_sent.recipient_ids {
396       self.send_user_room_message(
397         user_operation,
398         &comment_reply_sent,
399         *recipient_id,
400         websocket_id,
401       )?;
402     }
403
404     Ok(())
405   }
406
407   pub fn send_post<OP>(
408     &self,
409     user_operation: &OP,
410     post_res: &PostResponse,
411     websocket_id: Option<ConnectionId>,
412   ) -> Result<(), LemmyError>
413   where
414     OP: OperationType + ToString,
415   {
416     let community_id = post_res.post_view.community.id;
417
418     // Don't send my data with it
419     let mut post_sent = post_res.clone();
420     post_sent.post_view.my_vote = None;
421
422     // Send it to /c/all and that community
423     self.send_community_room_message(user_operation, &post_sent, CommunityId(0), websocket_id)?;
424     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
425
426     // Send it to the post room
427     self.send_post_room_message(
428       user_operation,
429       &post_sent,
430       post_res.post_view.post.id,
431       websocket_id,
432     )?;
433
434     Ok(())
435   }
436
437   fn sendit(&self, message: &str, id: ConnectionId) {
438     if let Some(info) = self.sessions.get(&id) {
439       info.addr.do_send(WsMessage(message.to_owned()));
440     }
441   }
442
443   pub(super) fn parse_json_message(
444     &mut self,
445     msg: StandardMessage,
446     ctx: &mut Context<Self>,
447   ) -> impl Future<Output = Result<String, LemmyError>> {
448     let ip: IpAddr = match self.sessions.get(&msg.id) {
449       Some(info) => info.ip.clone(),
450       None => IpAddr("blank_ip".to_string()),
451     };
452
453     let context = LemmyContext {
454       pool: self.pool.clone(),
455       chat_server: ctx.address(),
456       client: self.client.clone(),
457       settings: self.settings.clone(),
458       secret: self.secret.clone(),
459       rate_limit_cell: self.rate_limit_cell.clone(),
460     };
461     let message_handler_crud = self.message_handler_crud;
462     let message_handler = self.message_handler;
463     let rate_limiter = self.rate_limit_cell.clone();
464     async move {
465       let json: Value = serde_json::from_str(&msg.msg)?;
466       let data = &json["data"].to_string();
467       let op = &json["op"]
468         .as_str()
469         .ok_or_else(|| LemmyError::from_message("missing op"))?;
470
471       // check if api call passes the rate limit, and generate future for later execution
472       let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
473         let passed = match user_operation_crud {
474           UserOperationCrud::Register => rate_limiter.register().check(ip),
475           UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
476           UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
477           UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
478           _ => rate_limiter.message().check(ip),
479         };
480         let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
481         (passed, fut)
482       } else {
483         let user_operation = UserOperation::from_str(op)?;
484         let passed = match user_operation {
485           UserOperation::GetCaptcha => rate_limiter.post().check(ip),
486           UserOperation::Search => rate_limiter.search().check(ip),
487           _ => rate_limiter.message().check(ip),
488         };
489         let fut = (message_handler)(context, msg.id, user_operation, data);
490         (passed, fut)
491       };
492
493       // if rate limit passed, execute api call future
494       if passed {
495         fut.await
496       } else {
497         // if rate limit was hit, respond with message
498         Err(LemmyError::from_message("rate_limit_error"))
499       }
500     }
501   }
502 }