1#![allow(clippy::result_large_err)]
2
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use futures::stream::TryStreamExt as _;
7use native_tls::TlsConnector;
8use postgres_native_tls::MakeTlsConnector;
9use spin_world::async_trait;
10use spin_world::spin::postgres4_2_0::postgres::{
11 self as v4, Column, DbValue, ParameterValue, RowSet,
12};
13use tokio_postgres::config::SslMode;
14use tokio_postgres::types::ToSql;
15use tokio_postgres::{NoTls, Row};
16
17use crate::types::{convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters};
18
/// Maximum number of connections in each `deadpool_postgres` pool
/// (passed to `Pool::builder(..).max_size`).
const CONNECTION_POOL_SIZE: usize = 64;

/// Capacity of the factory's cache of connection pools; one pool is cached
/// per (connection string, root CA digest) key.
const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
23
/// Produces [`Client`]s connected to a PostgreSQL server.
#[async_trait]
pub trait ClientFactory: Default + Send + Sync + 'static {
    /// The client type handed out by this factory.
    type Client: Client;

    /// Returns a client for the PostgreSQL connection string `address`.
    ///
    /// `root_ca` is an optional extra root certificate; in the pooled
    /// implementation it is trusted when the connection uses TLS, and its
    /// digest is part of the pool cache key.
    async fn get_client(
        &self,
        address: &str,
        root_ca: Option<HashableCertificate>,
    ) -> Result<Self::Client>;
}
37
38#[derive(Clone)]
39pub struct HashableCertificate {
40 certificate: native_tls::Certificate,
41 hash: String,
42}
43
44impl HashableCertificate {
45 pub fn from_pem(text: &str) -> anyhow::Result<Self> {
46 let cert_bytes = text.as_bytes();
47 let hash = spin_common::sha256::hex_digest_from_bytes(cert_bytes);
48 let certificate =
49 native_tls::Certificate::from_pem(cert_bytes).context("invalid root certificate")?;
50 Ok(Self { certificate, hash })
51 }
52}
53
54pub struct PooledTokioClientFactory {
56 pools: moka::sync::Cache<(String, Option<String>), deadpool_postgres::Pool>,
57}
58
59impl Default for PooledTokioClientFactory {
60 fn default() -> Self {
61 Self {
62 pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
63 }
64 }
65}
66
67#[derive(Clone)]
68pub struct PooledTokioClient(Arc<deadpool_postgres::Object>);
69
70impl AsRef<deadpool_postgres::Object> for PooledTokioClient {
71 fn as_ref(&self) -> &deadpool_postgres::Object {
72 self.0.as_ref()
73 }
74}
75
76#[async_trait]
77impl ClientFactory for PooledTokioClientFactory {
78 type Client = PooledTokioClient;
79
80 async fn get_client(
81 &self,
82 address: &str,
83 root_ca: Option<HashableCertificate>,
84 ) -> Result<Self::Client> {
85 let (root_ca, root_ca_hash) = match root_ca {
86 None => (None, None),
87 Some(HashableCertificate { certificate, hash }) => (Some(certificate), Some(hash)),
88 };
89 let pool_key = (address.to_string(), root_ca_hash);
90 let pool = self
91 .pools
92 .try_get_with_by_ref(&pool_key, || create_connection_pool(address, root_ca))
93 .map_err(ArcError)
94 .context("establishing PostgreSQL connection pool")?;
95
96 Ok(PooledTokioClient(Arc::new(pool.get().await?)))
97 }
98}
99
100fn create_connection_pool(
102 address: &str,
103 root_ca: Option<native_tls::Certificate>,
104) -> Result<deadpool_postgres::Pool> {
105 let config = address
106 .parse::<tokio_postgres::Config>()
107 .context("parsing Postgres connection string")?;
108
109 tracing::debug!("Build new connection: {config:?}");
112
113 let mgr_config = deadpool_postgres::ManagerConfig {
114 recycling_method: deadpool_postgres::RecyclingMethod::Clean,
115 };
116
117 let mgr = if config.get_ssl_mode() == SslMode::Disable {
118 deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
119 } else {
120 let mut builder = TlsConnector::builder();
121 if let Some(root_ca) = root_ca {
122 builder.add_root_certificate(root_ca);
123 }
124 let connector = MakeTlsConnector::new(builder.build()?);
125 deadpool_postgres::Manager::from_config(config, connector, mgr_config)
126 };
127
128 let pool = deadpool_postgres::Pool::builder(mgr)
132 .max_size(CONNECTION_POOL_SIZE)
133 .build()
134 .context("building Postgres connection pool")?;
135
136 Ok(pool)
137}
138
/// A database connection that can execute statements and run queries.
///
/// Every operation reports failure as a [`v4::Error`].
#[async_trait]
pub trait Client: Clone + Send + Sync + 'static {
    /// Executes `statement` with `params` and returns the row count reported
    /// for it.
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v4::Error>;

    /// Runs `statement` with `params` and collects the complete result set.
    ///
    /// `max_result_bytes` limits the size of the collected result; how that
    /// size is measured is defined by the implementation.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<RowSet, v4::Error>;

    /// Starts `statement` with `params` and returns its columns together with
    /// channels delivering the rows and the final outcome
    /// (see [`QueryAsyncResult`]).
    ///
    /// `max_result_bytes` limits result size as defined by the implementation.
    async fn query_async(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v4::Error>;
}
161
/// The result of [`Client::query_async`]: column metadata plus channels that
/// stream the rows and report how the query finished.
pub struct QueryAsyncResult {
    /// Columns of the result set. In the pooled implementation these are
    /// inferred from the first row, and are empty if no row was produced.
    pub columns: Vec<v4::Column>,
    /// Rows in the order they were read.
    pub rows: tokio::sync::mpsc::Receiver<v4::Row>,
    /// Resolves once streaming stops: `Ok(())` after the last row, or the
    /// error that ended the stream.
    pub error: tokio::sync::oneshot::Receiver<Result<(), v4::Error>>,
}
167
168fn pg_extras(dbe: &tokio_postgres::error::DbError) -> Vec<(String, String)> {
170 let mut extras = vec![];
171
172 macro_rules! pg_extra {
173 ( $n:ident ) => {
174 if let Some(value) = dbe.$n() {
175 extras.push((stringify!($n).to_owned(), value.to_string()));
176 }
177 };
178 }
179
180 pg_extra!(column);
181 pg_extra!(constraint);
182 pg_extra!(routine);
183 pg_extra!(hint);
184 pg_extra!(table);
185 pg_extra!(datatype);
186 pg_extra!(schema);
187 pg_extra!(file);
188 pg_extra!(line);
189 pg_extra!(where_);
190
191 extras
192}
193
194fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
195 let flattened = format!("{e:?}");
196 let query_error = match e.as_db_error() {
197 None => v4::QueryError::Text(flattened),
198 Some(dbe) => v4::QueryError::DbError(v4::DbError {
199 as_text: flattened,
200 severity: dbe.severity().to_owned(),
201 code: dbe.code().code().to_owned(),
202 message: dbe.message().to_owned(),
203 detail: dbe.detail().map(|s| s.to_owned()),
204 extras: pg_extras(dbe),
205 }),
206 };
207 v4::Error::QueryFailed(query_error)
208}
209
210fn query_failed_anyhow(e: anyhow::Error) -> v4::Error {
211 let text = format!("{e:?}");
212 v4::Error::QueryFailed(v4::QueryError::Text(text))
213}
214
215#[async_trait]
216impl Client for PooledTokioClient {
217 async fn execute(
218 &self,
219 statement: String,
220 params: Vec<ParameterValue>,
221 ) -> Result<u64, v4::Error> {
222 let params = params
223 .iter()
224 .map(to_sql_parameter)
225 .collect::<Result<Vec<_>>>()
226 .map_err(|e| v4::Error::ValueConversionFailed(format!("{e:?}")))?;
227
228 let params_refs: Vec<&(dyn ToSql + Sync)> = params
229 .iter()
230 .map(|b| b.as_ref() as &(dyn ToSql + Sync))
231 .collect();
232
233 self.as_ref()
234 .execute(&statement, params_refs.as_slice())
235 .await
236 .map_err(query_failed)
237 }
238
239 async fn query(
240 &self,
241 statement: String,
242 params: Vec<ParameterValue>,
243 max_result_bytes: usize,
244 ) -> Result<RowSet, v4::Error> {
245 let (cols_fut, mut results) = self.query_stream(statement, params).await?;
246
247 let mut columns = None;
248 let mut byte_count = std::mem::size_of::<RowSet>();
249 let mut rows = Vec::new();
250
251 async {
252 while let Some(row) = results.try_next().await? {
253 byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
254 if byte_count > max_result_bytes {
255 anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
256 }
257 rows.push(row);
258 }
259 columns = Some(cols_fut.await);
260 Ok(())
261 }
262 .await
263 .map_err(|e| v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))))?;
264
265 Ok(RowSet {
266 columns: columns.unwrap_or_default(),
267 rows,
268 })
269 }
270
271 async fn query_async(
272 &self,
273 statement: String,
274 params: Vec<ParameterValue>,
275 max_result_bytes: usize,
276 ) -> Result<QueryAsyncResult, v4::Error> {
277 use futures::StreamExt;
278
279 let (cols_fut, mut rows) = self.query_stream(statement, params).await?;
280
281 let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
282 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
283
284 tokio::spawn(async move {
285 loop {
286 let Some(row) = rows.next().await else {
287 _ = err_tx.send(Ok(()));
288 return;
289 };
290 match row {
291 Ok(row) => {
292 let byte_count = row.iter().map(|v| v.memory_size()).sum::<usize>();
293 if byte_count > max_result_bytes {
294 _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
295 format!("query result exceeds limit of {max_result_bytes} bytes"),
296 ))));
297 return;
298 }
299
300 if let Err(e) = rows_tx.send(row).await {
301 _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
302 format!("async error: {e}"),
303 ))));
304 return;
305 }
306 }
307 Err(e) => {
308 _ = err_tx.send(Err(e));
309 return;
310 }
311 }
312 }
313 });
314
315 let cols = cols_fut.await;
316
317 Ok(QueryAsyncResult {
318 columns: cols,
319 rows: rows_rx,
320 error: err_rx,
321 })
322 }
323}
324
impl PooledTokioClient {
    /// Starts `statement` with `params` and returns two things: a future for
    /// the column metadata, and a stream of converted rows.
    ///
    /// When each part resolves:
    /// - The column future resolves with the columns inferred from the first
    ///   row, as soon as that row has been pulled through the stream.
    /// - Its sender lives inside the stream's mapping closure. So if the
    ///   stream never yields a successful first row (an empty result, or an
    ///   error as the first item), the future resolves only once the stream
    ///   is dropped, and then to an empty `Vec`.
    ///
    /// Callers must therefore not await the column future while still holding
    /// an exhausted stream.
    ///
    /// Row stream items behave as follows:
    /// - Driver/server errors are mapped with `query_failed`.
    /// - Value conversion errors are mapped with `query_failed_anyhow`.
    async fn query_stream(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<
        (
            impl std::future::Future<Output = Vec<v4::Column>>,
            impl futures::Stream<Item = Result<Vec<DbValue>, v4::Error>> + 'static,
        ),
        v4::Error,
    > {
        use futures::{FutureExt, StreamExt};

        // Conversion errors are propagated by `?`; presumably `to_sql_parameters`
        // already yields a `v4::Error` (or one convertible to it) — TODO confirm.
        let params = to_sql_parameters(params)?;

        let results = Box::pin(
            self.as_ref()
                .query_raw(&statement, params)
                .await
                .map_err(query_failed)?,
        );

        // One-shot channel for the column metadata. The sender is moved into the
        // row-mapping closure below and taken on the first row only.
        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
        let mut cols_tx_opt = Some(cols_tx);

        let row_stm = results.enumerate().map(move |(index, row_res)| {
            let row_res = row_res.map_err(query_failed);
            row_res.and_then(|r| {
                // Only a successful item at index 0 publishes the columns. If the
                // first item was an error, later rows do not send them, and the
                // sender stays in this closure until the stream is dropped.
                if index == 0
                    && let Some(cols_tx) = cols_tx_opt.take()
                {
                    let cols = infer_columns(&r);
                    // Ignore a dropped receiver: nobody wants the columns.
                    _ = cols_tx.send(cols);
                }
                convert_row(&r).map_err(query_failed_anyhow)
            })
        });

        // A sender dropped without sending (no first row) yields no columns.
        let cols_rx = cols_rx.map(|result| result.unwrap_or_default());

        Ok((cols_rx, Box::pin(row_stm)))
    }
}
368}
369
370fn infer_columns(row: &Row) -> Vec<Column> {
371 let mut result = Vec::with_capacity(row.len());
372 for index in 0..row.len() {
373 result.push(infer_column(row, index));
374 }
375 result
376}
377
378fn infer_column(row: &Row, index: usize) -> Column {
379 let column = &row.columns()[index];
380 let name = column.name().to_owned();
381 let data_type = convert_data_type(column.type_());
382 Column { name, data_type }
383}
384
385fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
386 let mut result = Vec::with_capacity(row.len());
387 for index in 0..row.len() {
388 result.push(convert_entry(row, index)?);
389 }
390 Ok(result)
391}
392
393struct ArcError(std::sync::Arc<anyhow::Error>);
396
397impl std::error::Error for ArcError {
398 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
399 self.0.source()
400 }
401}
402
403impl std::fmt::Debug for ArcError {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 std::fmt::Debug::fmt(&self.0, f)
406 }
407}
408
409impl std::fmt::Display for ArcError {
410 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411 std::fmt::Display::fmt(&self.0, f)
412 }
413}