]> Untitled Git - lemmy.git/commitdiff
Trying to add r2d2 connection pooling to websockets.
authorDessalines <tyhou13@gmx.com>
Sun, 12 Jan 2020 15:31:51 +0000 (10:31 -0500)
committerDessalines <tyhou13@gmx.com>
Sun, 12 Jan 2020 15:31:51 +0000 (10:31 -0500)
25 files changed:
server/src/api/comment.rs
server/src/api/community.rs
server/src/api/mod.rs
server/src/api/post.rs
server/src/api/site.rs
server/src/api/user.rs
server/src/apub/community.rs
server/src/apub/user.rs
server/src/db/category.rs
server/src/db/comment.rs
server/src/db/comment_view.rs
server/src/db/community.rs
server/src/db/mod.rs
server/src/db/moderator.rs
server/src/db/password_reset_request.rs
server/src/db/post.rs
server/src/db/post_view.rs
server/src/db/user.rs
server/src/db/user_mention.rs
server/src/main.rs
server/src/routes/feeds.rs
server/src/routes/nodeinfo.rs
server/src/routes/webfinger.rs
server/src/routes/websocket.rs
server/src/websocket/server.rs

index 62759578bc4c5b7d91ba9d912686fe4f730b9119..61cc95063344c5af6618b95a3ae157cf6835bb21 100644 (file)
@@ -1,6 +1,7 @@
 use super::*;
 use crate::send_email;
 use crate::settings::Settings;
+use diesel::PgConnection;
 
 #[derive(Serialize, Deserialize)]
 pub struct CreateComment {
@@ -47,9 +48,8 @@ pub struct CreateCommentLike {
 }
 
 impl Perform<CommentResponse> for Oper<CreateComment> {
-  fn perform(&self) -> Result<CommentResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> {
     let data: &CreateComment = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -206,9 +206,8 @@ impl Perform<CommentResponse> for Oper<CreateComment> {
 }
 
 impl Perform<CommentResponse> for Oper<EditComment> {
-  fn perform(&self) -> Result<CommentResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> {
     let data: &EditComment = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -318,9 +317,8 @@ impl Perform<CommentResponse> for Oper<EditComment> {
 }
 
 impl Perform<CommentResponse> for Oper<SaveComment> {
-  fn perform(&self) -> Result<CommentResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> {
     let data: &SaveComment = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -356,9 +354,8 @@ impl Perform<CommentResponse> for Oper<SaveComment> {
 }
 
 impl Perform<CommentResponse> for Oper<CreateCommentLike> {
-  fn perform(&self) -> Result<CommentResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CommentResponse, Error> {
     let data: &CreateCommentLike = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
index a1109c03c8137c0c788b599403bdf7043700690c..0bf846c3d9dd4488a1bf62811657b339114fdce9 100644 (file)
@@ -1,4 +1,5 @@
 use super::*;
+use diesel::PgConnection;
 use std::str::FromStr;
 
 #[derive(Serialize, Deserialize)]
@@ -118,9 +119,8 @@ pub struct TransferCommunity {
 }
 
 impl Perform<GetCommunityResponse> for Oper<GetCommunity> {
-  fn perform(&self) -> Result<GetCommunityResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetCommunityResponse, Error> {
     let data: &GetCommunity = &self.data;
-    let conn = establish_connection();
 
     let user_id: Option<i32> = match &data.auth {
       Some(auth) => match Claims::decode(&auth) {
@@ -173,9 +173,8 @@ impl Perform<GetCommunityResponse> for Oper<GetCommunity> {
 }
 
 impl Perform<CommunityResponse> for Oper<CreateCommunity> {
-  fn perform(&self) -> Result<CommunityResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CommunityResponse, Error> {
     let data: &CreateCommunity = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -248,15 +247,13 @@ impl Perform<CommunityResponse> for Oper<CreateCommunity> {
 }
 
 impl Perform<CommunityResponse> for Oper<EditCommunity> {
-  fn perform(&self) -> Result<CommunityResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CommunityResponse, Error> {
     let data: &EditCommunity = &self.data;
 
     if has_slurs(&data.name) || has_slurs(&data.title) {
       return Err(APIError::err(&self.op, "no_slurs").into());
     }
 
-    let conn = establish_connection();
-
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
       Err(_e) => return Err(APIError::err(&self.op, "not_logged_in").into()),
@@ -325,9 +322,8 @@ impl Perform<CommunityResponse> for Oper<EditCommunity> {
 }
 
 impl Perform<ListCommunitiesResponse> for Oper<ListCommunities> {
-  fn perform(&self) -> Result<ListCommunitiesResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<ListCommunitiesResponse, Error> {
     let data: &ListCommunities = &self.data;
-    let conn = establish_connection();
 
     let user_claims: Option<Claims> = match &data.auth {
       Some(auth) => match Claims::decode(&auth) {
@@ -366,9 +362,8 @@ impl Perform<ListCommunitiesResponse> for Oper<ListCommunities> {
 }
 
 impl Perform<CommunityResponse> for Oper<FollowCommunity> {
-  fn perform(&self) -> Result<CommunityResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CommunityResponse, Error> {
     let data: &FollowCommunity = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -404,9 +399,8 @@ impl Perform<CommunityResponse> for Oper<FollowCommunity> {
 }
 
 impl Perform<GetFollowedCommunitiesResponse> for Oper<GetFollowedCommunities> {
-  fn perform(&self) -> Result<GetFollowedCommunitiesResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetFollowedCommunitiesResponse, Error> {
     let data: &GetFollowedCommunities = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -430,9 +424,8 @@ impl Perform<GetFollowedCommunitiesResponse> for Oper<GetFollowedCommunities> {
 }
 
 impl Perform<BanFromCommunityResponse> for Oper<BanFromCommunity> {
-  fn perform(&self) -> Result<BanFromCommunityResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<BanFromCommunityResponse, Error> {
     let data: &BanFromCommunity = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -485,9 +478,8 @@ impl Perform<BanFromCommunityResponse> for Oper<BanFromCommunity> {
 }
 
 impl Perform<AddModToCommunityResponse> for Oper<AddModToCommunity> {
-  fn perform(&self) -> Result<AddModToCommunityResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<AddModToCommunityResponse, Error> {
     let data: &AddModToCommunity = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -536,9 +528,8 @@ impl Perform<AddModToCommunityResponse> for Oper<AddModToCommunity> {
 }
 
 impl Perform<GetCommunityResponse> for Oper<TransferCommunity> {
-  fn perform(&self) -> Result<GetCommunityResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetCommunityResponse, Error> {
     let data: &TransferCommunity = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
index 07712e874800f8426ae568f9f16071e8ad26de7d..e35804476850de62cc00bb2648a8afdaf0bf155d 100644 (file)
@@ -16,6 +16,7 @@ use crate::db::user_mention_view::*;
 use crate::db::user_view::*;
 use crate::db::*;
 use crate::{extract_usernames, has_slurs, naive_from_unix, naive_now, remove_slurs};
+use diesel::PgConnection;
 use failure::Error;
 use serde::{Deserialize, Serialize};
 
@@ -96,7 +97,7 @@ impl<T> Oper<T> {
 }
 
 pub trait Perform<T> {
-  fn perform(&self) -> Result<T, Error>
+  fn perform(&self, conn: &PgConnection) -> Result<T, Error>
   where
     T: Sized;
 }
index 5bc31defe1223647dae28e4b88a0bb47afecfeec..b0fcdd0c1b9f24023b106d1395232462d79ecd58 100644 (file)
@@ -1,4 +1,5 @@
 use super::*;
+use diesel::PgConnection;
 use std::str::FromStr;
 
 #[derive(Serialize, Deserialize)]
@@ -87,9 +88,8 @@ pub struct SavePost {
 }
 
 impl Perform<PostResponse> for Oper<CreatePost> {
-  fn perform(&self) -> Result<PostResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<PostResponse, Error> {
     let data: &CreatePost = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -158,9 +158,8 @@ impl Perform<PostResponse> for Oper<CreatePost> {
 }
 
 impl Perform<GetPostResponse> for Oper<GetPost> {
-  fn perform(&self) -> Result<GetPostResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetPostResponse, Error> {
     let data: &GetPost = &self.data;
-    let conn = establish_connection();
 
     let user_id: Option<i32> = match &data.auth {
       Some(auth) => match Claims::decode(&auth) {
@@ -207,9 +206,8 @@ impl Perform<GetPostResponse> for Oper<GetPost> {
 }
 
 impl Perform<GetPostsResponse> for Oper<GetPosts> {
-  fn perform(&self) -> Result<GetPostsResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetPostsResponse, Error> {
     let data: &GetPosts = &self.data;
-    let conn = establish_connection();
 
     let user_claims: Option<Claims> = match &data.auth {
       Some(auth) => match Claims::decode(&auth) {
@@ -254,9 +252,8 @@ impl Perform<GetPostsResponse> for Oper<GetPosts> {
 }
 
 impl Perform<CreatePostLikeResponse> for Oper<CreatePostLike> {
-  fn perform(&self) -> Result<CreatePostLikeResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<CreatePostLikeResponse, Error> {
     let data: &CreatePostLike = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -316,14 +313,12 @@ impl Perform<CreatePostLikeResponse> for Oper<CreatePostLike> {
 }
 
 impl Perform<PostResponse> for Oper<EditPost> {
-  fn perform(&self) -> Result<PostResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<PostResponse, Error> {
     let data: &EditPost = &self.data;
     if has_slurs(&data.name) || (data.body.is_some() && has_slurs(&data.body.to_owned().unwrap())) {
       return Err(APIError::err(&self.op, "no_slurs").into());
     }
 
-    let conn = establish_connection();
-
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
       Err(_e) => return Err(APIError::err(&self.op, "not_logged_in").into()),
@@ -412,9 +407,8 @@ impl Perform<PostResponse> for Oper<EditPost> {
 }
 
 impl Perform<PostResponse> for Oper<SavePost> {
-  fn perform(&self) -> Result<PostResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<PostResponse, Error> {
     let data: &SavePost = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
index 58c34e8fa85d93fb77c2cf8a392a1bf06f56cf9f..a189a0308866dfc48897c111446cc46e18a20491 100644 (file)
@@ -1,4 +1,5 @@
 use super::*;
+use diesel::PgConnection;
 use std::str::FromStr;
 
 #[derive(Serialize, Deserialize)]
@@ -97,9 +98,8 @@ pub struct TransferSite {
 }
 
 impl Perform<ListCategoriesResponse> for Oper<ListCategories> {
-  fn perform(&self) -> Result<ListCategoriesResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<ListCategoriesResponse, Error> {
     let _data: &ListCategories = &self.data;
-    let conn = establish_connection();
 
     let categories: Vec<Category> = Category::list_all(&conn)?;
 
@@ -112,9 +112,8 @@ impl Perform<ListCategoriesResponse> for Oper<ListCategories> {
 }
 
 impl Perform<GetModlogResponse> for Oper<GetModlog> {
-  fn perform(&self) -> Result<GetModlogResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetModlogResponse, Error> {
     let data: &GetModlog = &self.data;
-    let conn = establish_connection();
 
     let removed_posts = ModRemovePostView::list(
       &conn,
@@ -187,9 +186,8 @@ impl Perform<GetModlogResponse> for Oper<GetModlog> {
 }
 
 impl Perform<SiteResponse> for Oper<CreateSite> {
-  fn perform(&self) -> Result<SiteResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<SiteResponse, Error> {
     let data: &CreateSite = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -234,9 +232,8 @@ impl Perform<SiteResponse> for Oper<CreateSite> {
 }
 
 impl Perform<SiteResponse> for Oper<EditSite> {
-  fn perform(&self) -> Result<SiteResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<SiteResponse, Error> {
     let data: &EditSite = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -283,9 +280,8 @@ impl Perform<SiteResponse> for Oper<EditSite> {
 }
 
 impl Perform<GetSiteResponse> for Oper<GetSite> {
-  fn perform(&self) -> Result<GetSiteResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetSiteResponse, Error> {
     let _data: &GetSite = &self.data;
-    let conn = establish_connection();
 
     // It can return a null site in order to redirect
     let site_view = match Site::read(&conn, 1) {
@@ -314,9 +310,8 @@ impl Perform<GetSiteResponse> for Oper<GetSite> {
 }
 
 impl Perform<SearchResponse> for Oper<Search> {
-  fn perform(&self) -> Result<SearchResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<SearchResponse, Error> {
     let data: &Search = &self.data;
-    let conn = establish_connection();
 
     let sort = SortType::from_str(&data.sort)?;
     let type_ = SearchType::from_str(&data.type_)?;
@@ -419,9 +414,8 @@ impl Perform<SearchResponse> for Oper<Search> {
 }
 
 impl Perform<GetSiteResponse> for Oper<TransferSite> {
-  fn perform(&self) -> Result<GetSiteResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetSiteResponse, Error> {
     let data: &TransferSite = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
index 41729eb756196c212b0c6a82549d940cd08eda1e..20eb09c19a9ee6498c20e222213668a9f5567897 100644 (file)
@@ -2,6 +2,7 @@ use super::*;
 use crate::settings::Settings;
 use crate::{generate_random_string, send_email};
 use bcrypt::verify;
+use diesel::PgConnection;
 use std::str::FromStr;
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -167,9 +168,8 @@ pub struct PasswordChange {
 }
 
 impl Perform<LoginResponse> for Oper<Login> {
-  fn perform(&self) -> Result<LoginResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> {
     let data: &Login = &self.data;
-    let conn = establish_connection();
 
     // Fetch that username / email
     let user: User_ = match User_::find_by_email_or_username(&conn, &data.username_or_email) {
@@ -192,9 +192,8 @@ impl Perform<LoginResponse> for Oper<Login> {
 }
 
 impl Perform<LoginResponse> for Oper<Register> {
-  fn perform(&self) -> Result<LoginResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> {
     let data: &Register = &self.data;
-    let conn = establish_connection();
 
     // Make sure site has open registration
     if let Ok(site) = SiteView::read(&conn) {
@@ -299,9 +298,8 @@ impl Perform<LoginResponse> for Oper<Register> {
 }
 
 impl Perform<LoginResponse> for Oper<SaveUserSettings> {
-  fn perform(&self) -> Result<LoginResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> {
     let data: &SaveUserSettings = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -378,9 +376,8 @@ impl Perform<LoginResponse> for Oper<SaveUserSettings> {
 }
 
 impl Perform<GetUserDetailsResponse> for Oper<GetUserDetails> {
-  fn perform(&self) -> Result<GetUserDetailsResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetUserDetailsResponse, Error> {
     let data: &GetUserDetails = &self.data;
-    let conn = establish_connection();
 
     let user_claims: Option<Claims> = match &data.auth {
       Some(auth) => match Claims::decode(&auth) {
@@ -470,9 +467,8 @@ impl Perform<GetUserDetailsResponse> for Oper<GetUserDetails> {
 }
 
 impl Perform<AddAdminResponse> for Oper<AddAdmin> {
-  fn perform(&self) -> Result<AddAdminResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<AddAdminResponse, Error> {
     let data: &AddAdmin = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -535,9 +531,8 @@ impl Perform<AddAdminResponse> for Oper<AddAdmin> {
 }
 
 impl Perform<BanUserResponse> for Oper<BanUser> {
-  fn perform(&self) -> Result<BanUserResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<BanUserResponse, Error> {
     let data: &BanUser = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -604,9 +599,8 @@ impl Perform<BanUserResponse> for Oper<BanUser> {
 }
 
 impl Perform<GetRepliesResponse> for Oper<GetReplies> {
-  fn perform(&self) -> Result<GetRepliesResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetRepliesResponse, Error> {
     let data: &GetReplies = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -632,9 +626,8 @@ impl Perform<GetRepliesResponse> for Oper<GetReplies> {
 }
 
 impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> {
-  fn perform(&self) -> Result<GetUserMentionsResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetUserMentionsResponse, Error> {
     let data: &GetUserMentions = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -660,9 +653,8 @@ impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> {
 }
 
 impl Perform<UserMentionResponse> for Oper<EditUserMention> {
-  fn perform(&self) -> Result<UserMentionResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<UserMentionResponse, Error> {
     let data: &EditUserMention = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -695,9 +687,8 @@ impl Perform<UserMentionResponse> for Oper<EditUserMention> {
 }
 
 impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> {
-  fn perform(&self) -> Result<GetRepliesResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<GetRepliesResponse, Error> {
     let data: &MarkAllAsRead = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -759,9 +750,8 @@ impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> {
 }
 
 impl Perform<LoginResponse> for Oper<DeleteAccount> {
-  fn perform(&self) -> Result<LoginResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> {
     let data: &DeleteAccount = &self.data;
-    let conn = establish_connection();
 
     let claims = match Claims::decode(&data.auth) {
       Ok(claims) => claims.claims,
@@ -838,9 +828,8 @@ impl Perform<LoginResponse> for Oper<DeleteAccount> {
 }
 
 impl Perform<PasswordResetResponse> for Oper<PasswordReset> {
-  fn perform(&self) -> Result<PasswordResetResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<PasswordResetResponse, Error> {
     let data: &PasswordReset = &self.data;
-    let conn = establish_connection();
 
     // Fetch that email
     let user: User_ = match User_::find_by_email(&conn, &data.email) {
@@ -872,9 +861,8 @@ impl Perform<PasswordResetResponse> for Oper<PasswordReset> {
 }
 
 impl Perform<LoginResponse> for Oper<PasswordChange> {
-  fn perform(&self) -> Result<LoginResponse, Error> {
+  fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> {
     let data: &PasswordChange = &self.data;
-    let conn = establish_connection();
 
     // Fetch the user_id from the token
     let user_id = PasswordResetRequest::read_from_token(&conn, &data.token)?.user_id;
index fac6088e46418eabdf88627e05a34467decc9413..32f14eeb28f7309b6832c9b51555f4aa7027505c 100644 (file)
@@ -1,7 +1,7 @@
 use crate::apub::make_apub_endpoint;
 use crate::db::community::Community;
 use crate::db::community_view::CommunityFollowerView;
-use crate::db::establish_connection;
+use crate::db::establish_unpooled_connection;
 use crate::to_datetime_utc;
 use activitypub::{actor::Group, collection::UnorderedCollection, context};
 use actix_web::body::Body;
@@ -62,7 +62,7 @@ impl Community {
     collection.object_props.set_context_object(context()).ok();
     collection.object_props.set_id_string(base_url).ok();
 
-    let connection = establish_connection();
+    let connection = establish_unpooled_connection();
     //As we are an object, we validated that the community id was valid
     let community_followers = CommunityFollowerView::for_community(&connection, self.id).unwrap();
 
@@ -85,7 +85,7 @@ pub struct CommunityQuery {
 }
 
 pub async fn get_apub_community(info: Path<CommunityQuery>) -> HttpResponse<Body> {
-  let connection = establish_connection();
+  let connection = establish_unpooled_connection();
 
   if let Ok(community) = Community::read_from_name(&connection, info.community_name.to_owned()) {
     HttpResponse::Ok()
@@ -97,7 +97,7 @@ pub async fn get_apub_community(info: Path<CommunityQuery>) -> HttpResponse<Body
 }
 
 pub async fn get_apub_community_followers(info: Path<CommunityQuery>) -> HttpResponse<Body> {
-  let connection = establish_connection();
+  let connection = establish_unpooled_connection();
 
   if let Ok(community) = Community::read_from_name(&connection, info.community_name.to_owned()) {
     HttpResponse::Ok()
index cf9a9797cf8277787d53f2fd7e63678e371b43d0..5f2421f11cdb57c9ae6b7997a6f993859b3c8111 100644 (file)
@@ -1,5 +1,5 @@
 use crate::apub::make_apub_endpoint;
-use crate::db::establish_connection;
+use crate::db::establish_unpooled_connection;
 use crate::db::user::User_;
 use crate::to_datetime_utc;
 use activitypub::{actor::Person, context};
@@ -62,7 +62,7 @@ pub struct UserQuery {
 }
 
 pub async fn get_apub_user(info: Path<UserQuery>) -> HttpResponse<Body> {
-  let connection = establish_connection();
+  let connection = establish_unpooled_connection();
 
   if let Ok(user) = User_::find_by_email_or_username(&connection, &info.user_name) {
     HttpResponse::Ok()
index 6e483ce0878b5551290cd5e8107ecca0d2553e63..408c8231bfb215e43319cb4aac62dd86baeb48ea 100644 (file)
@@ -52,7 +52,7 @@ mod tests {
   use super::*;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let categories = Category::list_all(&conn).unwrap();
     let expected_first_category = Category {
index b319805225f3920749d161384cabc9ecf95ea90a..a9c7d81ddc04ad58af9401e1c678b41f0c289900 100644 (file)
@@ -166,7 +166,7 @@ mod tests {
   use super::*;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_user = UserForm {
       name: "terry".into(),
index 2942bbe746ac104723f26986ccc936d9a7a84d2e..ba085af64aaa4222483f57c0140a9e06ad969ec4 100644 (file)
@@ -364,7 +364,7 @@ mod tests {
   use super::*;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_user = UserForm {
       name: "timmy".into(),
index 09c3ddc4ff167d408dd8aead3ad4e0d97b9d3d40..b482ca4a91d5f78df33086688dc32c1cdc9bdf13 100644 (file)
@@ -212,7 +212,7 @@ mod tests {
   use super::*;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_user = UserForm {
       name: "bobbee".into(),
index fe6cb3ce4c845125d6b9120609f5f1ec4a4a6c19..c25a64560a7226452755c06050f74cfd6a01eeaa 100644 (file)
@@ -1,7 +1,6 @@
 extern crate lazy_static;
 use crate::settings::Settings;
 use diesel::dsl::*;
-use diesel::r2d2::*;
 use diesel::result::Error;
 use diesel::*;
 use serde::{Deserialize, Serialize};
@@ -111,19 +110,9 @@ impl<T> MaybeOptional<T> for Option<T> {
   }
 }
 
-lazy_static! {
-  static ref PG_POOL: Pool<ConnectionManager<PgConnection>> = {
-    let db_url = Settings::get().get_database_url();
-    let manager = ConnectionManager::<PgConnection>::new(&db_url);
-    Pool::builder()
-      .max_size(Settings::get().database.pool_size)
-      .build(manager)
-      .unwrap_or_else(|_| panic!("Error connecting to {}", db_url))
-  };
-}
-
-pub fn establish_connection() -> PooledConnection<ConnectionManager<PgConnection>> {
-  PG_POOL.get().unwrap()
+pub fn establish_unpooled_connection() -> PgConnection {
+  let db_url = Settings::get().get_database_url();
+  PgConnection::establish(&db_url).expect(&format!("Error connecting to {}", db_url))
 }
 
 #[derive(EnumString, ToString, Debug, Serialize, Deserialize)]
index dc018bd93957ccbb753c87c92acdc18457383368..3c6233cb99d43a27c79f82491feb5df127ffb356 100644 (file)
@@ -434,7 +434,7 @@ mod tests {
   // use Crud;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_mod = UserForm {
       name: "the mod".into(),
index 1664516b36bd896b2c4e44b5c921bbb2ba84c06c..fa060a591bfc6feb27f04787ae850b9ed75b2f60 100644 (file)
@@ -84,7 +84,7 @@ mod tests {
 
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_user = UserForm {
       name: "thommy prw".into(),
index 084edc9bf3867013f127fa75529f4c01fea7127a..d3fba4dad0ff7e009c523f0478d16dbce194baa8 100644 (file)
@@ -179,7 +179,7 @@ mod tests {
   use super::*;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_user = UserForm {
       name: "jim".into(),
index 0efcab43e59fbc5e753a4104e5bdcaad5e5cc73b..5217e73a7a5607af4e76c9bc27a83b2626f97555 100644 (file)
@@ -290,7 +290,7 @@ mod tests {
   use super::*;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let user_name = "tegan".to_string();
     let community_name = "test_community_3".to_string();
index b04012974128a427b6b8912986d032f96af9453d..71b63d742c8a62c0a51bb9df44fcf90b71497821 100644 (file)
@@ -176,7 +176,7 @@ mod tests {
 
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_user = UserForm {
       name: "thommy".into(),
index 67286779769206de8cda5242567d67d89a9aedb1..21dd1675d3591bba1addb7c21d15cac86dcdd442 100644 (file)
@@ -60,7 +60,7 @@ mod tests {
   use super::*;
   #[test]
   fn test_crud() {
-    let conn = establish_connection();
+    let conn = establish_unpooled_connection();
 
     let new_user = UserForm {
       name: "terrylake".into(),
index 763f540fbca5e67d19d7ace1c924761138ffe47f..ea49cd13340d3d282e8742914cd55634bf039e9d 100644 (file)
@@ -3,7 +3,8 @@ extern crate lemmy_server;
 extern crate diesel_migrations;
 
 use actix_web::*;
-use lemmy_server::db::establish_connection;
+use diesel::r2d2::{ConnectionManager, Pool};
+use diesel::PgConnection;
 use lemmy_server::routes::{federation, feeds, index, nodeinfo, webfinger, websocket};
 use lemmy_server::settings::Settings;
 use std::io;
@@ -13,13 +14,19 @@ embed_migrations!();
 #[actix_rt::main]
 async fn main() -> io::Result<()> {
   env_logger::init();
+  let settings = Settings::get();
+
+  // Set up the r2d2 connection pool
+  let manager = ConnectionManager::<PgConnection>::new(&settings.get_database_url());
+  let pool = Pool::builder()
+    .max_size(settings.database.pool_size)
+    .build(manager)
+    .unwrap_or_else(|_| panic!("Error connecting to {}", settings.get_database_url()));
 
   // Run the migrations from code
-  let conn = establish_connection();
+  let conn = pool.get().unwrap();
   embedded_migrations::run(&conn).unwrap();
 
-  let settings = Settings::get();
-
   println!(
     "Starting http server at {}:{}",
     settings.bind, settings.port
@@ -28,12 +35,18 @@ async fn main() -> io::Result<()> {
   // Create Http server with websocket support
   HttpServer::new(move || {
     App::new()
+      .wrap(middleware::Logger::default())
+      .data(pool.clone())
+      // The routes
       .configure(federation::config)
       .configure(feeds::config)
       .configure(index::config)
       .configure(nodeinfo::config)
       .configure(webfinger::config)
       .configure(websocket::config)
+      // .configure(websocket.config(pool))
+      // .configure(websocket.
+      // static files
       .service(actix_files::Files::new(
         "/static",
         settings.front_end_dir.to_owned(),
index ae1631e2d7780dabfc0f23adad8375f1dfc923da..ad0f28d5a14c8f9fbdec8f09cb8f89f9199aa444 100644 (file)
@@ -7,11 +7,12 @@ use crate::db::post_view::{PostQueryBuilder, PostView};
 use crate::db::site_view::SiteView;
 use crate::db::user::{Claims, User_};
 use crate::db::user_mention_view::{UserMentionQueryBuilder, UserMentionView};
-use crate::db::{establish_connection, ListingType, SortType};
+use crate::db::{ListingType, SortType};
 use crate::Settings;
-use actix_web::body::Body;
 use actix_web::{web, HttpResponse, Result};
 use chrono::{DateTime, Utc};
+use diesel::r2d2::{ConnectionManager, Pool};
+use diesel::PgConnection;
 use failure::Error;
 use rss::{CategoryBuilder, ChannelBuilder, GuidBuilder, Item, ItemBuilder};
 use serde::Deserialize;
@@ -37,54 +38,61 @@ pub fn config(cfg: &mut web::ServiceConfig) {
     .route("/feeds/all.xml", web::get().to(feeds::get_all_feed));
 }
 
-async fn get_all_feed(info: web::Query<Params>) -> HttpResponse<Body> {
-  let sort_type = match get_sort_type(info) {
-    Ok(sort_type) => sort_type,
-    Err(_) => return HttpResponse::BadRequest().finish(),
-  };
-
-  let feed_result = get_feed_all_data(&sort_type);
-
-  match feed_result {
-    Ok(rss) => HttpResponse::Ok()
+async fn get_all_feed(
+  info: web::Query<Params>,
+  db: web::Data<Pool<ConnectionManager<PgConnection>>>,
+) -> Result<HttpResponse, actix_web::Error> {
+  let res = web::block(move || {
+    let conn = db.get()?;
+
+    let sort_type = get_sort_type(info)?;
+    get_feed_all_data(&conn, &sort_type)
+  })
+  .await
+  .map(|rss| {
+    HttpResponse::Ok()
       .content_type("application/rss+xml")
-      .body(rss),
-    Err(_) => HttpResponse::NotFound().finish(),
-  }
+      .body(rss)
+  })
+  .map_err(|_| HttpResponse::InternalServerError())?;
+  Ok(res)
 }
 
 async fn get_feed(
   path: web::Path<(String, String)>,
   info: web::Query<Params>,
-) -> HttpResponse<Body> {
-  let sort_type = match get_sort_type(info) {
-    Ok(sort_type) => sort_type,
-    Err(_) => return HttpResponse::BadRequest().finish(),
-  };
-
-  let request_type = match path.0.as_ref() {
-    "u" => RequestType::User,
-    "c" => RequestType::Community,
-    "front" => RequestType::Front,
-    "inbox" => RequestType::Inbox,
-    _ => return HttpResponse::NotFound().finish(),
-  };
-
-  let param = path.1.to_owned();
-
-  let feed_result = match request_type {
-    RequestType::User => get_feed_user(&sort_type, param),
-    RequestType::Community => get_feed_community(&sort_type, param),
-    RequestType::Front => get_feed_front(&sort_type, param),
-    RequestType::Inbox => get_feed_inbox(param),
-  };
-
-  match feed_result {
-    Ok(rss) => HttpResponse::Ok()
+  db: web::Data<Pool<ConnectionManager<PgConnection>>>,
+) -> Result<HttpResponse, actix_web::Error> {
+  let res = web::block(move || {
+    let conn = db.get()?;
+
+    let sort_type = get_sort_type(info)?;
+
+    let request_type = match path.0.as_ref() {
+      "u" => RequestType::User,
+      "c" => RequestType::Community,
+      "front" => RequestType::Front,
+      "inbox" => RequestType::Inbox,
+      _ => return Err(format_err!("wrong_type")),
+    };
+
+    let param = path.1.to_owned();
+
+    match request_type {
+      RequestType::User => get_feed_user(&conn, &sort_type, param),
+      RequestType::Community => get_feed_community(&conn, &sort_type, param),
+      RequestType::Front => get_feed_front(&conn, &sort_type, param),
+      RequestType::Inbox => get_feed_inbox(&conn, param),
+    }
+  })
+  .await
+  .map(|rss| {
+    HttpResponse::Ok()
       .content_type("application/rss+xml")
-      .body(rss),
-    Err(_) => HttpResponse::NotFound().finish(),
-  }
+      .body(rss)
+  })
+  .map_err(|_| HttpResponse::InternalServerError())?;
+  Ok(res)
 }
 
 fn get_sort_type(info: web::Query<Params>) -> Result<SortType, ParseError> {
@@ -95,9 +103,7 @@ fn get_sort_type(info: web::Query<Params>) -> Result<SortType, ParseError> {
   SortType::from_str(&sort_query)
 }
 
-fn get_feed_all_data(sort_type: &SortType) -> Result<String, Error> {
-  let conn = establish_connection();
-
+fn get_feed_all_data(conn: &PgConnection, sort_type: &SortType) -> Result<String, failure::Error> {
   let site_view = SiteView::read(&conn)?;
 
   let posts = PostQueryBuilder::create(&conn)
@@ -120,9 +126,11 @@ fn get_feed_all_data(sort_type: &SortType) -> Result<String, Error> {
   Ok(channel_builder.build().unwrap().to_string())
 }
 
-fn get_feed_user(sort_type: &SortType, user_name: String) -> Result<String, Error> {
-  let conn = establish_connection();
-
+fn get_feed_user(
+  conn: &PgConnection,
+  sort_type: &SortType,
+  user_name: String,
+) -> Result<String, Error> {
   let site_view = SiteView::read(&conn)?;
   let user = User_::find_by_username(&conn, &user_name)?;
   let user_url = user.get_profile_url();
@@ -144,9 +152,11 @@ fn get_feed_user(sort_type: &SortType, user_name: String) -> Result<String, Erro
   Ok(channel_builder.build().unwrap().to_string())
 }
 
-fn get_feed_community(sort_type: &SortType, community_name: String) -> Result<String, Error> {
-  let conn = establish_connection();
-
+fn get_feed_community(
+  conn: &PgConnection,
+  sort_type: &SortType,
+  community_name: String,
+) -> Result<String, Error> {
   let site_view = SiteView::read(&conn)?;
   let community = Community::read_from_name(&conn, community_name)?;
   let community_url = community.get_url();
@@ -172,9 +182,7 @@ fn get_feed_community(sort_type: &SortType, community_name: String) -> Result<St
   Ok(channel_builder.build().unwrap().to_string())
 }
 
-fn get_feed_front(sort_type: &SortType, jwt: String) -> Result<String, Error> {
-  let conn = establish_connection();
-
+fn get_feed_front(conn: &PgConnection, sort_type: &SortType, jwt: String) -> Result<String, Error> {
   let site_view = SiteView::read(&conn)?;
   let user_id = Claims::decode(&jwt)?.claims.id;
 
@@ -199,9 +207,7 @@ fn get_feed_front(sort_type: &SortType, jwt: String) -> Result<String, Error> {
   Ok(channel_builder.build().unwrap().to_string())
 }
 
-fn get_feed_inbox(jwt: String) -> Result<String, Error> {
-  let conn = establish_connection();
-
+fn get_feed_inbox(conn: &PgConnection, jwt: String) -> Result<String, Error> {
   let site_view = SiteView::read(&conn)?;
   let user_id = Claims::decode(&jwt)?.claims.id;
 
index 2b7135fba462b787b20e1b664bc683fc1c3b55d9..6ab540f9e459e16b4c147949a3e278bbc394cb35 100644 (file)
@@ -1,10 +1,11 @@
-use crate::db::establish_connection;
 use crate::db::site_view::SiteView;
 use crate::version;
 use crate::Settings;
 use actix_web::body::Body;
 use actix_web::web;
 use actix_web::HttpResponse;
+use diesel::r2d2::{ConnectionManager, Pool};
+use diesel::PgConnection;
 use serde_json::json;
 
 pub fn config(cfg: &mut web::ServiceConfig) {
@@ -26,34 +27,39 @@ async fn node_info_well_known() -> HttpResponse<Body> {
     .body(json.to_string())
 }
 
-async fn node_info() -> HttpResponse<Body> {
-  let conn = establish_connection();
-  let site_view = match SiteView::read(&conn) {
-    Ok(site_view) => site_view,
-    Err(_e) => return HttpResponse::InternalServerError().finish(),
-  };
-  let protocols = if Settings::get().federation_enabled {
-    vec!["activitypub"]
-  } else {
-    vec![]
-  };
-  let json = json!({
-    "version": "2.0",
-    "software": {
-      "name": "lemmy",
-      "version": version::VERSION,
-    },
-    "protocols": protocols,
-    "usage": {
-      "users": {
-        "total": site_view.number_of_users
+async fn node_info(
+  db: web::Data<Pool<ConnectionManager<PgConnection>>>,
+) -> Result<HttpResponse, actix_web::Error> {
+  let res = web::block(move || {
+    let conn = db.get()?;
+    let site_view = match SiteView::read(&conn) {
+      Ok(site_view) => site_view,
+      Err(_) => return Err(format_err!("not_found")),
+    };
+    let protocols = if Settings::get().federation_enabled {
+      vec!["activitypub"]
+    } else {
+      vec![]
+    };
+    Ok(json!({
+      "version": "2.0",
+      "software": {
+        "name": "lemmy",
+        "version": version::VERSION,
       },
-      "localPosts": site_view.number_of_posts,
-      "localComments": site_view.number_of_comments,
-      "openRegistrations": site_view.open_registration,
+      "protocols": protocols,
+      "usage": {
+        "users": {
+          "total": site_view.number_of_users
+        },
+        "localPosts": site_view.number_of_posts,
+        "localComments": site_view.number_of_comments,
+        "openRegistrations": site_view.open_registration,
       }
-  });
-  HttpResponse::Ok()
-    .content_type("application/json")
-    .body(json.to_string())
+    }))
+  })
+  .await
+  .map(|json| HttpResponse::Ok().json(json))
+  .map_err(|_| HttpResponse::InternalServerError())?;
+  Ok(res)
 }
index c538f5b1d19329740fad9b70df3ea96e3904bd49..20f76a9ad0c898598ae0a975cf5cf42bacf2d015 100644 (file)
@@ -1,10 +1,10 @@
 use crate::db::community::Community;
-use crate::db::establish_connection;
 use crate::Settings;
-use actix_web::body::Body;
 use actix_web::web;
 use actix_web::web::Query;
 use actix_web::HttpResponse;
+use diesel::r2d2::{ConnectionManager, Pool};
+use diesel::PgConnection;
 use regex::Regex;
 use serde::Deserialize;
 use serde_json::json;
@@ -37,54 +37,61 @@ lazy_static! {
 ///
 /// You can also view the webfinger response that Mastodon sends:
 /// https://radical.town/.well-known/webfinger?resource=acct:felix@radical.town
-async fn get_webfinger_response(info: Query<Params>) -> HttpResponse<Body> {
-  let regex_parsed = WEBFINGER_COMMUNITY_REGEX
-    .captures(&info.resource)
-    .map(|c| c.get(1));
-  // TODO: replace this with .flatten() once we are running rust 1.40
-  let regex_parsed_flattened = match regex_parsed {
-    Some(s) => s,
-    None => None,
-  };
-  let community_name = match regex_parsed_flattened {
-    Some(c) => c.as_str(),
-    None => return HttpResponse::NotFound().finish(),
-  };
+async fn get_webfinger_response(
+  info: Query<Params>,
+  db: web::Data<Pool<ConnectionManager<PgConnection>>>,
+) -> Result<HttpResponse, actix_web::Error> {
+  let res = web::block(move || {
+    let conn = db.get()?;
 
-  // Make sure the requested community exists.
-  let conn = establish_connection();
-  let community = match Community::read_from_name(&conn, community_name.to_string()) {
-    Ok(o) => o,
-    Err(_) => return HttpResponse::NotFound().finish(),
-  };
+    let regex_parsed = WEBFINGER_COMMUNITY_REGEX
+      .captures(&info.resource)
+      .map(|c| c.get(1));
+    // TODO: replace this with .flatten() once we are running rust 1.40
+    let regex_parsed_flattened = match regex_parsed {
+      Some(s) => s,
+      None => None,
+    };
+    let community_name = match regex_parsed_flattened {
+      Some(c) => c.as_str(),
+      None => return Err(format_err!("not_found")),
+    };
 
-  let community_url = community.get_url();
+    // Make sure the requested community exists.
+    let community = match Community::read_from_name(&conn, community_name.to_string()) {
+      Ok(o) => o,
+      Err(_) => return Err(format_err!("not_found")),
+    };
 
-  let json = json!({
+    let community_url = community.get_url();
+
+    Ok(json!({
     "subject": info.resource,
     "aliases": [
       community_url,
     ],
     "links": [
-      {
-        "rel": "http://webfinger.net/rel/profile-page",
-        "type": "text/html",
-        "href": community_url
-      },
-      {
-        "rel": "self",
-        "type": "application/activity+json",
-        // Yes this is correct, this link doesn't include the `.json` extension
-        "href": community_url
-      }
-      // TODO: this also needs to return the subscribe link once that's implemented
-      //{
-      //  "rel": "http://ostatus.org/schema/1.0/subscribe",
-      //  "template": "https://my_instance.com/authorize_interaction?uri={uri}"
-      //}
+    {
+      "rel": "http://webfinger.net/rel/profile-page",
+      "type": "text/html",
+      "href": community_url
+    },
+    {
+      "rel": "self",
+      "type": "application/activity+json",
+      // Yes this is correct, this link doesn't include the `.json` extension
+      "href": community_url
+    }
+    // TODO: this also needs to return the subscribe link once that's implemented
+    //{
+    //  "rel": "http://ostatus.org/schema/1.0/subscribe",
+    //  "template": "https://my_instance.com/authorize_interaction?uri={uri}"
+    //}
     ]
-  });
-  HttpResponse::Ok()
-    .content_type("application/activity+json")
-    .body(json.to_string())
+    }))
+  })
+  .await
+  .map(|json| HttpResponse::Ok().json(json))
+  .map_err(|_| HttpResponse::InternalServerError())?;
+  Ok(res)
 }
index 8113a613eb583cc48bad08006234e20455349f83..0d953d24313ae2115e0172aa29137a1841fbf97b 100644 (file)
@@ -1,13 +1,24 @@
 use crate::websocket::server::*;
+use crate::Settings;
 use actix::prelude::*;
 use actix_web::web;
 use actix_web::*;
 use actix_web_actors::ws;
+use diesel::r2d2::{ConnectionManager, Pool};
+use diesel::PgConnection;
 use std::time::{Duration, Instant};
 
 pub fn config(cfg: &mut web::ServiceConfig) {
+  // TODO couldn't figure out how to get this method to recieve the other pool
+  let settings = Settings::get();
+  let manager = ConnectionManager::<PgConnection>::new(&settings.get_database_url());
+  let pool = Pool::builder()
+    .max_size(settings.database.pool_size)
+    .build(manager)
+    .unwrap_or_else(|_| panic!("Error connecting to {}", settings.get_database_url()));
+
   // Start chat server actor in separate thread
-  let server = ChatServer::default().start();
+  let server = ChatServer::startup(pool).start();
   cfg
     .data(server)
     .service(web::resource("/api/v1/ws").to(chat_route));
@@ -24,9 +35,11 @@ async fn chat_route(
   stream: web::Payload,
   chat_server: web::Data<Addr<ChatServer>>,
 ) -> Result<HttpResponse, Error> {
+  // TODO not sure if the blocking should be here or not
   ws::start(
     WSSession {
-      cs_addr: chat_server.get_ref().to_owned(),
+      // db: db.get_ref().clone(),
+      cs_addr: chat_server.get_ref().clone(),
       id: 0,
       hb: Instant::now(),
       ip: req
@@ -51,6 +64,7 @@ struct WSSession {
   /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
   /// otherwise we drop connection.
   hb: Instant,
+  // db: Pool<ConnectionManager<PgConnection>>,
 }
 
 impl Actor for WSSession {
@@ -127,7 +141,7 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSSession {
       }
       ws::Message::Text(text) => {
         let m = text.trim().to_owned();
-        println!("WEBSOCKET MESSAGE: {:?} from id: {}", &m, self.id);
+        // println!("WEBSOCKET MESSAGE: {:?} from id: {}", &m, self.id);
 
         self
           .cs_addr
index b9dad9a4b246879da405ed0c1c2d85784b0e9e2a..006169dfb9e47ad6c0ea4f926a7a78058cc9c58d 100644 (file)
@@ -3,6 +3,8 @@
 //! room through `ChatServer`.
 
 use actix::prelude::*;
+use diesel::r2d2::{ConnectionManager, Pool};
+use diesel::PgConnection;
 use failure::Error;
 use rand::{rngs::ThreadRng, Rng};
 use serde::{Deserialize, Serialize};
@@ -42,6 +44,7 @@ pub struct Disconnect {
   pub ip: String,
 }
 
+// TODO this is unused rn
 /// Send message to specific room
 #[derive(Message)]
 #[rtype(result = "()")]
@@ -81,10 +84,26 @@ pub struct ChatServer {
   rate_limits: HashMap<String, RateLimitBucket>,
   rooms: HashMap<i32, HashSet<usize>>, // A map from room / post name to set of connectionIDs
   rng: ThreadRng,
+  db: Pool<ConnectionManager<PgConnection>>,
 }
 
-impl Default for ChatServer {
-  fn default() -> ChatServer {
+// impl Default for ChatServer {
+//   fn default(nah: String) -> ChatServer {
+//     // default room
+//     let rooms = HashMap::new();
+
+//     ChatServer {
+//       sessions: HashMap::new(),
+//       rate_limits: HashMap::new(),
+//       rooms,
+//       rng: rand::thread_rng(),
+//       nah: nah,
+//     }
+//   }
+// }
+
+impl ChatServer {
+  pub fn startup(db: Pool<ConnectionManager<PgConnection>>) -> ChatServer {
     // default room
     let rooms = HashMap::new();
 
@@ -93,11 +112,10 @@ impl Default for ChatServer {
       rate_limits: HashMap::new(),
       rooms,
       rng: rand::thread_rng(),
+      db,
     }
   }
-}
 
-impl ChatServer {
   /// Send message to all users in the room
   fn send_room_message(&self, room: i32, message: &str, skip_id: usize) {
     if let Some(sessions) = self.rooms.get(&room) {
@@ -133,7 +151,8 @@ impl ChatServer {
   ) -> Result<(), Error> {
     use crate::db::post_view::*;
     use crate::db::*;
-    let conn = establish_connection();
+
+    let conn = self.db.get()?;
 
     let posts = PostQueryBuilder::create(&conn)
       .listing_type(ListingType::Community)
@@ -299,17 +318,19 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     message: "Unknown op type".to_string(),
   })?;
 
+  let conn = chat.db.get()?;
+
   let user_operation: UserOperation = UserOperation::from_str(&op)?;
 
   match user_operation {
     UserOperation::Login => {
       let login: Login = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, login).perform()?;
+      let res = Oper::new(user_operation, login).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::Register => {
       let register: Register = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, register).perform();
+      let res = Oper::new(user_operation, register).perform(&conn);
       if res.is_ok() {
         chat.check_rate_limit_register(msg.id)?;
       }
@@ -317,63 +338,63 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     }
     UserOperation::GetUserDetails => {
       let get_user_details: GetUserDetails = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, get_user_details).perform()?;
+      let res = Oper::new(user_operation, get_user_details).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::SaveUserSettings => {
       let save_user_settings: SaveUserSettings = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, save_user_settings).perform()?;
+      let res = Oper::new(user_operation, save_user_settings).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::AddAdmin => {
       let add_admin: AddAdmin = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, add_admin).perform()?;
+      let res = Oper::new(user_operation, add_admin).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::BanUser => {
       let ban_user: BanUser = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, ban_user).perform()?;
+      let res = Oper::new(user_operation, ban_user).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::GetReplies => {
       let get_replies: GetReplies = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, get_replies).perform()?;
+      let res = Oper::new(user_operation, get_replies).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::GetUserMentions => {
       let get_user_mentions: GetUserMentions = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, get_user_mentions).perform()?;
+      let res = Oper::new(user_operation, get_user_mentions).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::EditUserMention => {
       let edit_user_mention: EditUserMention = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, edit_user_mention).perform()?;
+      let res = Oper::new(user_operation, edit_user_mention).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::MarkAllAsRead => {
       let mark_all_as_read: MarkAllAsRead = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, mark_all_as_read).perform()?;
+      let res = Oper::new(user_operation, mark_all_as_read).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::GetCommunity => {
       let get_community: GetCommunity = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, get_community).perform()?;
+      let res = Oper::new(user_operation, get_community).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::ListCommunities => {
       let list_communities: ListCommunities = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, list_communities).perform()?;
+      let res = Oper::new(user_operation, list_communities).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::CreateCommunity => {
       chat.check_rate_limit_register(msg.id)?;
       let create_community: CreateCommunity = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, create_community).perform()?;
+      let res = Oper::new(user_operation, create_community).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::EditCommunity => {
       let edit_community: EditCommunity = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, edit_community).perform()?;
+      let res = Oper::new(user_operation, edit_community).perform(&conn)?;
       let mut community_sent: CommunityResponse = res.clone();
       community_sent.community.user_id = None;
       community_sent.community.subscribed = None;
@@ -383,18 +404,18 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     }
     UserOperation::FollowCommunity => {
       let follow_community: FollowCommunity = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, follow_community).perform()?;
+      let res = Oper::new(user_operation, follow_community).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::GetFollowedCommunities => {
       let followed_communities: GetFollowedCommunities = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, followed_communities).perform()?;
+      let res = Oper::new(user_operation, followed_communities).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::BanFromCommunity => {
       let ban_from_community: BanFromCommunity = serde_json::from_str(data)?;
       let community_id = ban_from_community.community_id;
-      let res = Oper::new(user_operation, ban_from_community).perform()?;
+      let res = Oper::new(user_operation, ban_from_community).perform(&conn)?;
       let res_str = serde_json::to_string(&res)?;
       chat.send_community_message(community_id, &res_str, msg.id)?;
       Ok(res_str)
@@ -402,42 +423,42 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     UserOperation::AddModToCommunity => {
       let mod_add_to_community: AddModToCommunity = serde_json::from_str(data)?;
       let community_id = mod_add_to_community.community_id;
-      let res = Oper::new(user_operation, mod_add_to_community).perform()?;
+      let res = Oper::new(user_operation, mod_add_to_community).perform(&conn)?;
       let res_str = serde_json::to_string(&res)?;
       chat.send_community_message(community_id, &res_str, msg.id)?;
       Ok(res_str)
     }
     UserOperation::ListCategories => {
       let list_categories: ListCategories = ListCategories;
-      let res = Oper::new(user_operation, list_categories).perform()?;
+      let res = Oper::new(user_operation, list_categories).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::CreatePost => {
       chat.check_rate_limit_post(msg.id)?;
       let create_post: CreatePost = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, create_post).perform()?;
+      let res = Oper::new(user_operation, create_post).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::GetPost => {
       let get_post: GetPost = serde_json::from_str(data)?;
       chat.join_room(get_post.id, msg.id);
-      let res = Oper::new(user_operation, get_post).perform()?;
+      let res = Oper::new(user_operation, get_post).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::GetPosts => {
       let get_posts: GetPosts = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, get_posts).perform()?;
+      let res = Oper::new(user_operation, get_posts).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::CreatePostLike => {
       chat.check_rate_limit_message(msg.id)?;
       let create_post_like: CreatePostLike = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, create_post_like).perform()?;
+      let res = Oper::new(user_operation, create_post_like).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::EditPost => {
       let edit_post: EditPost = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, edit_post).perform()?;
+      let res = Oper::new(user_operation, edit_post).perform(&conn)?;
       let mut post_sent = res.clone();
       post_sent.post.my_vote = None;
       let post_sent_str = serde_json::to_string(&post_sent)?;
@@ -446,14 +467,14 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     }
     UserOperation::SavePost => {
       let save_post: SavePost = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, save_post).perform()?;
+      let res = Oper::new(user_operation, save_post).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::CreateComment => {
       chat.check_rate_limit_message(msg.id)?;
       let create_comment: CreateComment = serde_json::from_str(data)?;
       let post_id = create_comment.post_id;
-      let res = Oper::new(user_operation, create_comment).perform()?;
+      let res = Oper::new(user_operation, create_comment).perform(&conn)?;
       let mut comment_sent = res.clone();
       comment_sent.comment.my_vote = None;
       comment_sent.comment.user_id = None;
@@ -464,7 +485,7 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     UserOperation::EditComment => {
       let edit_comment: EditComment = serde_json::from_str(data)?;
       let post_id = edit_comment.post_id;
-      let res = Oper::new(user_operation, edit_comment).perform()?;
+      let res = Oper::new(user_operation, edit_comment).perform(&conn)?;
       let mut comment_sent = res.clone();
       comment_sent.comment.my_vote = None;
       comment_sent.comment.user_id = None;
@@ -474,14 +495,14 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     }
     UserOperation::SaveComment => {
       let save_comment: SaveComment = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, save_comment).perform()?;
+      let res = Oper::new(user_operation, save_comment).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::CreateCommentLike => {
       chat.check_rate_limit_message(msg.id)?;
       let create_comment_like: CreateCommentLike = serde_json::from_str(data)?;
       let post_id = create_comment_like.post_id;
-      let res = Oper::new(user_operation, create_comment_like).perform()?;
+      let res = Oper::new(user_operation, create_comment_like).perform(&conn)?;
       let mut comment_sent = res.clone();
       comment_sent.comment.my_vote = None;
       comment_sent.comment.user_id = None;
@@ -491,54 +512,54 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
     }
     UserOperation::GetModlog => {
       let get_modlog: GetModlog = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, get_modlog).perform()?;
+      let res = Oper::new(user_operation, get_modlog).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::CreateSite => {
       let create_site: CreateSite = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, create_site).perform()?;
+      let res = Oper::new(user_operation, create_site).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::EditSite => {
       let edit_site: EditSite = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, edit_site).perform()?;
+      let res = Oper::new(user_operation, edit_site).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::GetSite => {
       let online: usize = chat.sessions.len();
       let get_site: GetSite = serde_json::from_str(data)?;
-      let mut res = Oper::new(user_operation, get_site).perform()?;
+      let mut res = Oper::new(user_operation, get_site).perform(&conn)?;
       res.online = online;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::Search => {
       let search: Search = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, search).perform()?;
+      let res = Oper::new(user_operation, search).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::TransferCommunity => {
       let transfer_community: TransferCommunity = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, transfer_community).perform()?;
+      let res = Oper::new(user_operation, transfer_community).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::TransferSite => {
       let transfer_site: TransferSite = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, transfer_site).perform()?;
+      let res = Oper::new(user_operation, transfer_site).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::DeleteAccount => {
       let delete_account: DeleteAccount = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, delete_account).perform()?;
+      let res = Oper::new(user_operation, delete_account).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::PasswordReset => {
       let password_reset: PasswordReset = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, password_reset).perform()?;
+      let res = Oper::new(user_operation, password_reset).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
     UserOperation::PasswordChange => {
       let password_change: PasswordChange = serde_json::from_str(data)?;
-      let res = Oper::new(user_operation, password_change).perform()?;
+      let res = Oper::new(user_operation, password_change).perform(&conn)?;
       Ok(serde_json::to_string(&res)?)
     }
   }