spin_factor_outbound_pg/client.rs

1#![allow(clippy::result_large_err)]
2
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use futures::stream::TryStreamExt as _;
7use native_tls::TlsConnector;
8use postgres_native_tls::MakeTlsConnector;
9use spin_world::async_trait;
10use spin_world::spin::postgres4_2_0::postgres::{
11    self as v4, Column, DbValue, ParameterValue, RowSet,
12};
13use tokio_postgres::config::SslMode;
14use tokio_postgres::types::ToSql;
15use tokio_postgres::{NoTls, Row};
16
17use crate::types::{convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters};
18
/// Max connections in a given address' connection pool.
const CONNECTION_POOL_SIZE: usize = 64;
/// Max addresses for which to keep pools in cache; least-recently-used
/// pools are evicted beyond this count.
const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
23
/// A factory object for Postgres clients. This abstracts
/// details of client creation such as pooling.
#[async_trait]
pub trait ClientFactory: Default + Send + Sync + 'static {
    /// The type of client produced by `get_client`.
    type Client: Client;
    /// Gets a client from the factory.
    ///
    /// `address` is a Postgres connection string; `root_ca`, when present,
    /// supplies an additional root certificate for TLS verification.
    async fn get_client(
        &self,
        address: &str,
        root_ca: Option<HashableCertificate>,
    ) -> Result<Self::Client>;
}
37
/// A root CA certificate paired with a hex SHA-256 digest of its PEM text,
/// so the certificate can participate in cache keys (see `get_client`).
#[derive(Clone)]
pub struct HashableCertificate {
    // Parsed certificate handed to the TLS connector builder.
    certificate: native_tls::Certificate,
    // Hex SHA-256 digest of the PEM bytes; used as a pool-cache key component.
    hash: String,
}
43
44impl HashableCertificate {
45    pub fn from_pem(text: &str) -> anyhow::Result<Self> {
46        let cert_bytes = text.as_bytes();
47        let hash = spin_common::sha256::hex_digest_from_bytes(cert_bytes);
48        let certificate =
49            native_tls::Certificate::from_pem(cert_bytes).context("invalid root certificate")?;
50        Ok(Self { certificate, hash })
51    }
52}
53
/// A `ClientFactory` that uses a connection pool per address.
pub struct PooledTokioClientFactory {
    // One pool per (address, root-CA digest) pair; bounded LRU-ish cache
    // sized by CONNECTION_POOL_CACHE_CAPACITY.
    pools: moka::sync::Cache<(String, Option<String>), deadpool_postgres::Pool>,
}
58
59impl Default for PooledTokioClientFactory {
60    fn default() -> Self {
61        Self {
62            pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
63        }
64    }
65}
66
/// A handle to a pooled Postgres connection. Cloning is cheap (shared
/// `Arc`); the inner connection presumably returns to its pool when the
/// last clone drops (deadpool semantics) — confirm against deadpool docs.
#[derive(Clone)]
pub struct PooledTokioClient(Arc<deadpool_postgres::Object>);
69
70impl AsRef<deadpool_postgres::Object> for PooledTokioClient {
71    fn as_ref(&self) -> &deadpool_postgres::Object {
72        self.0.as_ref()
73    }
74}
75
76#[async_trait]
77impl ClientFactory for PooledTokioClientFactory {
78    type Client = PooledTokioClient;
79
80    async fn get_client(
81        &self,
82        address: &str,
83        root_ca: Option<HashableCertificate>,
84    ) -> Result<Self::Client> {
85        let (root_ca, root_ca_hash) = match root_ca {
86            None => (None, None),
87            Some(HashableCertificate { certificate, hash }) => (Some(certificate), Some(hash)),
88        };
89        let pool_key = (address.to_string(), root_ca_hash);
90        let pool = self
91            .pools
92            .try_get_with_by_ref(&pool_key, || create_connection_pool(address, root_ca))
93            .map_err(ArcError)
94            .context("establishing PostgreSQL connection pool")?;
95
96        Ok(PooledTokioClient(Arc::new(pool.get().await?)))
97    }
98}
99
100/// Creates a Postgres connection pool for the given address.
101fn create_connection_pool(
102    address: &str,
103    root_ca: Option<native_tls::Certificate>,
104) -> Result<deadpool_postgres::Pool> {
105    let config = address
106        .parse::<tokio_postgres::Config>()
107        .context("parsing Postgres connection string")?;
108
109    // The debug form isn't as easy to read as logging the address,
110    // but it redacts the password!
111    tracing::debug!("Build new connection: {config:?}");
112
113    let mgr_config = deadpool_postgres::ManagerConfig {
114        recycling_method: deadpool_postgres::RecyclingMethod::Clean,
115    };
116
117    let mgr = if config.get_ssl_mode() == SslMode::Disable {
118        deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
119    } else {
120        let mut builder = TlsConnector::builder();
121        if let Some(root_ca) = root_ca {
122            builder.add_root_certificate(root_ca);
123        }
124        let connector = MakeTlsConnector::new(builder.build()?);
125        deadpool_postgres::Manager::from_config(config, connector, mgr_config)
126    };
127
128    // TODO: what is our max size heuristic?  Should this be passed in so that different
129    // hosts can manage it according to their needs?  Will a plain number suffice for
130    // sophisticated hosts anyway?
131    let pool = deadpool_postgres::Pool::builder(mgr)
132        .max_size(CONNECTION_POOL_SIZE)
133        .build()
134        .context("building Postgres connection pool")?;
135
136    Ok(pool)
137}
138
/// The operations a Postgres client must support, expressed in terms of
/// the WIT (`v4`) parameter and result types.
#[async_trait]
pub trait Client: Clone + Send + Sync + 'static {
    /// Executes a statement, returning the number of rows affected.
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v4::Error>;

    /// Runs a query and buffers the full result set, failing if the
    /// accumulated result exceeds `max_result_bytes`.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<RowSet, v4::Error>;

    /// Runs a query, streaming rows back via `QueryAsyncResult`.
    async fn query_async(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v4::Error>;
}
161
/// The handles returned by a streaming query: column metadata up front,
/// rows over an mpsc channel, and the terminal status over a oneshot.
pub struct QueryAsyncResult {
    /// Column names and types, inferred from the first row (empty when
    /// the query returns no rows — see `query_stream`).
    pub columns: Vec<v4::Column>,
    /// Converted rows, delivered as the producer task pulls them.
    pub rows: tokio::sync::mpsc::Receiver<v4::Row>,
    /// Resolves once streaming ends: `Ok(())` on success, or the error
    /// that terminated the stream.
    pub error: tokio::sync::oneshot::Receiver<Result<(), v4::Error>>,
}
167
168/// Extract weak-typed error data for WIT purposes
169fn pg_extras(dbe: &tokio_postgres::error::DbError) -> Vec<(String, String)> {
170    let mut extras = vec![];
171
172    macro_rules! pg_extra {
173        ( $n:ident ) => {
174            if let Some(value) = dbe.$n() {
175                extras.push((stringify!($n).to_owned(), value.to_string()));
176            }
177        };
178    }
179
180    pg_extra!(column);
181    pg_extra!(constraint);
182    pg_extra!(routine);
183    pg_extra!(hint);
184    pg_extra!(table);
185    pg_extra!(datatype);
186    pg_extra!(schema);
187    pg_extra!(file);
188    pg_extra!(line);
189    pg_extra!(where_);
190
191    extras
192}
193
194fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
195    let flattened = format!("{e:?}");
196    let query_error = match e.as_db_error() {
197        None => v4::QueryError::Text(flattened),
198        Some(dbe) => v4::QueryError::DbError(v4::DbError {
199            as_text: flattened,
200            severity: dbe.severity().to_owned(),
201            code: dbe.code().code().to_owned(),
202            message: dbe.message().to_owned(),
203            detail: dbe.detail().map(|s| s.to_owned()),
204            extras: pg_extras(dbe),
205        }),
206    };
207    v4::Error::QueryFailed(query_error)
208}
209
210fn query_failed_anyhow(e: anyhow::Error) -> v4::Error {
211    let text = format!("{e:?}");
212    v4::Error::QueryFailed(v4::QueryError::Text(text))
213}
214
#[async_trait]
impl Client for PooledTokioClient {
    /// Runs a statement and returns the number of rows affected.
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v4::Error> {
        // Convert WIT parameter values into boxed `ToSql` values.
        let params = params
            .iter()
            .map(to_sql_parameter)
            .collect::<Result<Vec<_>>>()
            .map_err(|e| v4::Error::ValueConversionFailed(format!("{e:?}")))?;

        // `execute` wants a slice of trait-object references.
        let params_refs: Vec<&(dyn ToSql + Sync)> = params
            .iter()
            .map(|b| b.as_ref() as &(dyn ToSql + Sync))
            .collect();

        self.as_ref()
            .execute(&statement, params_refs.as_slice())
            .await
            .map_err(query_failed)
    }

    /// Runs a query and buffers the entire result set in memory, failing
    /// once the accumulated size exceeds `max_result_bytes`.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<RowSet, v4::Error> {
        let (cols_fut, mut results) = self.query_stream(statement, params).await?;

        let mut columns = None;
        // Start from the container's own size so the limit covers the
        // whole returned value, not just row contents.
        let mut byte_count = std::mem::size_of::<RowSet>();
        let mut rows = Vec::new();

        // The inner async block lets `?` funnel both stream errors and the
        // size-limit bail into one anyhow error, mapped to a WIT error below.
        async {
            while let Some(row) = results.try_next().await? {
                // Cumulative size check across all rows.
                byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
                if byte_count > max_result_bytes {
                    anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
                }
                rows.push(row);
            }
            // Column metadata is sent only once the first row is seen (see
            // `query_stream`); for zero-row results the oneshot sender is
            // dropped and this resolves to the empty default.
            columns = Some(cols_fut.await);
            Ok(())
        }
        .await
        .map_err(|e| v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))))?;

        Ok(RowSet {
            columns: columns.unwrap_or_default(),
            rows,
        })
    }

    /// Runs a query and streams rows to the caller over an mpsc channel,
    /// reporting the final status on a separate oneshot channel.
    async fn query_async(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v4::Error> {
        use futures::StreamExt;

        let (cols_fut, mut rows) = self.query_stream(statement, params).await?;

        // Small buffer: the producer blocks once the consumer is 4 rows behind.
        let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
        let (err_tx, err_rx) = tokio::sync::oneshot::channel();

        tokio::spawn(async move {
            loop {
                let Some(row) = rows.next().await else {
                    // Stream exhausted without error: report success.
                    _ = err_tx.send(Ok(()));
                    return;
                };
                match row {
                    Ok(row) => {
                        // NOTE(review): here the byte limit applies per row,
                        // whereas `query` applies it cumulatively — presumably
                        // intentional since rows are streamed rather than
                        // buffered, but worth confirming.
                        let byte_count = row.iter().map(|v| v.memory_size()).sum::<usize>();
                        if byte_count > max_result_bytes {
                            _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
                                format!("query result exceeds limit of {max_result_bytes} bytes"),
                            ))));
                            return;
                        }

                        // Send fails only when the receiver was dropped;
                        // surface that and stop producing.
                        if let Err(e) = rows_tx.send(row).await {
                            _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
                                format!("async error: {e}"),
                            ))));
                            return;
                        }
                    }
                    Err(e) => {
                        _ = err_tx.send(Err(e));
                        return;
                    }
                }
            }
        });

        // Resolves once the spawned task pulls the first row (or immediately
        // to the empty default when the query yields no rows).
        let cols = cols_fut.await;

        Ok(QueryAsyncResult {
            columns: cols,
            rows: rows_rx,
            error: err_rx,
        })
    }
}
324
impl PooledTokioClient {
    /// Issues a query and returns (a) a future resolving to the column
    /// metadata and (b) a stream of rows converted to WIT `DbValue`s.
    ///
    /// Column metadata is inferred from the first row and delivered over a
    /// oneshot channel; when the query yields no rows the sender is dropped
    /// and the metadata future resolves to an empty `Vec`.
    async fn query_stream(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<
        (
            impl std::future::Future<Output = Vec<v4::Column>>,
            impl futures::Stream<Item = Result<Vec<DbValue>, v4::Error>> + 'static,
        ),
        v4::Error,
    > {
        use futures::{FutureExt, StreamExt};

        let params = to_sql_parameters(params)?;

        // `query_raw` takes the parameters directly, avoiding a slice of
        // `&dyn ToSql` references; pin the stream so it can be adapted.
        let results = Box::pin(
            self.as_ref()
                .query_raw(&statement, params)
                .await
                .map_err(query_failed)?,
        );

        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
        // `take`n on first use so the metadata is sent exactly once.
        let mut cols_tx_opt = Some(cols_tx);

        let row_stm = results.enumerate().map(move |(index, row_res)| {
            let row_res = row_res.map_err(query_failed);
            row_res.and_then(|r| {
                // On the first row, infer and publish the column metadata.
                if index == 0
                    && let Some(cols_tx) = cols_tx_opt.take()
                {
                    let cols = infer_columns(&r);
                    _ = cols_tx.send(cols);
                }
                convert_row(&r).map_err(query_failed_anyhow)
            })
        });

        // A dropped sender (zero-row query) yields the empty default.
        let cols_rx = cols_rx.map(|result| result.unwrap_or_default());

        Ok((cols_rx, Box::pin(row_stm)))
    }
}
369
370fn infer_columns(row: &Row) -> Vec<Column> {
371    let mut result = Vec::with_capacity(row.len());
372    for index in 0..row.len() {
373        result.push(infer_column(row, index));
374    }
375    result
376}
377
378fn infer_column(row: &Row, index: usize) -> Column {
379    let column = &row.columns()[index];
380    let name = column.name().to_owned();
381    let data_type = convert_data_type(column.type_());
382    Column { name, data_type }
383}
384
385fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
386    let mut result = Vec::with_capacity(row.len());
387    for index in 0..row.len() {
388        result.push(convert_entry(row, index)?);
389    }
390    Ok(result)
391}
392
/// Workaround for moka returning Arc<Error> which, although
/// necessary for concurrency, does not play well with others.
/// Wrapping it restores a `std::error::Error` impl so the result can be
/// consumed by `anyhow::Context` (see `get_client`).
struct ArcError(std::sync::Arc<anyhow::Error>);
396
impl std::error::Error for ArcError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Delegates to the wrapped error — presumably resolving through
        // anyhow::Error's Deref to `dyn Error` so the causal chain is
        // preserved; confirm against the anyhow docs.
        self.0.source()
    }
}
402
impl std::fmt::Debug for ArcError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Forward the formatter itself so flags (e.g. `{:#?}`) propagate
        // to the wrapped error's Debug impl.
        std::fmt::Debug::fmt(&self.0, f)
    }
}
408
impl std::fmt::Display for ArcError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Forward the formatter so width/precision flags reach the
        // wrapped error's Display impl.
        std::fmt::Display::fmt(&self.0, f)
    }
}