1#![allow(clippy::result_large_err)]
2
3use anyhow::Result;
4use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
5use spin_world::MAX_HOST_BUFFERED_BYTES;
6use spin_world::spin::postgres3_0_0::postgres::{self as v3};
7use spin_world::spin::postgres4_2_0::postgres::{self as v4};
8use spin_world::v1::postgres as v1;
9use spin_world::v1::rdbms_types as v1_types;
10use spin_world::v2::postgres::{self as v2};
11use spin_world::v2::rdbms_types as v2_types;
12use tracing::Level;
13use tracing::field::Empty;
14use tracing::instrument;
15
16use crate::InstanceState;
17use crate::allowed_hosts::AllowedHostChecker;
18use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult};
19
20impl<CF: ClientFactory> InstanceState<CF> {
21 async fn open_connection<Conn: 'static>(
22 &mut self,
23 address: &str,
24 root_ca: Option<HashableCertificate>,
25 ) -> Result<Resource<Conn>, v4::Error> {
26 self.connections
27 .push(
28 self.client_factory
29 .get_client(address, root_ca)
30 .await
31 .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?,
32 )
33 .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
34 .map(Resource::new_own)
35 }
36
37 async fn get_client<Conn: 'static>(
38 &self,
39 connection: Resource<Conn>,
40 ) -> Result<&CF::Client, v4::Error> {
41 self.connections
42 .get(connection.rep())
43 .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into()))
44 }
45
46 fn allowed_host_checker(&self) -> AllowedHostChecker {
47 self.allowed_host_checker.clone()
48 }
49
50 #[allow(clippy::result_large_err)]
51 async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
52 self.allowed_host_checker
53 .ensure_address_allowed(address)
54 .await
55 }
56}
57
58fn v2_params_to_v3(
59 params: Vec<v2_types::ParameterValue>,
60) -> Result<Vec<v4::ParameterValue>, v2::Error> {
61 params.into_iter().map(|p| p.try_into()).collect()
62}
63
64fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
65 params.into_iter().map(|p| p.into()).collect()
66}
67
68impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
69 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
70 async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
71 spin_factor_outbound_networking::record_address_fields(&address);
72
73 self.ensure_address_allowed(&address).await?;
74
75 Ok(self.open_connection(&address, None).await?)
76 }
77
78 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
79 async fn execute(
80 &mut self,
81 connection: Resource<v3::Connection>,
82 statement: String,
83 params: Vec<v3::ParameterValue>,
84 ) -> Result<u64, v3::Error> {
85 Ok(self
86 .get_client(connection)
87 .await?
88 .execute(statement, v3_params_to_v4(params))
89 .await?)
90 }
91
92 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
93 async fn query(
94 &mut self,
95 connection: Resource<v3::Connection>,
96 statement: String,
97 params: Vec<v3::ParameterValue>,
98 ) -> Result<v3::RowSet, v3::Error> {
99 Ok(self
100 .get_client(connection)
101 .await?
102 .query(statement, v3_params_to_v4(params), MAX_HOST_BUFFERED_BYTES)
103 .await?
104 .into())
105 }
106
107 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
108 self.connections.remove(connection.rep());
109 Ok(())
110 }
111}
112
/// Host-side state behind a v4 `ConnectionBuilder` resource: the address to
/// connect to and an optional custom root certificate.
pub(crate) struct ConnectionBuilder {
    /// Address given to `HostConnectionBuilder::new`.
    address: String,
    /// Root CA to trust when building; `None` until `set_ca_root` succeeds.
    root_ca: Option<HashableCertificate>,
}
117
118impl<CF: ClientFactory> v4::HostConnectionBuilder for InstanceState<CF> {
119 async fn new(&mut self, address: String) -> Result<Resource<v4::ConnectionBuilder>> {
120 let builder = ConnectionBuilder {
121 address,
122 root_ca: None,
123 };
124 let rep = self
125 .builders
126 .push(builder)
127 .map_err(|_| anyhow::anyhow!("out of builder table space"))?;
128 let rsrc = Resource::new_own(rep);
129 Ok(rsrc)
130 }
131
132 async fn set_ca_root(
133 &mut self,
134 self_: Resource<v4::ConnectionBuilder>,
135 certificate: String,
136 ) -> Result<(), v4::Error> {
137 let root_ca = HashableCertificate::from_pem(&certificate)
138 .map_err(|e| v4::Error::Other(format!("invalid root certificate: {e}")))?;
139 let builder = self
140 .builders
141 .get_mut(self_.rep())
142 .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
143 builder.root_ca = Some(root_ca);
144 Ok(())
145 }
146
147 async fn build(
148 &mut self,
149 self_: Resource<v4::ConnectionBuilder>,
150 ) -> Result<Resource<v4::Connection>, v4::Error> {
151 let (address, root_ca) = self.get_builder_info(self_.rep())?;
152 self.open_connection(&address, root_ca).await
153 }
154
155 async fn drop(&mut self, builder: Resource<v4::ConnectionBuilder>) -> Result<()> {
156 self.builders.remove(builder.rep());
157 Ok(())
158 }
159}
160
161impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
162 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
163 async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
164 spin_factor_outbound_networking::record_address_fields(&address);
165
166 self.ensure_address_allowed(&address).await?;
167
168 self.open_connection(&address, None).await
169 }
170
171 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
172 async fn execute(
173 &mut self,
174 connection: Resource<v4::Connection>,
175 statement: String,
176 params: Vec<v4::ParameterValue>,
177 ) -> Result<u64, v4::Error> {
178 self.get_client(connection)
179 .await?
180 .execute(statement, params)
181 .await
182 }
183
184 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
185 async fn query(
186 &mut self,
187 connection: Resource<v4::Connection>,
188 statement: String,
189 params: Vec<v4::ParameterValue>,
190 ) -> Result<v4::RowSet, v4::Error> {
191 self.get_client(connection)
192 .await?
193 .query(statement, params, MAX_HOST_BUFFERED_BYTES)
194 .await
195 }
196
197 async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
198 self.connections.remove(connection.rep());
199 Ok(())
200 }
201}
202
203impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore
204 for crate::PgFactorData<CF>
205{
206 #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
207 async fn open_async<T>(
208 accessor: &Accessor<T, Self>,
209 address: String,
210 ) -> Result<Resource<v4::Connection>, v4::Error> {
211 spin_factor_outbound_networking::record_address_fields(&address);
212
213 Self::ensure_address_allowed_async(accessor, &address).await?;
214 Self::open_connection_async(accessor, &address, None).await
215 }
216
217 #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
218 async fn execute_async<T>(
219 accessor: &Accessor<T, Self>,
220 connection: Resource<v4::Connection>,
221 statement: String,
222 params: Vec<v4::ParameterValue>,
223 ) -> Result<u64, v4::Error> {
224 let client = accessor.with(|mut access| {
225 let host = access.get();
226 host.connections.get(connection.rep()).unwrap().clone()
227 });
228
229 client.execute(statement, params).await
230 }
231
232 #[allow(clippy::type_complexity)] #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
234 async fn query_async<T>(
235 accessor: &Accessor<T, Self>,
236 connection: Resource<v4::Connection>,
237 statement: String,
238 params: Vec<v4::ParameterValue>,
239 ) -> Result<
240 (
241 Vec<v4::Column>,
242 StreamReader<v4::Row>,
243 FutureReader<Result<(), v4::Error>>,
244 ),
245 v4::Error,
246 > {
247 let client = accessor.with(|mut access| {
248 let host = access.get();
249 host.connections.get(connection.rep()).unwrap().clone()
250 });
251
252 let QueryAsyncResult {
253 columns,
254 rows,
255 error,
256 } = client
257 .query_async(statement, params, MAX_HOST_BUFFERED_BYTES)
258 .await?;
259
260 let row_producer = spin_wasi_async::stream::producer(rows);
261
262 let (sr, efr) = accessor
263 .with(|mut access| {
264 let sr = StreamReader::new(&mut access, row_producer)?;
265 let efr = FutureReader::new(&mut access, error)?;
266 anyhow::Ok((sr, efr))
267 })
268 .map_err(|e| v4::Error::Other(e.to_string()))?;
269
270 Ok((columns, sr, efr))
271 }
272}
273
274impl<CF: ClientFactory> InstanceState<CF> {
275 #[allow(clippy::result_large_err)]
276 fn get_builder_info(
277 &mut self,
278 builder_rep: u32,
279 ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
280 let builder = self
281 .builders
282 .get_mut(builder_rep)
283 .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
284
285 let address = builder.address.clone();
286 let root_ca = builder.root_ca.clone();
287
288 Ok((address, root_ca))
289 }
290}
291
292impl<CF: ClientFactory> crate::PgFactorData<CF> {
293 #[allow(clippy::result_large_err)]
294 fn get_builder_info<T>(
295 accessor: &Accessor<T, Self>,
296 builder: Resource<v4::ConnectionBuilder>,
297 ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
298 let builder_rep = builder.rep();
299 accessor.with(|mut access| {
300 let host = access.get();
301 host.get_builder_info(builder_rep)
302 })
303 }
304
305 async fn ensure_address_allowed_async<T>(
306 accessor: &Accessor<T, Self>,
307 address: &str,
308 ) -> Result<(), v4::Error> {
309 let allowed_host_checker = accessor.with(|mut access| {
311 let host = access.get();
312 host.allowed_host_checker()
313 });
314
315 allowed_host_checker.ensure_address_allowed(address).await
316 }
317
318 async fn open_connection_async<T>(
319 accessor: &Accessor<T, Self>,
320 address: &str,
321 root_ca: Option<HashableCertificate>,
322 ) -> Result<Resource<v4::Connection>, v4::Error> {
323 let cf = accessor.with(|mut access| {
324 let host = access.get();
325 host.client_factory.clone()
326 });
327
328 let client = cf
329 .get_client(address, root_ca)
330 .await
331 .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?;
332
333 accessor.with(|mut access| {
334 let host = access.get();
335 host.connections
336 .push(client)
337 .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
338 .map(Resource::new_own)
339 })
340 }
341}
342
343impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore
344 for crate::PgFactorData<CF>
345{
346 async fn build_async<T>(
347 accessor: &Accessor<T, Self>,
348 builder: Resource<v4::ConnectionBuilder>,
349 ) -> Result<Resource<v4::Connection>, v4::Error> {
350 let (address, root_ca) = Self::get_builder_info(accessor, builder)?;
351
352 spin_factor_outbound_networking::record_address_fields(&address);
353
354 Self::ensure_address_allowed_async(accessor, &address).await?;
355 Self::open_connection_async(accessor, &address, root_ca).await
356 }
357}
358
359impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
360 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
361 Ok(error)
362 }
363}
364
365impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
366 fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
367 Ok(error)
368 }
369}
370
371impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
372 fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
373 Ok(error)
374 }
375}
376
/// Implements a one-shot v1 call (`execute` / `query`) on top of the v4
/// connection API. It checks the address against the allowed hosts, opens a
/// temporary connection, runs the v4 method, and then removes the temporary
/// connection from the connection table again.
///
/// The statement and parameters are evaluated before the connection is opened.
/// As a result, a failing `?` inside them (e.g. a parameter conversion error)
/// returns early without leaving a connection behind.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $statement:expr, $params:expr)) => {{
        $self.ensure_address_allowed(&$address).await?;
        let statement: String = $statement;
        let params: Vec<v4::ParameterValue> = $params;
        let connection: Resource<v4::Connection> =
            match $self.open_connection(&$address, None).await {
                Ok(c) => c,
                Err(e) => return Err(e.into()),
            };
        let rep = connection.rep();
        let result =
            <Self as v4::HostConnection>::$name($self, connection, statement, params).await;
        // The temporary connection is never handed to the guest, so nothing else
        // would ever drop it. Release the table slot here so repeated v1 calls
        // cannot exhaust the connection table.
        $self.connections.remove(rep);
        result.map_err(|e| e.into())
    }};
}
390
391impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
392
393impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
394 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
395 async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
396 self.otel.reparent_tracing_span();
397 spin_factor_outbound_networking::record_address_fields(&address);
398
399 self.ensure_address_allowed(&address).await?;
400 Ok(self.open_connection(&address, None).await?)
401 }
402
403 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
404 async fn execute(
405 &mut self,
406 connection: Resource<v2::Connection>,
407 statement: String,
408 params: Vec<v2_types::ParameterValue>,
409 ) -> Result<u64, v2::Error> {
410 self.otel.reparent_tracing_span();
411 Ok(self
412 .get_client(connection)
413 .await?
414 .execute(statement, v2_params_to_v3(params)?)
415 .await?)
416 }
417
418 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
419 async fn query(
420 &mut self,
421 connection: Resource<v2::Connection>,
422 statement: String,
423 params: Vec<v2_types::ParameterValue>,
424 ) -> Result<v2_types::RowSet, v2::Error> {
425 self.otel.reparent_tracing_span();
426 Ok(self
427 .get_client(connection)
428 .await?
429 .query(statement, v2_params_to_v3(params)?, MAX_HOST_BUFFERED_BYTES)
430 .await?
431 .into())
432 }
433
434 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
435 self.connections.remove(connection.rep());
436 Ok(())
437 }
438}
439
440impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
441 async fn execute(
442 &mut self,
443 address: String,
444 statement: String,
445 params: Vec<v1_types::ParameterValue>,
446 ) -> Result<u64, v1::PgError> {
447 delegate!(
448 self.execute(
449 address,
450 statement,
451 params
452 .into_iter()
453 .map(TryInto::try_into)
454 .collect::<Result<Vec<_>, _>>()?
455 )
456 )
457 }
458
459 async fn query(
460 &mut self,
461 address: String,
462 statement: String,
463 params: Vec<v1_types::ParameterValue>,
464 ) -> Result<v1_types::RowSet, v1::PgError> {
465 delegate!(
466 self.query(
467 address,
468 statement,
469 params
470 .into_iter()
471 .map(TryInto::try_into)
472 .collect::<Result<Vec<_>, _>>()?
473 )
474 )
475 .map(Into::into)
476 }
477
478 fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
479 Ok(error)
480 }
481}