spin_factor_outbound_pg/
client.rs

1use anyhow::{Context, Result};
2use native_tls::TlsConnector;
3use postgres_native_tls::MakeTlsConnector;
4use spin_world::async_trait;
5use spin_world::spin::postgres4_0_0::postgres::{
6    self as v4, Column, DbValue, ParameterValue, RowSet,
7};
8use tokio_postgres::types::ToSql;
9use tokio_postgres::{config::SslMode, NoTls, Row};
10
11use crate::types::{convert_data_type, convert_entry, to_sql_parameter};
12
/// Upper bound on the number of connections held by the pool for any single address.
const CONNECTION_POOL_SIZE: usize = 64;
/// Upper bound on the number of distinct addresses whose pools are kept cached at once.
const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
17
18/// A factory object for Postgres clients. This abstracts
19/// details of client creation such as pooling.
20#[async_trait]
21pub trait ClientFactory: Default + Send + Sync + 'static {
22    /// The type of client produced by `get_client`.
23    type Client: Client;
24    /// Gets a client from the factory.
25    async fn get_client(&self, address: &str) -> Result<Self::Client>;
26}
27
28/// A `ClientFactory` that uses a connection pool per address.
29pub struct PooledTokioClientFactory {
30    pools: moka::sync::Cache<String, deadpool_postgres::Pool>,
31}
32
33impl Default for PooledTokioClientFactory {
34    fn default() -> Self {
35        Self {
36            pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
37        }
38    }
39}
40
41#[async_trait]
42impl ClientFactory for PooledTokioClientFactory {
43    type Client = deadpool_postgres::Object;
44
45    async fn get_client(&self, address: &str) -> Result<Self::Client> {
46        let pool = self
47            .pools
48            .try_get_with_by_ref(address, || create_connection_pool(address))
49            .map_err(ArcError)
50            .context("establishing PostgreSQL connection pool")?;
51
52        Ok(pool.get().await?)
53    }
54}
55
56/// Creates a Postgres connection pool for the given address.
57fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
58    let config = address
59        .parse::<tokio_postgres::Config>()
60        .context("parsing Postgres connection string")?;
61
62    tracing::debug!("Build new connection: {}", address);
63
64    let mgr_config = deadpool_postgres::ManagerConfig {
65        recycling_method: deadpool_postgres::RecyclingMethod::Clean,
66    };
67
68    let mgr = if config.get_ssl_mode() == SslMode::Disable {
69        deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
70    } else {
71        let builder = TlsConnector::builder();
72        let connector = MakeTlsConnector::new(builder.build()?);
73        deadpool_postgres::Manager::from_config(config, connector, mgr_config)
74    };
75
76    // TODO: what is our max size heuristic?  Should this be passed in so that different
77    // hosts can manage it according to their needs?  Will a plain number suffice for
78    // sophisticated hosts anyway?
79    let pool = deadpool_postgres::Pool::builder(mgr)
80        .max_size(CONNECTION_POOL_SIZE)
81        .build()
82        .context("building Postgres connection pool")?;
83
84    Ok(pool)
85}
86
87#[async_trait]
88pub trait Client: Send + Sync + 'static {
89    async fn execute(
90        &self,
91        statement: String,
92        params: Vec<ParameterValue>,
93    ) -> Result<u64, v4::Error>;
94
95    async fn query(
96        &self,
97        statement: String,
98        params: Vec<ParameterValue>,
99    ) -> Result<RowSet, v4::Error>;
100}
101
102/// Extract weak-typed error data for WIT purposes
103fn pg_extras(dbe: &tokio_postgres::error::DbError) -> Vec<(String, String)> {
104    let mut extras = vec![];
105
106    macro_rules! pg_extra {
107        ( $n:ident ) => {
108            if let Some(value) = dbe.$n() {
109                extras.push((stringify!($n).to_owned(), value.to_string()));
110            }
111        };
112    }
113
114    pg_extra!(column);
115    pg_extra!(constraint);
116    pg_extra!(routine);
117    pg_extra!(hint);
118    pg_extra!(table);
119    pg_extra!(datatype);
120    pg_extra!(schema);
121    pg_extra!(file);
122    pg_extra!(line);
123    pg_extra!(where_);
124
125    extras
126}
127
128fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
129    let flattened = format!("{e:?}");
130    let query_error = match e.as_db_error() {
131        None => v4::QueryError::Text(flattened),
132        Some(dbe) => v4::QueryError::DbError(v4::DbError {
133            as_text: flattened,
134            severity: dbe.severity().to_owned(),
135            code: dbe.code().code().to_owned(),
136            message: dbe.message().to_owned(),
137            detail: dbe.detail().map(|s| s.to_owned()),
138            extras: pg_extras(dbe),
139        }),
140    };
141    v4::Error::QueryFailed(query_error)
142}
143
144#[async_trait]
145impl Client for deadpool_postgres::Object {
146    async fn execute(
147        &self,
148        statement: String,
149        params: Vec<ParameterValue>,
150    ) -> Result<u64, v4::Error> {
151        let params = params
152            .iter()
153            .map(to_sql_parameter)
154            .collect::<Result<Vec<_>>>()
155            .map_err(|e| v4::Error::ValueConversionFailed(format!("{e:?}")))?;
156
157        let params_refs: Vec<&(dyn ToSql + Sync)> = params
158            .iter()
159            .map(|b| b.as_ref() as &(dyn ToSql + Sync))
160            .collect();
161
162        self.as_ref()
163            .execute(&statement, params_refs.as_slice())
164            .await
165            .map_err(query_failed)
166    }
167
168    async fn query(
169        &self,
170        statement: String,
171        params: Vec<ParameterValue>,
172    ) -> Result<RowSet, v4::Error> {
173        let params = params
174            .iter()
175            .map(to_sql_parameter)
176            .collect::<Result<Vec<_>>>()
177            .map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?;
178
179        let params_refs: Vec<&(dyn ToSql + Sync)> = params
180            .iter()
181            .map(|b| b.as_ref() as &(dyn ToSql + Sync))
182            .collect();
183
184        let results = self
185            .as_ref()
186            .query(&statement, params_refs.as_slice())
187            .await
188            .map_err(query_failed)?;
189
190        if results.is_empty() {
191            return Ok(RowSet {
192                columns: vec![],
193                rows: vec![],
194            });
195        }
196
197        let columns = infer_columns(&results[0]);
198        let rows = results
199            .iter()
200            .map(convert_row)
201            .collect::<Result<Vec<_>, _>>()
202            .map_err(|e| v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))))?;
203
204        Ok(RowSet { columns, rows })
205    }
206}
207
208fn infer_columns(row: &Row) -> Vec<Column> {
209    let mut result = Vec::with_capacity(row.len());
210    for index in 0..row.len() {
211        result.push(infer_column(row, index));
212    }
213    result
214}
215
216fn infer_column(row: &Row, index: usize) -> Column {
217    let column = &row.columns()[index];
218    let name = column.name().to_owned();
219    let data_type = convert_data_type(column.type_());
220    Column { name, data_type }
221}
222
223fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
224    let mut result = Vec::with_capacity(row.len());
225    for index in 0..row.len() {
226        result.push(convert_entry(row, index)?);
227    }
228    Ok(result)
229}
230
231/// Workaround for moka returning Arc<Error> which, although
232/// necessary for concurrency, does not play well with others.
233struct ArcError(std::sync::Arc<anyhow::Error>);
234
235impl std::error::Error for ArcError {
236    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
237        self.0.source()
238    }
239}
240
241impl std::fmt::Debug for ArcError {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        std::fmt::Debug::fmt(&self.0, f)
244    }
245}
246
247impl std::fmt::Display for ArcError {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        std::fmt::Display::fmt(&self.0, f)
250    }
251}