]> Untitled Git - lemmy.git/blobdiff - crates/db_schema/src/utils.rs
Use same table join code for both read and list functions (#3663)
[lemmy.git] / crates / db_schema / src / utils.rs
index cd2005ad072f11bc13517921980ea62ab82a723d..dc26bedfcde44e6181dd9fdf32bf494a719c1115 100644 (file)
@@ -2,6 +2,7 @@ use crate::{
   diesel::Connection,
   diesel_migrations::MigrationHarness,
   newtypes::DbUrl,
+  traits::JoinView,
   CommentSortType,
   PersonSortType,
   SortType,
@@ -26,7 +27,7 @@ use diesel_async::{
   },
 };
 use diesel_migrations::EmbeddedMigrations;
-use futures_util::{future::BoxFuture, FutureExt};
+use futures_util::{future::BoxFuture, Future, FutureExt};
 use lemmy_utils::{
   error::{LemmyError, LemmyErrorExt, LemmyErrorType},
   settings::structs::Settings,
@@ -197,12 +198,12 @@ pub fn is_email_regex(test: &str) -> bool {
   EMAIL_REGEX.is_match(test)
 }
 
-pub fn diesel_option_overwrite(opt: &Option<String>) -> Option<Option<String>> {
+pub fn diesel_option_overwrite(opt: Option<String>) -> Option<Option<String>> {
   match opt {
     // An empty string is an erase
     Some(unwrapped) => {
       if !unwrapped.eq("") {
-        Some(Some(unwrapped.clone()))
+        Some(Some(unwrapped))
       } else {
         Some(None)
       }
@@ -420,6 +421,94 @@ where
   }
 }
 
+pub type ResultFuture<'a, T> = BoxFuture<'a, Result<T, DieselError>>;
+
+pub trait ReadFn<'a, T: JoinView, Args>:
+  Fn(DbConn<'a>, Args) -> ResultFuture<'a, <T as JoinView>::JoinTuple>
+{
+}
+
+impl<
+    'a,
+    T: JoinView,
+    Args,
+    F: Fn(DbConn<'a>, Args) -> ResultFuture<'a, <T as JoinView>::JoinTuple>,
+  > ReadFn<'a, T, Args> for F
+{
+}
+
+pub trait ListFn<'a, T: JoinView, Args>:
+  Fn(DbConn<'a>, Args) -> ResultFuture<'a, Vec<<T as JoinView>::JoinTuple>>
+{
+}
+
+impl<
+    'a,
+    T: JoinView,
+    Args,
+    F: Fn(DbConn<'a>, Args) -> ResultFuture<'a, Vec<<T as JoinView>::JoinTuple>>,
+  > ListFn<'a, T, Args> for F
+{
+}
+
+/// Allows read and list functions to capture a shared closure that has an inferred return type, which is useful for join logic
+pub struct Queries<RF, LF> {
+  pub read_fn: RF,
+  pub list_fn: LF,
+}
+
+// `()` is used to prevent type inference error
+impl Queries<(), ()> {
+  pub fn new<'a, RFut, LFut, RT, LT, RA, LA, RF2, LF2>(
+    read_fn: RF2,
+    list_fn: LF2,
+  ) -> Queries<impl ReadFn<'a, RT, RA>, impl ListFn<'a, LT, LA>>
+  where
+    RFut: Future<Output = Result<<RT as JoinView>::JoinTuple, DieselError>> + Sized + Send + 'a,
+    LFut:
+      Future<Output = Result<Vec<<LT as JoinView>::JoinTuple>, DieselError>> + Sized + Send + 'a,
+    RT: JoinView,
+    LT: JoinView,
+    RF2: Fn(DbConn<'a>, RA) -> RFut,
+    LF2: Fn(DbConn<'a>, LA) -> LFut,
+  {
+    Queries {
+      read_fn: move |conn, args| read_fn(conn, args).boxed(),
+      list_fn: move |conn, args| list_fn(conn, args).boxed(),
+    }
+  }
+}
+
+impl<RF, LF> Queries<RF, LF> {
+  pub async fn read<'a, T, Args>(
+    self,
+    pool: &'a mut DbPool<'_>,
+    args: Args,
+  ) -> Result<T, DieselError>
+  where
+    T: JoinView,
+    RF: ReadFn<'a, T, Args>,
+  {
+    let conn = get_conn(pool).await?;
+    let res = (self.read_fn)(conn, args).await?;
+    Ok(T::from_tuple(res))
+  }
+
+  pub async fn list<'a, T, Args>(
+    self,
+    pool: &'a mut DbPool<'_>,
+    args: Args,
+  ) -> Result<Vec<T>, DieselError>
+  where
+    T: JoinView,
+    LF: ListFn<'a, T, Args>,
+  {
+    let conn = get_conn(pool).await?;
+    let res = (self.list_fn)(conn, args).await?;
+    Ok(res.into_iter().map(T::from_tuple).collect())
+  }
+}
+
 #[cfg(test)]
 mod tests {
   #![allow(clippy::unwrap_used)]
@@ -445,10 +534,10 @@ mod tests {
 
   #[test]
   fn test_diesel_option_overwrite() {
-    assert_eq!(diesel_option_overwrite(&None), None);
-    assert_eq!(diesel_option_overwrite(&Some(String::new())), Some(None));
+    assert_eq!(diesel_option_overwrite(None), None);
+    assert_eq!(diesel_option_overwrite(Some(String::new())), Some(None));
     assert_eq!(
-      diesel_option_overwrite(&Some("test".to_string())),
+      diesel_option_overwrite(Some("test".to_string())),
       Some(Some("test".to_string()))
     );
   }