1#![allow(clippy::result_large_err)]
2
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use futures::stream::TryStreamExt as _;
7use native_tls::TlsConnector;
8use postgres_native_tls::MakeTlsConnector;
9use spin_world::async_trait;
10use spin_world::spin::postgres4_2_0::postgres::{
11 self as v4, Column, DbValue, ParameterValue, RowSet,
12};
13use tokio_postgres::config::SslMode;
14use tokio_postgres::types::ToSql;
15use tokio_postgres::{NoTls, Row};
16
17use crate::types::{convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters};
18
/// Maximum number of connections a single pool may hold (passed to the pool
/// builder's `max_size` in `create_connection_pool`).
const CONNECTION_POOL_SIZE: usize = 64;
/// Capacity of the cache of connection pools held by
/// `PooledTokioClientFactory`, which keeps one pool per
/// (connection string, root-CA digest) key.
const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
23
/// Hands out Postgres clients for a given connection string.
#[async_trait]
pub trait ClientFactory: Default + Send + Sync + 'static {
    /// The client type produced by this factory.
    type Client: Client;

    /// Returns a client for the database at `address`.
    ///
    /// `root_ca`, when provided, is an additional root certificate to trust
    /// for TLS connections (the pooled implementation adds it to its
    /// `native_tls` connector).
    async fn get_client(
        &self,
        address: &str,
        root_ca: Option<HashableCertificate>,
    ) -> Result<Self::Client>;
}
37
38#[derive(Clone)]
39pub struct HashableCertificate {
40 certificate: native_tls::Certificate,
41 hash: String,
42}
43
44impl HashableCertificate {
45 pub fn from_pem(text: &str) -> anyhow::Result<Self> {
46 let cert_bytes = text.as_bytes();
47 let hash = spin_common::sha256::hex_digest_from_bytes(cert_bytes);
48 let certificate =
49 native_tls::Certificate::from_pem(cert_bytes).context("invalid root certificate")?;
50 Ok(Self { certificate, hash })
51 }
52}
53
54pub struct PooledTokioClientFactory {
56 pools: moka::sync::Cache<(String, Option<String>), deadpool_postgres::Pool>,
57}
58
59impl Default for PooledTokioClientFactory {
60 fn default() -> Self {
61 Self {
62 pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
63 }
64 }
65}
66
67#[derive(Clone)]
68pub struct PooledTokioClient(Arc<deadpool_postgres::Object>);
69
70impl AsRef<deadpool_postgres::Object> for PooledTokioClient {
71 fn as_ref(&self) -> &deadpool_postgres::Object {
72 self.0.as_ref()
73 }
74}
75
76#[async_trait]
77impl ClientFactory for PooledTokioClientFactory {
78 type Client = PooledTokioClient;
79
80 async fn get_client(
81 &self,
82 address: &str,
83 root_ca: Option<HashableCertificate>,
84 ) -> Result<Self::Client> {
85 let (root_ca, root_ca_hash) = match root_ca {
86 None => (None, None),
87 Some(HashableCertificate { certificate, hash }) => (Some(certificate), Some(hash)),
88 };
89 let pool_key = (address.to_string(), root_ca_hash);
90 let pool = self
91 .pools
92 .try_get_with_by_ref(&pool_key, || create_connection_pool(address, root_ca))
93 .map_err(ArcError)
94 .context("establishing PostgreSQL connection pool")?;
95
96 Ok(PooledTokioClient(Arc::new(pool.get().await?)))
97 }
98}
99
100fn create_connection_pool(
102 address: &str,
103 root_ca: Option<native_tls::Certificate>,
104) -> Result<deadpool_postgres::Pool> {
105 let config = address
106 .parse::<tokio_postgres::Config>()
107 .context("parsing Postgres connection string")?;
108
109 tracing::debug!("Build new connection: {config:?}");
112
113 let mgr_config = deadpool_postgres::ManagerConfig {
114 recycling_method: deadpool_postgres::RecyclingMethod::Clean,
115 };
116
117 let mgr = if config.get_ssl_mode() == SslMode::Disable {
118 deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
119 } else {
120 let mut builder = TlsConnector::builder();
121 if let Some(root_ca) = root_ca {
122 builder.add_root_certificate(root_ca);
123 }
124 let connector = MakeTlsConnector::new(builder.build()?);
125 deadpool_postgres::Manager::from_config(config, connector, mgr_config)
126 };
127
128 let pool = deadpool_postgres::Pool::builder(mgr)
132 .max_size(CONNECTION_POOL_SIZE)
133 .build()
134 .context("building Postgres connection pool")?;
135
136 Ok(pool)
137}
138
/// Operations a Postgres client exposes to the host.
#[async_trait]
pub trait Client: Clone + Send + Sync + 'static {
    /// Executes `statement` with `params` and returns a `u64` count (the
    /// pooled implementation returns what `tokio_postgres`'s `execute`
    /// reports).
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v4::Error>;

    /// Runs `statement` with `params` and returns the whole result set.
    /// `max_result_bytes` bounds the estimated size of the buffered result.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<RowSet, v4::Error>;

    /// Runs `statement` with `params` and returns column metadata plus a
    /// channel of rows and a completion signal (see [`QueryAsyncResult`]).
    /// `max_result_bytes` is a size bound enforced by the implementation.
    async fn query_async(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v4::Error>;
}
161
/// Result of [`Client::query_async`]: column metadata up front, rows delivered
/// through a channel, and a one-shot completion signal.
pub struct QueryAsyncResult {
    /// Column metadata for the result set.
    pub columns: Vec<v4::Column>,
    /// Rows, delivered as they are produced.
    pub rows: tokio::sync::mpsc::Receiver<v4::Row>,
    /// Resolves once row delivery stops. In the pooled implementation this is
    /// `Ok(())` when the rows are exhausted or the row receiver is closed, and
    /// `Err` when producing or forwarding a row failed.
    pub error: tokio::sync::oneshot::Receiver<Result<(), v4::Error>>,
}
167
168fn pg_extras(dbe: &tokio_postgres::error::DbError) -> Vec<(String, String)> {
170 let mut extras = vec![];
171
172 macro_rules! pg_extra {
173 ( $n:ident ) => {
174 if let Some(value) = dbe.$n() {
175 extras.push((stringify!($n).to_owned(), value.to_string()));
176 }
177 };
178 }
179
180 pg_extra!(column);
181 pg_extra!(constraint);
182 pg_extra!(routine);
183 pg_extra!(hint);
184 pg_extra!(table);
185 pg_extra!(datatype);
186 pg_extra!(schema);
187 pg_extra!(file);
188 pg_extra!(line);
189 pg_extra!(where_);
190
191 extras
192}
193
194fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
195 let flattened = format!("{e:?}");
196 let query_error = match e.as_db_error() {
197 None => v4::QueryError::Text(flattened),
198 Some(dbe) => v4::QueryError::DbError(v4::DbError {
199 as_text: flattened,
200 severity: dbe.severity().to_owned(),
201 code: dbe.code().code().to_owned(),
202 message: dbe.message().to_owned(),
203 detail: dbe.detail().map(|s| s.to_owned()),
204 extras: pg_extras(dbe),
205 }),
206 };
207 v4::Error::QueryFailed(query_error)
208}
209
210fn query_failed_anyhow(e: anyhow::Error) -> v4::Error {
211 let text = format!("{e:?}");
212 v4::Error::QueryFailed(v4::QueryError::Text(text))
213}
214
215#[async_trait]
216impl Client for PooledTokioClient {
217 async fn execute(
218 &self,
219 statement: String,
220 params: Vec<ParameterValue>,
221 ) -> Result<u64, v4::Error> {
222 let params = params
223 .iter()
224 .map(to_sql_parameter)
225 .collect::<Result<Vec<_>>>()
226 .map_err(|e| v4::Error::ValueConversionFailed(format!("{e:?}")))?;
227
228 let params_refs: Vec<&(dyn ToSql + Sync)> = params
229 .iter()
230 .map(|b| b.as_ref() as &(dyn ToSql + Sync))
231 .collect();
232
233 self.as_ref()
234 .execute(&statement, params_refs.as_slice())
235 .await
236 .map_err(query_failed)
237 }
238
239 async fn query(
240 &self,
241 statement: String,
242 params: Vec<ParameterValue>,
243 max_result_bytes: usize,
244 ) -> Result<RowSet, v4::Error> {
245 let (cols_fut, mut results) = self.query_stream(statement, params).await?;
246
247 let mut columns = None;
248 let mut byte_count = std::mem::size_of::<RowSet>();
249 let mut rows = Vec::new();
250
251 async {
252 while let Some(row) = results.try_next().await? {
253 byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
254 if byte_count > max_result_bytes {
255 anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
256 }
257 rows.push(row);
258 }
259 columns = Some(cols_fut.await);
260 Ok(())
261 }
262 .await
263 .map_err(|e| v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))))?;
264
265 Ok(RowSet {
266 columns: columns.unwrap_or_default(),
267 rows,
268 })
269 }
270
271 async fn query_async(
272 &self,
273 statement: String,
274 params: Vec<ParameterValue>,
275 max_result_bytes: usize,
276 ) -> Result<QueryAsyncResult, v4::Error> {
277 use futures::StreamExt;
278
279 let (cols_fut, mut rows) = self.query_stream(statement, params).await?;
280
281 let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
282 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
283
284 tokio::spawn(async move {
285 loop {
286 tokio::select! {
287 biased;
288 _ = rows_tx.closed() => {
291 _ = err_tx.send(Ok(()));
292 return;
293 }
294 row = rows.next() => {
295 let Some(row) = row else {
296 _ = err_tx.send(Ok(()));
297 return;
298 };
299 match row {
300 Ok(row) => {
301 let byte_count = row.iter().map(|v| v.memory_size()).sum::<usize>();
302 if byte_count > max_result_bytes {
303 _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
304 format!("query result exceeds limit of {max_result_bytes} bytes"),
305 ))));
306 return;
307 }
308
309 if let Err(e) = rows_tx.send(row).await {
310 _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
311 format!("async error: {e}"),
312 ))));
313 return;
314 }
315 }
316 Err(e) => {
317 _ = err_tx.send(Err(e));
318 return;
319 }
320 }
321 }
322 }
323 }
324 });
325
326 let cols = cols_fut.await;
327
328 Ok(QueryAsyncResult {
329 columns: cols,
330 rows: rows_rx,
331 error: err_rx,
332 })
333 }
334}
335
impl PooledTokioClient {
    /// Starts `statement` with `params` and returns two halves of the result:
    ///
    /// * a future resolving to the column metadata, and
    /// * a stream of converted rows (`Err` items for per-row failures).
    ///
    /// Column metadata is inferred from the first item of the stream when that
    /// item is a successful row. It is sent before the row is converted, so it
    /// is available by the time the consumer sees that row. Otherwise (first
    /// item is an error, or no rows at all) the column future resolves to an
    /// empty list. It does so only once the stream, which owns the sender, is
    /// dropped, so callers must drop the stream before awaiting the columns in
    /// that case.
    ///
    /// # Errors
    /// Parameter conversion errors (`to_sql_parameters`) and failures starting
    /// the query (mapped through `query_failed`) are returned immediately.
    async fn query_stream(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<
        (
            impl std::future::Future<Output = Vec<v4::Column>>,
            impl futures::Stream<Item = Result<Vec<DbValue>, v4::Error>> + 'static,
        ),
        v4::Error,
    > {
        use futures::{FutureExt, StreamExt};

        let params = to_sql_parameters(params)?;

        let results = Box::pin(
            self.as_ref()
                .query_raw(&statement, params)
                .await
                .map_err(query_failed)?,
        );

        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
        // Moved into the `map` closure below, so the sender lives exactly as
        // long as the returned stream unless the first row consumes it.
        let mut cols_tx_opt = Some(cols_tx);

        let row_stm = results.enumerate().map(move |(index, row_res)| {
            let row_res = row_res.map_err(query_failed);
            row_res.and_then(|r| {
                // Only the very first item may publish column metadata; if it
                // was an error, later rows do not retry.
                if index == 0
                    && let Some(cols_tx) = cols_tx_opt.take()
                {
                    let cols = infer_columns(&r);
                    _ = cols_tx.send(cols);
                }
                convert_row(&r).map_err(query_failed_anyhow)
            })
        });

        // A dropped sender (no successful first row) maps to empty columns.
        let cols_rx = cols_rx.map(|result| result.unwrap_or_default());

        Ok((cols_rx, Box::pin(row_stm)))
    }
}
380
381fn infer_columns(row: &Row) -> Vec<Column> {
382 let mut result = Vec::with_capacity(row.len());
383 for index in 0..row.len() {
384 result.push(infer_column(row, index));
385 }
386 result
387}
388
389fn infer_column(row: &Row, index: usize) -> Column {
390 let column = &row.columns()[index];
391 let name = column.name().to_owned();
392 let data_type = convert_data_type(column.type_());
393 Column { name, data_type }
394}
395
396fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
397 let mut result = Vec::with_capacity(row.len());
398 for index in 0..row.len() {
399 result.push(convert_entry(row, index)?);
400 }
401 Ok(result)
402}
403
404struct ArcError(std::sync::Arc<anyhow::Error>);
407
408impl std::error::Error for ArcError {
409 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
410 self.0.source()
411 }
412}
413
414impl std::fmt::Debug for ArcError {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 std::fmt::Debug::fmt(&self.0, f)
417 }
418}
419
420impl std::fmt::Display for ArcError {
421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422 std::fmt::Display::fmt(&self.0, f)
423 }
424}