1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use futures::stream::TryStreamExt as _;
5use native_tls::TlsConnector;
6use postgres_native_tls::MakeTlsConnector;
7use spin_world::async_trait;
8use spin_world::spin::postgres4_2_0::postgres::{
9 self as v4, Column, DbValue, ParameterValue, RowSet,
10};
11use tokio_postgres::config::SslMode;
12use tokio_postgres::types::ToSql;
13use tokio_postgres::{NoTls, Row};
14
15use crate::types::{convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters};
16
/// Maximum number of connections in each Postgres connection pool
/// (passed to `deadpool_postgres::PoolBuilder::max_size`).
const CONNECTION_POOL_SIZE: usize = 64;
/// Capacity of the factory's cache of connection pools
/// (passed to `moka::sync::Cache::new`).
const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
21
/// Produces Postgres [`Client`]s for a connection string and an optional
/// custom root certificate.
#[async_trait]
pub trait ClientFactory: Default + Send + Sync + 'static {
    /// The client type handed out by this factory.
    type Client: Client;

    /// Returns a client for the database at `address`.
    ///
    /// `root_ca`, when present, is an extra root certificate to trust for TLS
    /// (this is how `PooledTokioClientFactory` uses it).
    async fn get_client(
        &self,
        address: &str,
        root_ca: Option<HashableCertificate>,
    ) -> Result<Self::Client>;
}
35
/// A native-TLS root certificate paired with a hash of the PEM text it was
/// parsed from, so the certificate can take part in a cache key.
#[derive(Clone)]
pub struct HashableCertificate {
    /// The parsed certificate.
    certificate: native_tls::Certificate,
    /// SHA-256 hex digest of the certificate's PEM text.
    hash: String,
}
41
42impl HashableCertificate {
43 pub fn from_pem(text: &str) -> anyhow::Result<Self> {
44 let cert_bytes = text.as_bytes();
45 let hash = spin_common::sha256::hex_digest_from_bytes(cert_bytes);
46 let certificate =
47 native_tls::Certificate::from_pem(cert_bytes).context("invalid root certificate")?;
48 Ok(Self { certificate, hash })
49 }
50}
51
/// A [`ClientFactory`] that caches one `deadpool_postgres` connection pool per
/// connection string and root certificate.
pub struct PooledTokioClientFactory {
    /// Pools keyed by `(connection string, root certificate hash)`.
    pools: moka::sync::Cache<(String, Option<String>), deadpool_postgres::Pool>,
}
56
57impl Default for PooledTokioClientFactory {
58 fn default() -> Self {
59 Self {
60 pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
61 }
62 }
63}
64
/// A connection checked out of a `deadpool_postgres` pool, held behind an
/// `Arc` so that clones share the same connection.
#[derive(Clone)]
pub struct PooledTokioClient(Arc<deadpool_postgres::Object>);
67
68impl AsRef<deadpool_postgres::Object> for PooledTokioClient {
69 fn as_ref(&self) -> &deadpool_postgres::Object {
70 self.0.as_ref()
71 }
72}
73
74#[async_trait]
75impl ClientFactory for PooledTokioClientFactory {
76 type Client = PooledTokioClient;
77
78 async fn get_client(
79 &self,
80 address: &str,
81 root_ca: Option<HashableCertificate>,
82 ) -> Result<Self::Client> {
83 let (root_ca, root_ca_hash) = match root_ca {
84 None => (None, None),
85 Some(HashableCertificate { certificate, hash }) => (Some(certificate), Some(hash)),
86 };
87 let pool_key = (address.to_string(), root_ca_hash);
88 let pool = self
89 .pools
90 .try_get_with_by_ref(&pool_key, || create_connection_pool(address, root_ca))
91 .map_err(ArcError)
92 .context("establishing PostgreSQL connection pool")?;
93
94 Ok(PooledTokioClient(Arc::new(pool.get().await?)))
95 }
96}
97
98fn create_connection_pool(
100 address: &str,
101 root_ca: Option<native_tls::Certificate>,
102) -> Result<deadpool_postgres::Pool> {
103 let config = address
104 .parse::<tokio_postgres::Config>()
105 .context("parsing Postgres connection string")?;
106
107 tracing::debug!("Build new connection: {}", address);
108
109 let mgr_config = deadpool_postgres::ManagerConfig {
110 recycling_method: deadpool_postgres::RecyclingMethod::Clean,
111 };
112
113 let mgr = if config.get_ssl_mode() == SslMode::Disable {
114 deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
115 } else {
116 let mut builder = TlsConnector::builder();
117 if let Some(root_ca) = root_ca {
118 builder.add_root_certificate(root_ca);
119 }
120 let connector = MakeTlsConnector::new(builder.build()?);
121 deadpool_postgres::Manager::from_config(config, connector, mgr_config)
122 };
123
124 let pool = deadpool_postgres::Pool::builder(mgr)
128 .max_size(CONNECTION_POOL_SIZE)
129 .build()
130 .context("building Postgres connection pool")?;
131
132 Ok(pool)
133}
134
/// A Postgres client that can execute statements and run queries.
#[async_trait]
pub trait Client: Clone + Send + Sync + 'static {
    /// Executes `statement` with `params` and returns the row count reported
    /// for it.
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v4::Error>;

    /// Runs `statement` with `params` and returns the complete result set.
    ///
    /// `max_result_bytes` limits the size of the result; how that size is
    /// measured is up to the implementation.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<RowSet, v4::Error>;

    /// Runs `statement` with `params` and returns its columns together with a
    /// channel of rows and a separate channel for the final outcome.
    ///
    /// `max_result_bytes` limits result size; how it is applied is up to the
    /// implementation.
    async fn query_async(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v4::Error>;
}
157
/// The result of [`Client::query_async`]: column metadata up front, with rows
/// and the final status delivered asynchronously.
pub struct QueryAsyncResult {
    /// Columns of the result set.
    pub columns: Vec<v4::Column>,
    /// Receives rows as they are produced.
    pub rows: tokio::sync::mpsc::Receiver<v4::Row>,
    /// Receives the outcome once row production stops. In `PooledTokioClient`
    /// this is `Ok(())` after the last row, or the error that ended the stream.
    pub error: tokio::sync::oneshot::Receiver<Result<(), v4::Error>>,
}
163
164fn pg_extras(dbe: &tokio_postgres::error::DbError) -> Vec<(String, String)> {
166 let mut extras = vec![];
167
168 macro_rules! pg_extra {
169 ( $n:ident ) => {
170 if let Some(value) = dbe.$n() {
171 extras.push((stringify!($n).to_owned(), value.to_string()));
172 }
173 };
174 }
175
176 pg_extra!(column);
177 pg_extra!(constraint);
178 pg_extra!(routine);
179 pg_extra!(hint);
180 pg_extra!(table);
181 pg_extra!(datatype);
182 pg_extra!(schema);
183 pg_extra!(file);
184 pg_extra!(line);
185 pg_extra!(where_);
186
187 extras
188}
189
190fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
191 let flattened = format!("{e:?}");
192 let query_error = match e.as_db_error() {
193 None => v4::QueryError::Text(flattened),
194 Some(dbe) => v4::QueryError::DbError(v4::DbError {
195 as_text: flattened,
196 severity: dbe.severity().to_owned(),
197 code: dbe.code().code().to_owned(),
198 message: dbe.message().to_owned(),
199 detail: dbe.detail().map(|s| s.to_owned()),
200 extras: pg_extras(dbe),
201 }),
202 };
203 v4::Error::QueryFailed(query_error)
204}
205
206fn query_failed_anyhow(e: anyhow::Error) -> v4::Error {
207 let text = format!("{e:?}");
208 v4::Error::QueryFailed(v4::QueryError::Text(text))
209}
210
211#[async_trait]
212impl Client for PooledTokioClient {
213 async fn execute(
214 &self,
215 statement: String,
216 params: Vec<ParameterValue>,
217 ) -> Result<u64, v4::Error> {
218 let params = params
219 .iter()
220 .map(to_sql_parameter)
221 .collect::<Result<Vec<_>>>()
222 .map_err(|e| v4::Error::ValueConversionFailed(format!("{e:?}")))?;
223
224 let params_refs: Vec<&(dyn ToSql + Sync)> = params
225 .iter()
226 .map(|b| b.as_ref() as &(dyn ToSql + Sync))
227 .collect();
228
229 self.as_ref()
230 .execute(&statement, params_refs.as_slice())
231 .await
232 .map_err(query_failed)
233 }
234
235 async fn query(
236 &self,
237 statement: String,
238 params: Vec<ParameterValue>,
239 max_result_bytes: usize,
240 ) -> Result<RowSet, v4::Error> {
241 let (cols_fut, mut results) = self.query_stream(statement, params).await?;
242
243 let mut columns = None;
244 let mut byte_count = std::mem::size_of::<RowSet>();
245 let mut rows = Vec::new();
246
247 async {
248 while let Some(row) = results.try_next().await? {
249 byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
250 if byte_count > max_result_bytes {
251 anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
252 }
253 rows.push(row);
254 }
255 columns = Some(cols_fut.await);
256 Ok(())
257 }
258 .await
259 .map_err(|e| v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))))?;
260
261 Ok(RowSet {
262 columns: columns.unwrap_or_default(),
263 rows,
264 })
265 }
266
267 async fn query_async(
268 &self,
269 statement: String,
270 params: Vec<ParameterValue>,
271 max_result_bytes: usize,
272 ) -> Result<QueryAsyncResult, v4::Error> {
273 use futures::StreamExt;
274
275 let (cols_fut, mut rows) = self.query_stream(statement, params).await?;
276
277 let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
278 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
279
280 tokio::spawn(async move {
281 loop {
282 let Some(row) = rows.next().await else {
283 _ = err_tx.send(Ok(()));
284 return;
285 };
286 match row {
287 Ok(row) => {
288 let byte_count = row.iter().map(|v| v.memory_size()).sum::<usize>();
289 if byte_count > max_result_bytes {
290 _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
291 format!("query result exceeds limit of {max_result_bytes} bytes"),
292 ))));
293 return;
294 }
295
296 if let Err(e) = rows_tx.send(row).await {
297 _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
298 format!("async error: {e}"),
299 ))));
300 return;
301 }
302 }
303 Err(e) => {
304 _ = err_tx.send(Err(e));
305 return;
306 }
307 }
308 }
309 });
310
311 let cols = cols_fut.await;
312
313 Ok(QueryAsyncResult {
314 columns: cols,
315 rows: rows_rx,
316 error: err_rx,
317 })
318 }
319}
320
impl PooledTokioClient {
    /// Starts `statement` with `params` and returns two things:
    /// - a future for the result's columns;
    /// - a stream of rows converted to `DbValue`s.
    ///
    /// Columns are inferred from the first row the stream yields. The column
    /// future resolves to an empty list if the sender is dropped without sending.
    ///
    /// NOTE(review): the column sender is moved into the returned stream's `map`
    /// closure and is only consumed when the first row arrives. For a query
    /// with no rows, the column future therefore stays pending until the stream
    /// itself is dropped. Callers must drop (or hand off) the stream before
    /// awaiting the columns.
    async fn query_stream(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<
        (
            impl std::future::Future<Output = Vec<v4::Column>>,
            impl futures::Stream<Item = Result<Vec<DbValue>, v4::Error>>,
        ),
        v4::Error,
    > {
        use futures::{FutureExt, StreamExt};

        // Conversion failures are turned into `v4::Error` by `?`.
        let params = to_sql_parameters(params)?;

        let results = Box::pin(
            self.as_ref()
                .query_raw(&statement, params)
                .await
                .map_err(query_failed)?,
        );

        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
        let mut cols_tx_opt = Some(cols_tx);

        let row_stm = results.enumerate().map(move |(index, row_res)| {
            let row_res = row_res.map_err(query_failed);
            row_res.and_then(|r| {
                // Only the first successfully received row publishes the column
                // metadata; `take()` guarantees the sender is used at most once.
                if index == 0 {
                    if let Some(cols_tx) = cols_tx_opt.take() {
                        let cols = infer_columns(&r);
                        _ = cols_tx.send(cols);
                    }
                }
                convert_row(&r).map_err(query_failed_anyhow)
            })
        });

        // A dropped sender (no row was ever published) yields empty columns.
        let cols_rx = cols_rx.map(|result| result.unwrap_or_default());

        Ok((cols_rx, Box::pin(row_stm)))
    }
}
365
366fn infer_columns(row: &Row) -> Vec<Column> {
367 let mut result = Vec::with_capacity(row.len());
368 for index in 0..row.len() {
369 result.push(infer_column(row, index));
370 }
371 result
372}
373
374fn infer_column(row: &Row, index: usize) -> Column {
375 let column = &row.columns()[index];
376 let name = column.name().to_owned();
377 let data_type = convert_data_type(column.type_());
378 Column { name, data_type }
379}
380
381fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
382 let mut result = Vec::with_capacity(row.len());
383 for index in 0..row.len() {
384 result.push(convert_entry(row, index)?);
385 }
386 Ok(result)
387}
388
389struct ArcError(std::sync::Arc<anyhow::Error>);
392
393impl std::error::Error for ArcError {
394 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
395 self.0.source()
396 }
397}
398
399impl std::fmt::Debug for ArcError {
400 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 std::fmt::Debug::fmt(&self.0, f)
402 }
403}
404
405impl std::fmt::Display for ArcError {
406 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407 std::fmt::Display::fmt(&self.0, f)
408 }
409}