]> Untitled Git - lemmy.git/blob - crates/db_schema/src/utils.rs
Sanitize html (#3708)
[lemmy.git] / crates / db_schema / src / utils.rs
1 use crate::{
2   diesel::Connection,
3   diesel_migrations::MigrationHarness,
4   newtypes::DbUrl,
5   CommentSortType,
6   PersonSortType,
7   SortType,
8 };
9 use activitypub_federation::{fetch::object_id::ObjectId, traits::Object};
10 use chrono::NaiveDateTime;
11 use deadpool::Runtime;
12 use diesel::{
13   backend::Backend,
14   deserialize::FromSql,
15   pg::Pg,
16   result::{ConnectionError, ConnectionResult, Error as DieselError, Error::QueryBuilderError},
17   serialize::{Output, ToSql},
18   sql_types::Text,
19   PgConnection,
20 };
21 use diesel_async::{
22   pg::AsyncPgConnection,
23   pooled_connection::{
24     deadpool::{Object as PooledConnection, Pool},
25     AsyncDieselConnectionManager,
26   },
27 };
28 use diesel_migrations::EmbeddedMigrations;
29 use futures_util::{future::BoxFuture, FutureExt};
30 use lemmy_utils::{
31   error::{LemmyError, LemmyErrorExt, LemmyErrorType},
32   settings::structs::Settings,
33 };
34 use once_cell::sync::Lazy;
35 use regex::Regex;
36 use rustls::{
37   client::{ServerCertVerified, ServerCertVerifier},
38   ServerName,
39 };
40 use std::{
41   env,
42   env::VarError,
43   ops::{Deref, DerefMut},
44   sync::Arc,
45   time::{Duration, SystemTime},
46 };
47 use tracing::{error, info};
48 use url::Url;
49
50 const FETCH_LIMIT_DEFAULT: i64 = 10;
51 pub const FETCH_LIMIT_MAX: i64 = 50;
52 const POOL_TIMEOUT: Option<Duration> = Some(Duration::from_secs(5));
53
54 pub type ActualDbPool = Pool<AsyncPgConnection>;
55
56 /// References a pool or connection. Functions must take `&mut DbPool<'_>` to allow implicit reborrowing.
57 ///
58 /// https://github.com/rust-lang/rfcs/issues/1403
59 pub enum DbPool<'a> {
60   Pool(&'a ActualDbPool),
61   Conn(&'a mut AsyncPgConnection),
62 }
63
64 pub enum DbConn<'a> {
65   Pool(PooledConnection<AsyncPgConnection>),
66   Conn(&'a mut AsyncPgConnection),
67 }
68
69 pub async fn get_conn<'a, 'b: 'a>(pool: &'a mut DbPool<'b>) -> Result<DbConn<'a>, DieselError> {
70   Ok(match pool {
71     DbPool::Pool(pool) => DbConn::Pool(pool.get().await.map_err(|e| QueryBuilderError(e.into()))?),
72     DbPool::Conn(conn) => DbConn::Conn(conn),
73   })
74 }
75
76 impl<'a> Deref for DbConn<'a> {
77   type Target = AsyncPgConnection;
78
79   fn deref(&self) -> &Self::Target {
80     match self {
81       DbConn::Pool(conn) => conn.deref(),
82       DbConn::Conn(conn) => conn.deref(),
83     }
84   }
85 }
86
87 impl<'a> DerefMut for DbConn<'a> {
88   fn deref_mut(&mut self) -> &mut Self::Target {
89     match self {
90       DbConn::Pool(conn) => conn.deref_mut(),
91       DbConn::Conn(conn) => conn.deref_mut(),
92     }
93   }
94 }
95
96 // Allows functions that take `DbPool<'_>` to be called in a transaction by passing `&mut conn.into()`
97 impl<'a> From<&'a mut AsyncPgConnection> for DbPool<'a> {
98   fn from(value: &'a mut AsyncPgConnection) -> Self {
99     DbPool::Conn(value)
100   }
101 }
102
103 impl<'a, 'b: 'a> From<&'a mut DbConn<'b>> for DbPool<'a> {
104   fn from(value: &'a mut DbConn<'b>) -> Self {
105     DbPool::Conn(value.deref_mut())
106   }
107 }
108
109 impl<'a> From<&'a ActualDbPool> for DbPool<'a> {
110   fn from(value: &'a ActualDbPool) -> Self {
111     DbPool::Pool(value)
112   }
113 }
114
115 /// Runs multiple async functions that take `&mut DbPool<'_>` as input and return `Result`. Only works when the  `futures` crate is listed in `Cargo.toml`.
116 ///
117 /// `$pool` is the value given to each function.
118 ///
119 /// A `Result` is returned (not in a `Future`, so don't use `.await`). The `Ok` variant contains a tuple with the values returned by the given functions.
120 ///
121 /// The functions run concurrently if `$pool` has the `DbPool::Pool` variant.
122 #[macro_export]
123 macro_rules! try_join_with_pool {
124   ($pool:ident => ($($func:expr),+)) => {{
125     // Check type
126     let _: &mut $crate::utils::DbPool<'_> = $pool;
127
128     match $pool {
129       // Run concurrently with `try_join`
130       $crate::utils::DbPool::Pool(__pool) => ::futures::try_join!(
131         $(async {
132           let mut __dbpool = $crate::utils::DbPool::Pool(__pool);
133           ($func)(&mut __dbpool).await
134         }),+
135       ),
136       // Run sequentially
137       $crate::utils::DbPool::Conn(__conn) => async {
138         Ok(($({
139           let mut __dbpool = $crate::utils::DbPool::Conn(__conn);
140           // `?` prevents the error type from being inferred in an `async` block, so `match` is used instead
141           match ($func)(&mut __dbpool).await {
142             ::core::result::Result::Ok(__v) => __v,
143             ::core::result::Result::Err(__v) => return ::core::result::Result::Err(__v),
144           }
145         }),+))
146       }.await,
147     }
148   }};
149 }
150
151 pub fn get_database_url_from_env() -> Result<String, VarError> {
152   env::var("LEMMY_DATABASE_URL")
153 }
154
155 pub fn fuzzy_search(q: &str) -> String {
156   let replaced = q.replace('%', "\\%").replace('_', "\\_").replace(' ', "%");
157   format!("%{replaced}%")
158 }
159
160 pub fn limit_and_offset(
161   page: Option<i64>,
162   limit: Option<i64>,
163 ) -> Result<(i64, i64), diesel::result::Error> {
164   let page = match page {
165     Some(page) => {
166       if page < 1 {
167         return Err(QueryBuilderError("Page is < 1".into()));
168       } else {
169         page
170       }
171     }
172     None => 1,
173   };
174   let limit = match limit {
175     Some(limit) => {
176       if !(1..=FETCH_LIMIT_MAX).contains(&limit) {
177         return Err(QueryBuilderError(
178           format!("Fetch limit is > {FETCH_LIMIT_MAX}").into(),
179         ));
180       } else {
181         limit
182       }
183     }
184     None => FETCH_LIMIT_DEFAULT,
185   };
186   let offset = limit * (page - 1);
187   Ok((limit, offset))
188 }
189
190 pub fn limit_and_offset_unlimited(page: Option<i64>, limit: Option<i64>) -> (i64, i64) {
191   let limit = limit.unwrap_or(FETCH_LIMIT_DEFAULT);
192   let offset = limit * (page.unwrap_or(1) - 1);
193   (limit, offset)
194 }
195
196 pub fn is_email_regex(test: &str) -> bool {
197   EMAIL_REGEX.is_match(test)
198 }
199
200 pub fn diesel_option_overwrite(opt: Option<String>) -> Option<Option<String>> {
201   match opt {
202     // An empty string is an erase
203     Some(unwrapped) => {
204       if !unwrapped.eq("") {
205         Some(Some(unwrapped))
206       } else {
207         Some(None)
208       }
209     }
210     None => None,
211   }
212 }
213
214 pub fn diesel_option_overwrite_to_url(
215   opt: &Option<String>,
216 ) -> Result<Option<Option<DbUrl>>, LemmyError> {
217   match opt.as_ref().map(String::as_str) {
218     // An empty string is an erase
219     Some("") => Ok(Some(None)),
220     Some(str_url) => Url::parse(str_url)
221       .map(|u| Some(Some(u.into())))
222       .with_lemmy_type(LemmyErrorType::InvalidUrl),
223     None => Ok(None),
224   }
225 }
226
227 pub fn diesel_option_overwrite_to_url_create(
228   opt: &Option<String>,
229 ) -> Result<Option<DbUrl>, LemmyError> {
230   match opt.as_ref().map(String::as_str) {
231     // An empty string is nothing
232     Some("") => Ok(None),
233     Some(str_url) => Url::parse(str_url)
234       .map(|u| Some(u.into()))
235       .with_lemmy_type(LemmyErrorType::InvalidUrl),
236     None => Ok(None),
237   }
238 }
239
240 async fn build_db_pool_settings_opt(
241   settings: Option<&Settings>,
242 ) -> Result<ActualDbPool, LemmyError> {
243   let db_url = get_database_url(settings);
244   let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
245   // We only support TLS with sslmode=require currently
246   let tls_enabled = db_url.contains("sslmode=require");
247   let manager = if tls_enabled {
248     // diesel-async does not support any TLS connections out of the box, so we need to manually
249     // provide a setup function which handles creating the connection
250     AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
251   } else {
252     AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
253   };
254   let pool = Pool::builder(manager)
255     .max_size(pool_size)
256     .wait_timeout(POOL_TIMEOUT)
257     .create_timeout(POOL_TIMEOUT)
258     .recycle_timeout(POOL_TIMEOUT)
259     .runtime(Runtime::Tokio1)
260     .build()?;
261
262   // If there's no settings, that means its a unit test, and migrations need to be run
263   if settings.is_none() {
264     run_migrations(&db_url);
265   }
266
267   Ok(pool)
268 }
269
270 fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
271   let fut = async {
272     let rustls_config = rustls::ClientConfig::builder()
273       .with_safe_defaults()
274       .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
275       .with_no_client_auth();
276
277     let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
278     let (client, conn) = tokio_postgres::connect(config, tls)
279       .await
280       .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
281     tokio::spawn(async move {
282       if let Err(e) = conn.await {
283         error!("Database connection failed: {e}");
284       }
285     });
286     AsyncPgConnection::try_from(client).await
287   };
288   fut.boxed()
289 }
290
291 struct NoCertVerifier {}
292
293 impl ServerCertVerifier for NoCertVerifier {
294   fn verify_server_cert(
295     &self,
296     _end_entity: &rustls::Certificate,
297     _intermediates: &[rustls::Certificate],
298     _server_name: &ServerName,
299     _scts: &mut dyn Iterator<Item = &[u8]>,
300     _ocsp_response: &[u8],
301     _now: SystemTime,
302   ) -> Result<ServerCertVerified, rustls::Error> {
303     // Will verify all (even invalid) certs without any checks (sslmode=require)
304     Ok(ServerCertVerified::assertion())
305   }
306 }
307
308 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
309
310 pub fn run_migrations(db_url: &str) {
311   // Needs to be a sync connection
312   let mut conn =
313     PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}"));
314   info!("Running Database migrations (This may take a long time)...");
315   let _ = &mut conn
316     .run_pending_migrations(MIGRATIONS)
317     .unwrap_or_else(|e| panic!("Couldn't run DB Migrations: {e}"));
318   info!("Database migrations complete.");
319 }
320
321 pub async fn build_db_pool(settings: &Settings) -> Result<ActualDbPool, LemmyError> {
322   build_db_pool_settings_opt(Some(settings)).await
323 }
324
325 pub async fn build_db_pool_for_tests() -> ActualDbPool {
326   build_db_pool_settings_opt(None)
327     .await
328     .expect("db pool missing")
329 }
330
331 pub fn get_database_url(settings: Option<&Settings>) -> String {
332   // The env var should override anything in the settings config
333   match get_database_url_from_env() {
334     Ok(url) => url,
335     Err(e) => match settings {
336       Some(settings) => settings.get_database_url(),
337       None => panic!("Failed to read database URL from env var LEMMY_DATABASE_URL: {e}"),
338     },
339   }
340 }
341
342 pub fn naive_now() -> NaiveDateTime {
343   chrono::prelude::Utc::now().naive_utc()
344 }
345
346 pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType {
347   match sort {
348     SortType::Active | SortType::Hot => CommentSortType::Hot,
349     SortType::New | SortType::NewComments | SortType::MostComments => CommentSortType::New,
350     SortType::Old => CommentSortType::Old,
351     SortType::Controversial => CommentSortType::Controversial,
352     SortType::TopHour
353     | SortType::TopSixHour
354     | SortType::TopTwelveHour
355     | SortType::TopDay
356     | SortType::TopAll
357     | SortType::TopWeek
358     | SortType::TopYear
359     | SortType::TopMonth
360     | SortType::TopThreeMonths
361     | SortType::TopSixMonths
362     | SortType::TopNineMonths => CommentSortType::Top,
363   }
364 }
365
366 pub fn post_to_person_sort_type(sort: SortType) -> PersonSortType {
367   match sort {
368     SortType::Active | SortType::Hot | SortType::Controversial => PersonSortType::CommentScore,
369     SortType::New | SortType::NewComments => PersonSortType::New,
370     SortType::MostComments => PersonSortType::MostComments,
371     SortType::Old => PersonSortType::Old,
372     _ => PersonSortType::CommentScore,
373   }
374 }
375
376 static EMAIL_REGEX: Lazy<Regex> = Lazy::new(|| {
377   Regex::new(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
378     .expect("compile email regex")
379 });
380
381 pub mod functions {
382   use diesel::sql_types::{BigInt, Text, Timestamp};
383
384   sql_function! {
385     fn hot_rank(score: BigInt, time: Timestamp) -> Integer;
386   }
387
388   sql_function! {
389     fn controversy_rank(upvotes: BigInt, downvotes: BigInt, score: BigInt) -> Double;
390   }
391
392   sql_function!(fn lower(x: Text) -> Text);
393 }
394
395 pub const DELETED_REPLACEMENT_TEXT: &str = "*Permanently Deleted*";
396
397 impl ToSql<Text, Pg> for DbUrl {
398   fn to_sql(&self, out: &mut Output<Pg>) -> diesel::serialize::Result {
399     <std::string::String as ToSql<Text, Pg>>::to_sql(&self.0.to_string(), &mut out.reborrow())
400   }
401 }
402
403 impl<DB: Backend> FromSql<Text, DB> for DbUrl
404 where
405   String: FromSql<Text, DB>,
406 {
407   fn from_sql(value: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
408     let str = String::from_sql(value)?;
409     Ok(DbUrl(Box::new(Url::parse(&str)?)))
410   }
411 }
412
413 impl<Kind> From<ObjectId<Kind>> for DbUrl
414 where
415   Kind: Object + Send + 'static,
416   for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
417 {
418   fn from(id: ObjectId<Kind>) -> Self {
419     DbUrl(Box::new(id.into()))
420   }
421 }
422
423 #[cfg(test)]
424 mod tests {
425   #![allow(clippy::unwrap_used)]
426   #![allow(clippy::indexing_slicing)]
427
428   use super::{fuzzy_search, *};
429   use crate::utils::is_email_regex;
430
431   #[test]
432   fn test_fuzzy_search() {
433     let test = "This %is% _a_ fuzzy search";
434     assert_eq!(
435       fuzzy_search(test),
436       "%This%\\%is\\%%\\_a\\_%fuzzy%search%".to_string()
437     );
438   }
439
440   #[test]
441   fn test_email() {
442     assert!(is_email_regex("gush@gmail.com"));
443     assert!(!is_email_regex("nada_neutho"));
444   }
445
446   #[test]
447   fn test_diesel_option_overwrite() {
448     assert_eq!(diesel_option_overwrite(None), None);
449     assert_eq!(diesel_option_overwrite(Some(String::new())), Some(None));
450     assert_eq!(
451       diesel_option_overwrite(Some("test".to_string())),
452       Some(Some("test".to_string()))
453     );
454   }
455
456   #[test]
457   fn test_diesel_option_overwrite_to_url() {
458     assert!(matches!(diesel_option_overwrite_to_url(&None), Ok(None)));
459     assert!(matches!(
460       diesel_option_overwrite_to_url(&Some(String::new())),
461       Ok(Some(None))
462     ));
463     assert!(diesel_option_overwrite_to_url(&Some("invalid_url".to_string())).is_err());
464     let example_url = "https://example.com";
465     assert!(matches!(
466       diesel_option_overwrite_to_url(&Some(example_url.to_string())),
467       Ok(Some(Some(url))) if url == Url::parse(example_url).unwrap().into()
468     ));
469   }
470 }