1use anyhow::Result;
2use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
3use spin_world::spin::postgres3_0_0::postgres::{self as v3};
4use spin_world::spin::postgres4_2_0::postgres::{self as v4};
5use spin_world::v1::postgres as v1;
6use spin_world::v1::rdbms_types as v1_types;
7use spin_world::v2::postgres::{self as v2};
8use spin_world::v2::rdbms_types as v2_types;
9use spin_world::MAX_HOST_BUFFERED_BYTES;
10use tracing::field::Empty;
11use tracing::instrument;
12use tracing::Level;
13
14use crate::allowed_hosts::AllowedHostChecker;
15use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult};
16use crate::InstanceState;
17
18impl<CF: ClientFactory> InstanceState<CF> {
19 async fn open_connection<Conn: 'static>(
20 &mut self,
21 address: &str,
22 root_ca: Option<HashableCertificate>,
23 ) -> Result<Resource<Conn>, v4::Error> {
24 self.connections
25 .push(
26 self.client_factory
27 .get_client(address, root_ca)
28 .await
29 .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?,
30 )
31 .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
32 .map(Resource::new_own)
33 }
34
35 async fn get_client<Conn: 'static>(
36 &self,
37 connection: Resource<Conn>,
38 ) -> Result<&CF::Client, v4::Error> {
39 self.connections
40 .get(connection.rep())
41 .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into()))
42 }
43
44 fn allowed_host_checker(&self) -> AllowedHostChecker {
45 self.allowed_host_checker.clone()
46 }
47
48 #[allow(clippy::result_large_err)]
49 async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
50 self.allowed_host_checker
51 .ensure_address_allowed(address)
52 .await
53 }
54}
55
56fn v2_params_to_v3(
57 params: Vec<v2_types::ParameterValue>,
58) -> Result<Vec<v4::ParameterValue>, v2::Error> {
59 params.into_iter().map(|p| p.try_into()).collect()
60}
61
62fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
63 params.into_iter().map(|p| p.into()).collect()
64}
65
66impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
67 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
68 async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
69 spin_factor_outbound_networking::record_address_fields(&address);
70
71 self.ensure_address_allowed(&address).await?;
72
73 Ok(self.open_connection(&address, None).await?)
74 }
75
76 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
77 async fn execute(
78 &mut self,
79 connection: Resource<v3::Connection>,
80 statement: String,
81 params: Vec<v3::ParameterValue>,
82 ) -> Result<u64, v3::Error> {
83 Ok(self
84 .get_client(connection)
85 .await?
86 .execute(statement, v3_params_to_v4(params))
87 .await?)
88 }
89
90 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
91 async fn query(
92 &mut self,
93 connection: Resource<v3::Connection>,
94 statement: String,
95 params: Vec<v3::ParameterValue>,
96 ) -> Result<v3::RowSet, v3::Error> {
97 Ok(self
98 .get_client(connection)
99 .await?
100 .query(statement, v3_params_to_v4(params), MAX_HOST_BUFFERED_BYTES)
101 .await?
102 .into())
103 }
104
105 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
106 self.connections.remove(connection.rep());
107 Ok(())
108 }
109}
110
111pub(crate) struct ConnectionBuilder {
112 address: String,
113 root_ca: Option<HashableCertificate>,
114}
115
116impl<CF: ClientFactory> v4::HostConnectionBuilder for InstanceState<CF> {
117 async fn new(&mut self, address: String) -> Result<Resource<v4::ConnectionBuilder>> {
118 let builder = ConnectionBuilder {
119 address,
120 root_ca: None,
121 };
122 let rep = self
123 .builders
124 .push(builder)
125 .map_err(|_| anyhow::anyhow!("out of builder table space"))?;
126 let rsrc = Resource::new_own(rep);
127 Ok(rsrc)
128 }
129
130 async fn set_ca_root(
131 &mut self,
132 self_: Resource<v4::ConnectionBuilder>,
133 certificate: String,
134 ) -> Result<(), v4::Error> {
135 let root_ca = HashableCertificate::from_pem(&certificate)
136 .map_err(|e| v4::Error::Other(format!("invalid root certificate: {e}")))?;
137 let builder = self
138 .builders
139 .get_mut(self_.rep())
140 .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
141 builder.root_ca = Some(root_ca);
142 Ok(())
143 }
144
145 async fn build(
146 &mut self,
147 self_: Resource<v4::ConnectionBuilder>,
148 ) -> Result<Resource<v4::Connection>, v4::Error> {
149 let (address, root_ca) = self.get_builder_info(self_.rep())?;
150 self.open_connection(&address, root_ca).await
151 }
152
153 async fn drop(&mut self, builder: Resource<v4::ConnectionBuilder>) -> Result<()> {
154 self.builders.remove(builder.rep());
155 Ok(())
156 }
157}
158
159impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
160 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
161 async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
162 spin_factor_outbound_networking::record_address_fields(&address);
163
164 self.ensure_address_allowed(&address).await?;
165
166 self.open_connection(&address, None).await
167 }
168
169 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
170 async fn execute(
171 &mut self,
172 connection: Resource<v4::Connection>,
173 statement: String,
174 params: Vec<v4::ParameterValue>,
175 ) -> Result<u64, v4::Error> {
176 self.get_client(connection)
177 .await?
178 .execute(statement, params)
179 .await
180 }
181
182 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
183 async fn query(
184 &mut self,
185 connection: Resource<v4::Connection>,
186 statement: String,
187 params: Vec<v4::ParameterValue>,
188 ) -> Result<v4::RowSet, v4::Error> {
189 self.get_client(connection)
190 .await?
191 .query(statement, params, MAX_HOST_BUFFERED_BYTES)
192 .await
193 }
194
195 async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
196 self.connections.remove(connection.rep());
197 Ok(())
198 }
199}
200
201impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore
202 for crate::PgFactorData<CF>
203{
204 #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
205 async fn open_async<T>(
206 accessor: &Accessor<T, Self>,
207 address: String,
208 ) -> Result<Resource<v4::Connection>, v4::Error> {
209 spin_factor_outbound_networking::record_address_fields(&address);
210
211 Self::ensure_address_allowed_async(accessor, &address).await?;
212 Self::open_connection_async(accessor, &address, None).await
213 }
214
215 #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
216 async fn execute_async<T>(
217 accessor: &Accessor<T, Self>,
218 connection: Resource<v4::Connection>,
219 statement: String,
220 params: Vec<v4::ParameterValue>,
221 ) -> Result<u64, v4::Error> {
222 let client = accessor.with(|mut access| {
223 let host = access.get();
224 host.connections.get(connection.rep()).unwrap().clone()
225 });
226
227 client.execute(statement, params).await
228 }
229
230 #[allow(clippy::type_complexity)] #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
232 async fn query_async<T>(
233 accessor: &Accessor<T, Self>,
234 connection: Resource<v4::Connection>,
235 statement: String,
236 params: Vec<v4::ParameterValue>,
237 ) -> Result<
238 (
239 Vec<v4::Column>,
240 StreamReader<v4::Row>,
241 FutureReader<Result<(), v4::Error>>,
242 ),
243 v4::Error,
244 > {
245 let client = accessor.with(|mut access| {
246 let host = access.get();
247 host.connections.get(connection.rep()).unwrap().clone()
248 });
249
250 let QueryAsyncResult {
251 columns,
252 rows,
253 error,
254 } = client
255 .query_async(statement, params, MAX_HOST_BUFFERED_BYTES)
256 .await?;
257
258 let row_producer = spin_wasi_async::stream::producer(rows);
259
260 let (sr, efr) = accessor.with(|mut access| {
261 let sr = StreamReader::new(&mut access, row_producer);
262 let efr = FutureReader::new(&mut access, error);
263 (sr, efr)
264 });
265
266 Ok((columns, sr, efr))
267 }
268}
269
270impl<CF: ClientFactory> InstanceState<CF> {
271 #[allow(clippy::result_large_err)]
272 fn get_builder_info(
273 &mut self,
274 builder_rep: u32,
275 ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
276 let builder = self
277 .builders
278 .get_mut(builder_rep)
279 .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
280
281 let address = builder.address.clone();
282 let root_ca = builder.root_ca.clone();
283
284 Ok((address, root_ca))
285 }
286}
287
288impl<CF: ClientFactory> crate::PgFactorData<CF> {
289 #[allow(clippy::result_large_err)]
290 fn get_builder_info<T>(
291 accessor: &Accessor<T, Self>,
292 builder: Resource<v4::ConnectionBuilder>,
293 ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
294 let builder_rep = builder.rep();
295 accessor.with(|mut access| {
296 let host = access.get();
297 host.get_builder_info(builder_rep)
298 })
299 }
300
301 async fn ensure_address_allowed_async<T>(
302 accessor: &Accessor<T, Self>,
303 address: &str,
304 ) -> Result<(), v4::Error> {
305 let allowed_host_checker = accessor.with(|mut access| {
307 let host = access.get();
308 host.allowed_host_checker()
309 });
310
311 allowed_host_checker.ensure_address_allowed(address).await
312 }
313
314 async fn open_connection_async<T>(
315 accessor: &Accessor<T, Self>,
316 address: &str,
317 root_ca: Option<HashableCertificate>,
318 ) -> Result<Resource<v4::Connection>, v4::Error> {
319 let cf = accessor.with(|mut access| {
320 let host = access.get();
321 host.client_factory.clone()
322 });
323
324 let client = cf
325 .get_client(address, root_ca)
326 .await
327 .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?;
328
329 let rsrc = accessor.with(|mut access| {
330 let host = access.get();
331 host.connections
332 .push(client)
333 .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
334 .map(Resource::new_own)
335 });
336
337 rsrc
338 }
339}
340
341impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore
342 for crate::PgFactorData<CF>
343{
344 async fn build_async<T>(
345 accessor: &Accessor<T, Self>,
346 builder: Resource<v4::ConnectionBuilder>,
347 ) -> Result<Resource<v4::Connection>, v4::Error> {
348 let (address, root_ca) = Self::get_builder_info(accessor, builder)?;
349
350 spin_factor_outbound_networking::record_address_fields(&address);
351
352 Self::ensure_address_allowed_async(accessor, &address).await?;
353 Self::open_connection_async(accessor, &address, root_ca).await
354 }
355}
356
// Identity error-conversion hooks: each `convert_error` below returns the
// error it was given unchanged, wrapped in `Ok`.
impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
    /// Returns `error` unchanged.
    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
        Ok(error)
    }
}

impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
    /// Returns `error` unchanged.
    fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
        Ok(error)
    }
}

impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
    /// Returns `error` unchanged.
    fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
        Ok(error)
    }
}
374
/// Implements a one-shot v1 call on top of the v4 `HostConnection` API.
///
/// Checks `$address` against the allowed hosts, opens a connection, forwards
/// to the v4 method `$name` with the remaining arguments, and then removes the
/// connection from `self.connections`. v1 exposes no connection handle the
/// guest could later drop. Without the explicit `remove`, every call would
/// strand a client in the table until pushes start failing with
/// "too many connections".
///
/// Arguments that can early-return (e.g. contain `?`) should be evaluated
/// before invoking the macro, because a return from inside `$arg` would skip
/// the cleanup.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        $self.ensure_address_allowed(&$address).await?;
        let connection: Resource<v4::Connection> =
            $self.open_connection(&$address, None).await?;
        let rep = connection.rep();
        let result =
            <Self as v4::HostConnection>::$name(&mut *$self, connection, $($arg),*).await;
        // Release the table slot whether the call succeeded or failed.
        $self.connections.remove(rep);
        result.map_err(|e| e.into())
    }};
}
388
389impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
390
391impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
392 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
393 async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
394 self.otel.reparent_tracing_span();
395 spin_factor_outbound_networking::record_address_fields(&address);
396
397 self.ensure_address_allowed(&address).await?;
398 Ok(self.open_connection(&address, None).await?)
399 }
400
401 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
402 async fn execute(
403 &mut self,
404 connection: Resource<v2::Connection>,
405 statement: String,
406 params: Vec<v2_types::ParameterValue>,
407 ) -> Result<u64, v2::Error> {
408 self.otel.reparent_tracing_span();
409 Ok(self
410 .get_client(connection)
411 .await?
412 .execute(statement, v2_params_to_v3(params)?)
413 .await?)
414 }
415
416 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
417 async fn query(
418 &mut self,
419 connection: Resource<v2::Connection>,
420 statement: String,
421 params: Vec<v2_types::ParameterValue>,
422 ) -> Result<v2_types::RowSet, v2::Error> {
423 self.otel.reparent_tracing_span();
424 Ok(self
425 .get_client(connection)
426 .await?
427 .query(statement, v2_params_to_v3(params)?, MAX_HOST_BUFFERED_BYTES)
428 .await?
429 .into())
430 }
431
432 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
433 self.connections.remove(connection.rep());
434 Ok(())
435 }
436}
437
438impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
439 async fn execute(
440 &mut self,
441 address: String,
442 statement: String,
443 params: Vec<v1_types::ParameterValue>,
444 ) -> Result<u64, v1::PgError> {
445 delegate!(self.execute(
446 address,
447 statement,
448 params
449 .into_iter()
450 .map(TryInto::try_into)
451 .collect::<Result<Vec<_>, _>>()?
452 ))
453 }
454
455 async fn query(
456 &mut self,
457 address: String,
458 statement: String,
459 params: Vec<v1_types::ParameterValue>,
460 ) -> Result<v1_types::RowSet, v1::PgError> {
461 delegate!(self.query(
462 address,
463 statement,
464 params
465 .into_iter()
466 .map(TryInto::try_into)
467 .collect::<Result<Vec<_>, _>>()?
468 ))
469 .map(Into::into)
470 }
471
472 fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
473 Ok(error)
474 }
475}