-use crate::newtypes::DbUrl;
+use crate::{
+ diesel::Connection,
+ diesel_migrations::MigrationHarness,
+ newtypes::DbUrl,
+ CommentSortType,
+ SortType,
+};
use activitypub_federation::{core::object_id::ObjectId, traits::ApubObject};
+use bb8::PooledConnection;
use chrono::NaiveDateTime;
use diesel::{
backend::Backend,
deserialize::FromSql,
+ pg::Pg,
+ result::{Error as DieselError, Error::QueryBuilderError},
serialize::{Output, ToSql},
sql_types::Text,
- Connection,
PgConnection,
};
-use lemmy_utils::error::LemmyError;
+use diesel_async::{
+ pg::AsyncPgConnection,
+ pooled_connection::{bb8::Pool, AsyncDieselConnectionManager},
+};
+use diesel_migrations::EmbeddedMigrations;
+use lemmy_utils::{error::LemmyError, settings::structs::Settings};
use once_cell::sync::Lazy;
use regex::Regex;
-use std::{env, env::VarError, io::Write};
+use std::{env, env::VarError};
+use tracing::info;
use url::Url;
-pub type DbPool = diesel::r2d2::Pool<diesel::r2d2::ConnectionManager<diesel::PgConnection>>;
+/// Default page size when a query does not specify `limit`.
+const FETCH_LIMIT_DEFAULT: i64 = 10;
+/// Maximum page size accepted by `limit_and_offset`.
+pub const FETCH_LIMIT_MAX: i64 = 50;
+
+/// Async connection pool: bb8 over diesel-async Postgres connections.
+pub type DbPool = Pool<AsyncPgConnection>;
+
+/// Checks a connection out of the pool, converting the bb8 pool error into a
+/// `diesel::result::Error` (`QueryBuilderError`) so callers stay in diesel's error type.
+pub async fn get_conn(
+ pool: &DbPool,
+) -> Result<PooledConnection<AsyncDieselConnectionManager<AsyncPgConnection>>, DieselError> {
+ pool.get().await.map_err(|e| QueryBuilderError(e.into()))
+}
pub fn get_database_url_from_env() -> Result<String, VarError> {
env::var("LEMMY_DATABASE_URL")
format!("%{}%", replaced)
}
-pub fn limit_and_offset(page: Option<i64>, limit: Option<i64>) -> (i64, i64) {
- let page = page.unwrap_or(1);
- let limit = limit.unwrap_or(10);
+/// Validates pagination arguments and converts them into a `(limit, offset)` pair.
+///
+/// * `page` defaults to 1 and must be >= 1.
+/// * `limit` defaults to `FETCH_LIMIT_DEFAULT` and must lie in `1..=FETCH_LIMIT_MAX`.
+///
+/// Failures are reported as `QueryBuilderError` so they surface as diesel errors.
+pub fn limit_and_offset(
+ page: Option<i64>,
+ limit: Option<i64>,
+) -> Result<(i64, i64), diesel::result::Error> {
+ let page = match page {
+ Some(page) => {
+ if page < 1 {
+ return Err(QueryBuilderError("Page is < 1".into()));
+ } else {
+ page
+ }
+ }
+ None => 1,
+ };
+ let limit = match limit {
+ Some(limit) => {
+ if !(1..=FETCH_LIMIT_MAX).contains(&limit) {
+ // NOTE(review): this branch also fires for limit < 1, where the
+ // "Fetch limit is > {FETCH_LIMIT_MAX}" message is misleading.
+ return Err(QueryBuilderError(
+ format!("Fetch limit is > {}", FETCH_LIMIT_MAX).into(),
+ ));
+ } else {
+ limit
+ }
+ }
+ None => FETCH_LIMIT_DEFAULT,
+ };
let offset = limit * (page - 1);
+ Ok((limit, offset))
+}
+
+/// Like `limit_and_offset` but performs no range validation; out-of-range
+/// `page`/`limit` values are passed through unchanged.
+pub fn limit_and_offset_unlimited(page: Option<i64>, limit: Option<i64>) -> (i64, i64) {
+ let limit = limit.unwrap_or(FETCH_LIMIT_DEFAULT);
+ let offset = limit * (page.unwrap_or(1) - 1);
(limit, offset)
}
// An empty string is an erase
Some(unwrapped) => {
if !unwrapped.eq("") {
- Some(Some(unwrapped.to_owned()))
+ Some(Some(unwrapped.clone()))
} else {
Some(None)
}
pub fn diesel_option_overwrite_to_url(
opt: &Option<String>,
) -> Result<Option<Option<DbUrl>>, LemmyError> {
- match opt.as_ref().map(|s| s.as_str()) {
+ match opt.as_ref().map(std::string::String::as_str) {
// An empty string is an erase
Some("") => Ok(Some(None)),
Some(str_url) => match Url::parse(str_url) {
}
}
-embed_migrations!();
+/// Parses an optional URL string for create forms.
+///
+/// `None` and the empty string both yield `Ok(None)`; any other value must
+/// parse as a URL or an `invalid_url` error is returned.
+pub fn diesel_option_overwrite_to_url_create(
+ opt: &Option<String>,
+) -> Result<Option<DbUrl>, LemmyError> {
+ match opt.as_ref().map(std::string::String::as_str) {
+ // An empty string is nothing
+ Some("") => Ok(None),
+ Some(str_url) => match Url::parse(str_url) {
+ Ok(url) => Ok(Some(url.into())),
+ Err(e) => Err(LemmyError::from_error_message(e, "invalid_url")),
+ },
+ None => Ok(None),
+ }
+}
+
+/// Builds the async connection pool. With `None` settings (the unit-test path)
+/// the database URL must come from the env var, pool size defaults to 5, and
+/// pending migrations are run synchronously before returning.
+async fn build_db_pool_settings_opt(settings: Option<&Settings>) -> Result<DbPool, LemmyError> {
+ let db_url = get_database_url(settings);
+ let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
+ let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url);
+ let pool = Pool::builder()
+ .max_size(pool_size)
+ .min_idle(Some(1))
+ .build(manager)
+ .await?;
+
+ // If there's no settings, that means its a unit test, and migrations need to be run
+ if settings.is_none() {
+ run_migrations(&db_url);
+ }
+
+ Ok(pool)
+}
+
+/// Migrations embedded into the binary at compile time.
+pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
+
+/// Runs all pending embedded migrations against `db_url`.
+///
+/// Panics on connection or migration failure. Uses a blocking `PgConnection`
+/// because `MigrationHarness` requires a synchronous connection.
+pub fn run_migrations(db_url: &str) {
+ // Needs to be a sync connection
+ let mut conn =
+ PgConnection::establish(db_url).unwrap_or_else(|_| panic!("Error connecting to {}", db_url));
+ info!("Running Database migrations (This may take a long time)...");
+ let _ = &mut conn
+ .run_pending_migrations(MIGRATIONS)
+ .unwrap_or_else(|_| panic!("Couldn't run DB Migrations"));
+ info!("Database migrations complete.");
+}
+
+/// Builds the production connection pool from the given settings.
+pub async fn build_db_pool(settings: &Settings) -> Result<DbPool, LemmyError> {
+ build_db_pool_settings_opt(Some(settings)).await
+}
-pub fn establish_unpooled_connection() -> PgConnection {
- let db_url = match get_database_url_from_env() {
+/// Test-only pool builder: reads LEMMY_DATABASE_URL, runs migrations, and
+/// panics if the pool cannot be created.
+pub async fn build_db_pool_for_tests() -> DbPool {
+ build_db_pool_settings_opt(None)
+ .await
+ .expect("db pool missing")
+}
+
+/// Resolves the database URL: the LEMMY_DATABASE_URL env var always wins,
+/// otherwise fall back to the settings config; panics when neither is available.
+pub fn get_database_url(settings: Option<&Settings>) -> String {
+ // The env var should override anything in the settings config
+ match get_database_url_from_env() {
Ok(url) => url,
- Err(e) => panic!(
- "Failed to read database URL from env var LEMMY_DATABASE_URL: {}",
- e
- ),
- };
- let conn =
- PgConnection::establish(&db_url).unwrap_or_else(|_| panic!("Error connecting to {}", db_url));
- embedded_migrations::run(&conn).expect("load migrations");
- conn
+ Err(e) => match settings {
+ Some(settings) => settings.get_database_url(),
+ None => panic!(
+ "Failed to read database URL from env var LEMMY_DATABASE_URL: {}",
+ e
+ ),
+ },
+ }
}
+/// Current UTC time as a timezone-naive `NaiveDateTime`.
pub fn naive_now() -> NaiveDateTime {
chrono::prelude::Utc::now().naive_utc()
}
+/// Maps a post `SortType` to the closest available `CommentSortType`
+/// (all Top* variants collapse to `CommentSortType::Top`).
+pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType {
+ match sort {
+ SortType::Active | SortType::Hot => CommentSortType::Hot,
+ SortType::New | SortType::NewComments | SortType::MostComments => CommentSortType::New,
+ SortType::Old => CommentSortType::Old,
+ SortType::TopDay
+ | SortType::TopAll
+ | SortType::TopWeek
+ | SortType::TopYear
+ | SortType::TopMonth => CommentSortType::Top,
+ }
+}
+
static EMAIL_REGEX: Lazy<Regex> = Lazy::new(|| {
Regex::new(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
.expect("compile email regex")
});
pub mod functions {
- use diesel::sql_types::*;
+ use diesel::sql_types::{BigInt, Text, Timestamp};
sql_function! {
fn hot_rank(score: BigInt, time: Timestamp) -> Integer;
sql_function!(fn lower(x: Text) -> Text);
}
-impl<DB: Backend> ToSql<Text, DB> for DbUrl
-where
- String: ToSql<Text, DB>,
-{
- fn to_sql<W: Write>(&self, out: &mut Output<W, DB>) -> diesel::serialize::Result {
- self.0.to_string().to_sql(out)
+// DbUrl is stored as its string form. NOTE(review): the impl is now pinned to
+// the Pg backend and serializes the owned String through out.reborrow(),
+// matching diesel's newer Output API — confirm against the pinned diesel version.
+impl ToSql<Text, Pg> for DbUrl {
+ fn to_sql(&self, out: &mut Output<Pg>) -> diesel::serialize::Result {
+ <std::string::String as ToSql<Text, Pg>>::to_sql(&self.0.to_string(), &mut out.reborrow())
}
}
where
String: FromSql<Text, DB>,
{
- fn from_sql(bytes: Option<&DB::RawValue>) -> diesel::deserialize::Result<Self> {
- let str = String::from_sql(bytes)?;
+ // Reads the column as text, then parses it into a Url. NOTE(review):
+ // signature updated to diesel's RawValue-based from_sql (no Option wrapper)
+ // — confirm against the pinned diesel version.
+ fn from_sql(value: diesel::backend::RawValue<'_, DB>) -> diesel::deserialize::Result<Self> {
+ let str = String::from_sql(value)?;
Ok(DbUrl(Url::parse(&str)?))
}
}
#[test]
fn test_diesel_option_overwrite() {
assert_eq!(diesel_option_overwrite(&None), None);
- assert_eq!(diesel_option_overwrite(&Some("".to_string())), Some(None));
+ assert_eq!(diesel_option_overwrite(&Some(String::new())), Some(None));
assert_eq!(
diesel_option_overwrite(&Some("test".to_string())),
Some(Some("test".to_string()))
fn test_diesel_option_overwrite_to_url() {
assert!(matches!(diesel_option_overwrite_to_url(&None), Ok(None)));
assert!(matches!(
- diesel_option_overwrite_to_url(&Some("".to_string())),
+ diesel_option_overwrite_to_url(&Some(String::new())),
Ok(Some(None))
));
assert!(matches!(