]> Untitled Git - lemmy.git/blob - crates/api_common/src/websocket/handlers.rs
Check user accepted before sending jwt in password reset (fixes #2591) (#2597)
[lemmy.git] / crates / api_common / src / websocket / handlers.rs
1 use crate::websocket::{
2   chat_server::{ChatServer, SessionInfo},
3   messages::{
4     CaptchaItem,
5     CheckCaptcha,
6     Connect,
7     Disconnect,
8     GetCommunityUsersOnline,
9     GetPostUsersOnline,
10     GetUsersOnline,
11     JoinCommunityRoom,
12     JoinModRoom,
13     JoinPostRoom,
14     JoinUserRoom,
15     SendAllMessage,
16     SendComment,
17     SendCommunityRoomMessage,
18     SendModRoomMessage,
19     SendPost,
20     SendUserRoomMessage,
21     StandardMessage,
22   },
23   OperationType,
24 };
25 use actix::{Actor, Context, Handler, ResponseFuture};
26 use lemmy_db_schema::utils::naive_now;
27 use lemmy_utils::ConnectionId;
28 use opentelemetry::trace::TraceContextExt;
29 use rand::Rng;
30 use serde::Serialize;
31 use tracing::{error, info};
32 use tracing_opentelemetry::OpenTelemetrySpanExt;
33
34 /// Make actor from `ChatServer`
35 impl Actor for ChatServer {
36   /// We are going to use simple Context, we just need ability to communicate
37   /// with other actors.
38   type Context = Context<Self>;
39 }
40
41 /// Handler for Connect message.
42 ///
43 /// Register new session and assign unique id to this session
44 impl Handler<Connect> for ChatServer {
45   type Result = ConnectionId;
46
47   fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
48     // register session with random id
49     let id = self.rng.gen::<usize>();
50     info!("{} joined", &msg.ip);
51
52     self.sessions.insert(
53       id,
54       SessionInfo {
55         addr: msg.addr,
56         ip: msg.ip,
57       },
58     );
59
60     id
61   }
62 }
63
64 /// Handler for Disconnect message.
65 impl Handler<Disconnect> for ChatServer {
66   type Result = ();
67
68   fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) {
69     // Remove connections from sessions and all 3 scopes
70     if self.sessions.remove(&msg.id).is_some() {
71       for sessions in self.user_rooms.values_mut() {
72         sessions.remove(&msg.id);
73       }
74
75       for sessions in self.post_rooms.values_mut() {
76         sessions.remove(&msg.id);
77       }
78
79       for sessions in self.community_rooms.values_mut() {
80         sessions.remove(&msg.id);
81       }
82     }
83   }
84 }
85
86 fn root_span() -> tracing::Span {
87   let span = tracing::info_span!(
88     parent: None,
89     "Websocket Request",
90     trace_id = tracing::field::Empty,
91   );
92   {
93     let trace_id = span.context().span().span_context().trace_id().to_string();
94     span.record("trace_id", &tracing::field::display(trace_id));
95   }
96
97   span
98 }
99
100 /// Handler for Message message.
101 impl Handler<StandardMessage> for ChatServer {
102   type Result = ResponseFuture<Result<String, std::convert::Infallible>>;
103
104   fn handle(&mut self, msg: StandardMessage, ctx: &mut Context<Self>) -> Self::Result {
105     use tracing::Instrument;
106     let fut = self.parse_json_message(msg, ctx);
107     let span = root_span();
108
109     Box::pin(
110       async move {
111         match fut.await {
112           Ok(m) => {
113             // info!("Message Sent: {}", m);
114             Ok(m)
115           }
116           Err(e) => {
117             error!("Error during message handling {}", e);
118             Ok(
119               e.to_json()
120                 .unwrap_or_else(|_| String::from(r#"{"error":"failed to serialize json"}"#)),
121             )
122           }
123         }
124       }
125       .instrument(span),
126     )
127   }
128 }
129
130 impl<OP, Response> Handler<SendAllMessage<OP, Response>> for ChatServer
131 where
132   OP: OperationType + ToString,
133   Response: Serialize,
134 {
135   type Result = ();
136
137   fn handle(&mut self, msg: SendAllMessage<OP, Response>, _: &mut Context<Self>) {
138     self
139       .send_all_message(&msg.op, &msg.response, msg.websocket_id)
140       .ok();
141   }
142 }
143
144 impl<OP, Response> Handler<SendUserRoomMessage<OP, Response>> for ChatServer
145 where
146   OP: OperationType + ToString,
147   Response: Serialize,
148 {
149   type Result = ();
150
151   fn handle(&mut self, msg: SendUserRoomMessage<OP, Response>, _: &mut Context<Self>) {
152     self
153       .send_user_room_message(
154         &msg.op,
155         &msg.response,
156         msg.local_recipient_id,
157         msg.websocket_id,
158       )
159       .ok();
160   }
161 }
162
163 impl<OP, Response> Handler<SendCommunityRoomMessage<OP, Response>> for ChatServer
164 where
165   OP: OperationType + ToString,
166   Response: Serialize,
167 {
168   type Result = ();
169
170   fn handle(&mut self, msg: SendCommunityRoomMessage<OP, Response>, _: &mut Context<Self>) {
171     self
172       .send_community_room_message(&msg.op, &msg.response, msg.community_id, msg.websocket_id)
173       .ok();
174   }
175 }
176
177 impl<Response> Handler<SendModRoomMessage<Response>> for ChatServer
178 where
179   Response: Serialize,
180 {
181   type Result = ();
182
183   fn handle(&mut self, msg: SendModRoomMessage<Response>, _: &mut Context<Self>) {
184     self
185       .send_mod_room_message(&msg.op, &msg.response, msg.community_id, msg.websocket_id)
186       .ok();
187   }
188 }
189
190 impl<OP> Handler<SendPost<OP>> for ChatServer
191 where
192   OP: OperationType + ToString,
193 {
194   type Result = ();
195
196   fn handle(&mut self, msg: SendPost<OP>, _: &mut Context<Self>) {
197     self.send_post(&msg.op, &msg.post, msg.websocket_id).ok();
198   }
199 }
200
201 impl<OP> Handler<SendComment<OP>> for ChatServer
202 where
203   OP: OperationType + ToString,
204 {
205   type Result = ();
206
207   fn handle(&mut self, msg: SendComment<OP>, _: &mut Context<Self>) {
208     self
209       .send_comment(&msg.op, &msg.comment, msg.websocket_id)
210       .ok();
211   }
212 }
213
214 impl Handler<JoinUserRoom> for ChatServer {
215   type Result = ();
216
217   fn handle(&mut self, msg: JoinUserRoom, _: &mut Context<Self>) {
218     self.join_user_room(msg.local_user_id, msg.id).ok();
219   }
220 }
221
222 impl Handler<JoinCommunityRoom> for ChatServer {
223   type Result = ();
224
225   fn handle(&mut self, msg: JoinCommunityRoom, _: &mut Context<Self>) {
226     self.join_community_room(msg.community_id, msg.id).ok();
227   }
228 }
229
230 impl Handler<JoinModRoom> for ChatServer {
231   type Result = ();
232
233   fn handle(&mut self, msg: JoinModRoom, _: &mut Context<Self>) {
234     self.join_mod_room(msg.community_id, msg.id).ok();
235   }
236 }
237
238 impl Handler<JoinPostRoom> for ChatServer {
239   type Result = ();
240
241   fn handle(&mut self, msg: JoinPostRoom, _: &mut Context<Self>) {
242     self.join_post_room(msg.post_id, msg.id).ok();
243   }
244 }
245
246 impl Handler<GetUsersOnline> for ChatServer {
247   type Result = usize;
248
249   fn handle(&mut self, _msg: GetUsersOnline, _: &mut Context<Self>) -> Self::Result {
250     self.sessions.len()
251   }
252 }
253
254 impl Handler<GetPostUsersOnline> for ChatServer {
255   type Result = usize;
256
257   fn handle(&mut self, msg: GetPostUsersOnline, _: &mut Context<Self>) -> Self::Result {
258     if let Some(users) = self.post_rooms.get(&msg.post_id) {
259       users.len()
260     } else {
261       0
262     }
263   }
264 }
265
266 impl Handler<GetCommunityUsersOnline> for ChatServer {
267   type Result = usize;
268
269   fn handle(&mut self, msg: GetCommunityUsersOnline, _: &mut Context<Self>) -> Self::Result {
270     if let Some(users) = self.community_rooms.get(&msg.community_id) {
271       users.len()
272     } else {
273       0
274     }
275   }
276 }
277
278 impl Handler<CaptchaItem> for ChatServer {
279   type Result = ();
280
281   fn handle(&mut self, msg: CaptchaItem, _: &mut Context<Self>) {
282     self.captchas.push(msg);
283   }
284 }
285
286 impl Handler<CheckCaptcha> for ChatServer {
287   type Result = bool;
288
289   fn handle(&mut self, msg: CheckCaptcha, _: &mut Context<Self>) -> Self::Result {
290     // Remove all the ones that are past the expire time
291     self.captchas.retain(|x| x.expires.gt(&naive_now()));
292
293     let check = self
294       .captchas
295       .iter()
296       .any(|r| r.uuid == msg.uuid && r.answer.to_lowercase() == msg.answer.to_lowercase());
297
298     // Remove this uuid so it can't be re-checked (Checks only work once)
299     self.captchas.retain(|x| x.uuid != msg.uuid);
300
301     check
302   }
303 }