3 diesel_migrations::MigrationHarness,
10 use activitypub_federation::{fetch::object_id::ObjectId, traits::Object};
11 use chrono::NaiveDateTime;
12 use deadpool::Runtime;
17 result::{ConnectionError, ConnectionResult, Error as DieselError, Error::QueryBuilderError},
18 serialize::{Output, ToSql},
23 pg::AsyncPgConnection,
25 deadpool::{Object as PooledConnection, Pool},
26 AsyncDieselConnectionManager,
29 use diesel_migrations::EmbeddedMigrations;
30 use futures_util::{future::BoxFuture, Future, FutureExt};
32 error::{LemmyError, LemmyErrorExt, LemmyErrorType},
33 settings::structs::Settings,
35 use once_cell::sync::Lazy;
38 client::{ServerCertVerified, ServerCertVerifier},
44 ops::{Deref, DerefMut},
46 time::{Duration, SystemTime},
48 use tracing::{error, info};
/// Page size used when a request does not specify a `limit`.
const FETCH_LIMIT_DEFAULT: i64 = 10;
/// Largest `limit` accepted by [`limit_and_offset`]; larger values are rejected.
pub const FETCH_LIMIT_MAX: i64 = 50;
/// Timeout applied to waiting for, creating, and recycling pooled connections.
const POOL_TIMEOUT: Option<Duration> = Some(Duration::from_secs(5));
55 pub type ActualDbPool = Pool<AsyncPgConnection>;
57 /// References a pool or connection. Functions must take `&mut DbPool<'_>` to allow implicit reborrowing.
59 /// https://github.com/rust-lang/rfcs/issues/1403
61 Pool(&'a ActualDbPool),
62 Conn(&'a mut AsyncPgConnection),
66 Pool(PooledConnection<AsyncPgConnection>),
67 Conn(&'a mut AsyncPgConnection),
70 pub async fn get_conn<'a, 'b: 'a>(pool: &'a mut DbPool<'b>) -> Result<DbConn<'a>, DieselError> {
72 DbPool::Pool(pool) => DbConn::Pool(pool.get().await.map_err(|e| QueryBuilderError(e.into()))?),
73 DbPool::Conn(conn) => DbConn::Conn(conn),
77 impl<'a> Deref for DbConn<'a> {
78 type Target = AsyncPgConnection;
80 fn deref(&self) -> &Self::Target {
82 DbConn::Pool(conn) => conn.deref(),
83 DbConn::Conn(conn) => conn.deref(),
88 impl<'a> DerefMut for DbConn<'a> {
89 fn deref_mut(&mut self) -> &mut Self::Target {
91 DbConn::Pool(conn) => conn.deref_mut(),
92 DbConn::Conn(conn) => conn.deref_mut(),
97 // Allows functions that take `DbPool<'_>` to be called in a transaction by passing `&mut conn.into()`
98 impl<'a> From<&'a mut AsyncPgConnection> for DbPool<'a> {
99 fn from(value: &'a mut AsyncPgConnection) -> Self {
104 impl<'a, 'b: 'a> From<&'a mut DbConn<'b>> for DbPool<'a> {
105 fn from(value: &'a mut DbConn<'b>) -> Self {
106 DbPool::Conn(value.deref_mut())
110 impl<'a> From<&'a ActualDbPool> for DbPool<'a> {
111 fn from(value: &'a ActualDbPool) -> Self {
116 /// Runs multiple async functions that take `&mut DbPool<'_>` as input and return `Result`. Only works when the `futures` crate is listed in `Cargo.toml`.
118 /// `$pool` is the value given to each function.
120 /// A `Result` is returned (not in a `Future`, so don't use `.await`). The `Ok` variant contains a tuple with the values returned by the given functions.
122 /// The functions run concurrently if `$pool` has the `DbPool::Pool` variant.
124 macro_rules! try_join_with_pool {
125 ($pool:ident => ($($func:expr),+)) => {{
127 let _: &mut $crate::utils::DbPool<'_> = $pool;
130 // Run concurrently with `try_join`
131 $crate::utils::DbPool::Pool(__pool) => ::futures::try_join!(
133 let mut __dbpool = $crate::utils::DbPool::Pool(__pool);
134 ($func)(&mut __dbpool).await
138 $crate::utils::DbPool::Conn(__conn) => async {
140 let mut __dbpool = $crate::utils::DbPool::Conn(__conn);
141 // `?` prevents the error type from being inferred in an `async` block, so `match` is used instead
142 match ($func)(&mut __dbpool).await {
143 ::core::result::Result::Ok(__v) => __v,
144 ::core::result::Result::Err(__v) => return ::core::result::Result::Err(__v),
/// Reads the database connection URL from the `LEMMY_DATABASE_URL` environment variable.
///
/// # Errors
/// Returns the [`VarError`] produced by [`env::var`] when the variable is unset or not valid
/// Unicode.
pub fn get_database_url_from_env() -> Result<String, VarError> {
  env::var("LEMMY_DATABASE_URL")
}
/// Builds a SQL `LIKE`/`ILIKE` pattern matching `q` anywhere in a column.
///
/// Pattern metacharacters in `q` are escaped so they match literally. `\` is the escape character
/// the other escapes rely on, so it is escaped first; otherwise a user-supplied `\` would escape
/// whatever follows it. After that `%` and `_` are escaped, and spaces become `%` so the words
/// only need to appear in order, with anything in between.
pub fn fuzzy_search(q: &str) -> String {
  let replaced = q
    .replace('\\', "\\\\")
    .replace('%', "\\%")
    .replace('_', "\\_")
    .replace(' ', "%");
  format!("%{replaced}%")
}
161 pub fn limit_and_offset(
164 ) -> Result<(i64, i64), diesel::result::Error> {
165 let page = match page {
168 return Err(QueryBuilderError("Page is < 1".into()));
175 let limit = match limit {
177 if !(1..=FETCH_LIMIT_MAX).contains(&limit) {
178 return Err(QueryBuilderError(
179 format!("Fetch limit is > {FETCH_LIMIT_MAX}").into(),
185 None => FETCH_LIMIT_DEFAULT,
187 let offset = limit * (page - 1);
191 pub fn limit_and_offset_unlimited(page: Option<i64>, limit: Option<i64>) -> (i64, i64) {
192 let limit = limit.unwrap_or(FETCH_LIMIT_DEFAULT);
193 let offset = limit * (page.unwrap_or(1) - 1);
197 pub fn is_email_regex(test: &str) -> bool {
198 EMAIL_REGEX.is_match(test)
/// Maps an optional form field to an "overwrite" value for a nullable column.
///
/// * `None` → `None` (field not supplied)
/// * `Some("")` → `Some(None)` (an empty string erases the value)
/// * `Some(s)` → `Some(Some(s))` (any other string, including whitespace, is kept)
pub fn diesel_option_overwrite(opt: Option<String>) -> Option<Option<String>> {
  // An empty string is an erase
  opt.map(|s| if s.is_empty() { None } else { Some(s) })
}
215 pub fn diesel_option_overwrite_to_url(
216 opt: &Option<String>,
217 ) -> Result<Option<Option<DbUrl>>, LemmyError> {
218 match opt.as_ref().map(String::as_str) {
219 // An empty string is an erase
220 Some("") => Ok(Some(None)),
221 Some(str_url) => Url::parse(str_url)
222 .map(|u| Some(Some(u.into())))
223 .with_lemmy_type(LemmyErrorType::InvalidUrl),
228 pub fn diesel_option_overwrite_to_url_create(
229 opt: &Option<String>,
230 ) -> Result<Option<DbUrl>, LemmyError> {
231 match opt.as_ref().map(String::as_str) {
232 // An empty string is nothing
233 Some("") => Ok(None),
234 Some(str_url) => Url::parse(str_url)
235 .map(|u| Some(u.into()))
236 .with_lemmy_type(LemmyErrorType::InvalidUrl),
241 async fn build_db_pool_settings_opt(
242 settings: Option<&Settings>,
243 ) -> Result<ActualDbPool, LemmyError> {
244 let db_url = get_database_url(settings);
245 let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
246 // We only support TLS with sslmode=require currently
247 let tls_enabled = db_url.contains("sslmode=require");
248 let manager = if tls_enabled {
249 // diesel-async does not support any TLS connections out of the box, so we need to manually
250 // provide a setup function which handles creating the connection
251 AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
253 AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
255 let pool = Pool::builder(manager)
257 .wait_timeout(POOL_TIMEOUT)
258 .create_timeout(POOL_TIMEOUT)
259 .recycle_timeout(POOL_TIMEOUT)
260 .runtime(Runtime::Tokio1)
263 // If there's no settings, that means its a unit test, and migrations need to be run
264 if settings.is_none() {
265 run_migrations(&db_url);
271 fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
273 let rustls_config = rustls::ClientConfig::builder()
274 .with_safe_defaults()
275 .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
276 .with_no_client_auth();
278 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
279 let (client, conn) = tokio_postgres::connect(config, tls)
281 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
282 tokio::spawn(async move {
283 if let Err(e) = conn.await {
284 error!("Database connection failed: {e}");
287 AsyncPgConnection::try_from(client).await
292 struct NoCertVerifier {}
294 impl ServerCertVerifier for NoCertVerifier {
295 fn verify_server_cert(
297 _end_entity: &rustls::Certificate,
298 _intermediates: &[rustls::Certificate],
299 _server_name: &ServerName,
300 _scts: &mut dyn Iterator<Item = &[u8]>,
301 _ocsp_response: &[u8],
303 ) -> Result<ServerCertVerified, rustls::Error> {
304 // Will verify all (even invalid) certs without any checks (sslmode=require)
305 Ok(ServerCertVerified::assertion())
309 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
311 pub fn run_migrations(db_url: &str) {
312 // Needs to be a sync connection
314 PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}"));
315 info!("Running Database migrations (This may take a long time)...");
317 .run_pending_migrations(MIGRATIONS)
318 .unwrap_or_else(|e| panic!("Couldn't run DB Migrations: {e}"));
319 info!("Database migrations complete.");
322 pub async fn build_db_pool(settings: &Settings) -> Result<ActualDbPool, LemmyError> {
323 build_db_pool_settings_opt(Some(settings)).await
326 pub async fn build_db_pool_for_tests() -> ActualDbPool {
327 build_db_pool_settings_opt(None)
329 .expect("db pool missing")
332 pub fn get_database_url(settings: Option<&Settings>) -> String {
333 // The env var should override anything in the settings config
334 match get_database_url_from_env() {
336 Err(e) => match settings {
337 Some(settings) => settings.get_database_url(),
338 None => panic!("Failed to read database URL from env var LEMMY_DATABASE_URL: {e}"),
343 pub fn naive_now() -> NaiveDateTime {
344 chrono::prelude::Utc::now().naive_utc()
347 pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType {
349 SortType::Active | SortType::Hot => CommentSortType::Hot,
350 SortType::New | SortType::NewComments | SortType::MostComments => CommentSortType::New,
351 SortType::Old => CommentSortType::Old,
352 SortType::Controversial => CommentSortType::Controversial,
354 | SortType::TopSixHour
355 | SortType::TopTwelveHour
361 | SortType::TopThreeMonths
362 | SortType::TopSixMonths
363 | SortType::TopNineMonths => CommentSortType::Top,
367 pub fn post_to_person_sort_type(sort: SortType) -> PersonSortType {
369 SortType::Active | SortType::Hot | SortType::Controversial => PersonSortType::CommentScore,
370 SortType::New | SortType::NewComments => PersonSortType::New,
371 SortType::MostComments => PersonSortType::MostComments,
372 SortType::Old => PersonSortType::Old,
373 _ => PersonSortType::CommentScore,
377 static EMAIL_REGEX: Lazy<Regex> = Lazy::new(|| {
378 Regex::new(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
379 .expect("compile email regex")
383 use diesel::sql_types::{BigInt, Text, Timestamp};
386 fn hot_rank(score: BigInt, time: Timestamp) -> Integer;
390 fn controversy_rank(upvotes: BigInt, downvotes: BigInt, score: BigInt) -> Double;
393 sql_function!(fn lower(x: Text) -> Text);
/// Placeholder text; presumably substituted for the body of permanently deleted content
/// (TODO confirm against callers).
pub const DELETED_REPLACEMENT_TEXT: &str = "*Permanently Deleted*";
398 impl ToSql<Text, Pg> for DbUrl {
399 fn to_sql(&self, out: &mut Output<Pg>) -> diesel::serialize::Result {
400 <std::string::String as ToSql<Text, Pg>>::to_sql(&self.0.to_string(), &mut out.reborrow())
404 impl<DB: Backend> FromSql<Text, DB> for DbUrl
406 String: FromSql<Text, DB>,
408 fn from_sql(value: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
409 let str = String::from_sql(value)?;
410 Ok(DbUrl(Box::new(Url::parse(&str)?)))
414 impl<Kind> From<ObjectId<Kind>> for DbUrl
416 Kind: Object + Send + 'static,
417 for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
419 fn from(id: ObjectId<Kind>) -> Self {
420 DbUrl(Box::new(id.into()))
424 pub type ResultFuture<'a, T> = BoxFuture<'a, Result<T, DieselError>>;
426 pub trait ReadFn<'a, T: JoinView, Args>:
427 Fn(DbConn<'a>, Args) -> ResultFuture<'a, <T as JoinView>::JoinTuple>
435 F: Fn(DbConn<'a>, Args) -> ResultFuture<'a, <T as JoinView>::JoinTuple>,
436 > ReadFn<'a, T, Args> for F
440 pub trait ListFn<'a, T: JoinView, Args>:
441 Fn(DbConn<'a>, Args) -> ResultFuture<'a, Vec<<T as JoinView>::JoinTuple>>
449 F: Fn(DbConn<'a>, Args) -> ResultFuture<'a, Vec<<T as JoinView>::JoinTuple>>,
450 > ListFn<'a, T, Args> for F
454 /// Allows read and list functions to capture a shared closure that has an inferred return type, which is useful for join logic
455 pub struct Queries<RF, LF> {
460 // `()` is used to prevent type inference error
461 impl Queries<(), ()> {
462 pub fn new<'a, RFut, LFut, RT, LT, RA, LA, RF2, LF2>(
465 ) -> Queries<impl ReadFn<'a, RT, RA>, impl ListFn<'a, LT, LA>>
467 RFut: Future<Output = Result<<RT as JoinView>::JoinTuple, DieselError>> + Sized + Send + 'a,
469 Future<Output = Result<Vec<<LT as JoinView>::JoinTuple>, DieselError>> + Sized + Send + 'a,
472 RF2: Fn(DbConn<'a>, RA) -> RFut,
473 LF2: Fn(DbConn<'a>, LA) -> LFut,
476 read_fn: move |conn, args| read_fn(conn, args).boxed(),
477 list_fn: move |conn, args| list_fn(conn, args).boxed(),
482 impl<RF, LF> Queries<RF, LF> {
483 pub async fn read<'a, T, Args>(
485 pool: &'a mut DbPool<'_>,
487 ) -> Result<T, DieselError>
490 RF: ReadFn<'a, T, Args>,
492 let conn = get_conn(pool).await?;
493 let res = (self.read_fn)(conn, args).await?;
494 Ok(T::from_tuple(res))
497 pub async fn list<'a, T, Args>(
499 pool: &'a mut DbPool<'_>,
501 ) -> Result<Vec<T>, DieselError>
504 LF: ListFn<'a, T, Args>,
506 let conn = get_conn(pool).await?;
507 let res = (self.list_fn)(conn, args).await?;
508 Ok(res.into_iter().map(T::from_tuple).collect())
514 #![allow(clippy::unwrap_used)]
515 #![allow(clippy::indexing_slicing)]
517 use super::{fuzzy_search, *};
518 use crate::utils::is_email_regex;
521 fn test_fuzzy_search() {
522 let test = "This %is% _a_ fuzzy search";
525 "%This%\\%is\\%%\\_a\\_%fuzzy%search%".to_string()
531 assert!(is_email_regex("gush@gmail.com"));
532 assert!(!is_email_regex("nada_neutho"));
536 fn test_diesel_option_overwrite() {
537 assert_eq!(diesel_option_overwrite(None), None);
538 assert_eq!(diesel_option_overwrite(Some(String::new())), Some(None));
540 diesel_option_overwrite(Some("test".to_string())),
541 Some(Some("test".to_string()))
546 fn test_diesel_option_overwrite_to_url() {
547 assert!(matches!(diesel_option_overwrite_to_url(&None), Ok(None)));
549 diesel_option_overwrite_to_url(&Some(String::new())),
552 assert!(diesel_option_overwrite_to_url(&Some("invalid_url".to_string())).is_err());
553 let example_url = "https://example.com";
555 diesel_option_overwrite_to_url(&Some(example_url.to_string())),
556 Ok(Some(Some(url))) if url == Url::parse(example_url).unwrap().into()