]> Untitled Git - lemmy.git/blob - server/src/websocket/server.rs
Version v0.6.11
[lemmy.git] / server / src / websocket / server.rs
1 //! `ChatServer` is an actor. It maintains list of connection client session.
2 //! And manages available rooms. Peers send messages to other peers in same
3 //! room through `ChatServer`.
4
5 use actix::prelude::*;
6 use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
7 use diesel::PgConnection;
8 use failure::Error;
9 use rand::{rngs::ThreadRng, Rng};
10 use serde::{Deserialize, Serialize};
11 use serde_json::Value;
12 use std::collections::{HashMap, HashSet};
13 use std::str::FromStr;
14 use std::time::SystemTime;
15
16 use crate::api::comment::*;
17 use crate::api::community::*;
18 use crate::api::post::*;
19 use crate::api::site::*;
20 use crate::api::user::*;
21 use crate::api::*;
22 use crate::websocket::UserOperation;
23 use crate::Settings;
24
25 type ConnectionId = usize;
26 type PostId = i32;
27 type CommunityId = i32;
28 type UserId = i32;
29 type IPAddr = String;
30
31 /// Chat server sends this messages to session
32 #[derive(Message)]
33 #[rtype(result = "()")]
34 pub struct WSMessage(pub String);
35
36 /// Message for chat server communications
37
38 /// New chat session is created
39 #[derive(Message)]
40 #[rtype(usize)]
41 pub struct Connect {
42   pub addr: Recipient<WSMessage>,
43   pub ip: IPAddr,
44 }
45
46 /// Session is disconnected
47 #[derive(Message)]
48 #[rtype(result = "()")]
49 pub struct Disconnect {
50   pub id: ConnectionId,
51   pub ip: IPAddr,
52 }
53
54 #[derive(Serialize, Deserialize, Message)]
55 #[rtype(String)]
56 pub struct StandardMessage {
57   /// Id of the client session
58   pub id: ConnectionId,
59   /// Peer message
60   pub msg: String,
61 }
62
63 #[derive(Debug)]
64 pub struct RateLimitBucket {
65   last_checked: SystemTime,
66   allowance: f64,
67 }
68
69 pub struct SessionInfo {
70   pub addr: Recipient<WSMessage>,
71   pub ip: IPAddr,
72 }
73
74 /// `ChatServer` manages chat rooms and responsible for coordinating chat
75 /// session.
76 pub struct ChatServer {
77   /// A map from generated random ID to session addr
78   sessions: HashMap<ConnectionId, SessionInfo>,
79
80   /// A map from post_id to set of connectionIDs
81   post_rooms: HashMap<PostId, HashSet<ConnectionId>>,
82
83   /// A map from community to set of connectionIDs
84   community_rooms: HashMap<CommunityId, HashSet<ConnectionId>>,
85
86   /// A map from user id to its connection ID for joined users. Remember a user can have multiple
87   /// sessions (IE clients)
88   user_rooms: HashMap<UserId, HashSet<ConnectionId>>,
89
90   /// Rate limiting based on IP addr
91   rate_limits: HashMap<IPAddr, RateLimitBucket>,
92
93   rng: ThreadRng,
94   db: Pool<ConnectionManager<PgConnection>>,
95 }
96
97 impl ChatServer {
98   pub fn startup(db: Pool<ConnectionManager<PgConnection>>) -> ChatServer {
99     ChatServer {
100       sessions: HashMap::new(),
101       rate_limits: HashMap::new(),
102       post_rooms: HashMap::new(),
103       community_rooms: HashMap::new(),
104       user_rooms: HashMap::new(),
105       rng: rand::thread_rng(),
106       db,
107     }
108   }
109
110   fn join_community_room(&mut self, community_id: CommunityId, id: ConnectionId) {
111     // remove session from all rooms
112     for sessions in self.community_rooms.values_mut() {
113       sessions.remove(&id);
114     }
115
116     // If the room doesn't exist yet
117     if self.community_rooms.get_mut(&community_id).is_none() {
118       self.community_rooms.insert(community_id, HashSet::new());
119     }
120
121     self
122       .community_rooms
123       .get_mut(&community_id)
124       .unwrap()
125       .insert(id);
126   }
127
128   fn join_post_room(&mut self, post_id: PostId, id: ConnectionId) {
129     // remove session from all rooms
130     for sessions in self.post_rooms.values_mut() {
131       sessions.remove(&id);
132     }
133
134     // If the room doesn't exist yet
135     if self.post_rooms.get_mut(&post_id).is_none() {
136       self.post_rooms.insert(post_id, HashSet::new());
137     }
138
139     self.post_rooms.get_mut(&post_id).unwrap().insert(id);
140   }
141
142   fn join_user_room(&mut self, user_id: UserId, id: ConnectionId) {
143     // remove session from all rooms
144     for sessions in self.user_rooms.values_mut() {
145       sessions.remove(&id);
146     }
147
148     // If the room doesn't exist yet
149     if self.user_rooms.get_mut(&user_id).is_none() {
150       self.user_rooms.insert(user_id, HashSet::new());
151     }
152
153     self.user_rooms.get_mut(&user_id).unwrap().insert(id);
154   }
155
156   fn send_post_room_message(&self, post_id: PostId, message: &str, skip_id: ConnectionId) {
157     if let Some(sessions) = self.post_rooms.get(&post_id) {
158       for id in sessions {
159         if *id != skip_id {
160           if let Some(info) = self.sessions.get(id) {
161             let _ = info.addr.do_send(WSMessage(message.to_owned()));
162           }
163         }
164       }
165     }
166   }
167
168   fn send_community_room_message(
169     &self,
170     community_id: CommunityId,
171     message: &str,
172     skip_id: ConnectionId,
173   ) {
174     if let Some(sessions) = self.community_rooms.get(&community_id) {
175       for id in sessions {
176         if *id != skip_id {
177           if let Some(info) = self.sessions.get(id) {
178             let _ = info.addr.do_send(WSMessage(message.to_owned()));
179           }
180         }
181       }
182     }
183   }
184
185   fn send_user_room_message(&self, user_id: UserId, message: &str, skip_id: ConnectionId) {
186     if let Some(sessions) = self.user_rooms.get(&user_id) {
187       for id in sessions {
188         if *id != skip_id {
189           if let Some(info) = self.sessions.get(id) {
190             let _ = info.addr.do_send(WSMessage(message.to_owned()));
191           }
192         }
193       }
194     }
195   }
196
197   fn send_all_message(&self, message: &str, skip_id: ConnectionId) {
198     for id in self.sessions.keys() {
199       if *id != skip_id {
200         if let Some(info) = self.sessions.get(id) {
201           let _ = info.addr.do_send(WSMessage(message.to_owned()));
202         }
203       }
204     }
205   }
206
207   fn comment_sends(
208     &self,
209     user_operation: UserOperation,
210     comment: CommentResponse,
211     id: ConnectionId,
212   ) -> Result<String, Error> {
213     let mut comment_reply_sent = comment.clone();
214     comment_reply_sent.comment.my_vote = None;
215     comment_reply_sent.comment.user_id = None;
216
217     // For the post room ones, and the directs back to the user
218     // strip out the recipient_ids, so that
219     // users don't get double notifs
220     let mut comment_user_sent = comment.clone();
221     comment_user_sent.recipient_ids = Vec::new();
222
223     let mut comment_post_sent = comment_reply_sent.clone();
224     comment_post_sent.recipient_ids = Vec::new();
225
226     let comment_reply_sent_str = to_json_string(&user_operation, &comment_reply_sent)?;
227     let comment_post_sent_str = to_json_string(&user_operation, &comment_post_sent)?;
228     let comment_user_sent_str = to_json_string(&user_operation, &comment_user_sent)?;
229
230     // Send it to the post room
231     self.send_post_room_message(comment.comment.post_id, &comment_post_sent_str, id);
232
233     // Send it to the recipient(s) including the mentioned users
234     for recipient_id in comment_reply_sent.recipient_ids {
235       self.send_user_room_message(recipient_id, &comment_reply_sent_str, id);
236     }
237
238     Ok(comment_user_sent_str)
239   }
240
241   fn post_sends(
242     &self,
243     user_operation: UserOperation,
244     post: PostResponse,
245     id: ConnectionId,
246   ) -> Result<String, Error> {
247     let community_id = post.post.community_id;
248
249     // Don't send my data with it
250     let mut post_sent = post.clone();
251     post_sent.post.my_vote = None;
252     post_sent.post.user_id = None;
253     let post_sent_str = to_json_string(&user_operation, &post_sent)?;
254
255     // Send it to /c/all and that community
256     self.send_community_room_message(0, &post_sent_str, id);
257     self.send_community_room_message(community_id, &post_sent_str, id);
258
259     to_json_string(&user_operation, post)
260   }
261
262   fn check_rate_limit_register(&mut self, id: usize) -> Result<(), Error> {
263     self.check_rate_limit_full(
264       id,
265       Settings::get().rate_limit.register,
266       Settings::get().rate_limit.register_per_second,
267     )
268   }
269
270   fn check_rate_limit_post(&mut self, id: usize) -> Result<(), Error> {
271     self.check_rate_limit_full(
272       id,
273       Settings::get().rate_limit.post,
274       Settings::get().rate_limit.post_per_second,
275     )
276   }
277
278   fn check_rate_limit_message(&mut self, id: usize) -> Result<(), Error> {
279     self.check_rate_limit_full(
280       id,
281       Settings::get().rate_limit.message,
282       Settings::get().rate_limit.message_per_second,
283     )
284   }
285
286   #[allow(clippy::float_cmp)]
287   fn check_rate_limit_full(&mut self, id: usize, rate: i32, per: i32) -> Result<(), Error> {
288     if let Some(info) = self.sessions.get(&id) {
289       if let Some(rate_limit) = self.rate_limits.get_mut(&info.ip) {
290         // The initial value
291         if rate_limit.allowance == -2f64 {
292           rate_limit.allowance = rate as f64;
293         };
294
295         let current = SystemTime::now();
296         let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
297         rate_limit.last_checked = current;
298         rate_limit.allowance += time_passed * (rate as f64 / per as f64);
299         if rate_limit.allowance > rate as f64 {
300           rate_limit.allowance = rate as f64;
301         }
302
303         if rate_limit.allowance < 1.0 {
304           println!(
305             "Rate limited IP: {}, time_passed: {}, allowance: {}",
306             &info.ip, time_passed, rate_limit.allowance
307           );
308           Err(
309             APIError {
310               message: format!("Too many requests. {} per {} seconds", rate, per),
311             }
312             .into(),
313           )
314         } else {
315           rate_limit.allowance -= 1.0;
316           Ok(())
317         }
318       } else {
319         Ok(())
320       }
321     } else {
322       Ok(())
323     }
324   }
325 }
326
327 /// Make actor from `ChatServer`
328 impl Actor for ChatServer {
329   /// We are going to use simple Context, we just need ability to communicate
330   /// with other actors.
331   type Context = Context<Self>;
332 }
333
334 /// Handler for Connect message.
335 ///
336 /// Register new session and assign unique id to this session
337 impl Handler<Connect> for ChatServer {
338   type Result = usize;
339
340   fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
341     // register session with random id
342     let id = self.rng.gen::<usize>();
343     println!("{} joined", &msg.ip);
344
345     self.sessions.insert(
346       id,
347       SessionInfo {
348         addr: msg.addr,
349         ip: msg.ip.to_owned(),
350       },
351     );
352
353     if self.rate_limits.get(&msg.ip).is_none() {
354       self.rate_limits.insert(
355         msg.ip,
356         RateLimitBucket {
357           last_checked: SystemTime::now(),
358           allowance: -2f64,
359         },
360       );
361     }
362
363     id
364   }
365 }
366
367 /// Handler for Disconnect message.
368 impl Handler<Disconnect> for ChatServer {
369   type Result = ();
370
371   fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) {
372     // Remove connections from sessions and all 3 scopes
373     if self.sessions.remove(&msg.id).is_some() {
374       for sessions in self.user_rooms.values_mut() {
375         sessions.remove(&msg.id);
376       }
377
378       for sessions in self.post_rooms.values_mut() {
379         sessions.remove(&msg.id);
380       }
381
382       for sessions in self.community_rooms.values_mut() {
383         sessions.remove(&msg.id);
384       }
385     }
386   }
387 }
388
389 /// Handler for Message message.
390 impl Handler<StandardMessage> for ChatServer {
391   type Result = MessageResult<StandardMessage>;
392
393   fn handle(&mut self, msg: StandardMessage, _: &mut Context<Self>) -> Self::Result {
394     let msg_out = match parse_json_message(self, msg) {
395       Ok(m) => m,
396       Err(e) => e.to_string(),
397     };
398
399     println!("Message Sent: {}", msg_out);
400     MessageResult(msg_out)
401   }
402 }
403
404 #[derive(Serialize)]
405 struct WebsocketResponse<T> {
406   op: String,
407   data: T,
408 }
409
410 fn to_json_string<T>(op: &UserOperation, data: T) -> Result<String, Error>
411 where
412   T: Serialize,
413 {
414   let response = WebsocketResponse {
415     op: op.to_string(),
416     data,
417   };
418   Ok(serde_json::to_string(&response)?)
419 }
420
421 fn do_user_operation<'a, Data, Response>(
422   op: UserOperation,
423   data: &str,
424   conn: &PooledConnection<ConnectionManager<PgConnection>>,
425 ) -> Result<String, Error>
426 where
427   for<'de> Data: Deserialize<'de> + 'a,
428   Response: Serialize,
429   Oper<Data>: Perform<Response>,
430 {
431   let parsed_data: Data = serde_json::from_str(data)?;
432   let res = Oper::new(parsed_data).perform(&conn)?;
433   to_json_string(&op, &res)
434 }
435
436 fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<String, Error> {
437   let json: Value = serde_json::from_str(&msg.msg)?;
438   let data = &json["data"].to_string();
439   let op = &json["op"].as_str().ok_or(APIError {
440     message: "Unknown op type".to_string(),
441   })?;
442
443   let conn = chat.db.get()?;
444
445   let user_operation: UserOperation = UserOperation::from_str(&op)?;
446
447   // TODO: none of the chat messages are going to work if stuff is submitted via http api,
448   //       need to move that handling elsewhere
449   match user_operation {
450     UserOperation::Login => do_user_operation::<Login, LoginResponse>(user_operation, data, &conn),
451     UserOperation::Register => {
452       chat.check_rate_limit_register(msg.id)?;
453       do_user_operation::<Register, LoginResponse>(user_operation, data, &conn)
454     }
455     UserOperation::GetUserDetails => {
456       do_user_operation::<GetUserDetails, GetUserDetailsResponse>(user_operation, data, &conn)
457     }
458     UserOperation::SaveUserSettings => {
459       do_user_operation::<SaveUserSettings, LoginResponse>(user_operation, data, &conn)
460     }
461     UserOperation::AddAdmin => {
462       let add_admin: AddAdmin = serde_json::from_str(data)?;
463       let res = Oper::new(add_admin).perform(&conn)?;
464       let res_str = to_json_string(&user_operation, &res)?;
465       chat.send_all_message(&res_str, msg.id);
466       Ok(res_str)
467     }
468     UserOperation::BanUser => {
469       let ban_user: BanUser = serde_json::from_str(data)?;
470       let res = Oper::new(ban_user).perform(&conn)?;
471       let res_str = to_json_string(&user_operation, &res)?;
472       chat.send_all_message(&res_str, msg.id);
473       Ok(res_str)
474     }
475     UserOperation::GetReplies => {
476       do_user_operation::<GetReplies, GetRepliesResponse>(user_operation, data, &conn)
477     }
478     UserOperation::GetUserMentions => {
479       do_user_operation::<GetUserMentions, GetUserMentionsResponse>(user_operation, data, &conn)
480     }
481     UserOperation::EditUserMention => {
482       do_user_operation::<EditUserMention, UserMentionResponse>(user_operation, data, &conn)
483     }
484     UserOperation::MarkAllAsRead => {
485       do_user_operation::<MarkAllAsRead, GetRepliesResponse>(user_operation, data, &conn)
486     }
487     UserOperation::GetCommunity => {
488       let get_community: GetCommunity = serde_json::from_str(data)?;
489       let mut res = Oper::new(get_community).perform(&conn)?;
490       let community_id = res.community.id;
491
492       chat.join_community_room(community_id, msg.id);
493
494       res.online = if let Some(community_users) = chat.community_rooms.get(&community_id) {
495         community_users.len()
496       } else {
497         0
498       };
499
500       to_json_string(&user_operation, &res)
501     }
502     UserOperation::ListCommunities => {
503       do_user_operation::<ListCommunities, ListCommunitiesResponse>(user_operation, data, &conn)
504     }
505     UserOperation::CreateCommunity => {
506       chat.check_rate_limit_register(msg.id)?;
507       do_user_operation::<CreateCommunity, CommunityResponse>(user_operation, data, &conn)
508     }
509     UserOperation::EditCommunity => {
510       let edit_community: EditCommunity = serde_json::from_str(data)?;
511       let res = Oper::new(edit_community).perform(&conn)?;
512       let mut community_sent: CommunityResponse = res.clone();
513       community_sent.community.user_id = None;
514       community_sent.community.subscribed = None;
515       let community_sent_str = to_json_string(&user_operation, &community_sent)?;
516       chat.send_community_room_message(community_sent.community.id, &community_sent_str, msg.id);
517       to_json_string(&user_operation, &res)
518     }
519     UserOperation::FollowCommunity => {
520       do_user_operation::<FollowCommunity, CommunityResponse>(user_operation, data, &conn)
521     }
522     UserOperation::GetFollowedCommunities => do_user_operation::<
523       GetFollowedCommunities,
524       GetFollowedCommunitiesResponse,
525     >(user_operation, data, &conn),
526     UserOperation::BanFromCommunity => {
527       let ban_from_community: BanFromCommunity = serde_json::from_str(data)?;
528       let community_id = ban_from_community.community_id;
529       let res = Oper::new(ban_from_community).perform(&conn)?;
530       let res_str = to_json_string(&user_operation, &res)?;
531       chat.send_community_room_message(community_id, &res_str, msg.id);
532       Ok(res_str)
533     }
534     UserOperation::AddModToCommunity => {
535       let mod_add_to_community: AddModToCommunity = serde_json::from_str(data)?;
536       let community_id = mod_add_to_community.community_id;
537       let res = Oper::new(mod_add_to_community).perform(&conn)?;
538       let res_str = to_json_string(&user_operation, &res)?;
539       chat.send_community_room_message(community_id, &res_str, msg.id);
540       Ok(res_str)
541     }
542     UserOperation::ListCategories => {
543       do_user_operation::<ListCategories, ListCategoriesResponse>(user_operation, data, &conn)
544     }
545     UserOperation::GetPost => {
546       let get_post: GetPost = serde_json::from_str(data)?;
547       let post_id = get_post.id;
548       chat.join_post_room(post_id, msg.id);
549       let mut res = Oper::new(get_post).perform(&conn)?;
550
551       res.online = if let Some(post_users) = chat.post_rooms.get(&post_id) {
552         post_users.len()
553       } else {
554         0
555       };
556
557       to_json_string(&user_operation, &res)
558     }
559     UserOperation::GetPosts => {
560       let get_posts: GetPosts = serde_json::from_str(data)?;
561       if get_posts.community_id.is_none() {
562         // 0 is the "all" community
563         chat.join_community_room(0, msg.id);
564       }
565       let res = Oper::new(get_posts).perform(&conn)?;
566       to_json_string(&user_operation, &res)
567     }
568     UserOperation::CreatePost => {
569       chat.check_rate_limit_post(msg.id)?;
570       let create_post: CreatePost = serde_json::from_str(data)?;
571       let res = Oper::new(create_post).perform(&conn)?;
572
573       chat.post_sends(UserOperation::CreatePost, res, msg.id)
574     }
575     UserOperation::CreatePostLike => {
576       chat.check_rate_limit_message(msg.id)?;
577       let create_post_like: CreatePostLike = serde_json::from_str(data)?;
578       let res = Oper::new(create_post_like).perform(&conn)?;
579
580       chat.post_sends(UserOperation::CreatePostLike, res, msg.id)
581     }
582     UserOperation::EditPost => {
583       let edit_post: EditPost = serde_json::from_str(data)?;
584       let res = Oper::new(edit_post).perform(&conn)?;
585
586       chat.post_sends(UserOperation::EditPost, res, msg.id)
587     }
588     UserOperation::SavePost => {
589       do_user_operation::<SavePost, PostResponse>(user_operation, data, &conn)
590     }
591     UserOperation::CreateComment => {
592       chat.check_rate_limit_message(msg.id)?;
593       let create_comment: CreateComment = serde_json::from_str(data)?;
594       let res = Oper::new(create_comment).perform(&conn)?;
595
596       chat.comment_sends(UserOperation::CreateComment, res, msg.id)
597     }
598     UserOperation::EditComment => {
599       let edit_comment: EditComment = serde_json::from_str(data)?;
600       let res = Oper::new(edit_comment).perform(&conn)?;
601
602       chat.comment_sends(UserOperation::EditComment, res, msg.id)
603     }
604     UserOperation::SaveComment => {
605       do_user_operation::<SaveComment, CommentResponse>(user_operation, data, &conn)
606     }
607     UserOperation::CreateCommentLike => {
608       chat.check_rate_limit_message(msg.id)?;
609       let create_comment_like: CreateCommentLike = serde_json::from_str(data)?;
610       let res = Oper::new(create_comment_like).perform(&conn)?;
611
612       chat.comment_sends(UserOperation::CreateCommentLike, res, msg.id)
613     }
614     UserOperation::GetModlog => {
615       do_user_operation::<GetModlog, GetModlogResponse>(user_operation, data, &conn)
616     }
617     UserOperation::CreateSite => {
618       do_user_operation::<CreateSite, SiteResponse>(user_operation, data, &conn)
619     }
620     UserOperation::EditSite => {
621       let edit_site: EditSite = serde_json::from_str(data)?;
622       let res = Oper::new(edit_site).perform(&conn)?;
623       let res_str = to_json_string(&user_operation, &res)?;
624       chat.send_all_message(&res_str, msg.id);
625       Ok(res_str)
626     }
627     UserOperation::GetSite => {
628       let get_site: GetSite = serde_json::from_str(data)?;
629       let mut res = Oper::new(get_site).perform(&conn)?;
630       res.online = chat.sessions.len();
631       to_json_string(&user_operation, &res)
632     }
633     UserOperation::Search => {
634       do_user_operation::<Search, SearchResponse>(user_operation, data, &conn)
635     }
636     UserOperation::TransferCommunity => {
637       do_user_operation::<TransferCommunity, GetCommunityResponse>(user_operation, data, &conn)
638     }
639     UserOperation::TransferSite => {
640       do_user_operation::<TransferSite, GetSiteResponse>(user_operation, data, &conn)
641     }
642     UserOperation::DeleteAccount => {
643       do_user_operation::<DeleteAccount, LoginResponse>(user_operation, data, &conn)
644     }
645     UserOperation::PasswordReset => {
646       do_user_operation::<PasswordReset, PasswordResetResponse>(user_operation, data, &conn)
647     }
648     UserOperation::PasswordChange => {
649       do_user_operation::<PasswordChange, LoginResponse>(user_operation, data, &conn)
650     }
651     UserOperation::CreatePrivateMessage => {
652       chat.check_rate_limit_message(msg.id)?;
653       let create_private_message: CreatePrivateMessage = serde_json::from_str(data)?;
654       let recipient_id = create_private_message.recipient_id;
655       let res = Oper::new(create_private_message).perform(&conn)?;
656       let res_str = to_json_string(&user_operation, &res)?;
657
658       chat.send_user_room_message(recipient_id, &res_str, msg.id);
659       Ok(res_str)
660     }
661     UserOperation::EditPrivateMessage => {
662       do_user_operation::<EditPrivateMessage, PrivateMessageResponse>(user_operation, data, &conn)
663     }
664     UserOperation::GetPrivateMessages => {
665       do_user_operation::<GetPrivateMessages, PrivateMessagesResponse>(user_operation, data, &conn)
666     }
667     UserOperation::UserJoin => {
668       let user_join: UserJoin = serde_json::from_str(data)?;
669       let res = Oper::new(user_join).perform(&conn)?;
670       chat.join_user_room(res.user_id, msg.id);
671       to_json_string(&user_operation, &res)
672     }
673   }
674 }