1use anyhow::{anyhow, Result};
2use native_tls::TlsConnector;
3use postgres_native_tls::MakeTlsConnector;
4use spin_world::async_trait;
5use spin_world::spin::postgres::postgres::{
6 self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet,
7};
8use tokio_postgres::types::Type;
9use tokio_postgres::{config::SslMode, types::ToSql, Row};
10use tokio_postgres::{Client as TokioClient, NoTls, Socket};
11
/// Asynchronous Postgres client interface: build a connection from an
/// address string, then run statements and queries against it.
#[async_trait]
pub trait Client {
    /// Creates a new client from the connection string `address`.
    ///
    /// # Errors
    /// Returns an error if the client cannot be built.
    async fn build_client(address: &str) -> Result<Self>
    where
        Self: Sized;

    /// Runs `statement` with the positional `params`.
    ///
    /// On success returns a `u64` (presumably the number of affected rows —
    /// confirm against the implementor).
    ///
    /// # Errors
    /// Returns a [`v3::Error`] on failure.
    async fn execute(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<u64, v3::Error>;

    /// Runs `statement` with the positional `params` and returns the
    /// resulting rows together with their column descriptions.
    ///
    /// # Errors
    /// Returns a [`v3::Error`] on failure.
    async fn query(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<RowSet, v3::Error>;
}
30
31#[async_trait]
32impl Client for TokioClient {
33 async fn build_client(address: &str) -> Result<Self>
34 where
35 Self: Sized,
36 {
37 let config = address.parse::<tokio_postgres::Config>()?;
38
39 tracing::debug!("Build new connection: {}", address);
40
41 if config.get_ssl_mode() == SslMode::Disable {
42 let (client, connection) = config.connect(NoTls).await?;
43 spawn_connection(connection);
44 Ok(client)
45 } else {
46 let builder = TlsConnector::builder();
47 let connector = MakeTlsConnector::new(builder.build()?);
48 let (client, connection) = config.connect(connector).await?;
49 spawn_connection(connection);
50 Ok(client)
51 }
52 }
53
54 async fn execute(
55 &self,
56 statement: String,
57 params: Vec<ParameterValue>,
58 ) -> Result<u64, v3::Error> {
59 let params = params
60 .iter()
61 .map(to_sql_parameter)
62 .collect::<Result<Vec<_>>>()
63 .map_err(|e| v3::Error::ValueConversionFailed(format!("{:?}", e)))?;
64
65 let params_refs: Vec<&(dyn ToSql + Sync)> = params
66 .iter()
67 .map(|b| b.as_ref() as &(dyn ToSql + Sync))
68 .collect();
69
70 self.execute(&statement, params_refs.as_slice())
71 .await
72 .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))
73 }
74
75 async fn query(
76 &self,
77 statement: String,
78 params: Vec<ParameterValue>,
79 ) -> Result<RowSet, v3::Error> {
80 let params = params
81 .iter()
82 .map(to_sql_parameter)
83 .collect::<Result<Vec<_>>>()
84 .map_err(|e| v3::Error::BadParameter(format!("{:?}", e)))?;
85
86 let params_refs: Vec<&(dyn ToSql + Sync)> = params
87 .iter()
88 .map(|b| b.as_ref() as &(dyn ToSql + Sync))
89 .collect();
90
91 let results = self
92 .query(&statement, params_refs.as_slice())
93 .await
94 .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
95
96 if results.is_empty() {
97 return Ok(RowSet {
98 columns: vec![],
99 rows: vec![],
100 });
101 }
102
103 let columns = infer_columns(&results[0]);
104 let rows = results
105 .iter()
106 .map(convert_row)
107 .collect::<Result<Vec<_>, _>>()
108 .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
109
110 Ok(RowSet { columns, rows })
111 }
112}
113
114fn spawn_connection<T>(connection: tokio_postgres::Connection<Socket, T>)
115where
116 T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
117{
118 tokio::spawn(async move {
119 if let Err(e) = connection.await {
120 tracing::error!("Postgres connection error: {}", e);
121 }
122 });
123}
124
125fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Sync>> {
126 match value {
127 ParameterValue::Boolean(v) => Ok(Box::new(*v)),
128 ParameterValue::Int32(v) => Ok(Box::new(*v)),
129 ParameterValue::Int64(v) => Ok(Box::new(*v)),
130 ParameterValue::Int8(v) => Ok(Box::new(*v)),
131 ParameterValue::Int16(v) => Ok(Box::new(*v)),
132 ParameterValue::Floating32(v) => Ok(Box::new(*v)),
133 ParameterValue::Floating64(v) => Ok(Box::new(*v)),
134 ParameterValue::Str(v) => Ok(Box::new(v.clone())),
135 ParameterValue::Binary(v) => Ok(Box::new(v.clone())),
136 ParameterValue::Date((y, mon, d)) => {
137 let naive_date = chrono::NaiveDate::from_ymd_opt(*y, (*mon).into(), (*d).into())
138 .ok_or_else(|| anyhow!("invalid date y={y}, m={mon}, d={d}"))?;
139 Ok(Box::new(naive_date))
140 }
141 ParameterValue::Time((h, min, s, ns)) => {
142 let naive_time =
143 chrono::NaiveTime::from_hms_nano_opt((*h).into(), (*min).into(), (*s).into(), *ns)
144 .ok_or_else(|| anyhow!("invalid time {h}:{min}:{s}:{ns}"))?;
145 Ok(Box::new(naive_time))
146 }
147 ParameterValue::Datetime((y, mon, d, h, min, s, ns)) => {
148 let naive_date = chrono::NaiveDate::from_ymd_opt(*y, (*mon).into(), (*d).into())
149 .ok_or_else(|| anyhow!("invalid date y={y}, m={mon}, d={d}"))?;
150 let naive_time =
151 chrono::NaiveTime::from_hms_nano_opt((*h).into(), (*min).into(), (*s).into(), *ns)
152 .ok_or_else(|| anyhow!("invalid time {h}:{min}:{s}:{ns}"))?;
153 let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
154 Ok(Box::new(dt))
155 }
156 ParameterValue::Timestamp(v) => {
157 let ts = chrono::DateTime::<chrono::Utc>::from_timestamp(*v, 0)
158 .ok_or_else(|| anyhow!("invalid epoch timestamp {v}"))?;
159 Ok(Box::new(ts))
160 }
161 ParameterValue::DbNull => Ok(Box::new(PgNull)),
162 }
163}
164
165fn infer_columns(row: &Row) -> Vec<Column> {
166 let mut result = Vec::with_capacity(row.len());
167 for index in 0..row.len() {
168 result.push(infer_column(row, index));
169 }
170 result
171}
172
173fn infer_column(row: &Row, index: usize) -> Column {
174 let column = &row.columns()[index];
175 let name = column.name().to_owned();
176 let data_type = convert_data_type(column.type_());
177 Column { name, data_type }
178}
179
180fn convert_data_type(pg_type: &Type) -> DbDataType {
181 match *pg_type {
182 Type::BOOL => DbDataType::Boolean,
183 Type::BYTEA => DbDataType::Binary,
184 Type::FLOAT4 => DbDataType::Floating32,
185 Type::FLOAT8 => DbDataType::Floating64,
186 Type::INT2 => DbDataType::Int16,
187 Type::INT4 => DbDataType::Int32,
188 Type::INT8 => DbDataType::Int64,
189 Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str,
190 Type::TIMESTAMP | Type::TIMESTAMPTZ => DbDataType::Timestamp,
191 Type::DATE => DbDataType::Date,
192 Type::TIME => DbDataType::Time,
193 _ => {
194 tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),);
195 DbDataType::Other
196 }
197 }
198}
199
200fn convert_row(row: &Row) -> anyhow::Result<Vec<DbValue>> {
201 let mut result = Vec::with_capacity(row.len());
202 for index in 0..row.len() {
203 result.push(convert_entry(row, index)?);
204 }
205 Ok(result)
206}
207
208fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
209 let column = &row.columns()[index];
210 let value = match column.type_() {
211 &Type::BOOL => {
212 let value: Option<bool> = row.try_get(index)?;
213 match value {
214 Some(v) => DbValue::Boolean(v),
215 None => DbValue::DbNull,
216 }
217 }
218 &Type::BYTEA => {
219 let value: Option<Vec<u8>> = row.try_get(index)?;
220 match value {
221 Some(v) => DbValue::Binary(v),
222 None => DbValue::DbNull,
223 }
224 }
225 &Type::FLOAT4 => {
226 let value: Option<f32> = row.try_get(index)?;
227 match value {
228 Some(v) => DbValue::Floating32(v),
229 None => DbValue::DbNull,
230 }
231 }
232 &Type::FLOAT8 => {
233 let value: Option<f64> = row.try_get(index)?;
234 match value {
235 Some(v) => DbValue::Floating64(v),
236 None => DbValue::DbNull,
237 }
238 }
239 &Type::INT2 => {
240 let value: Option<i16> = row.try_get(index)?;
241 match value {
242 Some(v) => DbValue::Int16(v),
243 None => DbValue::DbNull,
244 }
245 }
246 &Type::INT4 => {
247 let value: Option<i32> = row.try_get(index)?;
248 match value {
249 Some(v) => DbValue::Int32(v),
250 None => DbValue::DbNull,
251 }
252 }
253 &Type::INT8 => {
254 let value: Option<i64> = row.try_get(index)?;
255 match value {
256 Some(v) => DbValue::Int64(v),
257 None => DbValue::DbNull,
258 }
259 }
260 &Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => {
261 let value: Option<String> = row.try_get(index)?;
262 match value {
263 Some(v) => DbValue::Str(v),
264 None => DbValue::DbNull,
265 }
266 }
267 &Type::TIMESTAMP | &Type::TIMESTAMPTZ => {
268 let value: Option<chrono::NaiveDateTime> = row.try_get(index)?;
269 match value {
270 Some(v) => DbValue::Datetime(tuplify_date_time(v)?),
271 None => DbValue::DbNull,
272 }
273 }
274 &Type::DATE => {
275 let value: Option<chrono::NaiveDate> = row.try_get(index)?;
276 match value {
277 Some(v) => DbValue::Date(tuplify_date(v)?),
278 None => DbValue::DbNull,
279 }
280 }
281 &Type::TIME => {
282 let value: Option<chrono::NaiveTime> = row.try_get(index)?;
283 match value {
284 Some(v) => DbValue::Time(tuplify_time(v)?),
285 None => DbValue::DbNull,
286 }
287 }
288 t => {
289 tracing::debug!(
290 "Couldn't convert Postgres type {} in column {}",
291 t.name(),
292 column.name()
293 );
294 DbValue::Unsupported
295 }
296 };
297 Ok(value)
298}
299
300fn tuplify_date_time(
302 value: chrono::NaiveDateTime,
303) -> anyhow::Result<(i32, u8, u8, u8, u8, u8, u32)> {
304 use chrono::{Datelike, Timelike};
305 Ok((
306 value.year(),
307 value.month().try_into()?,
308 value.day().try_into()?,
309 value.hour().try_into()?,
310 value.minute().try_into()?,
311 value.second().try_into()?,
312 value.nanosecond(),
313 ))
314}
315
316fn tuplify_date(value: chrono::NaiveDate) -> anyhow::Result<(i32, u8, u8)> {
317 use chrono::Datelike;
318 Ok((
319 value.year(),
320 value.month().try_into()?,
321 value.day().try_into()?,
322 ))
323}
324
325fn tuplify_time(value: chrono::NaiveTime) -> anyhow::Result<(u8, u8, u8, u32)> {
326 use chrono::Timelike;
327 Ok((
328 value.hour().try_into()?,
329 value.minute().try_into()?,
330 value.second().try_into()?,
331 value.nanosecond(),
332 ))
333}
334
/// Marker value used to bind a SQL `NULL` as a query parameter, whatever the
/// parameter's Postgres type.
struct PgNull;
342
343impl ToSql for PgNull {
344 fn to_sql(
345 &self,
346 _ty: &Type,
347 _out: &mut tokio_postgres::types::private::BytesMut,
348 ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
349 where
350 Self: Sized,
351 {
352 Ok(tokio_postgres::types::IsNull::Yes)
353 }
354
355 fn accepts(_ty: &Type) -> bool
356 where
357 Self: Sized,
358 {
359 true
360 }
361
362 fn to_sql_checked(
363 &self,
364 _ty: &Type,
365 _out: &mut tokio_postgres::types::private::BytesMut,
366 ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
367 Ok(tokio_postgres::types::IsNull::Yes)
368 }
369}
370
371impl std::fmt::Debug for PgNull {
372 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 f.debug_struct("NULL").finish()
374 }
375}