//! Postgres client creation and pooling for the outbound Postgres factor.
//! (`spin_factor_outbound_pg/client.rs`)
1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use futures::stream::TryStreamExt as _;
5use native_tls::TlsConnector;
6use postgres_native_tls::MakeTlsConnector;
7use spin_world::async_trait;
8use spin_world::spin::postgres4_2_0::postgres::{
9    self as v4, Column, DbValue, ParameterValue, RowSet,
10};
11use tokio_postgres::config::SslMode;
12use tokio_postgres::types::ToSql;
13use tokio_postgres::{NoTls, Row};
14
15use crate::types::{convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters};
16
/// Max connections in a given address' connection pool
const CONNECTION_POOL_SIZE: usize = 64;
/// Max addresses for which to keep pools in cache.
/// (Once full, the `moka` cache evicts entries per its own policy.)
const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
21
/// A factory object for Postgres clients. This abstracts
/// details of client creation such as pooling.
#[async_trait]
pub trait ClientFactory: Default + Send + Sync + 'static {
    /// The type of client produced by `get_client`.
    type Client: Client;
    /// Gets a client from the factory.
    ///
    /// `address` is a Postgres connection string. `root_ca`, when present,
    /// supplies an additional trusted root certificate for TLS connections.
    async fn get_client(
        &self,
        address: &str,
        root_ca: Option<HashableCertificate>,
    ) -> Result<Self::Client>;
}
35
/// A parsed TLS root certificate paired with a digest of its PEM source,
/// so certificates can be used as (part of) a cache key even though the
/// certificate type itself provides no hashing.
#[derive(Clone)]
pub struct HashableCertificate {
    // The parsed certificate, used to configure the TLS connector.
    certificate: native_tls::Certificate,
    // SHA-256 hex digest of the PEM text the certificate was parsed from.
    hash: String,
}
41
42impl HashableCertificate {
43    pub fn from_pem(text: &str) -> anyhow::Result<Self> {
44        let cert_bytes = text.as_bytes();
45        let hash = spin_common::sha256::hex_digest_from_bytes(cert_bytes);
46        let certificate =
47            native_tls::Certificate::from_pem(cert_bytes).context("invalid root certificate")?;
48        Ok(Self { certificate, hash })
49    }
50}
51
/// A `ClientFactory` that uses a connection pool per address.
pub struct PooledTokioClientFactory {
    // Pools keyed by (address, optional root-CA hash); a bounded cache so
    // pools for no-longer-used addresses eventually fall out.
    pools: moka::sync::Cache<(String, Option<String>), deadpool_postgres::Pool>,
}
56
57impl Default for PooledTokioClientFactory {
58    fn default() -> Self {
59        Self {
60            pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
61        }
62    }
63}
64
/// A Postgres client backed by a pooled connection. Cloning is cheap and
/// shares the same underlying pooled connection object.
#[derive(Clone)]
pub struct PooledTokioClient(Arc<deadpool_postgres::Object>);
67
68impl AsRef<deadpool_postgres::Object> for PooledTokioClient {
69    fn as_ref(&self) -> &deadpool_postgres::Object {
70        self.0.as_ref()
71    }
72}
73
74#[async_trait]
75impl ClientFactory for PooledTokioClientFactory {
76    type Client = PooledTokioClient;
77
78    async fn get_client(
79        &self,
80        address: &str,
81        root_ca: Option<HashableCertificate>,
82    ) -> Result<Self::Client> {
83        let (root_ca, root_ca_hash) = match root_ca {
84            None => (None, None),
85            Some(HashableCertificate { certificate, hash }) => (Some(certificate), Some(hash)),
86        };
87        let pool_key = (address.to_string(), root_ca_hash);
88        let pool = self
89            .pools
90            .try_get_with_by_ref(&pool_key, || create_connection_pool(address, root_ca))
91            .map_err(ArcError)
92            .context("establishing PostgreSQL connection pool")?;
93
94        Ok(PooledTokioClient(Arc::new(pool.get().await?)))
95    }
96}
97
98/// Creates a Postgres connection pool for the given address.
99fn create_connection_pool(
100    address: &str,
101    root_ca: Option<native_tls::Certificate>,
102) -> Result<deadpool_postgres::Pool> {
103    let config = address
104        .parse::<tokio_postgres::Config>()
105        .context("parsing Postgres connection string")?;
106
107    tracing::debug!("Build new connection: {}", address);
108
109    let mgr_config = deadpool_postgres::ManagerConfig {
110        recycling_method: deadpool_postgres::RecyclingMethod::Clean,
111    };
112
113    let mgr = if config.get_ssl_mode() == SslMode::Disable {
114        deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
115    } else {
116        let mut builder = TlsConnector::builder();
117        if let Some(root_ca) = root_ca {
118            builder.add_root_certificate(root_ca);
119        }
120        let connector = MakeTlsConnector::new(builder.build()?);
121        deadpool_postgres::Manager::from_config(config, connector, mgr_config)
122    };
123
124    // TODO: what is our max size heuristic?  Should this be passed in so that different
125    // hosts can manage it according to their needs?  Will a plain number suffice for
126    // sophisticated hosts anyway?
127    let pool = deadpool_postgres::Pool::builder(mgr)
128        .max_size(CONNECTION_POOL_SIZE)
129        .build()
130        .context("building Postgres connection pool")?;
131
132    Ok(pool)
133}
134
/// The operations a Postgres client must support.
#[async_trait]
pub trait Client: Clone + Send + Sync + 'static {
    /// Executes a statement that returns no result rows, returning the
    /// number of rows affected.
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v4::Error>;

    /// Runs a query, buffering the entire result set in memory; fails if the
    /// result exceeds `max_result_bytes`.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<RowSet, v4::Error>;

    /// Runs a query, streaming rows back incrementally via the channels in
    /// [`QueryAsyncResult`].
    async fn query_async(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v4::Error>;
}
157
/// The result of a streaming (`query_async`) query.
pub struct QueryAsyncResult {
    /// Column metadata, inferred from the first row (empty if no rows).
    pub columns: Vec<v4::Column>,
    /// Receives rows as they are produced by the database.
    pub rows: tokio::sync::mpsc::Receiver<v4::Row>,
    /// Receives the final outcome once the row stream ends: `Ok(())` on
    /// clean completion, `Err` if the query failed or a limit was exceeded.
    pub error: tokio::sync::oneshot::Receiver<Result<(), v4::Error>>,
}
163
164/// Extract weak-typed error data for WIT purposes
165fn pg_extras(dbe: &tokio_postgres::error::DbError) -> Vec<(String, String)> {
166    let mut extras = vec![];
167
168    macro_rules! pg_extra {
169        ( $n:ident ) => {
170            if let Some(value) = dbe.$n() {
171                extras.push((stringify!($n).to_owned(), value.to_string()));
172            }
173        };
174    }
175
176    pg_extra!(column);
177    pg_extra!(constraint);
178    pg_extra!(routine);
179    pg_extra!(hint);
180    pg_extra!(table);
181    pg_extra!(datatype);
182    pg_extra!(schema);
183    pg_extra!(file);
184    pg_extra!(line);
185    pg_extra!(where_);
186
187    extras
188}
189
190fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
191    let flattened = format!("{e:?}");
192    let query_error = match e.as_db_error() {
193        None => v4::QueryError::Text(flattened),
194        Some(dbe) => v4::QueryError::DbError(v4::DbError {
195            as_text: flattened,
196            severity: dbe.severity().to_owned(),
197            code: dbe.code().code().to_owned(),
198            message: dbe.message().to_owned(),
199            detail: dbe.detail().map(|s| s.to_owned()),
200            extras: pg_extras(dbe),
201        }),
202    };
203    v4::Error::QueryFailed(query_error)
204}
205
206fn query_failed_anyhow(e: anyhow::Error) -> v4::Error {
207    let text = format!("{e:?}");
208    v4::Error::QueryFailed(v4::QueryError::Text(text))
209}
210
#[async_trait]
impl Client for PooledTokioClient {
    /// Executes a statement that returns no result rows, returning the
    /// number of rows affected.
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v4::Error> {
        // Convert the WIT-level parameter values into boxed `ToSql` values;
        // any conversion failure is surfaced as `ValueConversionFailed`.
        let params = params
            .iter()
            .map(to_sql_parameter)
            .collect::<Result<Vec<_>>>()
            .map_err(|e| v4::Error::ValueConversionFailed(format!("{e:?}")))?;

        // `tokio_postgres::execute` wants `&[&(dyn ToSql + Sync)]`, so build
        // a vector of references into the boxed parameters above.
        let params_refs: Vec<&(dyn ToSql + Sync)> = params
            .iter()
            .map(|b| b.as_ref() as &(dyn ToSql + Sync))
            .collect();

        self.as_ref()
            .execute(&statement, params_refs.as_slice())
            .await
            .map_err(query_failed)
    }

    /// Runs a query and buffers the full result set in memory, failing once
    /// the accumulated row data exceeds `max_result_bytes`.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<RowSet, v4::Error> {
        let (cols_fut, mut results) = self.query_stream(statement, params).await?;

        let mut columns = None;
        // Start the byte budget at the fixed size of the `RowSet` container.
        let mut byte_count = std::mem::size_of::<RowSet>();
        let mut rows = Vec::new();

        // The async block lets all fallible steps share one `?`-based path,
        // mapped to a single `QueryFailed` error below.
        async {
            while let Some(row) = results.try_next().await? {
                byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
                if byte_count > max_result_bytes {
                    anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
                }
                rows.push(row);
            }
            // `query_stream` publishes column metadata when the first row
            // arrives, so only await it after the stream has been drained.
            columns = Some(cols_fut.await);
            Ok(())
        }
        .await
        .map_err(|e| v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))))?;

        Ok(RowSet {
            // `columns` is set above whenever the loop completes; the
            // fallback keeps this total rather than panicking.
            columns: columns.unwrap_or_default(),
            rows,
        })
    }

    /// Runs a query and streams rows back through a channel so the caller can
    /// consume results incrementally. Errors and over-limit rows terminate
    /// the stream and are reported on the `error` oneshot.
    async fn query_async(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v4::Error> {
        use futures::StreamExt;

        let (cols_fut, mut rows) = self.query_stream(statement, params).await?;

        // Small bounded channel: applies backpressure to the database read
        // loop if the consumer falls behind.
        let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
        let (err_tx, err_rx) = tokio::sync::oneshot::channel();

        // Pump the row stream on a background task so this method can return
        // immediately with the receiving ends of the channels.
        tokio::spawn(async move {
            loop {
                let Some(row) = rows.next().await else {
                    // Stream finished cleanly.
                    _ = err_tx.send(Ok(()));
                    return;
                };
                match row {
                    Ok(row) => {
                        // NOTE(review): unlike `query`, the limit here is
                        // applied per row, not to the accumulated result —
                        // confirm this asymmetry is intended.
                        let byte_count = row.iter().map(|v| v.memory_size()).sum::<usize>();
                        if byte_count > max_result_bytes {
                            _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
                                format!("query result exceeds limit of {max_result_bytes} bytes"),
                            ))));
                            return;
                        }

                        // A send error means the receiver was dropped; report
                        // it and stop pumping rows.
                        if let Err(e) = rows_tx.send(row).await {
                            _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
                                format!("async error: {e}"),
                            ))));
                            return;
                        }
                    }
                    Err(e) => {
                        _ = err_tx.send(Err(e));
                        return;
                    }
                }
            }
        });

        // Resolves when the first row is seen (or the stream ends row-less),
        // so column metadata is available up front.
        let cols = cols_fut.await;

        Ok(QueryAsyncResult {
            columns: cols,
            rows: rows_rx,
            error: err_rx,
        })
    }
}
320
impl PooledTokioClient {
    /// Runs a query and returns (a) a future resolving to the result's column
    /// metadata and (b) a stream of converted rows.
    ///
    /// Column metadata is inferred from the first successful row and handed
    /// over via a oneshot channel; if the query yields no rows, the sender is
    /// dropped and the metadata future resolves to an empty column list.
    async fn query_stream(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<
        (
            impl std::future::Future<Output = Vec<v4::Column>>,
            impl futures::Stream<Item = Result<Vec<DbValue>, v4::Error>>,
        ),
        v4::Error,
    > {
        use futures::{FutureExt, StreamExt};

        let params = to_sql_parameters(params)?;

        // `query_raw` streams rows rather than buffering them; pin the stream
        // so it can be driven through the adapters below.
        let results = Box::pin(
            self.as_ref()
                .query_raw(&statement, params)
                .await
                .map_err(query_failed)?,
        );

        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
        // The oneshot sender is consumed on first use, but the closure below
        // is `FnMut` — hence the `Option` + `take()` dance.
        let mut cols_tx_opt = Some(cols_tx);

        let row_stm = results.enumerate().map(move |(index, row_res)| {
            let row_res = row_res.map_err(query_failed);
            row_res.and_then(|r| {
                // On the first successful row, infer and publish the columns.
                if index == 0 {
                    if let Some(cols_tx) = cols_tx_opt.take() {
                        let cols = infer_columns(&r);
                        _ = cols_tx.send(cols);
                    }
                }
                convert_row(&r).map_err(query_failed_anyhow)
            })
        });

        // If the sender is dropped without sending (no rows, or the stream is
        // dropped early), fall back to an empty column list instead of erroring.
        let cols_rx = cols_rx.map(|result| result.unwrap_or_default());

        Ok((cols_rx, Box::pin(row_stm)))
    }
}
365
366fn infer_columns(row: &Row) -> Vec<Column> {
367    let mut result = Vec::with_capacity(row.len());
368    for index in 0..row.len() {
369        result.push(infer_column(row, index));
370    }
371    result
372}
373
374fn infer_column(row: &Row, index: usize) -> Column {
375    let column = &row.columns()[index];
376    let name = column.name().to_owned();
377    let data_type = convert_data_type(column.type_());
378    Column { name, data_type }
379}
380
381fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
382    let mut result = Vec::with_capacity(row.len());
383    for index in 0..row.len() {
384        result.push(convert_entry(row, index)?);
385    }
386    Ok(result)
387}
388
/// Workaround for moka returning Arc<Error> which, although
/// necessary for concurrency, does not play well with others.
/// Wrapping the shared error restores `std::error::Error` (and the
/// `Debug`/`Display` impls below) so it can be used with `anyhow` context.
struct ArcError(std::sync::Arc<anyhow::Error>);
392
impl std::error::Error for ArcError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Delegates to the wrapped error (resolved by auto-deref through
        // `Arc<anyhow::Error>` to `dyn Error`), preserving the error chain.
        self.0.source()
    }
}
398
399impl std::fmt::Debug for ArcError {
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        std::fmt::Debug::fmt(&self.0, f)
402    }
403}
404
405impl std::fmt::Display for ArcError {
406    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        std::fmt::Display::fmt(&self.0, f)
408    }
409}