3 diesel_migrations::MigrationHarness,
8 use activitypub_federation::{fetch::object_id::ObjectId, traits::Object};
9 use chrono::NaiveDateTime;
10 use deadpool::Runtime;
15 result::{ConnectionError, ConnectionResult, Error as DieselError, Error::QueryBuilderError},
16 serialize::{Output, ToSql},
21 pg::AsyncPgConnection,
23 deadpool::{Object as PooledConnection, Pool},
24 AsyncDieselConnectionManager,
27 use diesel_migrations::EmbeddedMigrations;
28 use futures_util::{future::BoxFuture, FutureExt};
30 error::{LemmyError, LemmyErrorExt, LemmyErrorType},
31 settings::structs::Settings,
33 use once_cell::sync::Lazy;
36 client::{ServerCertVerified, ServerCertVerifier},
42 ops::{Deref, DerefMut},
44 time::{Duration, SystemTime},
46 use tracing::{error, info};
/// Page size used when the caller supplies no limit.
49 const FETCH_LIMIT_DEFAULT: i64 = 10;
/// Upper bound of the limit range accepted by `limit_and_offset`.
50 pub const FETCH_LIMIT_MAX: i64 = 50;
/// Timeout passed to the pool builder for its wait, create and recycle settings.
51 const POOL_TIMEOUT: Option<Duration> = Some(Duration::from_secs(5));
/// The concrete pool type: a `Pool` of `AsyncPgConnection`s.
53 pub type ActualDbPool = Pool<AsyncPgConnection>;
55 /// References a pool or connection. Functions must take `&mut DbPool<'_>` to allow implicit reborrowing.
57 /// https://github.com/rust-lang/rfcs/issues/1403
/// Shared borrow of the connection pool.
59 Pool(&'a ActualDbPool),
/// Exclusive borrow of a single connection.
60 Conn(&'a mut AsyncPgConnection),
/// A pooled connection object, held by value.
64 Pool(PooledConnection<AsyncPgConnection>),
/// A connection borrowed exclusively from the caller.
65 Conn(&'a mut AsyncPgConnection),
/// Produces a `DbConn` from `pool`: the `Pool` variant asks the pool for a connection,
/// the `Conn` variant passes along the connection it already borrows.
///
/// # Errors
/// A failure to get a connection from the pool is reported as `QueryBuilderError`.
68 pub async fn get_conn<'a, 'b: 'a>(pool: &'a mut DbPool<'b>) -> Result<DbConn<'a>, DieselError> {
// The pool's error is converted with `into()` and wrapped so the function can return `DieselError`.
70 DbPool::Pool(pool) => DbConn::Pool(pool.get().await.map_err(|e| QueryBuilderError(e.into()))?),
71 DbPool::Conn(conn) => DbConn::Conn(conn),
/// Lets a `DbConn` be used wherever a `&AsyncPgConnection` is needed.
75 impl<'a> Deref for DbConn<'a> {
76 type Target = AsyncPgConnection;
/// Borrows the underlying connection, whichever variant holds it.
78 fn deref(&self) -> &Self::Target {
80 DbConn::Pool(conn) => conn.deref(),
81 DbConn::Conn(conn) => conn.deref(),
/// Lets a `DbConn` be used wherever a `&mut AsyncPgConnection` is needed.
86 impl<'a> DerefMut for DbConn<'a> {
/// Mutably borrows the underlying connection, whichever variant holds it.
87 fn deref_mut(&mut self) -> &mut Self::Target {
89 DbConn::Pool(conn) => conn.deref_mut(),
90 DbConn::Conn(conn) => conn.deref_mut(),
95 // Allows functions that take `DbPool<'_>` to be called in a transaction by passing `&mut conn.into()`
96 impl<'a> From<&'a mut AsyncPgConnection> for DbPool<'a> {
/// Converts an exclusive connection borrow into a `DbPool` with the same lifetime.
97 fn from(value: &'a mut AsyncPgConnection) -> Self {
102 impl<'a, 'b: 'a> From<&'a mut DbConn<'b>> for DbPool<'a> {
103 fn from(value: &'a mut DbConn<'b>) -> Self {
104 DbPool::Conn(value.deref_mut())
/// Lets a borrowed `ActualDbPool` be passed wherever a `DbPool` is expected.
108 impl<'a> From<&'a ActualDbPool> for DbPool<'a> {
/// Converts a shared pool borrow into a `DbPool` with the same lifetime.
109 fn from(value: &'a ActualDbPool) -> Self {
114 /// Runs multiple async functions that take `&mut DbPool<'_>` as input and return `Result`. Only works when the `futures` crate is listed in `Cargo.toml`.
116 /// `$pool` is the value given to each function.
118 /// A `Result` is returned (not in a `Future`, so don't use `.await`). The `Ok` variant contains a tuple with the values returned by the given functions.
120 /// The functions run concurrently if `$pool` has the `DbPool::Pool` variant.
122 macro_rules! try_join_with_pool {
123 ($pool:ident => ($($func:expr),+)) => {{
// Type assertion only: rejects a `$pool` that is not a `&mut DbPool<'_>`.
125 let _: &mut $crate::utils::DbPool<'_> = $pool;
128 // Run concurrently with `try_join`
129 $crate::utils::DbPool::Pool(__pool) => ::futures::try_join!(
// Each function gets its own `DbPool::Pool` value built from the same pool reference.
131 let mut __dbpool = $crate::utils::DbPool::Pool(__pool);
132 ($func)(&mut __dbpool).await
// Single-connection case: the functions are awaited inside one `async` block
// (presumably one after another — the surrounding lines are not visible here).
136 $crate::utils::DbPool::Conn(__conn) => async {
138 let mut __dbpool = $crate::utils::DbPool::Conn(__conn);
139 // `?` prevents the error type from being inferred in an `async` block, so `match` is used instead
140 match ($func)(&mut __dbpool).await {
141 ::core::result::Result::Ok(__v) => __v,
142 ::core::result::Result::Err(__v) => return ::core::result::Result::Err(__v),
/// Reads the database URL from the `LEMMY_DATABASE_URL` environment variable.
///
/// # Errors
/// Returns the `VarError` from `env::var` when the variable cannot be read.
pub fn get_database_url_from_env() -> Result<String, VarError> {
  let variable_name = "LEMMY_DATABASE_URL";
  env::var(variable_name)
}
/// Turns a free-text query into a SQL `LIKE`/`ILIKE` pattern.
///
/// The result is wrapped in `%` so it matches anywhere in the column, and every space in
/// `q` becomes a `%` wildcard so the words may be separated by arbitrary text.
///
/// The pattern metacharacters `%` and `_` are escaped with a backslash so they match
/// literally. The backslash itself is escaped too: it is the default `LIKE` escape
/// character, so an unescaped `\` in the query would swallow the character after it
/// (including the trailing `%` wildcard).
pub fn fuzzy_search(q: &str) -> String {
  // At least the input plus the two surrounding wildcards.
  let mut pattern = String::with_capacity(q.len() + 2);
  pattern.push('%');
  for ch in q.chars() {
    match ch {
      '\\' | '%' | '_' => {
        pattern.push('\\');
        pattern.push(ch);
      }
      ' ' => pattern.push('%'),
      _ => pattern.push(ch),
    }
  }
  pattern.push('%');
  pattern
}
/// Validates optional paging parameters and derives the limit and offset from them.
///
/// # Errors
/// Returns `QueryBuilderError` when the page is rejected ("Page is < 1") or when the
/// limit lies outside `1..=FETCH_LIMIT_MAX`.
159 pub fn limit_and_offset(
162 ) -> Result<(i64, i64), diesel::result::Error> {
163 let page = match page {
166 return Err(QueryBuilderError("Page is < 1".into()));
173 let limit = match limit {
// NOTE(review): this range check also rejects limits below 1, but the message below
// only mentions exceeding the maximum.
175 if !(1..=FETCH_LIMIT_MAX).contains(&limit) {
176 return Err(QueryBuilderError(
177 format!("Fetch limit is > {FETCH_LIMIT_MAX}").into(),
// No limit supplied: use the default page size.
183 None => FETCH_LIMIT_DEFAULT,
// Pages are 1-based, so page 1 maps to offset 0.
// NOTE(review): no upper bound on `page` is visible here, so this multiplication could
// overflow for very large pages — confirm.
185 let offset = limit * (page - 1);
/// Derives a limit and offset from optional paging parameters without the range checks
/// visible in `limit_and_offset`.
189 pub fn limit_and_offset_unlimited(page: Option<i64>, limit: Option<i64>) -> (i64, i64) {
// A missing limit falls back to the default page size.
190 let limit = limit.unwrap_or(FETCH_LIMIT_DEFAULT);
// A missing page is treated as page 1, which gives offset 0.
191 let offset = limit * (page.unwrap_or(1) - 1);
195 pub fn is_email_regex(test: &str) -> bool {
196 EMAIL_REGEX.is_match(test)
/// Converts an optional string into the nested-`Option` form used for updates.
///
/// Per the tests in this file: `None` stays `None`, an empty string becomes `Some(None)`,
/// and any other string becomes `Some(Some(value))`.
199 pub fn diesel_option_overwrite(opt: &Option<String>) -> Option<Option<String>> {
201 // An empty string is an erase
203 if !unwrapped.eq("") {
// Non-empty: keep the value (cloned, since only a reference is available).
204 Some(Some(unwrapped.clone()))
/// URL flavour of `diesel_option_overwrite`: an empty string becomes `Some(None)`, and
/// any other string is parsed as a URL and returned as `Some(Some(url))`.
///
/// # Errors
/// Returns a `LemmyErrorType::InvalidUrl` error when the string does not parse as a URL.
213 pub fn diesel_option_overwrite_to_url(
214 opt: &Option<String>,
215 ) -> Result<Option<Option<DbUrl>>, LemmyError> {
// Borrow the contents as `Option<&str>` so string literals can be matched directly.
216 match opt.as_ref().map(String::as_str) {
217 // An empty string is an erase
218 Some("") => Ok(Some(None)),
219 Some(str_url) => Url::parse(str_url)
220 .map(|u| Some(Some(u.into())))
221 .with_lemmy_type(LemmyErrorType::InvalidUrl),
/// Single-`Option` counterpart of `diesel_option_overwrite_to_url`: an empty string
/// becomes `None`, and any other string is parsed as a URL and returned as `Some(url)`.
///
/// # Errors
/// Returns a `LemmyErrorType::InvalidUrl` error when the string does not parse as a URL.
226 pub fn diesel_option_overwrite_to_url_create(
227 opt: &Option<String>,
228 ) -> Result<Option<DbUrl>, LemmyError> {
// Borrow the contents as `Option<&str>` so string literals can be matched directly.
229 match opt.as_ref().map(String::as_str) {
230 // An empty string is nothing
231 Some("") => Ok(None),
232 Some(str_url) => Url::parse(str_url)
233 .map(|u| Some(u.into()))
234 .with_lemmy_type(LemmyErrorType::InvalidUrl),
/// Builds the connection pool for the URL resolved by `get_database_url`.
///
/// With `settings` set to `None` the pool size defaults to 5 and the migrations are run
/// (the in-code comment states this case is the unit-test setup).
239 async fn build_db_pool_settings_opt(
240 settings: Option<&Settings>,
241 ) -> Result<ActualDbPool, LemmyError> {
242 let db_url = get_database_url(settings);
// Pool size comes from the settings, or 5 when there are none.
243 let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
244 // We only support TLS with sslmode=require currently
// A plain substring test on the URL decides whether the TLS setup path is used.
245 let tls_enabled = db_url.contains("sslmode=require");
246 let manager = if tls_enabled {
247 // diesel-async does not support any TLS connections out of the box, so we need to manually
248 // provide a setup function which handles creating the connection
249 AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
251 AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
253 let pool = Pool::builder(manager)
// The same `POOL_TIMEOUT` value is used for all three timeouts.
255 .wait_timeout(POOL_TIMEOUT)
256 .create_timeout(POOL_TIMEOUT)
257 .recycle_timeout(POOL_TIMEOUT)
258 .runtime(Runtime::Tokio1)
261 // If there's no settings, that means its a unit test, and migrations need to be run
262 if settings.is_none() {
263 run_migrations(&db_url);
/// Connection setup function passed to `new_with_setup` for TLS URLs: connects through
/// `tokio_postgres` with a rustls connector, then converts the client into an
/// `AsyncPgConnection`.
269 fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
// NOTE(review): `NoCertVerifier` accepts every server certificate, so the TLS session is
// encrypted but the server is not authenticated.
271 let rustls_config = rustls::ClientConfig::builder()
272 .with_safe_defaults()
273 .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
274 .with_no_client_auth();
276 let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
277 let (client, conn) = tokio_postgres::connect(config, tls)
// Connection failures are reported as `BadConnection` carrying the error text.
279 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
// The connection half is awaited on its own spawned task; an error from it is only logged.
280 tokio::spawn(async move {
281 if let Err(e) = conn.await {
282 error!("Database connection failed: {e}");
285 AsyncPgConnection::try_from(client).await
/// Certificate verifier that performs no verification; used by `establish_connection`.
290 struct NoCertVerifier {}
292 impl ServerCertVerifier for NoCertVerifier {
/// Ignores every argument and reports the certificate as verified.
293 fn verify_server_cert(
295 _end_entity: &rustls::Certificate,
296 _intermediates: &[rustls::Certificate],
297 _server_name: &ServerName,
298 _scts: &mut dyn Iterator<Item = &[u8]>,
299 _ocsp_response: &[u8],
301 ) -> Result<ServerCertVerified, rustls::Error> {
302 // Will verify all (even invalid) certs without any checks (sslmode=require)
303 Ok(ServerCertVerified::assertion())
/// The migrations embedded by `embed_migrations!`; applied by `run_migrations`.
307 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
/// Connects to `db_url` and runs all pending embedded migrations.
///
/// # Panics
/// Panics when the connection cannot be established or when running the migrations fails.
309 pub fn run_migrations(db_url: &str) {
310 // Needs to be a sync connection
// NOTE(review): the panic message embeds `db_url`, which may contain credentials — confirm
// this is acceptable for logs.
312 PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}"));
313 info!("Running Database migrations (This may take a long time)...");
315 .run_pending_migrations(MIGRATIONS)
316 .unwrap_or_else(|e| panic!("Couldn't run DB Migrations: {e}"));
317 info!("Database migrations complete.");
320 pub async fn build_db_pool(settings: &Settings) -> Result<ActualDbPool, LemmyError> {
321 build_db_pool_settings_opt(Some(settings)).await
/// Builds a pool without settings, for use in tests.
///
/// # Panics
/// Panics with "db pool missing" when the pool cannot be built.
324 pub async fn build_db_pool_for_tests() -> ActualDbPool {
// `None` selects the no-settings path of `build_db_pool_settings_opt`.
325 build_db_pool_settings_opt(None)
327 .expect("db pool missing")
/// Resolves the database URL, consulting the environment before `settings`.
///
/// # Panics
/// Panics when the environment variable cannot be read and `settings` is `None`.
330 pub fn get_database_url(settings: Option<&Settings>) -> String {
331 // The env var should override anything in the settings config
332 match get_database_url_from_env() {
// Env lookup failed: fall back to the settings, if any were supplied.
334 Err(e) => match settings {
335 Some(settings) => settings.get_database_url(),
336 None => panic!("Failed to read database URL from env var LEMMY_DATABASE_URL: {e}"),
341 pub fn naive_now() -> NaiveDateTime {
342 chrono::prelude::Utc::now().naive_utc()
/// Maps a post `SortType` onto a `CommentSortType`.
345 pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType {
347 SortType::Active | SortType::Hot => CommentSortType::Hot,
// The comment-count based post sorts are mapped to `New`.
348 SortType::New | SortType::NewComments | SortType::MostComments => CommentSortType::New,
349 SortType::Old => CommentSortType::Old,
// The `Top*` time-window variants all map to `Top`.
351 | SortType::TopSixHour
352 | SortType::TopTwelveHour
358 | SortType::TopThreeMonths
359 | SortType::TopSixMonths
360 | SortType::TopNineMonths => CommentSortType::Top,
364 static EMAIL_REGEX: Lazy<Regex> = Lazy::new(|| {
365 Regex::new(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
366 .expect("compile email regex")
370 use diesel::sql_types::{BigInt, Text, Timestamp};
// Signature of the SQL function `hot_rank(score, time)`; presumably wrapped in a
// `sql_function!` invocation whose opening line is not visible here.
373 fn hot_rank(score: BigInt, time: Timestamp) -> Integer;
// Declares the SQL function `lower` (text in, text out) for use in queries.
376 sql_function!(fn lower(x: Text) -> Text);
/// Placeholder text; by its name, substituted for content that has been permanently deleted.
379 pub const DELETED_REPLACEMENT_TEXT: &str = "*Permanently Deleted*";
381 impl ToSql<Text, Pg> for DbUrl {
382 fn to_sql(&self, out: &mut Output<Pg>) -> diesel::serialize::Result {
383 <std::string::String as ToSql<Text, Pg>>::to_sql(&self.0.to_string(), &mut out.reborrow())
/// Loads a `DbUrl` from a `Text` value on any backend where `String` can be loaded.
387 impl<DB: Backend> FromSql<Text, DB> for DbUrl
389 String: FromSql<Text, DB>,
/// Reads the value as a `String`, then parses it as a URL; a parse failure is returned
/// as the deserialization error.
391 fn from_sql(value: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
392 let str = String::from_sql(value)?;
393 Ok(DbUrl(Box::new(Url::parse(&str)?)))
/// Converts an ActivityPub `ObjectId` into a `DbUrl`.
397 impl<Kind> From<ObjectId<Kind>> for DbUrl
399 Kind: Object + Send + 'static,
400 for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
/// Converts the id with `into()` and boxes the result inside `DbUrl`.
402 fn from(id: ObjectId<Kind>) -> Self {
403 DbUrl(Box::new(id.into()))
409 #![allow(clippy::unwrap_used)]
410 #![allow(clippy::indexing_slicing)]
412 use super::{fuzzy_search, *};
413 use crate::utils::is_email_regex;
416 fn test_fuzzy_search() {
417 let test = "This %is% _a_ fuzzy search";
420 "%This%\\%is\\%%\\_a\\_%fuzzy%search%".to_string()
426 assert!(is_email_regex("gush@gmail.com"));
427 assert!(!is_email_regex("nada_neutho"));
431 fn test_diesel_option_overwrite() {
432 assert_eq!(diesel_option_overwrite(&None), None);
433 assert_eq!(diesel_option_overwrite(&Some(String::new())), Some(None));
435 diesel_option_overwrite(&Some("test".to_string())),
436 Some(Some("test".to_string()))
441 fn test_diesel_option_overwrite_to_url() {
442 assert!(matches!(diesel_option_overwrite_to_url(&None), Ok(None)));
444 diesel_option_overwrite_to_url(&Some(String::new())),
447 assert!(diesel_option_overwrite_to_url(&Some("invalid_url".to_string())).is_err());
448 let example_url = "https://example.com";
450 diesel_option_overwrite_to_url(&Some(example_url.to_string())),
451 Ok(Some(Some(url))) if url == Url::parse(example_url).unwrap().into()