1#![allow(clippy::result_large_err)]
2
3use anyhow::Result;
4use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
5use spin_telemetry::traces::{self, Blame};
6use spin_world::MAX_HOST_BUFFERED_BYTES;
7use spin_world::spin::postgres3_0_0::postgres::{self as v3};
8use spin_world::spin::postgres4_2_0::postgres::{self as v4};
9use spin_world::v1::postgres as v1;
10use spin_world::v1::rdbms_types as v1_types;
11use spin_world::v2::postgres::{self as v2};
12use spin_world::v2::rdbms_types as v2_types;
13use tracing::Level;
14use tracing::field::Empty;
15use tracing::instrument;
16
17use crate::InstanceState;
18use crate::allowed_hosts::AllowedHostChecker;
19use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult};
20
21impl<CF: ClientFactory> InstanceState<CF> {
22 async fn open_connection<Conn: 'static>(
23 &mut self,
24 address: &str,
25 root_ca: Option<HashableCertificate>,
26 ) -> Result<Resource<Conn>, v4::Error> {
27 let permit = self.semaphore.acquire().await.map_err(|_| {
28 let err = v4::Error::ConnectionFailed("too many connections".into());
29 traces::mark_as_error(&err, Some(Blame::Guest));
30 err
31 })?;
32 let client = self
33 .client_factory
34 .get_client(address, root_ca)
35 .await
36 .map_err(|e| {
37 let err = v4::Error::ConnectionFailed(format!("{e:?}"));
41 traces::mark_as_error(&err, Some(Blame::Guest));
42 err
43 })?;
44 self.connections
45 .push((client, permit))
46 .map_err(|_| {
47 let err = v4::Error::ConnectionFailed("too many connections".into());
49 traces::mark_as_error(&err, Some(Blame::Guest));
50 err
51 })
52 .map(Resource::new_own)
53 }
54
55 async fn get_client<Conn: 'static>(
56 &self,
57 connection: Resource<Conn>,
58 ) -> Result<&CF::Client, v4::Error> {
59 self.connections
60 .get(connection.rep())
61 .map(|(client, _permit)| client)
62 .ok_or_else(|| {
63 let err = v4::Error::ConnectionFailed("no connection found".into());
66 traces::mark_as_error(&err, Some(Blame::Host));
67 err
68 })
69 }
70
71 fn allowed_host_checker(&self) -> AllowedHostChecker {
72 self.allowed_host_checker.clone()
73 }
74
75 #[allow(clippy::result_large_err)]
76 async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
77 self.allowed_host_checker
78 .ensure_address_allowed(address)
79 .await
80 }
81}
82
83fn v2_params_to_v3(
84 params: Vec<v2_types::ParameterValue>,
85) -> Result<Vec<v4::ParameterValue>, v2::Error> {
86 params.into_iter().map(|p| p.try_into()).collect()
87}
88
89fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
90 params.into_iter().map(|p| p.into()).collect()
91}
92
93impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
94 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
95 async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
96 spin_factor_outbound_networking::record_address_fields(&address);
97
98 self.ensure_address_allowed(&address)
99 .await
100 .map_err(v3::Error::from)
101 .map_err(track_address_check_error_v3)?;
102
103 self.open_connection(&address, None)
104 .await
105 .map_err(v3::Error::from)
106 }
107
108 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
109 async fn execute(
110 &mut self,
111 connection: Resource<v3::Connection>,
112 statement: String,
113 params: Vec<v3::ParameterValue>,
114 ) -> Result<u64, v3::Error> {
115 self.get_client(connection)
116 .await
117 .map_err(v3::Error::from)?
118 .execute(statement, v3_params_to_v4(params))
119 .await
120 .map_err(v3::Error::from)
121 .map_err(track_db_error_on_span_v3)
122 }
123
124 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
125 async fn query(
126 &mut self,
127 connection: Resource<v3::Connection>,
128 statement: String,
129 params: Vec<v3::ParameterValue>,
130 ) -> Result<v3::RowSet, v3::Error> {
131 let rowset = self
132 .get_client(connection)
133 .await
134 .map_err(v3::Error::from)?
135 .query(statement, v3_params_to_v4(params), MAX_HOST_BUFFERED_BYTES)
136 .await
137 .map_err(v3::Error::from)
138 .map_err(track_db_error_on_span_v3)?;
139 Ok(rowset.into())
140 }
141
142 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
143 self.connections.remove(connection.rep());
144 Ok(())
145 }
146}
147
148pub(crate) struct ConnectionBuilder {
149 address: String,
150 root_ca: Option<HashableCertificate>,
151}
152
153impl<CF: ClientFactory> v4::HostConnectionBuilder for InstanceState<CF> {
154 async fn new(&mut self, address: String) -> Result<Resource<v4::ConnectionBuilder>> {
155 let builder = ConnectionBuilder {
156 address,
157 root_ca: None,
158 };
159 let rep = self
160 .builders
161 .push(builder)
162 .map_err(|_| anyhow::anyhow!("out of builder table space"))?;
163 let rsrc = Resource::new_own(rep);
164 Ok(rsrc)
165 }
166
167 async fn set_ca_root(
168 &mut self,
169 self_: Resource<v4::ConnectionBuilder>,
170 certificate: String,
171 ) -> Result<(), v4::Error> {
172 let root_ca = HashableCertificate::from_pem(&certificate).map_err(|e| {
173 let err = v4::Error::Other(format!("invalid root certificate: {e}"));
174 traces::mark_as_error(&err, Some(Blame::Guest));
175 err
176 })?;
177 let builder = self.builders.get_mut(self_.rep()).ok_or_else(|| {
178 let err = v4::Error::ConnectionFailed("no builder found".into());
179 traces::mark_as_error(&err, Some(Blame::Host));
180 err
181 })?;
182 builder.root_ca = Some(root_ca);
183 Ok(())
184 }
185
186 async fn build(
187 &mut self,
188 self_: Resource<v4::ConnectionBuilder>,
189 ) -> Result<Resource<v4::Connection>, v4::Error> {
190 let (address, root_ca) = self.get_builder_info(self_.rep())?;
191 self.open_connection(&address, root_ca).await
192 }
193
194 async fn drop(&mut self, builder: Resource<v4::ConnectionBuilder>) -> Result<()> {
195 self.builders.remove(builder.rep());
196 Ok(())
197 }
198}
199
200impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
201 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
202 async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
203 spin_factor_outbound_networking::record_address_fields(&address);
204
205 self.ensure_address_allowed(&address)
206 .await
207 .map_err(track_address_check_error_v4)?;
208
209 self.open_connection(&address, None).await
210 }
211
212 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
213 async fn execute(
214 &mut self,
215 connection: Resource<v4::Connection>,
216 statement: String,
217 params: Vec<v4::ParameterValue>,
218 ) -> Result<u64, v4::Error> {
219 self.get_client(connection)
220 .await?
221 .execute(statement, params)
222 .await
223 .map_err(track_db_error_on_span_v4)
224 }
225
226 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
227 async fn query(
228 &mut self,
229 connection: Resource<v4::Connection>,
230 statement: String,
231 params: Vec<v4::ParameterValue>,
232 ) -> Result<v4::RowSet, v4::Error> {
233 self.get_client(connection)
234 .await?
235 .query(statement, params, MAX_HOST_BUFFERED_BYTES)
236 .await
237 .map_err(track_db_error_on_span_v4)
238 }
239
240 async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
241 self.connections.remove(connection.rep());
242 Ok(())
243 }
244}
245
246impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore
247 for crate::PgFactorData<CF>
248{
249 #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
250 async fn open_async<T>(
251 accessor: &Accessor<T, Self>,
252 address: String,
253 ) -> Result<Resource<v4::Connection>, v4::Error> {
254 spin_factor_outbound_networking::record_address_fields(&address);
255
256 Self::ensure_address_allowed_async(accessor, &address)
257 .await
258 .map_err(track_address_check_error_v4)?;
259 Self::open_connection_async(accessor, &address, None).await
260 }
261
262 #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
263 async fn execute_async<T>(
264 accessor: &Accessor<T, Self>,
265 connection: Resource<v4::Connection>,
266 statement: String,
267 params: Vec<v4::ParameterValue>,
268 ) -> Result<u64, v4::Error> {
269 let client = accessor.with(|mut access| {
270 let host = access.get();
271 host.connections
272 .get(connection.rep())
273 .map(|(client, _permit)| client.clone())
274 .unwrap()
275 });
276
277 client
278 .execute(statement, params)
279 .await
280 .map_err(track_db_error_on_span_v4)
281 }
282
283 #[allow(clippy::type_complexity)] #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
285 async fn query_async<T>(
286 accessor: &Accessor<T, Self>,
287 connection: Resource<v4::Connection>,
288 statement: String,
289 params: Vec<v4::ParameterValue>,
290 ) -> Result<
291 (
292 Vec<v4::Column>,
293 StreamReader<v4::Row>,
294 FutureReader<Result<(), v4::Error>>,
295 ),
296 v4::Error,
297 > {
298 let client = accessor.with(|mut access| {
299 let host = access.get();
300 host.connections
301 .get(connection.rep())
302 .map(|(client, _permit)| client.clone())
303 .unwrap()
304 });
305
306 let QueryAsyncResult {
307 columns,
308 rows,
309 error,
310 } = client
311 .query_async(statement, params, MAX_HOST_BUFFERED_BYTES)
312 .await
313 .map_err(track_db_error_on_span_v4)?;
314
315 let row_producer = spin_wasi_async::stream::producer(rows);
316
317 let (sr, efr) = accessor
318 .with(|mut access| {
319 let sr = StreamReader::new(&mut access, row_producer)?;
320 let efr = FutureReader::new(&mut access, error)?;
321 anyhow::Ok((sr, efr))
322 })
323 .map_err(|e| {
324 let err = v4::Error::Other(e.to_string());
327 traces::mark_as_error(&err, Some(Blame::Host));
328 err
329 })?;
330
331 Ok((columns, sr, efr))
332 }
333}
334
335impl<CF: ClientFactory> InstanceState<CF> {
336 #[allow(clippy::result_large_err)]
337 fn get_builder_info(
338 &mut self,
339 builder_rep: u32,
340 ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
341 let builder = self.builders.get_mut(builder_rep).ok_or_else(|| {
342 let err = v4::Error::ConnectionFailed("no builder found".into());
343 traces::mark_as_error(&err, Some(Blame::Host));
344 err
345 })?;
346
347 let address = builder.address.clone();
348 let root_ca = builder.root_ca.clone();
349
350 Ok((address, root_ca))
351 }
352}
353
impl<CF: ClientFactory> crate::PgFactorData<CF> {
    /// Reads the address and optional root CA recorded on a connection builder.
    ///
    /// Delegates to the host state's `get_builder_info` inside `accessor.with`;
    /// any error comes from that lookup.
    #[allow(clippy::result_large_err)]
    fn get_builder_info<T>(
        accessor: &Accessor<T, Self>,
        builder: Resource<v4::ConnectionBuilder>,
    ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
        let builder_rep = builder.rep();
        accessor.with(|mut access| {
            let host = access.get();
            host.get_builder_info(builder_rep)
        })
    }

    /// Fails unless `address` is permitted by the host's allowed-host checker.
    ///
    /// The checker is cloned out inside `accessor.with`, and the async check runs
    /// after that synchronous section has returned.
    async fn ensure_address_allowed_async<T>(
        accessor: &Accessor<T, Self>,
        address: &str,
    ) -> Result<(), v4::Error> {
        let allowed_host_checker = accessor.with(|mut access| {
            let host = access.get();
            host.allowed_host_checker()
        });

        allowed_host_checker.ensure_address_allowed(address).await
    }

    /// Store-aware counterpart of the host's `open_connection`.
    ///
    /// Clones the client factory and semaphore out of the store. It then awaits a
    /// permit and a client with no store access held. Finally it re-enters the
    /// store to register `(client, permit)` in the connection table. Every failure
    /// is `v4::Error::ConnectionFailed`, marked on the current span with guest blame.
    async fn open_connection_async<T>(
        accessor: &Accessor<T, Self>,
        address: &str,
        root_ca: Option<HashableCertificate>,
    ) -> Result<Resource<v4::Connection>, v4::Error> {
        // NOTE(review): `semaphore` is cloned out of the store, so it is presumably
        // a shared handle (e.g. an Arc) whose permits outlive this borrow — confirm.
        let (cf, semaphore) = accessor.with(|mut access| {
            let host = access.get();
            (host.client_factory.clone(), host.semaphore.clone())
        });

        let permit = semaphore.acquire().await.map_err(|_| {
            let err = v4::Error::ConnectionFailed("too many connections".into());
            traces::mark_as_error(&err, Some(Blame::Guest));
            err
        })?;

        let client = cf.get_client(address, root_ca).await.map_err(|e| {
            let err = v4::Error::ConnectionFailed(format!("{e:?}"));
            traces::mark_as_error(&err, Some(Blame::Guest));
            err
        })?;

        // The permit is stored with the client, so it is held for as long as the
        // table entry exists.
        accessor.with(|mut access| {
            let host = access.get();
            host.connections
                .push((client, permit))
                .map_err(|_| {
                    let err = v4::Error::ConnectionFailed("too many connections".into());
                    traces::mark_as_error(&err, Some(Blame::Guest));
                    err
                })
                .map(Resource::new_own)
        })
    }
}
414}
415
416impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore
417 for crate::PgFactorData<CF>
418{
419 async fn build_async<T>(
420 accessor: &Accessor<T, Self>,
421 builder: Resource<v4::ConnectionBuilder>,
422 ) -> Result<Resource<v4::Connection>, v4::Error> {
423 let (address, root_ca) = Self::get_builder_info(accessor, builder)?;
424
425 spin_factor_outbound_networking::record_address_fields(&address);
426
427 Self::ensure_address_allowed_async(accessor, &address)
428 .await
429 .map_err(track_address_check_error_v4)?;
430 Self::open_connection_async(accessor, &address, root_ca).await
431 }
432}
433
434impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
435 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
436 Ok(error)
437 }
438}
439
440impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
441 fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
442 Ok(error)
443 }
444}
445
446impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
447 fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
448 Ok(error)
449 }
450}
451
// Implements a one-shot v1 call on top of the v4 `HostConnection` method `$name`.
// The sequence is:
//   1. check `$address` against the allowed hosts,
//   2. open a temporary connection,
//   3. invoke `$name` with the remaining arguments,
//   4. remove the connection from the table.
// Errors are converted with `Into` to the caller's error type.
//
// Caution: the `$arg` expressions are evaluated as call arguments, i.e. only
// AFTER the connection has been opened and registered. An argument expression
// that returns early (for example via `?`) skips step 4 and leaks the table
// entry, so callers should compute fallible arguments before invoking the macro.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        $self.ensure_address_allowed(&$address).await?;
        let connection = match $self.open_connection(&$address, None).await {
            Ok(c) => c,
            Err(e) => return Err(e.into()),
        };
        // Keep the rep so the entry can be removed after `connection` is moved.
        let rep = connection.rep();
        let result = <Self as v4::HostConnection>::$name($self, connection, $($arg),*)
            .await
            .map_err(|e| e.into());
        // Clean up whether or not the call succeeded.
        $self.connections.remove(rep);
        result
    }};
}
470
471impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
472
473impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
474 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
475 async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
476 self.otel.reparent_tracing_span();
477 spin_factor_outbound_networking::record_address_fields(&address);
478
479 self.ensure_address_allowed(&address)
480 .await
481 .map_err(v2::Error::from)
482 .map_err(track_address_check_error_v2)?;
483 self.open_connection(&address, None)
484 .await
485 .map_err(v2::Error::from)
486 }
487
488 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
489 async fn execute(
490 &mut self,
491 connection: Resource<v2::Connection>,
492 statement: String,
493 params: Vec<v2_types::ParameterValue>,
494 ) -> Result<u64, v2::Error> {
495 self.otel.reparent_tracing_span();
496 let params = v2_params_to_v3(params).inspect_err(|e| {
497 traces::mark_as_error(e, Some(Blame::Guest));
498 })?;
499 self.get_client(connection)
500 .await
501 .map_err(v2::Error::from)?
502 .execute(statement, params)
503 .await
504 .map_err(v2::Error::from)
505 .map_err(track_db_error_on_span_v2)
506 }
507
508 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
509 async fn query(
510 &mut self,
511 connection: Resource<v2::Connection>,
512 statement: String,
513 params: Vec<v2_types::ParameterValue>,
514 ) -> Result<v2_types::RowSet, v2::Error> {
515 self.otel.reparent_tracing_span();
516 let params = v2_params_to_v3(params).inspect_err(|e| {
517 traces::mark_as_error(e, Some(Blame::Guest));
518 })?;
519 Ok(self
520 .get_client(connection)
521 .await
522 .map_err(v2::Error::from)?
523 .query(statement, params, MAX_HOST_BUFFERED_BYTES)
524 .await
525 .map_err(v2::Error::from)
526 .map_err(track_db_error_on_span_v2)?
527 .into())
528 }
529
530 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
531 self.connections.remove(connection.rep());
532 Ok(())
533 }
534}
535
536impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
537 async fn execute(
538 &mut self,
539 address: String,
540 statement: String,
541 params: Vec<v1_types::ParameterValue>,
542 ) -> Result<u64, v1::PgError> {
543 delegate!(
544 self.execute(
545 address,
546 statement,
547 params
548 .into_iter()
549 .map(TryInto::try_into)
550 .collect::<Result<Vec<_>, _>>()?
551 )
552 )
553 }
554
555 async fn query(
556 &mut self,
557 address: String,
558 statement: String,
559 params: Vec<v1_types::ParameterValue>,
560 ) -> Result<v1_types::RowSet, v1::PgError> {
561 delegate!(
562 self.query(
563 address,
564 statement,
565 params
566 .into_iter()
567 .map(TryInto::try_into)
568 .collect::<Result<Vec<_>, _>>()?
569 )
570 )
571 .map(Into::into)
572 }
573
574 fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
575 Ok(error)
576 }
577}
578
579fn track_address_check_error_v4(err: v4::Error) -> v4::Error {
585 let blame = match &err {
586 v4::Error::Other(_) => Blame::Host,
587 _ => Blame::Guest,
588 };
589 traces::mark_as_error(&err, Some(blame));
590 err
591}
592
593fn track_address_check_error_v3(err: v3::Error) -> v3::Error {
594 let blame = match &err {
595 v3::Error::Other(_) => Blame::Host,
596 _ => Blame::Guest,
597 };
598 traces::mark_as_error(&err, Some(blame));
599 err
600}
601
602fn track_address_check_error_v2(err: v2::Error) -> v2::Error {
603 let blame = match &err {
604 v2::Error::Other(_) => Blame::Host,
605 _ => Blame::Guest,
606 };
607 traces::mark_as_error(&err, Some(blame));
608 err
609}
610
611fn track_db_error_on_span_v4(err: v4::Error) -> v4::Error {
613 let blame = match &err {
614 v4::Error::ConnectionFailed(_) => Blame::Guest,
618 v4::Error::BadParameter(_) => Blame::Guest,
619 v4::Error::QueryFailed(_) => Blame::Guest,
620 v4::Error::ValueConversionFailed(_) => Blame::Host,
623 v4::Error::Other(_) => Blame::Host,
624 };
625 traces::mark_as_error(&err, Some(blame));
626 err
627}
628
629fn track_db_error_on_span_v3(err: v3::Error) -> v3::Error {
630 let blame = match &err {
631 v3::Error::ConnectionFailed(_) => Blame::Guest,
632 v3::Error::BadParameter(_) => Blame::Guest,
633 v3::Error::QueryFailed(_) => Blame::Guest,
634 v3::Error::ValueConversionFailed(_) => Blame::Host,
635 v3::Error::Other(_) => Blame::Host,
636 };
637 traces::mark_as_error(&err, Some(blame));
638 err
639}
640
641fn track_db_error_on_span_v2(err: v2::Error) -> v2::Error {
642 let blame = match &err {
643 v2::Error::ConnectionFailed(_) => Blame::Guest,
644 v2::Error::BadParameter(_) => Blame::Guest,
645 v2::Error::QueryFailed(_) => Blame::Guest,
646 v2::Error::ValueConversionFailed(_) => Blame::Host,
647 v2::Error::Other(_) => Blame::Host,
648 };
649 traces::mark_as_error(&err, Some(blame));
650 err
651}