spin_factor_outbound_pg/
client.rs1use anyhow::{Context, Result};
2use native_tls::TlsConnector;
3use postgres_native_tls::MakeTlsConnector;
4use spin_world::async_trait;
5use spin_world::spin::postgres4_0_0::postgres::{
6 self as v4, Column, DbValue, ParameterValue, RowSet,
7};
8use tokio_postgres::types::ToSql;
9use tokio_postgres::{config::SslMode, NoTls, Row};
10
11use crate::types::{convert_data_type, convert_entry, to_sql_parameter};
12
/// Upper bound on the number of connections in the pool for a single address.
const CONNECTION_POOL_SIZE: usize = 64;
/// Capacity of the cache holding one connection pool per address.
const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
17
18#[async_trait]
21pub trait ClientFactory: Default + Send + Sync + 'static {
22 type Client: Client;
24 async fn get_client(&self, address: &str) -> Result<Self::Client>;
26}
27
28pub struct PooledTokioClientFactory {
30 pools: moka::sync::Cache<String, deadpool_postgres::Pool>,
31}
32
33impl Default for PooledTokioClientFactory {
34 fn default() -> Self {
35 Self {
36 pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
37 }
38 }
39}
40
41#[async_trait]
42impl ClientFactory for PooledTokioClientFactory {
43 type Client = deadpool_postgres::Object;
44
45 async fn get_client(&self, address: &str) -> Result<Self::Client> {
46 let pool = self
47 .pools
48 .try_get_with_by_ref(address, || create_connection_pool(address))
49 .map_err(ArcError)
50 .context("establishing PostgreSQL connection pool")?;
51
52 Ok(pool.get().await?)
53 }
54}
55
56fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
58 let config = address
59 .parse::<tokio_postgres::Config>()
60 .context("parsing Postgres connection string")?;
61
62 tracing::debug!("Build new connection: {}", address);
63
64 let mgr_config = deadpool_postgres::ManagerConfig {
65 recycling_method: deadpool_postgres::RecyclingMethod::Clean,
66 };
67
68 let mgr = if config.get_ssl_mode() == SslMode::Disable {
69 deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
70 } else {
71 let builder = TlsConnector::builder();
72 let connector = MakeTlsConnector::new(builder.build()?);
73 deadpool_postgres::Manager::from_config(config, connector, mgr_config)
74 };
75
76 let pool = deadpool_postgres::Pool::builder(mgr)
80 .max_size(CONNECTION_POOL_SIZE)
81 .build()
82 .context("building Postgres connection pool")?;
83
84 Ok(pool)
85}
86
87#[async_trait]
88pub trait Client: Send + Sync + 'static {
89 async fn execute(
90 &self,
91 statement: String,
92 params: Vec<ParameterValue>,
93 ) -> Result<u64, v4::Error>;
94
95 async fn query(
96 &self,
97 statement: String,
98 params: Vec<ParameterValue>,
99 ) -> Result<RowSet, v4::Error>;
100}
101
102fn pg_extras(dbe: &tokio_postgres::error::DbError) -> Vec<(String, String)> {
104 let mut extras = vec![];
105
106 macro_rules! pg_extra {
107 ( $n:ident ) => {
108 if let Some(value) = dbe.$n() {
109 extras.push((stringify!($n).to_owned(), value.to_string()));
110 }
111 };
112 }
113
114 pg_extra!(column);
115 pg_extra!(constraint);
116 pg_extra!(routine);
117 pg_extra!(hint);
118 pg_extra!(table);
119 pg_extra!(datatype);
120 pg_extra!(schema);
121 pg_extra!(file);
122 pg_extra!(line);
123 pg_extra!(where_);
124
125 extras
126}
127
128fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
129 let flattened = format!("{e:?}");
130 let query_error = match e.as_db_error() {
131 None => v4::QueryError::Text(flattened),
132 Some(dbe) => v4::QueryError::DbError(v4::DbError {
133 as_text: flattened,
134 severity: dbe.severity().to_owned(),
135 code: dbe.code().code().to_owned(),
136 message: dbe.message().to_owned(),
137 detail: dbe.detail().map(|s| s.to_owned()),
138 extras: pg_extras(dbe),
139 }),
140 };
141 v4::Error::QueryFailed(query_error)
142}
143
144#[async_trait]
145impl Client for deadpool_postgres::Object {
146 async fn execute(
147 &self,
148 statement: String,
149 params: Vec<ParameterValue>,
150 ) -> Result<u64, v4::Error> {
151 let params = params
152 .iter()
153 .map(to_sql_parameter)
154 .collect::<Result<Vec<_>>>()
155 .map_err(|e| v4::Error::ValueConversionFailed(format!("{e:?}")))?;
156
157 let params_refs: Vec<&(dyn ToSql + Sync)> = params
158 .iter()
159 .map(|b| b.as_ref() as &(dyn ToSql + Sync))
160 .collect();
161
162 self.as_ref()
163 .execute(&statement, params_refs.as_slice())
164 .await
165 .map_err(query_failed)
166 }
167
168 async fn query(
169 &self,
170 statement: String,
171 params: Vec<ParameterValue>,
172 ) -> Result<RowSet, v4::Error> {
173 let params = params
174 .iter()
175 .map(to_sql_parameter)
176 .collect::<Result<Vec<_>>>()
177 .map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?;
178
179 let params_refs: Vec<&(dyn ToSql + Sync)> = params
180 .iter()
181 .map(|b| b.as_ref() as &(dyn ToSql + Sync))
182 .collect();
183
184 let results = self
185 .as_ref()
186 .query(&statement, params_refs.as_slice())
187 .await
188 .map_err(query_failed)?;
189
190 if results.is_empty() {
191 return Ok(RowSet {
192 columns: vec![],
193 rows: vec![],
194 });
195 }
196
197 let columns = infer_columns(&results[0]);
198 let rows = results
199 .iter()
200 .map(convert_row)
201 .collect::<Result<Vec<_>, _>>()
202 .map_err(|e| v4::Error::QueryFailed(v4::QueryError::Text(format!("{e:?}"))))?;
203
204 Ok(RowSet { columns, rows })
205 }
206}
207
208fn infer_columns(row: &Row) -> Vec<Column> {
209 let mut result = Vec::with_capacity(row.len());
210 for index in 0..row.len() {
211 result.push(infer_column(row, index));
212 }
213 result
214}
215
216fn infer_column(row: &Row, index: usize) -> Column {
217 let column = &row.columns()[index];
218 let name = column.name().to_owned();
219 let data_type = convert_data_type(column.type_());
220 Column { name, data_type }
221}
222
223fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
224 let mut result = Vec::with_capacity(row.len());
225 for index in 0..row.len() {
226 result.push(convert_entry(row, index)?);
227 }
228 Ok(result)
229}
230
231struct ArcError(std::sync::Arc<anyhow::Error>);
234
235impl std::error::Error for ArcError {
236 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
237 self.0.source()
238 }
239}
240
241impl std::fmt::Debug for ArcError {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 std::fmt::Debug::fmt(&self.0, f)
244 }
245}
246
247impl std::fmt::Display for ArcError {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 std::fmt::Display::fmt(&self.0, f)
250 }
251}