From 6d67f88603482f62bea74c8ef7ceae4ca4b77d59 Mon Sep 17 00:00:00 2001
From: Sander Saarend <sander@saarend.com>
Date: Mon, 26 Jun 2023 11:25:38 +0300
Subject: [PATCH] Add support for sslmode=require for diesel-async DB
 connections (#3189)

---
 Cargo.lock                    | 92 +++++++++++++++++++++++++++++------
 Cargo.toml                    |  9 ++++
 crates/db_schema/Cargo.toml   |  6 ++-
 crates/db_schema/src/utils.rs | 64 ++++++++++++++++++++++--
 4 files changed, 152 insertions(+), 19 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 9d575f57..99d94948 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -215,7 +215,7 @@ dependencies = [
  "futures-util",
  "mio",
  "num_cpus",
- "socket2",
+ "socket2 0.4.9",
  "tokio",
  "tracing",
 ]
@@ -245,7 +245,7 @@ dependencies = [
  "http",
  "log",
  "pin-project-lite",
- "tokio-rustls",
+ "tokio-rustls 0.23.4",
  "tokio-util 0.7.4",
  "webpki-roots",
 ]
@@ -297,7 +297,7 @@ dependencies = [
  "serde_json",
  "serde_urlencoded",
  "smallvec",
- "socket2",
+ "socket2 0.4.9",
  "time 0.3.15",
  "url",
 ]
@@ -496,7 +496,7 @@ dependencies = [
  "percent-encoding",
  "pin-project-lite",
  "rand 0.8.5",
- "rustls",
+ "rustls 0.20.7",
  "serde",
  "serde_json",
  "serde_urlencoded",
@@ -2262,7 +2262,7 @@ dependencies = [
  "httpdate",
  "itoa",
  "pin-project-lite",
- "socket2",
+ "socket2 0.4.9",
  "tokio",
  "tower-service",
  "tracing",
@@ -2640,9 +2640,11 @@ dependencies = [
  "diesel-derive-newtype",
  "diesel_ltree",
  "diesel_migrations",
+ "futures-util",
  "lemmy_utils",
  "once_cell",
  "regex",
+ "rustls 0.21.2",
  "serde",
  "serde_json",
  "serde_with",
@@ -2651,6 +2653,8 @@ dependencies = [
  "strum",
  "strum_macros",
  "tokio",
+ "tokio-postgres",
+ "tokio-postgres-rustls",
  "tracing",
  "ts-rs",
  "typed-builder",
@@ -2736,6 +2740,7 @@ dependencies = [
  "diesel",
  "diesel-async",
  "doku",
+ "futures-util",
  "lemmy_api",
  "lemmy_api_common",
  "lemmy_api_crud",
@@ -2749,9 +2754,12 @@ dependencies = [
  "reqwest",
  "reqwest-middleware",
  "reqwest-tracing",
+ "rustls 0.21.2",
  "serde",
  "serde_json",
  "tokio",
+ "tokio-postgres",
+ "tokio-postgres-rustls",
  "tracing",
  "tracing-actix-web 0.6.2",
  "tracing-error",
@@ -2820,7 +2828,7 @@ dependencies = [
  "nom 7.1.1",
  "once_cell",
  "quoted_printable",
- "socket2",
+ "socket2 0.4.9",
 ]
 
 [[package]]
@@ -3932,11 +3940,11 @@ dependencies = [
 
 [[package]]
 name = "postgres-protocol"
-version = "0.6.4"
+version = "0.6.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "878c6cbf956e03af9aa8204b407b9cbf47c072164800aa918c516cd4b056c50c"
+checksum = "78b7fa9f396f51dffd61546fd8573ee20592287996568e6175ceb0f8699ad75d"
 dependencies = [
- "base64 0.13.1",
+ "base64 0.21.2",
  "byteorder",
  "bytes",
  "fallible-iterator",
@@ -4495,6 +4503,28 @@ dependencies = [
  "webpki",
 ]
 
+[[package]]
+name = "rustls"
+version = "0.21.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e32ca28af694bc1bbf399c33a516dbdf1c90090b8ab23c2bc24f834aa2247f5f"
+dependencies = [
+ "log",
+ "ring",
+ "rustls-webpki",
+ "sct",
+]
+
+[[package]]
+name = "rustls-webpki"
+version = "0.100.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b"
+dependencies = [
+ "ring",
+ "untrusted",
+]
+
 [[package]]
 name = "rustversion"
 version = "1.0.9"
@@ -4859,6 +4889,16 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "socket2"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877"
+dependencies = [
+ "libc",
+ "windows-sys 0.48.0",
+]
+
 [[package]]
 name = "spin"
 version = "0.5.2"
@@ -5304,7 +5344,7 @@ dependencies = [
  "parking_lot 0.12.1",
  "pin-project-lite",
  "signal-hook-registry",
- "socket2",
+ "socket2 0.4.9",
  "tokio-macros",
  "tracing",
  "windows-sys 0.48.0",
@@ -5343,9 +5383,9 @@ dependencies = [
 
 [[package]]
 name = "tokio-postgres"
-version = "0.7.7"
+version = "0.7.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "29a12c1b3e0704ae7dfc25562629798b29c72e6b1d0a681b6f29ab4ae5e7f7bf"
+checksum = "6e89f6234aa8fd43779746012fcf53603cdb91fdd8399aa0de868c2d56b6dde1"
 dependencies = [
  "async-trait",
  "byteorder",
@@ -5360,22 +5400,46 @@ dependencies = [
  "pin-project-lite",
  "postgres-protocol",
  "postgres-types",
- "socket2",
+ "socket2 0.5.3",
  "tokio",
  "tokio-util 0.7.4",
 ]
 
+[[package]]
+name = "tokio-postgres-rustls"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f"
+dependencies = [
+ "futures",
+ "ring",
+ "rustls 0.21.2",
+ "tokio",
+ "tokio-postgres",
+ "tokio-rustls 0.24.1",
+]
+
 [[package]]
 name = "tokio-rustls"
 version = "0.23.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59"
 dependencies = [
- "rustls",
+ "rustls 0.20.7",
  "tokio",
  "webpki",
 ]
 
+[[package]]
+name = "tokio-rustls"
+version = "0.24.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
+dependencies = [
+ "rustls 0.21.2",
+ "tokio",
+]
+
 [[package]]
 name = "tokio-stream"
 version = "0.1.11"
diff --git a/Cargo.toml b/Cargo.toml
index 07e41ab3..2ee5a530 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -107,6 +107,10 @@ rand = "0.8.5"
 opentelemetry = { version = "0.17.0", features = ["rt-tokio"] }
 tracing-opentelemetry = { version = "0.17.4" }
 ts-rs = { version = "6.2", features = ["serde-compat", "format", "chrono-impl"] }
+rustls = { version ="0.21.2", features = ["dangerous_configuration"]}
+futures-util = "0.3.28"
+tokio-postgres = "0.7.8"
+tokio-postgres-rustls = "0.10.0"
 
 [dependencies]
 lemmy_api = { workspace = true }
@@ -140,3 +144,8 @@ opentelemetry-otlp = { version = "0.10.0", optional = true }
 pict-rs = { version = "0.4.0-rc.3", optional = true }
 tokio.workspace = true
 actix-cors = "0.6.4"
+rustls = { workspace = true }
+futures-util = { workspace = true }
+tokio-postgres = { workspace = true }
+tokio-postgres-rustls = { workspace = true }
+
diff --git a/crates/db_schema/Cargo.toml b/crates/db_schema/Cargo.toml
index aa26382c..6c89ab9f 100644
--- a/crates/db_schema/Cargo.toml
+++ b/crates/db_schema/Cargo.toml
@@ -43,7 +43,11 @@ async-trait = { workspace = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
 deadpool = { version = "0.9.5", features = ["rt_tokio_1"], optional = true }
-ts-rs = { workspace = true, optional = true } 
+ts-rs = { workspace = true, optional = true }
+rustls = { workspace = true }
+futures-util = { workspace = true }
+tokio-postgres = { workspace = true }
+tokio-postgres-rustls = { workspace = true }
 
 [dev-dependencies]
 serial_test = { workspace = true }
diff --git a/crates/db_schema/src/utils.rs b/crates/db_schema/src/utils.rs
index 98d3952a..1319a62f 100644
--- a/crates/db_schema/src/utils.rs
+++ b/crates/db_schema/src/utils.rs
@@ -12,7 +12,7 @@ use diesel::{
   backend::Backend,
   deserialize::FromSql,
   pg::Pg,
-  result::{Error as DieselError, Error::QueryBuilderError},
+  result::{ConnectionError, ConnectionResult, Error as DieselError, Error::QueryBuilderError},
   serialize::{Output, ToSql},
   sql_types::Text,
   PgConnection,
@@ -25,11 +25,21 @@ use diesel_async::{
   },
 };
 use diesel_migrations::EmbeddedMigrations;
+use futures_util::{future::BoxFuture, FutureExt};
 use lemmy_utils::{error::LemmyError, settings::structs::Settings};
 use once_cell::sync::Lazy;
 use regex::Regex;
-use std::{env, env::VarError, time::Duration};
-use tracing::info;
+use rustls::{
+  client::{ServerCertVerified, ServerCertVerifier},
+  ServerName,
+};
+use std::{
+  env,
+  env::VarError,
+  sync::Arc,
+  time::{Duration, SystemTime},
+};
+use tracing::{error, info};
 use url::Url;
 
 const FETCH_LIMIT_DEFAULT: i64 = 10;
@@ -136,7 +146,15 @@ pub fn diesel_option_overwrite_to_url_create(
 async fn build_db_pool_settings_opt(settings: Option<&Settings>) -> Result<DbPool, LemmyError> {
   let db_url = get_database_url(settings);
   let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
-  let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url);
+  // We only support TLS with sslmode=require currently
+  let tls_enabled = db_url.contains("sslmode=require");
+  let manager = if tls_enabled {
+    // diesel-async does not support any TLS connections out of the box, so we need to manually
+    // provide a setup function which handles creating the connection
+    AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
+  } else {
+    AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
+  };
   let pool = Pool::builder(manager)
     .max_size(pool_size)
     .wait_timeout(POOL_TIMEOUT)
@@ -153,6 +171,44 @@ async fn build_db_pool_settings_opt(settings: Option<&Settings>) -> Result<DbPoo
   Ok(pool)
 }
 
+fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
+  let fut = async {
+    let rustls_config = rustls::ClientConfig::builder()
+      .with_safe_defaults()
+      .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
+      .with_no_client_auth();
+
+    let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
+    let (client, conn) = tokio_postgres::connect(config, tls)
+      .await
+      .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
+    tokio::spawn(async move {
+      if let Err(e) = conn.await {
+        error!("Database connection failed: {e}");
+      }
+    });
+    AsyncPgConnection::try_from(client).await
+  };
+  fut.boxed()
+}
+
+struct NoCertVerifier {}
+
+impl ServerCertVerifier for NoCertVerifier {
+  fn verify_server_cert(
+    &self,
+    _end_entity: &rustls::Certificate,
+    _intermediates: &[rustls::Certificate],
+    _server_name: &ServerName,
+    _scts: &mut dyn Iterator<Item = &[u8]>,
+    _ocsp_response: &[u8],
+    _now: SystemTime,
+  ) -> Result<ServerCertVerified, rustls::Error> {
+    // Will verify all (even invalid) certs without any checks (sslmode=require)
+    Ok(ServerCertVerified::assertion())
+  }
+}
+
 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
 
 pub fn run_migrations(db_url: &str) {
-- 
2.44.1