]> Untitled Git - lemmy.git/blob - crates/db_schema/src/utils.rs
Cache & Optimize Woodpecker CI (#3450)
[lemmy.git] / crates / db_schema / src / utils.rs
1 use crate::{
2   diesel::Connection,
3   diesel_migrations::MigrationHarness,
4   newtypes::DbUrl,
5   CommentSortType,
6   SortType,
7 };
8 use activitypub_federation::{fetch::object_id::ObjectId, traits::Object};
9 use chrono::NaiveDateTime;
10 use deadpool::Runtime;
11 use diesel::{
12   backend::Backend,
13   deserialize::FromSql,
14   pg::Pg,
15   result::{ConnectionError, ConnectionResult, Error as DieselError, Error::QueryBuilderError},
16   serialize::{Output, ToSql},
17   sql_types::Text,
18   PgConnection,
19 };
20 use diesel_async::{
21   pg::AsyncPgConnection,
22   pooled_connection::{
23     deadpool::{Object as PooledConnection, Pool},
24     AsyncDieselConnectionManager,
25   },
26 };
27 use diesel_migrations::EmbeddedMigrations;
28 use futures_util::{future::BoxFuture, FutureExt};
29 use lemmy_utils::{
30   error::{LemmyError, LemmyErrorExt, LemmyErrorType},
31   settings::structs::Settings,
32 };
33 use once_cell::sync::Lazy;
34 use regex::Regex;
35 use rustls::{
36   client::{ServerCertVerified, ServerCertVerifier},
37   ServerName,
38 };
39 use std::{
40   env,
41   env::VarError,
42   ops::{Deref, DerefMut},
43   sync::Arc,
44   time::{Duration, SystemTime},
45 };
46 use tracing::{error, info};
47 use url::Url;
48
49 const FETCH_LIMIT_DEFAULT: i64 = 10;
50 pub const FETCH_LIMIT_MAX: i64 = 50;
51 const POOL_TIMEOUT: Option<Duration> = Some(Duration::from_secs(5));
52
53 pub type ActualDbPool = Pool<AsyncPgConnection>;
54
55 /// References a pool or connection. Functions must take `&mut DbPool<'_>` to allow implicit reborrowing.
56 ///
57 /// https://github.com/rust-lang/rfcs/issues/1403
58 pub enum DbPool<'a> {
59   Pool(&'a ActualDbPool),
60   Conn(&'a mut AsyncPgConnection),
61 }
62
63 pub enum DbConn<'a> {
64   Pool(PooledConnection<AsyncPgConnection>),
65   Conn(&'a mut AsyncPgConnection),
66 }
67
68 pub async fn get_conn<'a, 'b: 'a>(pool: &'a mut DbPool<'b>) -> Result<DbConn<'a>, DieselError> {
69   Ok(match pool {
70     DbPool::Pool(pool) => DbConn::Pool(pool.get().await.map_err(|e| QueryBuilderError(e.into()))?),
71     DbPool::Conn(conn) => DbConn::Conn(conn),
72   })
73 }
74
75 impl<'a> Deref for DbConn<'a> {
76   type Target = AsyncPgConnection;
77
78   fn deref(&self) -> &Self::Target {
79     match self {
80       DbConn::Pool(conn) => conn.deref(),
81       DbConn::Conn(conn) => conn.deref(),
82     }
83   }
84 }
85
86 impl<'a> DerefMut for DbConn<'a> {
87   fn deref_mut(&mut self) -> &mut Self::Target {
88     match self {
89       DbConn::Pool(conn) => conn.deref_mut(),
90       DbConn::Conn(conn) => conn.deref_mut(),
91     }
92   }
93 }
94
95 // Allows functions that take `DbPool<'_>` to be called in a transaction by passing `&mut conn.into()`
96 impl<'a> From<&'a mut AsyncPgConnection> for DbPool<'a> {
97   fn from(value: &'a mut AsyncPgConnection) -> Self {
98     DbPool::Conn(value)
99   }
100 }
101
102 impl<'a, 'b: 'a> From<&'a mut DbConn<'b>> for DbPool<'a> {
103   fn from(value: &'a mut DbConn<'b>) -> Self {
104     DbPool::Conn(value.deref_mut())
105   }
106 }
107
108 impl<'a> From<&'a ActualDbPool> for DbPool<'a> {
109   fn from(value: &'a ActualDbPool) -> Self {
110     DbPool::Pool(value)
111   }
112 }
113
114 /// Runs multiple async functions that take `&mut DbPool<'_>` as input and return `Result`. Only works when the  `futures` crate is listed in `Cargo.toml`.
115 ///
116 /// `$pool` is the value given to each function.
117 ///
118 /// A `Result` is returned (not in a `Future`, so don't use `.await`). The `Ok` variant contains a tuple with the values returned by the given functions.
119 ///
120 /// The functions run concurrently if `$pool` has the `DbPool::Pool` variant.
121 #[macro_export]
122 macro_rules! try_join_with_pool {
123   ($pool:ident => ($($func:expr),+)) => {{
124     // Check type
125     let _: &mut $crate::utils::DbPool<'_> = $pool;
126
127     match $pool {
128       // Run concurrently with `try_join`
129       $crate::utils::DbPool::Pool(__pool) => ::futures::try_join!(
130         $(async {
131           let mut __dbpool = $crate::utils::DbPool::Pool(__pool);
132           ($func)(&mut __dbpool).await
133         }),+
134       ),
135       // Run sequentially
136       $crate::utils::DbPool::Conn(__conn) => async {
137         Ok(($({
138           let mut __dbpool = $crate::utils::DbPool::Conn(__conn);
139           // `?` prevents the error type from being inferred in an `async` block, so `match` is used instead
140           match ($func)(&mut __dbpool).await {
141             ::core::result::Result::Ok(__v) => __v,
142             ::core::result::Result::Err(__v) => return ::core::result::Result::Err(__v),
143           }
144         }),+))
145       }.await,
146     }
147   }};
148 }
149
150 pub fn get_database_url_from_env() -> Result<String, VarError> {
151   env::var("LEMMY_DATABASE_URL")
152 }
153
154 pub fn fuzzy_search(q: &str) -> String {
155   let replaced = q.replace('%', "\\%").replace('_', "\\_").replace(' ', "%");
156   format!("%{replaced}%")
157 }
158
159 pub fn limit_and_offset(
160   page: Option<i64>,
161   limit: Option<i64>,
162 ) -> Result<(i64, i64), diesel::result::Error> {
163   let page = match page {
164     Some(page) => {
165       if page < 1 {
166         return Err(QueryBuilderError("Page is < 1".into()));
167       } else {
168         page
169       }
170     }
171     None => 1,
172   };
173   let limit = match limit {
174     Some(limit) => {
175       if !(1..=FETCH_LIMIT_MAX).contains(&limit) {
176         return Err(QueryBuilderError(
177           format!("Fetch limit is > {FETCH_LIMIT_MAX}").into(),
178         ));
179       } else {
180         limit
181       }
182     }
183     None => FETCH_LIMIT_DEFAULT,
184   };
185   let offset = limit * (page - 1);
186   Ok((limit, offset))
187 }
188
189 pub fn limit_and_offset_unlimited(page: Option<i64>, limit: Option<i64>) -> (i64, i64) {
190   let limit = limit.unwrap_or(FETCH_LIMIT_DEFAULT);
191   let offset = limit * (page.unwrap_or(1) - 1);
192   (limit, offset)
193 }
194
195 pub fn is_email_regex(test: &str) -> bool {
196   EMAIL_REGEX.is_match(test)
197 }
198
199 pub fn diesel_option_overwrite(opt: &Option<String>) -> Option<Option<String>> {
200   match opt {
201     // An empty string is an erase
202     Some(unwrapped) => {
203       if !unwrapped.eq("") {
204         Some(Some(unwrapped.clone()))
205       } else {
206         Some(None)
207       }
208     }
209     None => None,
210   }
211 }
212
213 pub fn diesel_option_overwrite_to_url(
214   opt: &Option<String>,
215 ) -> Result<Option<Option<DbUrl>>, LemmyError> {
216   match opt.as_ref().map(String::as_str) {
217     // An empty string is an erase
218     Some("") => Ok(Some(None)),
219     Some(str_url) => Url::parse(str_url)
220       .map(|u| Some(Some(u.into())))
221       .with_lemmy_type(LemmyErrorType::InvalidUrl),
222     None => Ok(None),
223   }
224 }
225
226 pub fn diesel_option_overwrite_to_url_create(
227   opt: &Option<String>,
228 ) -> Result<Option<DbUrl>, LemmyError> {
229   match opt.as_ref().map(String::as_str) {
230     // An empty string is nothing
231     Some("") => Ok(None),
232     Some(str_url) => Url::parse(str_url)
233       .map(|u| Some(u.into()))
234       .with_lemmy_type(LemmyErrorType::InvalidUrl),
235     None => Ok(None),
236   }
237 }
238
239 async fn build_db_pool_settings_opt(
240   settings: Option<&Settings>,
241 ) -> Result<ActualDbPool, LemmyError> {
242   let db_url = get_database_url(settings);
243   let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
244   // We only support TLS with sslmode=require currently
245   let tls_enabled = db_url.contains("sslmode=require");
246   let manager = if tls_enabled {
247     // diesel-async does not support any TLS connections out of the box, so we need to manually
248     // provide a setup function which handles creating the connection
249     AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
250   } else {
251     AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
252   };
253   let pool = Pool::builder(manager)
254     .max_size(pool_size)
255     .wait_timeout(POOL_TIMEOUT)
256     .create_timeout(POOL_TIMEOUT)
257     .recycle_timeout(POOL_TIMEOUT)
258     .runtime(Runtime::Tokio1)
259     .build()?;
260
261   // If there's no settings, that means its a unit test, and migrations need to be run
262   if settings.is_none() {
263     run_migrations(&db_url);
264   }
265
266   Ok(pool)
267 }
268
269 fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
270   let fut = async {
271     let rustls_config = rustls::ClientConfig::builder()
272       .with_safe_defaults()
273       .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
274       .with_no_client_auth();
275
276     let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
277     let (client, conn) = tokio_postgres::connect(config, tls)
278       .await
279       .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
280     tokio::spawn(async move {
281       if let Err(e) = conn.await {
282         error!("Database connection failed: {e}");
283       }
284     });
285     AsyncPgConnection::try_from(client).await
286   };
287   fut.boxed()
288 }
289
290 struct NoCertVerifier {}
291
292 impl ServerCertVerifier for NoCertVerifier {
293   fn verify_server_cert(
294     &self,
295     _end_entity: &rustls::Certificate,
296     _intermediates: &[rustls::Certificate],
297     _server_name: &ServerName,
298     _scts: &mut dyn Iterator<Item = &[u8]>,
299     _ocsp_response: &[u8],
300     _now: SystemTime,
301   ) -> Result<ServerCertVerified, rustls::Error> {
302     // Will verify all (even invalid) certs without any checks (sslmode=require)
303     Ok(ServerCertVerified::assertion())
304   }
305 }
306
307 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
308
309 pub fn run_migrations(db_url: &str) {
310   // Needs to be a sync connection
311   let mut conn =
312     PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}"));
313   info!("Running Database migrations (This may take a long time)...");
314   let _ = &mut conn
315     .run_pending_migrations(MIGRATIONS)
316     .unwrap_or_else(|e| panic!("Couldn't run DB Migrations: {e}"));
317   info!("Database migrations complete.");
318 }
319
320 pub async fn build_db_pool(settings: &Settings) -> Result<ActualDbPool, LemmyError> {
321   build_db_pool_settings_opt(Some(settings)).await
322 }
323
324 pub async fn build_db_pool_for_tests() -> ActualDbPool {
325   build_db_pool_settings_opt(None)
326     .await
327     .expect("db pool missing")
328 }
329
330 pub fn get_database_url(settings: Option<&Settings>) -> String {
331   // The env var should override anything in the settings config
332   match get_database_url_from_env() {
333     Ok(url) => url,
334     Err(e) => match settings {
335       Some(settings) => settings.get_database_url(),
336       None => panic!("Failed to read database URL from env var LEMMY_DATABASE_URL: {e}"),
337     },
338   }
339 }
340
341 pub fn naive_now() -> NaiveDateTime {
342   chrono::prelude::Utc::now().naive_utc()
343 }
344
345 pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType {
346   match sort {
347     SortType::Active | SortType::Hot => CommentSortType::Hot,
348     SortType::New | SortType::NewComments | SortType::MostComments => CommentSortType::New,
349     SortType::Old => CommentSortType::Old,
350     SortType::TopHour
351     | SortType::TopSixHour
352     | SortType::TopTwelveHour
353     | SortType::TopDay
354     | SortType::TopAll
355     | SortType::TopWeek
356     | SortType::TopYear
357     | SortType::TopMonth
358     | SortType::TopThreeMonths
359     | SortType::TopSixMonths
360     | SortType::TopNineMonths => CommentSortType::Top,
361   }
362 }
363
364 static EMAIL_REGEX: Lazy<Regex> = Lazy::new(|| {
365   Regex::new(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
366     .expect("compile email regex")
367 });
368
369 pub mod functions {
370   use diesel::sql_types::{BigInt, Text, Timestamp};
371
372   sql_function! {
373     fn hot_rank(score: BigInt, time: Timestamp) -> Integer;
374   }
375
376   sql_function!(fn lower(x: Text) -> Text);
377 }
378
379 pub const DELETED_REPLACEMENT_TEXT: &str = "*Permanently Deleted*";
380
381 impl ToSql<Text, Pg> for DbUrl {
382   fn to_sql(&self, out: &mut Output<Pg>) -> diesel::serialize::Result {
383     <std::string::String as ToSql<Text, Pg>>::to_sql(&self.0.to_string(), &mut out.reborrow())
384   }
385 }
386
387 impl<DB: Backend> FromSql<Text, DB> for DbUrl
388 where
389   String: FromSql<Text, DB>,
390 {
391   fn from_sql(value: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
392     let str = String::from_sql(value)?;
393     Ok(DbUrl(Box::new(Url::parse(&str)?)))
394   }
395 }
396
397 impl<Kind> From<ObjectId<Kind>> for DbUrl
398 where
399   Kind: Object + Send + 'static,
400   for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
401 {
402   fn from(id: ObjectId<Kind>) -> Self {
403     DbUrl(Box::new(id.into()))
404   }
405 }
406
407 #[cfg(test)]
408 mod tests {
409   #![allow(clippy::unwrap_used)]
410   #![allow(clippy::indexing_slicing)]
411
412   use super::{fuzzy_search, *};
413   use crate::utils::is_email_regex;
414
415   #[test]
416   fn test_fuzzy_search() {
417     let test = "This %is% _a_ fuzzy search";
418     assert_eq!(
419       fuzzy_search(test),
420       "%This%\\%is\\%%\\_a\\_%fuzzy%search%".to_string()
421     );
422   }
423
424   #[test]
425   fn test_email() {
426     assert!(is_email_regex("gush@gmail.com"));
427     assert!(!is_email_regex("nada_neutho"));
428   }
429
430   #[test]
431   fn test_diesel_option_overwrite() {
432     assert_eq!(diesel_option_overwrite(&None), None);
433     assert_eq!(diesel_option_overwrite(&Some(String::new())), Some(None));
434     assert_eq!(
435       diesel_option_overwrite(&Some("test".to_string())),
436       Some(Some("test".to_string()))
437     );
438   }
439
440   #[test]
441   fn test_diesel_option_overwrite_to_url() {
442     assert!(matches!(diesel_option_overwrite_to_url(&None), Ok(None)));
443     assert!(matches!(
444       diesel_option_overwrite_to_url(&Some(String::new())),
445       Ok(Some(None))
446     ));
447     assert!(diesel_option_overwrite_to_url(&Some("invalid_url".to_string())).is_err());
448     let example_url = "https://example.com";
449     assert!(matches!(
450       diesel_option_overwrite_to_url(&Some(example_url.to_string())),
451       Ok(Some(Some(url))) if url == Url::parse(example_url).unwrap().into()
452     ));
453   }
454 }