]> Untitled Git - lemmy.git/blobdiff - src/api_routes_websocket.rs
Dont return error in case optional auth is invalid (#2879)
[lemmy.git] / src / api_routes_websocket.rs
index d030980686e8f560a0e852d8d352f912e8d1abc5..23b9a5b5e14e9240797f3a3035169e1d23a06378 100644 (file)
@@ -1,7 +1,18 @@
+use activitypub_federation::config::Data as ContextData;
+use actix::{
+  fut,
+  Actor,
+  ActorContext,
+  ActorFutureExt,
+  AsyncContext,
+  ContextFutureSpawner,
+  Handler,
+  Running,
+  StreamHandler,
+  WrapFuture,
+};
 use actix_web::{web, Error, HttpRequest, HttpResponse};
 use actix_web_actors::ws;
-use actix_ws::{MessageStream, Session};
-use futures::stream::StreamExt;
 use lemmy_api::Perform;
 use lemmy_api_common::{
   comment::{
@@ -32,6 +43,7 @@ use lemmy_api_common::{
     TransferCommunity,
   },
   context::LemmyContext,
+  custom_emoji::{CreateCustomEmoji, DeleteCustomEmoji, EditCustomEmoji},
   person::{
     AddAdmin,
     BanPerson,
@@ -86,6 +98,7 @@ use lemmy_api_common::{
     ApproveRegistrationApplication,
     CreateSite,
     EditSite,
+    GetFederatedInstances,
     GetModlog,
     GetSite,
     GetUnreadRegistrationApplicationCount,
@@ -99,6 +112,10 @@ use lemmy_api_common::{
     Search,
   },
   websocket::{
+    handlers::{
+      connect::{Connect, Disconnect},
+      WsMessage,
+    },
     serialize_websocket_message,
     structs::{CommunityJoin, ModJoin, PostJoin, UserJoin},
     UserOperation,
@@ -112,22 +129,39 @@ use lemmy_utils::{error::LemmyError, rate_limit::RateLimitCell, ConnectionId, Ip
 use serde::Deserialize;
 use serde_json::Value;
 use std::{
+  ops::Deref,
   result,
   str::FromStr,
-  sync::{Arc, Mutex},
   time::{Duration, Instant},
 };
-use tracing::{debug, error, info};
+use tracing::{debug, error};
+
+/// How often heartbeat pings are sent
+const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(25);
+
+/// How long before lack of client response causes a timeout
+const CLIENT_TIMEOUT: Duration = Duration::from_secs(60);
+
+pub struct WsChatSession {
+  /// unique session id
+  pub id: ConnectionId,
+
+  pub ip: IpAddr,
+
+  /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
+  /// otherwise we drop connection.
+  pub hb: Instant,
+
+  /// The context data
+  apub_data: ContextData<LemmyContext>,
+}
 
-/// Entry point for our route
 pub async fn websocket(
   req: HttpRequest,
   body: web::Payload,
-  context: web::Data<LemmyContext>,
   rate_limiter: web::Data<RateLimitCell>,
+  apub_data: ContextData<LemmyContext>,
 ) -> Result<HttpResponse, Error> {
-  let (response, session, stream) = actix_ws::handle(&req, body)?;
-
   let client_ip = IpAddr(
     req
       .connection_info()
@@ -142,120 +176,171 @@ pub async fn websocket(
       "Websocket join with IP: {} has been rate limited.",
       &client_ip
     );
-    session.close(None).await.map_err(LemmyError::from)?;
-    return Ok(response);
+    return Ok(HttpResponse::TooManyRequests().finish());
   }
 
-  let connection_id = context.chat_server().handle_connect(session.clone())?;
-  info!("{} joined", &client_ip);
+  ws::start(
+    WsChatSession {
+      id: 0,
+      ip: client_ip,
+      hb: Instant::now(),
+      apub_data,
+    },
+    &req,
+    body,
+  )
+}
+
+/// helper method that sends ping to client every few seconds (HEARTBEAT_INTERVAL).
+///
+/// also this method checks heartbeats from client
+fn hb(ctx: &mut ws::WebsocketContext<WsChatSession>) {
+  ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
+    // check client heartbeats
+    if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
+      // heartbeat timed out
+
+      // notify chat server
+      act
+        .apub_data
+        .chat_server()
+        .do_send(Disconnect { id: act.id });
 
-  let alive = Arc::new(Mutex::new(Instant::now()));
-  heartbeat(session.clone(), alive.clone());
+      // stop actor
+      ctx.stop();
 
-  actix_rt::spawn(handle_messages(
-    stream,
-    client_ip,
-    session,
-    connection_id,
-    alive,
-    rate_limiter,
-    context,
-  ));
+      // don't try to send a ping
+      return;
+    }
 
-  Ok(response)
+    ctx.ping(b"");
+  });
 }
 
-async fn handle_messages(
-  mut stream: MessageStream,
-  client_ip: IpAddr,
-  mut session: Session,
-  connection_id: ConnectionId,
-  alive: Arc<Mutex<Instant>>,
-  rate_limiter: web::Data<RateLimitCell>,
-  context: web::Data<LemmyContext>,
-) -> Result<(), LemmyError> {
-  while let Some(Ok(msg)) = stream.next().await {
-    match msg {
-      ws::Message::Ping(bytes) => {
-        if session.pong(&bytes).await.is_err() {
-          break;
+impl Actor for WsChatSession {
+  type Context = ws::WebsocketContext<Self>;
+
+  /// Method is called on actor start.
+  /// We register ws session with ChatServer
+  fn started(&mut self, ctx: &mut Self::Context) {
+    // we'll start heartbeat process on session start.
+    hb(ctx);
+
+    // register self in chat server. `AsyncContext::wait` register
+    // future within context, but context waits until this future resolves
+    // before processing any other events.
+    // HttpContext::state() is instance of WsChatSessionState, state is shared
+    // across all routes within application
+    let addr = ctx.address();
+    self
+      .apub_data
+      .chat_server()
+      .send(Connect {
+        addr: addr.recipient(),
+      })
+      .into_actor(self)
+      .then(|res, act, ctx| {
+        match res {
+          Ok(res) => act.id = res,
+          // something is wrong with chat server
+          _ => ctx.stop(),
         }
+        fut::ready(())
+      })
+      .wait(ctx);
+  }
+  fn stopping(&mut self, _: &mut Self::Context) -> Running {
+    // notify chat server
+    self
+      .apub_data
+      .chat_server()
+      .do_send(Disconnect { id: self.id });
+    Running::Stop
+  }
+}
+
+/// Handle messages from chat server, we simply send it to peer websocket
+impl Handler<WsMessage> for WsChatSession {
+  type Result = ();
+
+  fn handle(&mut self, msg: WsMessage, ctx: &mut Self::Context) {
+    ctx.text(msg.0);
+  }
+}
+
+/// WebSocket message handler
+impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsChatSession {
+  fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
+    let msg = match msg {
+      Err(_) => {
+        ctx.stop();
+        return;
+      }
+      Ok(msg) => msg,
+    };
+
+    match msg {
+      ws::Message::Ping(msg) => {
+        self.hb = Instant::now();
+        ctx.pong(&msg);
       }
       ws::Message::Pong(_) => {
-        let mut lock = alive
-          .lock()
-          .expect("Failed to acquire websocket heartbeat alive lock");
-        *lock = Instant::now();
+        self.hb = Instant::now();
       }
       ws::Message::Text(text) => {
-        let msg = text.trim().to_string();
-        let executed = parse_json_message(
-          msg,
-          client_ip.clone(),
-          connection_id,
-          rate_limiter.get_ref(),
-          context.get_ref().clone(),
-        )
-        .await;
-
-        let res = executed.unwrap_or_else(|e| {
-          error!("Error during message handling {}", e);
-          e.to_json()
-            .unwrap_or_else(|_| String::from(r#"{"error":"failed to serialize json"}"#))
+        let ip_clone = self.ip.clone();
+        let id_clone = self.id.to_owned();
+        let context_clone = self.apub_data.reset_request_count();
+
+        let fut = Box::pin(async move {
+          let msg = text.trim().to_string();
+          parse_json_message(msg, ip_clone, id_clone, context_clone).await
         });
-        session.text(res).await?;
-      }
-      ws::Message::Close(_) => {
-        session.close(None).await?;
-        context.chat_server().handle_disconnect(&connection_id)?;
-        break;
+        fut
+          .into_actor(self)
+          .then(|res, _, ctx| {
+            match res {
+              Ok(res) => ctx.text(res),
+              Err(e) => error!("{}", &e),
+            }
+            actix::fut::ready(())
+          })
+          .spawn(ctx);
       }
-      ws::Message::Binary(_) => info!("Unexpected binary"),
-      _ => {}
-    }
-  }
-  Ok(())
-}
-
-fn heartbeat(mut session: Session, alive: Arc<Mutex<Instant>>) {
-  actix_rt::spawn(async move {
-    let mut interval = actix_rt::time::interval(Duration::from_secs(5));
-    loop {
-      if session.ping(b"").await.is_err() {
-        break;
+      ws::Message::Binary(_) => println!("Unexpected binary"),
+      ws::Message::Close(reason) => {
+        ctx.close(reason);
+        ctx.stop();
       }
-
-      let duration_since = {
-        let alive_lock = alive
-          .lock()
-          .expect("Failed to acquire websocket heartbeat alive lock");
-        Instant::now().duration_since(*alive_lock)
-      };
-      if duration_since > Duration::from_secs(10) {
-        let _ = session.close(None).await;
-        break;
+      ws::Message::Continuation(_) => {
+        ctx.stop();
       }
-      interval.tick().await;
+      ws::Message::Nop => (),
     }
-  });
+  }
 }
 
+/// Entry point for our websocket route
 async fn parse_json_message(
   msg: String,
   ip: IpAddr,
   connection_id: ConnectionId,
-  rate_limiter: &RateLimitCell,
-  context: LemmyContext,
+  context: ContextData<LemmyContext>,
 ) -> Result<String, LemmyError> {
+  let rate_limiter = context.settings_updated_channel();
   let json: Value = serde_json::from_str(&msg)?;
-  let data = &json
+  let data = json
     .get("data")
-    .ok_or_else(|| LemmyError::from_message("missing data"))?
-    .to_string();
-  let op = &json
+    .cloned()
+    .ok_or_else(|| LemmyError::from_message("missing data"))?;
+
+  let missing_op_err = || LemmyError::from_message("missing op");
+
+  let op = json
     .get("op")
-    .ok_or_else(|| LemmyError::from_message("missing op"))?
-    .to_string();
+    .ok_or_else(missing_op_err)?
+    .as_str()
+    .ok_or_else(missing_op_err)?;
 
   // check if api call passes the rate limit, and generate future for later execution
   if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
@@ -296,10 +381,10 @@ fn check_rate_limit_passed(passed: bool) -> Result<(), LemmyError> {
 }
 
 pub async fn match_websocket_operation_crud(
-  context: LemmyContext,
+  context: ContextData<LemmyContext>,
   id: ConnectionId,
   op: UserOperationCrud,
-  data: &str,
+  data: Value,
 ) -> result::Result<String, LemmyError> {
   match op {
     // User ops
@@ -385,32 +470,42 @@ pub async fn match_websocket_operation_crud(
     UserOperationCrud::GetComment => {
       do_websocket_operation_crud::<GetComment>(context, id, op, data).await
     }
+    // Emojis
+    UserOperationCrud::CreateCustomEmoji => {
+      do_websocket_operation_crud::<CreateCustomEmoji>(context, id, op, data).await
+    }
+    UserOperationCrud::EditCustomEmoji => {
+      do_websocket_operation_crud::<EditCustomEmoji>(context, id, op, data).await
+    }
+    UserOperationCrud::DeleteCustomEmoji => {
+      do_websocket_operation_crud::<DeleteCustomEmoji>(context, id, op, data).await
+    }
   }
 }
 
 async fn do_websocket_operation_crud<'a, 'b, Data>(
-  context: LemmyContext,
+  context: ContextData<LemmyContext>,
   id: ConnectionId,
   op: UserOperationCrud,
-  data: &str,
+  data: Value,
 ) -> result::Result<String, LemmyError>
 where
-  Data: PerformCrud + SendActivity<Response = <Data as PerformCrud>::Response>,
+  Data: PerformCrud + SendActivity<Response = <Data as PerformCrud>::Response> + Send,
   for<'de> Data: Deserialize<'de>,
 {
-  let parsed_data: Data = serde_json::from_str(data)?;
+  let parsed_data: Data = serde_json::from_value(data)?;
   let res = parsed_data
-    .perform(&web::Data::new(context.clone()), Some(id))
+    .perform(&web::Data::new(context.deref().clone()), Some(id))
     .await?;
   SendActivity::send_activity(&parsed_data, &res, &context).await?;
   serialize_websocket_message(&op, &res)
 }
 
 pub async fn match_websocket_operation_apub(
-  context: LemmyContext,
+  context: ContextData<LemmyContext>,
   id: ConnectionId,
   op: UserOperationApub,
-  data: &str,
+  data: Value,
 ) -> result::Result<String, LemmyError> {
   match op {
     UserOperationApub::GetPersonDetails => {
@@ -433,28 +528,26 @@ pub async fn match_websocket_operation_apub(
 }
 
 async fn do_websocket_operation_apub<'a, 'b, Data>(
-  context: LemmyContext,
+  context: ContextData<LemmyContext>,
   id: ConnectionId,
   op: UserOperationApub,
-  data: &str,
+  data: Value,
 ) -> result::Result<String, LemmyError>
 where
-  Data: PerformApub + SendActivity<Response = <Data as PerformApub>::Response>,
+  Data: PerformApub + SendActivity<Response = <Data as PerformApub>::Response> + Send,
   for<'de> Data: Deserialize<'de>,
 {
-  let parsed_data: Data = serde_json::from_str(data)?;
-  let res = parsed_data
-    .perform(&web::Data::new(context.clone()), Some(id))
-    .await?;
+  let parsed_data: Data = serde_json::from_value(data)?;
+  let res = parsed_data.perform(&context, Some(id)).await?;
   SendActivity::send_activity(&parsed_data, &res, &context).await?;
   serialize_websocket_message(&op, &res)
 }
 
 pub async fn match_websocket_operation(
-  context: LemmyContext,
+  context: ContextData<LemmyContext>,
   id: ConnectionId,
   op: UserOperation,
-  data: &str,
+  data: Value,
 ) -> result::Result<String, LemmyError> {
   match op {
     // User ops
@@ -548,6 +641,9 @@ pub async fn match_websocket_operation(
       do_websocket_operation::<TransferCommunity>(context, id, op, data).await
     }
     UserOperation::LeaveAdmin => do_websocket_operation::<LeaveAdmin>(context, id, op, data).await,
+    UserOperation::GetFederatedInstances => {
+      do_websocket_operation::<GetFederatedInstances>(context, id, op, data).await
+    }
 
     // Community ops
     UserOperation::FollowCommunity => {
@@ -611,18 +707,18 @@ pub async fn match_websocket_operation(
 }
 
 async fn do_websocket_operation<'a, 'b, Data>(
-  context: LemmyContext,
+  context: ContextData<LemmyContext>,
   id: ConnectionId,
   op: UserOperation,
-  data: &str,
+  data: Value,
 ) -> result::Result<String, LemmyError>
 where
-  Data: Perform + SendActivity<Response = <Data as Perform>::Response>,
+  Data: Perform + SendActivity<Response = <Data as Perform>::Response> + Send,
   for<'de> Data: Deserialize<'de>,
 {
-  let parsed_data: Data = serde_json::from_str(data)?;
+  let parsed_data: Data = serde_json::from_value(data)?;
   let res = parsed_data
-    .perform(&web::Data::new(context.clone()), Some(id))
+    .perform(&web::Data::new(context.deref().clone()), Some(id))
     .await?;
   SendActivity::send_activity(&parsed_data, &res, &context).await?;
   serialize_websocket_message(&op, &res)