3 diesel_migrations::MigrationHarness,
9 use activitypub_federation::{fetch::object_id::ObjectId, traits::Object};
10 use chrono::NaiveDateTime;
11 use deadpool::Runtime;
16 result::{ConnectionError, ConnectionResult, Error as DieselError, Error::QueryBuilderError},
17 serialize::{Output, ToSql},
22 pg::AsyncPgConnection,
24 deadpool::{Object as PooledConnection, Pool},
25 AsyncDieselConnectionManager,
28 use diesel_migrations::EmbeddedMigrations;
29 use futures_util::{future::BoxFuture, FutureExt};
31 error::{LemmyError, LemmyErrorExt, LemmyErrorType},
32 settings::structs::Settings,
34 use once_cell::sync::Lazy;
37 client::{ServerCertVerified, ServerCertVerifier},
43 ops::{Deref, DerefMut},
45 time::{Duration, SystemTime},
47 use tracing::{error, info};
/// Page size used when the caller does not supply a limit.
50 const FETCH_LIMIT_DEFAULT: i64 = 10;
/// Largest page size accepted by `limit_and_offset`.
51 pub const FETCH_LIMIT_MAX: i64 = 50;
/// Five-second timeout passed to the pool builder's wait, create and recycle settings.
52 const POOL_TIMEOUT: Option<Duration> = Some(Duration::from_secs(5));
/// Pool of async Postgres connections used throughout this module.
54 pub type ActualDbPool = Pool<AsyncPgConnection>;
56 /// References a pool or connection. Functions must take `&mut DbPool<'_>` to allow implicit reborrowing.
58 /// https://github.com/rust-lang/rfcs/issues/1403
// Shared pool; `get_conn` checks a connection out of it on demand.
60 Pool(&'a ActualDbPool),
// Exclusive borrow of one connection; `get_conn` passes it through unchanged.
61 Conn(&'a mut AsyncPgConnection),
// Connection obtained from the pool and held by this value.
65 Pool(PooledConnection<AsyncPgConnection>),
// Connection borrowed from the caller.
66 Conn(&'a mut AsyncPgConnection),
69 pub async fn get_conn<'a, 'b: 'a>(pool: &'a mut DbPool<'b>) -> Result<DbConn<'a>, DieselError> {
71 DbPool::Pool(pool) => DbConn::Pool(pool.get().await.map_err(|e| QueryBuilderError(e.into()))?),
72 DbPool::Conn(conn) => DbConn::Conn(conn),
76 impl<'a> Deref for DbConn<'a> {
77 type Target = AsyncPgConnection;
79 fn deref(&self) -> &Self::Target {
81 DbConn::Pool(conn) => conn.deref(),
82 DbConn::Conn(conn) => conn.deref(),
87 impl<'a> DerefMut for DbConn<'a> {
88 fn deref_mut(&mut self) -> &mut Self::Target {
90 DbConn::Pool(conn) => conn.deref_mut(),
91 DbConn::Conn(conn) => conn.deref_mut(),
96 // Allows functions that take `DbPool<'_>` to be called in a transaction by passing `&mut conn.into()`
97 impl<'a> From<&'a mut AsyncPgConnection> for DbPool<'a> {
98 fn from(value: &'a mut AsyncPgConnection) -> Self {
103 impl<'a, 'b: 'a> From<&'a mut DbConn<'b>> for DbPool<'a> {
104 fn from(value: &'a mut DbConn<'b>) -> Self {
105 DbPool::Conn(value.deref_mut())
109 impl<'a> From<&'a ActualDbPool> for DbPool<'a> {
110 fn from(value: &'a ActualDbPool) -> Self {
115 /// Runs multiple async functions that take `&mut DbPool<'_>` as input and return `Result`. Only works when the `futures` crate is listed in `Cargo.toml`.
117 /// `$pool` is the value given to each function.
119 /// A `Result` is returned (not in a `Future`, so don't use `.await`). The `Ok` variant contains a tuple with the values returned by the given functions.
121 /// The functions run concurrently if `$pool` has the `DbPool::Pool` variant.
123 macro_rules! try_join_with_pool {
124 ($pool:ident => ($($func:expr),+)) => {{
// Type assertion only: rejects a `$pool` that is not a `&mut DbPool<'_>`.
126 let _: &mut $crate::utils::DbPool<'_> = $pool;
129 // Run concurrently with `try_join`
130 $crate::utils::DbPool::Pool(__pool) => ::futures::try_join!(
// Each function gets its own `DbPool::Pool` wrapper around the same pool reference.
132 let mut __dbpool = $crate::utils::DbPool::Pool(__pool);
133 ($func)(&mut __dbpool).await
// Single-connection case: the functions are awaited inside one `async` block
// (presumably one after another, since they share the connection; confirm against the full macro).
137 $crate::utils::DbPool::Conn(__conn) => async {
139 let mut __dbpool = $crate::utils::DbPool::Conn(__conn);
140 // `?` prevents the error type from being inferred in an `async` block, so `match` is used instead
141 match ($func)(&mut __dbpool).await {
142 ::core::result::Result::Ok(__v) => __v,
143 ::core::result::Result::Err(__v) => return ::core::result::Result::Err(__v),
/// Reads the database URL from the `LEMMY_DATABASE_URL` environment variable.
///
/// # Errors
///
/// Returns the `VarError` from `env::var` when the variable is unset or not valid Unicode.
pub fn get_database_url_from_env() -> Result<String, VarError> {
  const DATABASE_URL_VAR: &str = "LEMMY_DATABASE_URL";
  env::var(DATABASE_URL_VAR)
}
/// Turns free-text input into a `LIKE`/`ILIKE` pattern.
///
/// The result is wrapped in `%` wildcards and every space becomes `%`, so the words may be
/// separated by arbitrary text. The pattern metacharacters `%` and `_` are escaped with a
/// backslash, and so is the backslash itself: it is the default `LIKE` escape character, so a
/// raw `\` in the input would otherwise escape whatever follows it (or make the pattern
/// invalid when it is the last character).
pub fn fuzzy_search(q: &str) -> String {
  let mut pattern = String::with_capacity(q.len() + 2);
  pattern.push('%');
  for ch in q.chars() {
    match ch {
      // Literal characters that have a special meaning inside a LIKE pattern.
      '\\' | '%' | '_' => {
        pattern.push('\\');
        pattern.push(ch);
      }
      // Word separator: match anything in between.
      ' ' => pattern.push('%'),
      _ => pattern.push(ch),
    }
  }
  pattern.push('%');
  pattern
}
160 pub fn limit_and_offset(
163 ) -> Result<(i64, i64), diesel::result::Error> {
164 let page = match page {
167 return Err(QueryBuilderError("Page is < 1".into()));
174 let limit = match limit {
176 if !(1..=FETCH_LIMIT_MAX).contains(&limit) {
177 return Err(QueryBuilderError(
178 format!("Fetch limit is > {FETCH_LIMIT_MAX}").into(),
184 None => FETCH_LIMIT_DEFAULT,
186 let offset = limit * (page - 1);
190 pub fn limit_and_offset_unlimited(page: Option<i64>, limit: Option<i64>) -> (i64, i64) {
191 let limit = limit.unwrap_or(FETCH_LIMIT_DEFAULT);
192 let offset = limit * (page.unwrap_or(1) - 1);
196 pub fn is_email_regex(test: &str) -> bool {
197 EMAIL_REGEX.is_match(test)
/// Maps an optional form value onto the nested-option shape used for updates.
///
/// * `None` becomes `None` (leave the column untouched).
/// * `Some("")` becomes `Some(None)` (an empty string is an erase).
/// * `Some(text)` becomes `Some(Some(text))` (overwrite with `text`).
pub fn diesel_option_overwrite(opt: Option<String>) -> Option<Option<String>> {
  opt.map(|value| if value.is_empty() { None } else { Some(value) })
}
214 pub fn diesel_option_overwrite_to_url(
215 opt: &Option<String>,
216 ) -> Result<Option<Option<DbUrl>>, LemmyError> {
217 match opt.as_ref().map(String::as_str) {
218 // An empty string is an erase
219 Some("") => Ok(Some(None)),
220 Some(str_url) => Url::parse(str_url)
221 .map(|u| Some(Some(u.into())))
222 .with_lemmy_type(LemmyErrorType::InvalidUrl),
227 pub fn diesel_option_overwrite_to_url_create(
228 opt: &Option<String>,
229 ) -> Result<Option<DbUrl>, LemmyError> {
230 match opt.as_ref().map(String::as_str) {
231 // An empty string is nothing
232 Some("") => Ok(None),
233 Some(str_url) => Url::parse(str_url)
234 .map(|u| Some(u.into()))
235 .with_lemmy_type(LemmyErrorType::InvalidUrl),
/// Builds the connection pool from the resolved database URL.
///
/// `settings` is `None` for unit tests; in that case the pool size falls back to 5 and the
/// migrations are run before returning.
240 async fn build_db_pool_settings_opt(
241 settings: Option<&Settings>,
242 ) -> Result<ActualDbPool, LemmyError> {
243 let db_url = get_database_url(settings);
// Use the configured pool size, or 5 when no settings were supplied.
244 let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
245 // We only support TLS with sslmode=require currently
246 let tls_enabled = db_url.contains("sslmode=require");
247 let manager = if tls_enabled {
248 // diesel-async does not support any TLS connections out of the box, so we need to manually
249 // provide a setup function which handles creating the connection
250 AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
252 AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
// The same timeout is applied to waiting for, creating and recycling a connection.
254 let pool = Pool::builder(manager)
256 .wait_timeout(POOL_TIMEOUT)
257 .create_timeout(POOL_TIMEOUT)
258 .recycle_timeout(POOL_TIMEOUT)
259 .runtime(Runtime::Tokio1)
262 // If there are no settings, that means it's a unit test, and migrations need to be run
263 if settings.is_none() {
264 run_migrations(&db_url);
/// Connection setup used for TLS database URLs: connects through `tokio_postgres` with a
/// rustls connector and converts the resulting client into an `AsyncPgConnection`.
///
/// A connection failure is reported as `ConnectionError::BadConnection`.
270 fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
// NOTE(review): `NoCertVerifier` accepts every server certificate, so the session is
// encrypted but the server's identity is not checked.
272 let rustls_config = rustls::ClientConfig::builder()
273 .with_safe_defaults()
274 .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
275 .with_no_client_auth();
277 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
278 let (client, conn) = tokio_postgres::connect(config, tls)
280 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
// The connection future is driven on its own task; an error from it is only logged.
281 tokio::spawn(async move {
282 if let Err(e) = conn.await {
283 error!("Database connection failed: {e}");
286 AsyncPgConnection::try_from(client).await
/// Certificate verifier that performs no verification at all.
291 struct NoCertVerifier {}
293 impl ServerCertVerifier for NoCertVerifier {
// All inputs are ignored (note the `_` prefixes); the certificate is always accepted.
294 fn verify_server_cert(
296 _end_entity: &rustls::Certificate,
297 _intermediates: &[rustls::Certificate],
298 _server_name: &ServerName,
299 _scts: &mut dyn Iterator<Item = &[u8]>,
300 _ocsp_response: &[u8],
302 ) -> Result<ServerCertVerified, rustls::Error> {
303 // Will verify all (even invalid) certs without any checks (sslmode=require)
304 Ok(ServerCertVerified::assertion())
/// Migrations embedded via `embed_migrations!`; applied by `run_migrations`.
308 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
/// Applies all pending migrations from `MIGRATIONS` to the database at `db_url`.
///
/// # Panics
///
/// Panics when the connection cannot be established or a migration fails.
310 pub fn run_migrations(db_url: &str) {
311 // Needs to be a sync connection
313 PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}"));
314 info!("Running Database migrations (This may take a long time)...");
316 .run_pending_migrations(MIGRATIONS)
317 .unwrap_or_else(|e| panic!("Couldn't run DB Migrations: {e}"));
318 info!("Database migrations complete.");
321 pub async fn build_db_pool(settings: &Settings) -> Result<ActualDbPool, LemmyError> {
322 build_db_pool_settings_opt(Some(settings)).await
325 pub async fn build_db_pool_for_tests() -> ActualDbPool {
326 build_db_pool_settings_opt(None)
328 .expect("db pool missing")
331 pub fn get_database_url(settings: Option<&Settings>) -> String {
332 // The env var should override anything in the settings config
333 match get_database_url_from_env() {
335 Err(e) => match settings {
336 Some(settings) => settings.get_database_url(),
337 None => panic!("Failed to read database URL from env var LEMMY_DATABASE_URL: {e}"),
342 pub fn naive_now() -> NaiveDateTime {
343 chrono::prelude::Utc::now().naive_utc()
/// Maps a post `SortType` onto the closest `CommentSortType`.
346 pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType {
348 SortType::Active | SortType::Hot => CommentSortType::Hot,
// Comment-count based post sorts map to `New`.
349 SortType::New | SortType::NewComments | SortType::MostComments => CommentSortType::New,
350 SortType::Old => CommentSortType::Old,
351 SortType::Controversial => CommentSortType::Controversial,
// The visible `Top*` variants map to `Top`.
353 | SortType::TopSixHour
354 | SortType::TopTwelveHour
360 | SortType::TopThreeMonths
361 | SortType::TopSixMonths
362 | SortType::TopNineMonths => CommentSortType::Top,
366 pub fn post_to_person_sort_type(sort: SortType) -> PersonSortType {
368 SortType::Active | SortType::Hot | SortType::Controversial => PersonSortType::CommentScore,
369 SortType::New | SortType::NewComments => PersonSortType::New,
370 SortType::MostComments => PersonSortType::MostComments,
371 SortType::Old => PersonSortType::Old,
372 _ => PersonSortType::CommentScore,
376 static EMAIL_REGEX: Lazy<Regex> = Lazy::new(|| {
377 Regex::new(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
378 .expect("compile email regex")
382 use diesel::sql_types::{BigInt, Text, Timestamp};
// Declaration of `hot_rank(score, time)`; presumably wrapped in `sql_function!` — confirm against the full file.
385 fn hot_rank(score: BigInt, time: Timestamp) -> Integer;
// Declaration of `controversy_rank(upvotes, downvotes, score)`; same caveat as above.
389 fn controversy_rank(upvotes: BigInt, downvotes: BigInt, score: BigInt) -> Double;
// Exposes SQL `lower(text)` to the query builder.
392 sql_function!(fn lower(x: Text) -> Text);
/// Placeholder text; presumably written in place of permanently deleted content — confirm at the call sites.
395 pub const DELETED_REPLACEMENT_TEXT: &str = "*Permanently Deleted*";
// Stores a `DbUrl` as its string form in a `Text` column.
397 impl ToSql<Text, Pg> for DbUrl {
398 fn to_sql(&self, out: &mut Output<Pg>) -> diesel::serialize::Result {
// The string is a temporary, so it is written through a reborrowed `Output`.
399 <std::string::String as ToSql<Text, Pg>>::to_sql(&self.0.to_string(), &mut out.reborrow())
403 impl<DB: Backend> FromSql<Text, DB> for DbUrl
405 String: FromSql<Text, DB>,
407 fn from_sql(value: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
408 let str = String::from_sql(value)?;
409 Ok(DbUrl(Box::new(Url::parse(&str)?)))
413 impl<Kind> From<ObjectId<Kind>> for DbUrl
415 Kind: Object + Send + 'static,
416 for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
418 fn from(id: ObjectId<Kind>) -> Self {
419 DbUrl(Box::new(id.into()))
425 #![allow(clippy::unwrap_used)]
426 #![allow(clippy::indexing_slicing)]
428 use super::{fuzzy_search, *};
429 use crate::utils::is_email_regex;
// Expects `%` and `_` to be backslash-escaped, spaces to become `%`, and the whole pattern to be wrapped in `%`.
432 fn test_fuzzy_search() {
433 let test = "This %is% _a_ fuzzy search";
436 "%This%\\%is\\%%\\_a\\_%fuzzy%search%".to_string()
442 assert!(is_email_regex("gush@gmail.com"));
443 assert!(!is_email_regex("nada_neutho"));
// Covers the three cases: `None`, an empty string (erase) and a non-empty value.
447 fn test_diesel_option_overwrite() {
448 assert_eq!(diesel_option_overwrite(None), None);
449 assert_eq!(diesel_option_overwrite(Some(String::new())), Some(None));
451 diesel_option_overwrite(Some("test".to_string())),
452 Some(Some("test".to_string()))
457 fn test_diesel_option_overwrite_to_url() {
458 assert!(matches!(diesel_option_overwrite_to_url(&None), Ok(None)));
460 diesel_option_overwrite_to_url(&Some(String::new())),
463 assert!(diesel_option_overwrite_to_url(&Some("invalid_url".to_string())).is_err());
464 let example_url = "https://example.com";
466 diesel_option_overwrite_to_url(&Some(example_url.to_string())),
467 Ok(Some(Some(url))) if url == Url::parse(example_url).unwrap().into()