X-Git-Url: http://these/git/?a=blobdiff_plain;f=crates%2Fdb_schema%2Fsrc%2Futils.rs;h=dc26bedfcde44e6181dd9fdf32bf494a719c1115;hb=9a5a13c734a1792511e1bfef7b9ac4121e0e7371;hp=cd2005ad072f11bc13517921980ea62ab82a723d;hpb=c890797b370417b9479186acd86bc065693f8691;p=lemmy.git diff --git a/crates/db_schema/src/utils.rs b/crates/db_schema/src/utils.rs index cd2005ad..dc26bedf 100644 --- a/crates/db_schema/src/utils.rs +++ b/crates/db_schema/src/utils.rs @@ -2,6 +2,7 @@ use crate::{ diesel::Connection, diesel_migrations::MigrationHarness, newtypes::DbUrl, + traits::JoinView, CommentSortType, PersonSortType, SortType, @@ -26,7 +27,7 @@ use diesel_async::{ }, }; use diesel_migrations::EmbeddedMigrations; -use futures_util::{future::BoxFuture, FutureExt}; +use futures_util::{future::BoxFuture, Future, FutureExt}; use lemmy_utils::{ error::{LemmyError, LemmyErrorExt, LemmyErrorType}, settings::structs::Settings, @@ -197,12 +198,12 @@ pub fn is_email_regex(test: &str) -> bool { EMAIL_REGEX.is_match(test) } -pub fn diesel_option_overwrite(opt: &Option) -> Option> { +pub fn diesel_option_overwrite(opt: Option) -> Option> { match opt { // An empty string is an erase Some(unwrapped) => { if !unwrapped.eq("") { - Some(Some(unwrapped.clone())) + Some(Some(unwrapped)) } else { Some(None) } @@ -420,6 +421,94 @@ where } } +pub type ResultFuture<'a, T> = BoxFuture<'a, Result>; + +pub trait ReadFn<'a, T: JoinView, Args>: + Fn(DbConn<'a>, Args) -> ResultFuture<'a, ::JoinTuple> +{ +} + +impl< + 'a, + T: JoinView, + Args, + F: Fn(DbConn<'a>, Args) -> ResultFuture<'a, ::JoinTuple>, + > ReadFn<'a, T, Args> for F +{ +} + +pub trait ListFn<'a, T: JoinView, Args>: + Fn(DbConn<'a>, Args) -> ResultFuture<'a, Vec<::JoinTuple>> +{ +} + +impl< + 'a, + T: JoinView, + Args, + F: Fn(DbConn<'a>, Args) -> ResultFuture<'a, Vec<::JoinTuple>>, + > ListFn<'a, T, Args> for F +{ +} + +/// Allows read and list functions to capture a shared closure that has an inferred return type, which is useful for join logic +pub struct Queries { + pub read_fn: RF, + pub list_fn: LF, +} + +// `()` is used to prevent type inference error +impl Queries<(), ()> { + pub fn new<'a, RFut, LFut, RT, LT, RA, LA, RF2, LF2>( + read_fn: RF2, + list_fn: LF2, + ) -> Queries, impl ListFn<'a, LT, LA>> + where + RFut: Future::JoinTuple, DieselError>> + Sized + Send + 'a, + LFut: + Future::JoinTuple>, DieselError>> + Sized + Send + 'a, + RT: JoinView, + LT: JoinView, + RF2: Fn(DbConn<'a>, RA) -> RFut, + LF2: Fn(DbConn<'a>, LA) -> LFut, + { + Queries { + read_fn: move |conn, args| read_fn(conn, args).boxed(), + list_fn: move |conn, args| list_fn(conn, args).boxed(), + } + } +} + +impl Queries { + pub async fn read<'a, T, Args>( + self, + pool: &'a mut DbPool<'_>, + args: Args, + ) -> Result + where + T: JoinView, + RF: ReadFn<'a, T, Args>, + { + let conn = get_conn(pool).await?; + let res = (self.read_fn)(conn, args).await?; + Ok(T::from_tuple(res)) + } + + pub async fn list<'a, T, Args>( + self, + pool: &'a mut DbPool<'_>, + args: Args, + ) -> Result, DieselError> + where + T: JoinView, + LF: ListFn<'a, T, Args>, + { + let conn = get_conn(pool).await?; + let res = (self.list_fn)(conn, args).await?; + Ok(res.into_iter().map(T::from_tuple).collect()) + } +} + #[cfg(test)] mod tests { #![allow(clippy::unwrap_used)] @@ -445,10 +534,10 @@ mod tests { #[test] fn test_diesel_option_overwrite() { - assert_eq!(diesel_option_overwrite(&None), None); - assert_eq!(diesel_option_overwrite(&Some(String::new())), Some(None)); + assert_eq!(diesel_option_overwrite(None), None); + assert_eq!(diesel_option_overwrite(Some(String::new())), Some(None)); assert_eq!( - diesel_option_overwrite(&Some("test".to_string())), + diesel_option_overwrite(Some("test".to_string())), Some(Some("test".to_string())) ); }