]> Untitled Git - lemmy.git/blob - lemmy_websocket/src/chat_server.rs
Move websocket code into workspace (#107)
[lemmy.git] / lemmy_websocket / src / chat_server.rs
1 use crate::{messages::*, serialize_websocket_message, LemmyContext, UserOperation};
2 use actix::prelude::*;
3 use anyhow::Context as acontext;
4 use background_jobs::QueueHandle;
5 use diesel::{
6   r2d2::{ConnectionManager, Pool},
7   PgConnection,
8 };
9 use lemmy_rate_limit::RateLimit;
10 use lemmy_structs::{comment::*, post::*};
11 use lemmy_utils::{
12   location_info,
13   APIError,
14   CommunityId,
15   ConnectionId,
16   IPAddr,
17   LemmyError,
18   PostId,
19   UserId,
20 };
21 use rand::rngs::ThreadRng;
22 use reqwest::Client;
23 use serde::Serialize;
24 use serde_json::Value;
25 use std::{
26   collections::{HashMap, HashSet},
27   str::FromStr,
28 };
29 use tokio::macros::support::Pin;
30
31 type MessageHandlerType = fn(
32   context: LemmyContext,
33   id: ConnectionId,
34   op: UserOperation,
35   data: &str,
36 ) -> Pin<Box<dyn Future<Output = Result<String, LemmyError>> + '_>>;
37
38 /// `ChatServer` manages chat rooms and responsible for coordinating chat
39 /// session.
40 pub struct ChatServer {
41   /// A map from generated random ID to session addr
42   pub sessions: HashMap<ConnectionId, SessionInfo>,
43
44   /// A map from post_id to set of connectionIDs
45   pub post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
46
47   /// A map from community to set of connectionIDs
48   pub community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
49
50   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
51   /// sessions (IE clients)
52   pub(super) user_rooms: HashMap<UserId, HashSet<ConnectionId>>,
53
54   pub(super) rng: ThreadRng,
55
56   /// The DB Pool
57   pub(super) pool: Pool<ConnectionManager<PgConnection>>,
58
59   /// Rate limiting based on rate type and IP addr
60   pub(super) rate_limiter: RateLimit,
61
62   /// A list of the current captchas
63   pub(super) captchas: Vec<CaptchaItem>,
64
65   message_handler: MessageHandlerType,
66
67   /// An HTTP Client
68   client: Client,
69
70   activity_queue: QueueHandle,
71 }
72
73 pub struct SessionInfo {
74   pub addr: Recipient<WSMessage>,
75   pub ip: IPAddr,
76 }
77
78 /// `ChatServer` is an actor. It maintains list of connection client session.
79 /// And manages available rooms. Peers send messages to other peers in same
80 /// room through `ChatServer`.
81 impl ChatServer {
82   pub fn startup(
83     pool: Pool<ConnectionManager<PgConnection>>,
84     rate_limiter: RateLimit,
85     message_handler: MessageHandlerType,
86     client: Client,
87     activity_queue: QueueHandle,
88   ) -> ChatServer {
89     ChatServer {
90       sessions: HashMap::new(),
91       post_rooms: HashMap::new(),
92       community_rooms: HashMap::new(),
93       user_rooms: HashMap::new(),
94       rng: rand::thread_rng(),
95       pool,
96       rate_limiter,
97       captchas: Vec::new(),
98       message_handler,
99       client,
100       activity_queue,
101     }
102   }
103
104   pub fn join_community_room(
105     &mut self,
106     community_id: CommunityId,
107     id: ConnectionId,
108   ) -> Result<(), LemmyError> {
109     // remove session from all rooms
110     for sessions in self.community_rooms.values_mut() {
111       sessions.remove(&id);
112     }
113
114     // Also leave all post rooms
115     // This avoids double messages
116     for sessions in self.post_rooms.values_mut() {
117       sessions.remove(&id);
118     }
119
120     // If the room doesn't exist yet
121     if self.community_rooms.get_mut(&community_id).is_none() {
122       self.community_rooms.insert(community_id, HashSet::new());
123     }
124
125     self
126       .community_rooms
127       .get_mut(&community_id)
128       .context(location_info!())?
129       .insert(id);
130     Ok(())
131   }
132
133   pub fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) -> Result<(), LemmyError> {
134     // remove session from all rooms
135     for sessions in self.post_rooms.values_mut() {
136       sessions.remove(&id);
137     }
138
139     // Also leave all communities
140     // This avoids double messages
141     // TODO found a bug, whereby community messages like
142     // delete and remove aren't sent, because
143     // you left the community room
144     for sessions in self.community_rooms.values_mut() {
145       sessions.remove(&id);
146     }
147
148     // If the room doesn't exist yet
149     if self.post_rooms.get_mut(&post_id).is_none() {
150       self.post_rooms.insert(post_id, HashSet::new());
151     }
152
153     self
154       .post_rooms
155       .get_mut(&post_id)
156       .context(location_info!())?
157       .insert(id);
158
159     Ok(())
160   }
161
162   pub fn join_user_room(&mut self, user_id: UserId, id: ConnectionId) -> Result<(), LemmyError> {
163     // remove session from all rooms
164     for sessions in self.user_rooms.values_mut() {
165       sessions.remove(&id);
166     }
167
168     // If the room doesn't exist yet
169     if self.user_rooms.get_mut(&user_id).is_none() {
170       self.user_rooms.insert(user_id, HashSet::new());
171     }
172
173     self
174       .user_rooms
175       .get_mut(&user_id)
176       .context(location_info!())?
177       .insert(id);
178
179     Ok(())
180   }
181
182   fn send_post_room_message<Response>(
183     &self,
184     op: &UserOperation,
185     response: &Response,
186     post_id: PostId,
187     websocket_id: Option<ConnectionId>,
188   ) -> Result<(), LemmyError>
189   where
190     Response: Serialize,
191   {
192     let res_str = &serialize_websocket_message(op, response)?;
193     if let Some(sessions) = self.post_rooms.get(&post_id) {
194       for id in sessions {
195         if let Some(my_id) = websocket_id {
196           if *id == my_id {
197             continue;
198           }
199         }
200         self.sendit(res_str, *id);
201       }
202     }
203     Ok(())
204   }
205
206   pub fn send_community_room_message<Response>(
207     &self,
208     op: &UserOperation,
209     response: &Response,
210     community_id: CommunityId,
211     websocket_id: Option<ConnectionId>,
212   ) -> Result<(), LemmyError>
213   where
214     Response: Serialize,
215   {
216     let res_str = &serialize_websocket_message(op, response)?;
217     if let Some(sessions) = self.community_rooms.get(&community_id) {
218       for id in sessions {
219         if let Some(my_id) = websocket_id {
220           if *id == my_id {
221             continue;
222           }
223         }
224         self.sendit(res_str, *id);
225       }
226     }
227     Ok(())
228   }
229
230   pub fn send_all_message<Response>(
231     &self,
232     op: &UserOperation,
233     response: &Response,
234     websocket_id: Option<ConnectionId>,
235   ) -> Result<(), LemmyError>
236   where
237     Response: Serialize,
238   {
239     let res_str = &serialize_websocket_message(op, response)?;
240     for id in self.sessions.keys() {
241       if let Some(my_id) = websocket_id {
242         if *id == my_id {
243           continue;
244         }
245       }
246       self.sendit(res_str, *id);
247     }
248     Ok(())
249   }
250
251   pub fn send_user_room_message<Response>(
252     &self,
253     op: &UserOperation,
254     response: &Response,
255     recipient_id: UserId,
256     websocket_id: Option<ConnectionId>,
257   ) -> Result<(), LemmyError>
258   where
259     Response: Serialize,
260   {
261     let res_str = &serialize_websocket_message(op, response)?;
262     if let Some(sessions) = self.user_rooms.get(&recipient_id) {
263       for id in sessions {
264         if let Some(my_id) = websocket_id {
265           if *id == my_id {
266             continue;
267           }
268         }
269         self.sendit(res_str, *id);
270       }
271     }
272     Ok(())
273   }
274
275   pub fn send_comment(
276     &self,
277     user_operation: &UserOperation,
278     comment: &CommentResponse,
279     websocket_id: Option<ConnectionId>,
280   ) -> Result<(), LemmyError> {
281     let mut comment_reply_sent = comment.clone();
282     comment_reply_sent.comment.my_vote = None;
283     comment_reply_sent.comment.user_id = None;
284
285     let mut comment_post_sent = comment_reply_sent.clone();
286     comment_post_sent.recipient_ids = Vec::new();
287
288     // Send it to the post room
289     self.send_post_room_message(
290       user_operation,
291       &comment_post_sent,
292       comment_post_sent.comment.post_id,
293       websocket_id,
294     )?;
295
296     // Send it to the recipient(s) including the mentioned users
297     for recipient_id in &comment_reply_sent.recipient_ids {
298       self.send_user_room_message(
299         user_operation,
300         &comment_reply_sent,
301         *recipient_id,
302         websocket_id,
303       )?;
304     }
305
306     // Send it to the community too
307     self.send_community_room_message(user_operation, &comment_post_sent, 0, websocket_id)?;
308     self.send_community_room_message(
309       user_operation,
310       &comment_post_sent,
311       comment.comment.community_id,
312       websocket_id,
313     )?;
314
315     Ok(())
316   }
317
318   pub fn send_post(
319     &self,
320     user_operation: &UserOperation,
321     post: &PostResponse,
322     websocket_id: Option<ConnectionId>,
323   ) -> Result<(), LemmyError> {
324     let community_id = post.post.community_id;
325
326     // Don't send my data with it
327     let mut post_sent = post.clone();
328     post_sent.post.my_vote = None;
329     post_sent.post.user_id = None;
330
331     // Send it to /c/all and that community
332     self.send_community_room_message(user_operation, &post_sent, 0, websocket_id)?;
333     self.send_community_room_message(user_operation, &post_sent, community_id, websocket_id)?;
334
335     // Send it to the post room
336     self.send_post_room_message(user_operation, &post_sent, post.post.id, websocket_id)?;
337
338     Ok(())
339   }
340
341   fn sendit(&self, message: &str, id: ConnectionId) {
342     if let Some(info) = self.sessions.get(&id) {
343       let _ = info.addr.do_send(WSMessage(message.to_owned()));
344     }
345   }
346
347   pub(super) fn parse_json_message(
348     &mut self,
349     msg: StandardMessage,
350     ctx: &mut Context<Self>,
351   ) -> impl Future<Output = Result<String, LemmyError>> {
352     let rate_limiter = self.rate_limiter.clone();
353
354     let ip: IPAddr = match self.sessions.get(&msg.id) {
355       Some(info) => info.ip.to_owned(),
356       None => "blank_ip".to_string(),
357     };
358
359     let context = LemmyContext {
360       pool: self.pool.clone(),
361       chat_server: ctx.address(),
362       client: self.client.to_owned(),
363       activity_queue: self.activity_queue.to_owned(),
364     };
365     let message_handler = self.message_handler;
366     async move {
367       let json: Value = serde_json::from_str(&msg.msg)?;
368       let data = &json["data"].to_string();
369       let op = &json["op"].as_str().ok_or(APIError {
370         message: "Unknown op type".to_string(),
371       })?;
372
373       let user_operation = UserOperation::from_str(&op)?;
374       let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
375       match user_operation {
376         UserOperation::Register => rate_limiter.register().wrap(ip, fut).await,
377         UserOperation::CreatePost => rate_limiter.post().wrap(ip, fut).await,
378         UserOperation::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
379         _ => rate_limiter.message().wrap(ip, fut).await,
380       }
381     }
382   }
383 }