]> Untitled Git - lemmy.git/blob - crates/api_common/src/websocket/routes.rs
Merge pull request #2593 from LemmyNet/refactor-notifications
[lemmy.git] / crates / api_common / src / websocket / routes.rs
1 use crate::{
2   context::LemmyContext,
3   websocket::{
4     chat_server::ChatServer,
5     messages::{Connect, Disconnect, StandardMessage, WsMessage},
6   },
7 };
8 use actix::prelude::*;
9 use actix_web::{web, Error, HttpRequest, HttpResponse};
10 use actix_web_actors::ws;
11 use lemmy_utils::{rate_limit::RateLimitCell, utils::get_ip, ConnectionId, IpAddr};
12 use std::time::{Duration, Instant};
13 use tracing::{debug, error, info};
14
15 /// How often heartbeat pings are sent
16 const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
17 /// How long before lack of client response causes a timeout
18 const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
19
20 /// Entry point for our route
21 pub async fn chat_route(
22   req: HttpRequest,
23   stream: web::Payload,
24   context: web::Data<LemmyContext>,
25   rate_limiter: web::Data<RateLimitCell>,
26 ) -> Result<HttpResponse, Error> {
27   ws::start(
28     WsSession {
29       cs_addr: context.chat_server().clone(),
30       id: 0,
31       hb: Instant::now(),
32       ip: get_ip(&req.connection_info()),
33       rate_limiter: rate_limiter.as_ref().clone(),
34     },
35     &req,
36     stream,
37   )
38 }
39
40 struct WsSession {
41   cs_addr: Addr<ChatServer>,
42   /// unique session id
43   id: ConnectionId,
44   ip: IpAddr,
45   /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
46   /// otherwise we drop connection.
47   hb: Instant,
48   /// A rate limiter for websocket joins
49   rate_limiter: RateLimitCell,
50 }
51
52 impl Actor for WsSession {
53   type Context = ws::WebsocketContext<Self>;
54
55   /// Method is called on actor start.
56   /// We register ws session with ChatServer
57   fn started(&mut self, ctx: &mut Self::Context) {
58     // we'll start heartbeat process on session start.
59     WsSession::hb(ctx);
60
61     // register self in chat server. `AsyncContext::wait` register
62     // future within context, but context waits until this future resolves
63     // before processing any other events.
64     // across all routes within application
65     let addr = ctx.address();
66
67     if !self.rate_limit_check(ctx) {
68       return;
69     }
70
71     self
72       .cs_addr
73       .send(Connect {
74         addr: addr.recipient(),
75         ip: self.ip.clone(),
76       })
77       .into_actor(self)
78       .then(|res, act, ctx| {
79         match res {
80           Ok(res) => act.id = res,
81           // something is wrong with chat server
82           _ => ctx.stop(),
83         }
84         actix::fut::ready(())
85       })
86       .wait(ctx);
87   }
88
89   fn stopping(&mut self, _ctx: &mut Self::Context) -> Running {
90     // notify chat server
91     self.cs_addr.do_send(Disconnect {
92       id: self.id,
93       ip: self.ip.clone(),
94     });
95     Running::Stop
96   }
97 }
98
99 /// Handle messages from chat server, we simply send it to peer websocket
100 /// These are room messages, IE sent to others in the room
101 impl Handler<WsMessage> for WsSession {
102   type Result = ();
103
104   fn handle(&mut self, msg: WsMessage, ctx: &mut Self::Context) {
105     ctx.text(msg.0);
106   }
107 }
108
109 /// WebSocket message handler
110 impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
111   fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
112     if !self.rate_limit_check(ctx) {
113       return;
114     }
115
116     let message = match result {
117       Ok(m) => m,
118       Err(e) => {
119         error!("{}", e);
120         return;
121       }
122     };
123     match message {
124       ws::Message::Ping(msg) => {
125         self.hb = Instant::now();
126         ctx.pong(&msg);
127       }
128       ws::Message::Pong(_) => {
129         self.hb = Instant::now();
130       }
131       ws::Message::Text(text) => {
132         let m = text.trim().to_owned();
133
134         self
135           .cs_addr
136           .send(StandardMessage {
137             id: self.id,
138             msg: m,
139           })
140           .into_actor(self)
141           .then(|res, _, ctx| {
142             match res {
143               Ok(Ok(res)) => ctx.text(res),
144               Ok(Err(_)) => {}
145               Err(e) => error!("{}", &e),
146             }
147             actix::fut::ready(())
148           })
149           .spawn(ctx);
150       }
151       ws::Message::Binary(_bin) => info!("Unexpected binary"),
152       ws::Message::Close(_) => {
153         ctx.stop();
154       }
155       _ => {}
156     }
157   }
158 }
159
160 impl WsSession {
161   /// helper method that sends ping to client every second.
162   ///
163   /// also this method checks heartbeats from client
164   fn hb(ctx: &mut ws::WebsocketContext<Self>) {
165     ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
166       // check client heartbeats
167       if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
168         // heartbeat timed out
169         debug!("Websocket Client heartbeat failed, disconnecting!");
170
171         // notify chat server
172         act.cs_addr.do_send(Disconnect {
173           id: act.id,
174           ip: act.ip.clone(),
175         });
176
177         // stop actor
178         ctx.stop();
179
180         // don't try to send a ping
181         return;
182       }
183
184       ctx.ping(b"");
185     });
186   }
187
188   /// Check the rate limit, and stop the ctx if it fails
189   fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext<Self>) -> bool {
190     let check = self.rate_limiter.message().check(self.ip.clone());
191     if !check {
192       debug!("Websocket join with IP: {} has been rate limited.", self.ip);
193       ctx.stop()
194     }
195     check
196   }
197 }