]> Untitled Git - lemmy.git/blobdiff - crates/db_schema/src/impls/federation_allowlist.rs
Make functions work with both connection and pool (#3420)
[lemmy.git] / crates / db_schema / src / impls / federation_allowlist.rs
index 406b0e2ca866732ce33e8119b40def6a9d5b1269..d4aed484655b26a6c03c058b5e13b90f3b06ad90 100644 (file)
@@ -4,78 +4,91 @@ use crate::{
     federation_allowlist::{FederationAllowList, FederationAllowListForm},
     instance::Instance,
   },
+  utils::{get_conn, DbPool},
 };
-use diesel::{dsl::*, result::Error, *};
+use diesel::{dsl::insert_into, result::Error};
+use diesel_async::{AsyncPgConnection, RunQueryDsl};
 
 impl FederationAllowList {
-  pub fn replace(conn: &mut PgConnection, list_opt: Option<Vec<String>>) -> Result<(), Error> {
-    conn.build_transaction().read_write().run(|conn| {
-      if let Some(list) = list_opt {
-        Self::clear(conn)?;
+  pub async fn replace(pool: &mut DbPool<'_>, list_opt: Option<Vec<String>>) -> Result<(), Error> {
+    let conn = &mut get_conn(pool).await?;
+    conn
+      .build_transaction()
+      .run(|conn| {
+        Box::pin(async move {
+          if let Some(list) = list_opt {
+            Self::clear(conn).await?;
 
-        for domain in list {
-          // Upsert all of these as instances
-          let instance = Instance::create(conn, &domain)?;
+            for domain in list {
+              // Upsert all of these as instances
+              let instance = Instance::read_or_create(&mut conn.into(), domain).await?;
 
-          let form = FederationAllowListForm {
-            instance_id: instance.id,
-            updated: None,
-          };
-          insert_into(federation_allowlist::table)
-            .values(form)
-            .get_result::<Self>(conn)?;
-        }
-        Ok(())
-      } else {
-        Ok(())
-      }
-    })
+              let form = FederationAllowListForm {
+                instance_id: instance.id,
+                updated: None,
+              };
+              insert_into(federation_allowlist::table)
+                .values(form)
+                .get_result::<Self>(conn)
+                .await?;
+            }
+            Ok(())
+          } else {
+            Ok(())
+          }
+        }) as _
+      })
+      .await
   }
 
-  pub fn clear(conn: &mut PgConnection) -> Result<usize, Error> {
-    diesel::delete(federation_allowlist::table).execute(conn)
+  async fn clear(conn: &mut AsyncPgConnection) -> Result<usize, Error> {
+    diesel::delete(federation_allowlist::table)
+      .execute(conn)
+      .await
   }
 }
 #[cfg(test)]
 mod tests {
   use crate::{
     source::{federation_allowlist::FederationAllowList, instance::Instance},
-    utils::establish_unpooled_connection,
+    utils::build_db_pool_for_tests,
   };
   use serial_test::serial;
 
-  #[test]
+  #[tokio::test]
   #[serial]
-  fn test_allowlist_insert_and_clear() {
-    let conn = &mut establish_unpooled_connection();
-    let allowed = Some(vec![
+  async fn test_allowlist_insert_and_clear() {
+    let pool = &build_db_pool_for_tests().await;
+    let pool = &mut pool.into();
+    let domains = vec![
       "tld1.xyz".to_string(),
       "tld2.xyz".to_string(),
       "tld3.xyz".to_string(),
-    ]);
+    ];
 
-    FederationAllowList::replace(conn, allowed).unwrap();
+    let allowed = Some(domains.clone());
 
-    let allows = Instance::allowlist(conn).unwrap();
+    FederationAllowList::replace(pool, allowed).await.unwrap();
+
+    let allows = Instance::allowlist(pool).await.unwrap();
+    let allows_domains = allows
+      .iter()
+      .map(|i| i.domain.clone())
+      .collect::<Vec<String>>();
 
     assert_eq!(3, allows.len());
-    assert_eq!(
-      vec![
-        "tld1.xyz".to_string(),
-        "tld2.xyz".to_string(),
-        "tld3.xyz".to_string()
-      ],
-      allows
-    );
+    assert_eq!(domains, allows_domains);
 
     // Now test clearing them via Some(empty vec)
     let clear_allows = Some(Vec::new());
 
-    FederationAllowList::replace(conn, clear_allows).unwrap();
-    let allows = Instance::allowlist(conn).unwrap();
+    FederationAllowList::replace(pool, clear_allows)
+      .await
+      .unwrap();
+    let allows = Instance::allowlist(pool).await.unwrap();
 
     assert_eq!(0, allows.len());
 
-    Instance::delete_all(conn).unwrap();
+    Instance::delete_all(pool).await.unwrap();
   }
 }