spin_factor_outbound_pg/
client.rs

1use anyhow::{anyhow, Result};
2use native_tls::TlsConnector;
3use postgres_native_tls::MakeTlsConnector;
4use spin_world::async_trait;
5use spin_world::spin::postgres::postgres::{
6    self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet,
7};
8use tokio_postgres::types::Type;
9use tokio_postgres::{config::SslMode, types::ToSql, Row};
10use tokio_postgres::{Client as TokioClient, NoTls, Socket};
11
12#[async_trait]
13pub trait Client {
14    async fn build_client(address: &str) -> Result<Self>
15    where
16        Self: Sized;
17
18    async fn execute(
19        &self,
20        statement: String,
21        params: Vec<ParameterValue>,
22    ) -> Result<u64, v3::Error>;
23
24    async fn query(
25        &self,
26        statement: String,
27        params: Vec<ParameterValue>,
28    ) -> Result<RowSet, v3::Error>;
29}
30
31#[async_trait]
32impl Client for TokioClient {
33    async fn build_client(address: &str) -> Result<Self>
34    where
35        Self: Sized,
36    {
37        let config = address.parse::<tokio_postgres::Config>()?;
38
39        tracing::debug!("Build new connection: {}", address);
40
41        if config.get_ssl_mode() == SslMode::Disable {
42            let (client, connection) = config.connect(NoTls).await?;
43            spawn_connection(connection);
44            Ok(client)
45        } else {
46            let builder = TlsConnector::builder();
47            let connector = MakeTlsConnector::new(builder.build()?);
48            let (client, connection) = config.connect(connector).await?;
49            spawn_connection(connection);
50            Ok(client)
51        }
52    }
53
54    async fn execute(
55        &self,
56        statement: String,
57        params: Vec<ParameterValue>,
58    ) -> Result<u64, v3::Error> {
59        let params = params
60            .iter()
61            .map(to_sql_parameter)
62            .collect::<Result<Vec<_>>>()
63            .map_err(|e| v3::Error::ValueConversionFailed(format!("{:?}", e)))?;
64
65        let params_refs: Vec<&(dyn ToSql + Sync)> = params
66            .iter()
67            .map(|b| b.as_ref() as &(dyn ToSql + Sync))
68            .collect();
69
70        self.execute(&statement, params_refs.as_slice())
71            .await
72            .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))
73    }
74
75    async fn query(
76        &self,
77        statement: String,
78        params: Vec<ParameterValue>,
79    ) -> Result<RowSet, v3::Error> {
80        let params = params
81            .iter()
82            .map(to_sql_parameter)
83            .collect::<Result<Vec<_>>>()
84            .map_err(|e| v3::Error::BadParameter(format!("{:?}", e)))?;
85
86        let params_refs: Vec<&(dyn ToSql + Sync)> = params
87            .iter()
88            .map(|b| b.as_ref() as &(dyn ToSql + Sync))
89            .collect();
90
91        let results = self
92            .query(&statement, params_refs.as_slice())
93            .await
94            .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
95
96        if results.is_empty() {
97            return Ok(RowSet {
98                columns: vec![],
99                rows: vec![],
100            });
101        }
102
103        let columns = infer_columns(&results[0]);
104        let rows = results
105            .iter()
106            .map(convert_row)
107            .collect::<Result<Vec<_>, _>>()
108            .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
109
110        Ok(RowSet { columns, rows })
111    }
112}
113
114fn spawn_connection<T>(connection: tokio_postgres::Connection<Socket, T>)
115where
116    T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
117{
118    tokio::spawn(async move {
119        if let Err(e) = connection.await {
120            tracing::error!("Postgres connection error: {}", e);
121        }
122    });
123}
124
125fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Sync>> {
126    match value {
127        ParameterValue::Boolean(v) => Ok(Box::new(*v)),
128        ParameterValue::Int32(v) => Ok(Box::new(*v)),
129        ParameterValue::Int64(v) => Ok(Box::new(*v)),
130        ParameterValue::Int8(v) => Ok(Box::new(*v)),
131        ParameterValue::Int16(v) => Ok(Box::new(*v)),
132        ParameterValue::Floating32(v) => Ok(Box::new(*v)),
133        ParameterValue::Floating64(v) => Ok(Box::new(*v)),
134        ParameterValue::Str(v) => Ok(Box::new(v.clone())),
135        ParameterValue::Binary(v) => Ok(Box::new(v.clone())),
136        ParameterValue::Date((y, mon, d)) => {
137            let naive_date = chrono::NaiveDate::from_ymd_opt(*y, (*mon).into(), (*d).into())
138                .ok_or_else(|| anyhow!("invalid date y={y}, m={mon}, d={d}"))?;
139            Ok(Box::new(naive_date))
140        }
141        ParameterValue::Time((h, min, s, ns)) => {
142            let naive_time =
143                chrono::NaiveTime::from_hms_nano_opt((*h).into(), (*min).into(), (*s).into(), *ns)
144                    .ok_or_else(|| anyhow!("invalid time {h}:{min}:{s}:{ns}"))?;
145            Ok(Box::new(naive_time))
146        }
147        ParameterValue::Datetime((y, mon, d, h, min, s, ns)) => {
148            let naive_date = chrono::NaiveDate::from_ymd_opt(*y, (*mon).into(), (*d).into())
149                .ok_or_else(|| anyhow!("invalid date y={y}, m={mon}, d={d}"))?;
150            let naive_time =
151                chrono::NaiveTime::from_hms_nano_opt((*h).into(), (*min).into(), (*s).into(), *ns)
152                    .ok_or_else(|| anyhow!("invalid time {h}:{min}:{s}:{ns}"))?;
153            let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
154            Ok(Box::new(dt))
155        }
156        ParameterValue::Timestamp(v) => {
157            let ts = chrono::DateTime::<chrono::Utc>::from_timestamp(*v, 0)
158                .ok_or_else(|| anyhow!("invalid epoch timestamp {v}"))?;
159            Ok(Box::new(ts))
160        }
161        ParameterValue::DbNull => Ok(Box::new(PgNull)),
162    }
163}
164
165fn infer_columns(row: &Row) -> Vec<Column> {
166    let mut result = Vec::with_capacity(row.len());
167    for index in 0..row.len() {
168        result.push(infer_column(row, index));
169    }
170    result
171}
172
173fn infer_column(row: &Row, index: usize) -> Column {
174    let column = &row.columns()[index];
175    let name = column.name().to_owned();
176    let data_type = convert_data_type(column.type_());
177    Column { name, data_type }
178}
179
180fn convert_data_type(pg_type: &Type) -> DbDataType {
181    match *pg_type {
182        Type::BOOL => DbDataType::Boolean,
183        Type::BYTEA => DbDataType::Binary,
184        Type::FLOAT4 => DbDataType::Floating32,
185        Type::FLOAT8 => DbDataType::Floating64,
186        Type::INT2 => DbDataType::Int16,
187        Type::INT4 => DbDataType::Int32,
188        Type::INT8 => DbDataType::Int64,
189        Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str,
190        Type::TIMESTAMP | Type::TIMESTAMPTZ => DbDataType::Timestamp,
191        Type::DATE => DbDataType::Date,
192        Type::TIME => DbDataType::Time,
193        _ => {
194            tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),);
195            DbDataType::Other
196        }
197    }
198}
199
200fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
201    let mut result = Vec::with_capacity(row.len());
202    for index in 0..row.len() {
203        result.push(convert_entry(row, index)?);
204    }
205    Ok(result)
206}
207
208fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
209    let column = &row.columns()[index];
210    let value = match column.type_() {
211        &Type::BOOL => {
212            let value: Option<bool> = row.try_get(index)?;
213            match value {
214                Some(v) => DbValue::Boolean(v),
215                None => DbValue::DbNull,
216            }
217        }
218        &Type::BYTEA => {
219            let value: Option<Vec<u8>> = row.try_get(index)?;
220            match value {
221                Some(v) => DbValue::Binary(v),
222                None => DbValue::DbNull,
223            }
224        }
225        &Type::FLOAT4 => {
226            let value: Option<f32> = row.try_get(index)?;
227            match value {
228                Some(v) => DbValue::Floating32(v),
229                None => DbValue::DbNull,
230            }
231        }
232        &Type::FLOAT8 => {
233            let value: Option<f64> = row.try_get(index)?;
234            match value {
235                Some(v) => DbValue::Floating64(v),
236                None => DbValue::DbNull,
237            }
238        }
239        &Type::INT2 => {
240            let value: Option<i16> = row.try_get(index)?;
241            match value {
242                Some(v) => DbValue::Int16(v),
243                None => DbValue::DbNull,
244            }
245        }
246        &Type::INT4 => {
247            let value: Option<i32> = row.try_get(index)?;
248            match value {
249                Some(v) => DbValue::Int32(v),
250                None => DbValue::DbNull,
251            }
252        }
253        &Type::INT8 => {
254            let value: Option<i64> = row.try_get(index)?;
255            match value {
256                Some(v) => DbValue::Int64(v),
257                None => DbValue::DbNull,
258            }
259        }
260        &Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => {
261            let value: Option<String> = row.try_get(index)?;
262            match value {
263                Some(v) => DbValue::Str(v),
264                None => DbValue::DbNull,
265            }
266        }
267        &Type::TIMESTAMP | &Type::TIMESTAMPTZ => {
268            let value: Option<chrono::NaiveDateTime> = row.try_get(index)?;
269            match value {
270                Some(v) => DbValue::Datetime(tuplify_date_time(v)?),
271                None => DbValue::DbNull,
272            }
273        }
274        &Type::DATE => {
275            let value: Option<chrono::NaiveDate> = row.try_get(index)?;
276            match value {
277                Some(v) => DbValue::Date(tuplify_date(v)?),
278                None => DbValue::DbNull,
279            }
280        }
281        &Type::TIME => {
282            let value: Option<chrono::NaiveTime> = row.try_get(index)?;
283            match value {
284                Some(v) => DbValue::Time(tuplify_time(v)?),
285                None => DbValue::DbNull,
286            }
287        }
288        t => {
289            tracing::debug!(
290                "Couldn't convert Postgres type {} in column {}",
291                t.name(),
292                column.name()
293            );
294            DbValue::Unsupported
295        }
296    };
297    Ok(value)
298}
299
300// Functions to convert from the chrono types to the WIT interface tuples
301fn tuplify_date_time(
302    value: chrono::NaiveDateTime,
303) -> anyhow::Result<(i32, u8, u8, u8, u8, u8, u32)> {
304    use chrono::{Datelike, Timelike};
305    Ok((
306        value.year(),
307        value.month().try_into()?,
308        value.day().try_into()?,
309        value.hour().try_into()?,
310        value.minute().try_into()?,
311        value.second().try_into()?,
312        value.nanosecond(),
313    ))
314}
315
316fn tuplify_date(value: chrono::NaiveDate) -> anyhow::Result<(i32, u8, u8)> {
317    use chrono::Datelike;
318    Ok((
319        value.year(),
320        value.month().try_into()?,
321        value.day().try_into()?,
322    ))
323}
324
325fn tuplify_time(value: chrono::NaiveTime) -> anyhow::Result<(u8, u8, u8, u32)> {
326    use chrono::Timelike;
327    Ok((
328        value.hour().try_into()?,
329        value.minute().try_into()?,
330        value.second().try_into()?,
331        value.nanosecond(),
332    ))
333}
334
335/// Although the Postgres crate converts Rust Option::None to Postgres NULL,
336/// it enforces the type of the Option as it does so. (For example, trying to
337/// pass an Option::<i32>::None to a VARCHAR column fails conversion.) As we
338/// do not know expected column types, we instead use a "neutral" custom type
339/// which allows conversion to any type but always tells the Postgres crate to
340/// treat it as a SQL NULL.
341struct PgNull;
342
343impl ToSql for PgNull {
344    fn to_sql(
345        &self,
346        _ty: &Type,
347        _out: &mut tokio_postgres::types::private::BytesMut,
348    ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
349    where
350        Self: Sized,
351    {
352        Ok(tokio_postgres::types::IsNull::Yes)
353    }
354
355    fn accepts(_ty: &Type) -> bool
356    where
357        Self: Sized,
358    {
359        true
360    }
361
362    fn to_sql_checked(
363        &self,
364        _ty: &Type,
365        _out: &mut tokio_postgres::types::private::BytesMut,
366    ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
367        Ok(tokio_postgres::types::IsNull::Yes)
368    }
369}
370
371impl std::fmt::Debug for PgNull {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        f.debug_struct("NULL").finish()
374    }
375}