Skip to main content

spin_factor_outbound_pg/
host.rs

1#![allow(clippy::result_large_err)]
2
3use anyhow::Result;
4use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
5use spin_telemetry::traces::{self, Blame};
6use spin_world::MAX_HOST_BUFFERED_BYTES;
7use spin_world::spin::postgres3_0_0::postgres::{self as v3};
8use spin_world::spin::postgres4_2_0::postgres::{self as v4};
9use spin_world::v1::postgres as v1;
10use spin_world::v1::rdbms_types as v1_types;
11use spin_world::v2::postgres::{self as v2};
12use spin_world::v2::rdbms_types as v2_types;
13use tracing::Level;
14use tracing::field::Empty;
15use tracing::instrument;
16
17use crate::InstanceState;
18use crate::allowed_hosts::AllowedHostChecker;
19use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult};
20
21impl<CF: ClientFactory> InstanceState<CF> {
22    async fn open_connection<Conn: 'static>(
23        &mut self,
24        address: &str,
25        root_ca: Option<HashableCertificate>,
26    ) -> Result<Resource<Conn>, v4::Error> {
27        let permit = self.semaphore.acquire().await.map_err(|_| {
28            let err = v4::Error::ConnectionFailed("too many connections".into());
29            traces::mark_as_error(&err, Some(Blame::Guest));
30            err
31        })?;
32        let client = self
33            .client_factory
34            .get_client(address, root_ca)
35            .await
36            .map_err(|e| {
37                // The guest supplies the address and credentials; connection
38                // failures (wrong password, TLS error, unreachable host, etc.)
39                // are the guest's problem.
40                let err = v4::Error::ConnectionFailed(format!("{e:?}"));
41                traces::mark_as_error(&err, Some(Blame::Guest));
42                err
43            })?;
44        self.connections
45            .push((client, permit))
46            .map_err(|_| {
47                // The guest exceeded the host-imposed connection limit.
48                let err = v4::Error::ConnectionFailed("too many connections".into());
49                traces::mark_as_error(&err, Some(Blame::Guest));
50                err
51            })
52            .map(Resource::new_own)
53    }
54
55    async fn get_client<Conn: 'static>(
56        &self,
57        connection: Resource<Conn>,
58    ) -> Result<&CF::Client, v4::Error> {
59        self.connections
60            .get(connection.rep())
61            .map(|(client, _permit)| client)
62            .ok_or_else(|| {
63                // The connection table is managed entirely by the host, so a
64                // missing handle indicates a host-side bug, not a guest mistake.
65                let err = v4::Error::ConnectionFailed("no connection found".into());
66                traces::mark_as_error(&err, Some(Blame::Host));
67                err
68            })
69    }
70
71    fn allowed_host_checker(&self) -> AllowedHostChecker {
72        self.allowed_host_checker.clone()
73    }
74
75    #[allow(clippy::result_large_err)]
76    async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
77        self.allowed_host_checker
78            .ensure_address_allowed(address)
79            .await
80    }
81}
82
83fn v2_params_to_v3(
84    params: Vec<v2_types::ParameterValue>,
85) -> Result<Vec<v4::ParameterValue>, v2::Error> {
86    params.into_iter().map(|p| p.try_into()).collect()
87}
88
89fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
90    params.into_iter().map(|p| p.into()).collect()
91}
92
93impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
94    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
95    async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
96        spin_factor_outbound_networking::record_address_fields(&address);
97
98        self.ensure_address_allowed(&address)
99            .await
100            .map_err(v3::Error::from)
101            .map_err(track_address_check_error_v3)?;
102
103        self.open_connection(&address, None)
104            .await
105            .map_err(v3::Error::from)
106    }
107
108    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
109    async fn execute(
110        &mut self,
111        connection: Resource<v3::Connection>,
112        statement: String,
113        params: Vec<v3::ParameterValue>,
114    ) -> Result<u64, v3::Error> {
115        self.get_client(connection)
116            .await
117            .map_err(v3::Error::from)?
118            .execute(statement, v3_params_to_v4(params))
119            .await
120            .map_err(v3::Error::from)
121            .map_err(track_db_error_on_span_v3)
122    }
123
124    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
125    async fn query(
126        &mut self,
127        connection: Resource<v3::Connection>,
128        statement: String,
129        params: Vec<v3::ParameterValue>,
130    ) -> Result<v3::RowSet, v3::Error> {
131        let rowset = self
132            .get_client(connection)
133            .await
134            .map_err(v3::Error::from)?
135            .query(statement, v3_params_to_v4(params), MAX_HOST_BUFFERED_BYTES)
136            .await
137            .map_err(v3::Error::from)
138            .map_err(track_db_error_on_span_v3)?;
139        Ok(rowset.into())
140    }
141
142    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
143        self.connections.remove(connection.rep());
144        Ok(())
145    }
146}
147
148pub(crate) struct ConnectionBuilder {
149    address: String,
150    root_ca: Option<HashableCertificate>,
151}
152
153impl<CF: ClientFactory> v4::HostConnectionBuilder for InstanceState<CF> {
154    async fn new(&mut self, address: String) -> Result<Resource<v4::ConnectionBuilder>> {
155        let builder = ConnectionBuilder {
156            address,
157            root_ca: None,
158        };
159        let rep = self
160            .builders
161            .push(builder)
162            .map_err(|_| anyhow::anyhow!("out of builder table space"))?;
163        let rsrc = Resource::new_own(rep);
164        Ok(rsrc)
165    }
166
167    async fn set_ca_root(
168        &mut self,
169        self_: Resource<v4::ConnectionBuilder>,
170        certificate: String,
171    ) -> Result<(), v4::Error> {
172        let root_ca = HashableCertificate::from_pem(&certificate).map_err(|e| {
173            let err = v4::Error::Other(format!("invalid root certificate: {e}"));
174            traces::mark_as_error(&err, Some(Blame::Guest));
175            err
176        })?;
177        let builder = self.builders.get_mut(self_.rep()).ok_or_else(|| {
178            let err = v4::Error::ConnectionFailed("no builder found".into());
179            traces::mark_as_error(&err, Some(Blame::Host));
180            err
181        })?;
182        builder.root_ca = Some(root_ca);
183        Ok(())
184    }
185
186    async fn build(
187        &mut self,
188        self_: Resource<v4::ConnectionBuilder>,
189    ) -> Result<Resource<v4::Connection>, v4::Error> {
190        let (address, root_ca) = self.get_builder_info(self_.rep())?;
191        self.open_connection(&address, root_ca).await
192    }
193
194    async fn drop(&mut self, builder: Resource<v4::ConnectionBuilder>) -> Result<()> {
195        self.builders.remove(builder.rep());
196        Ok(())
197    }
198}
199
200impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
201    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
202    async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
203        spin_factor_outbound_networking::record_address_fields(&address);
204
205        self.ensure_address_allowed(&address)
206            .await
207            .map_err(track_address_check_error_v4)?;
208
209        self.open_connection(&address, None).await
210    }
211
212    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
213    async fn execute(
214        &mut self,
215        connection: Resource<v4::Connection>,
216        statement: String,
217        params: Vec<v4::ParameterValue>,
218    ) -> Result<u64, v4::Error> {
219        self.get_client(connection)
220            .await?
221            .execute(statement, params)
222            .await
223            .map_err(track_db_error_on_span_v4)
224    }
225
226    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
227    async fn query(
228        &mut self,
229        connection: Resource<v4::Connection>,
230        statement: String,
231        params: Vec<v4::ParameterValue>,
232    ) -> Result<v4::RowSet, v4::Error> {
233        self.get_client(connection)
234            .await?
235            .query(statement, params, MAX_HOST_BUFFERED_BYTES)
236            .await
237            .map_err(track_db_error_on_span_v4)
238    }
239
240    async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
241        self.connections.remove(connection.rep());
242        Ok(())
243    }
244}
245
246impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore
247    for crate::PgFactorData<CF>
248{
249    #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
250    async fn open_async<T>(
251        accessor: &Accessor<T, Self>,
252        address: String,
253    ) -> Result<Resource<v4::Connection>, v4::Error> {
254        spin_factor_outbound_networking::record_address_fields(&address);
255
256        Self::ensure_address_allowed_async(accessor, &address)
257            .await
258            .map_err(track_address_check_error_v4)?;
259        Self::open_connection_async(accessor, &address, None).await
260    }
261
262    #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
263    async fn execute_async<T>(
264        accessor: &Accessor<T, Self>,
265        connection: Resource<v4::Connection>,
266        statement: String,
267        params: Vec<v4::ParameterValue>,
268    ) -> Result<u64, v4::Error> {
269        let client = accessor.with(|mut access| {
270            let host = access.get();
271            host.connections
272                .get(connection.rep())
273                .map(|(client, _permit)| client.clone())
274                .unwrap()
275        });
276
277        client
278            .execute(statement, params)
279            .await
280            .map_err(track_db_error_on_span_v4)
281    }
282
283    #[allow(clippy::type_complexity)] // blame bindgen, clippy, blame bindgen
284    #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
285    async fn query_async<T>(
286        accessor: &Accessor<T, Self>,
287        connection: Resource<v4::Connection>,
288        statement: String,
289        params: Vec<v4::ParameterValue>,
290    ) -> Result<
291        (
292            Vec<v4::Column>,
293            StreamReader<v4::Row>,
294            FutureReader<Result<(), v4::Error>>,
295        ),
296        v4::Error,
297    > {
298        let client = accessor.with(|mut access| {
299            let host = access.get();
300            host.connections
301                .get(connection.rep())
302                .map(|(client, _permit)| client.clone())
303                .unwrap()
304        });
305
306        let QueryAsyncResult {
307            columns,
308            rows,
309            error,
310        } = client
311            .query_async(statement, params, MAX_HOST_BUFFERED_BYTES)
312            .await
313            .map_err(track_db_error_on_span_v4)?;
314
315        let row_producer = spin_wasi_async::stream::producer(rows);
316
317        let (sr, efr) = accessor
318            .with(|mut access| {
319                let sr = StreamReader::new(&mut access, row_producer)?;
320                let efr = FutureReader::new(&mut access, error)?;
321                anyhow::Ok((sr, efr))
322            })
323            .map_err(|e| {
324                // Setting up the async stream/future channels is a host
325                // implementation detail; if it fails, that's a host bug.
326                let err = v4::Error::Other(e.to_string());
327                traces::mark_as_error(&err, Some(Blame::Host));
328                err
329            })?;
330
331        Ok((columns, sr, efr))
332    }
333}
334
335impl<CF: ClientFactory> InstanceState<CF> {
336    #[allow(clippy::result_large_err)]
337    fn get_builder_info(
338        &mut self,
339        builder_rep: u32,
340    ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
341        let builder = self.builders.get_mut(builder_rep).ok_or_else(|| {
342            let err = v4::Error::ConnectionFailed("no builder found".into());
343            traces::mark_as_error(&err, Some(Blame::Host));
344            err
345        })?;
346
347        let address = builder.address.clone();
348        let root_ca = builder.root_ca.clone();
349
350        Ok((address, root_ca))
351    }
352}
353
354impl<CF: ClientFactory> crate::PgFactorData<CF> {
355    #[allow(clippy::result_large_err)]
356    fn get_builder_info<T>(
357        accessor: &Accessor<T, Self>,
358        builder: Resource<v4::ConnectionBuilder>,
359    ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
360        let builder_rep = builder.rep();
361        accessor.with(|mut access| {
362            let host = access.get();
363            host.get_builder_info(builder_rep)
364        })
365    }
366
367    async fn ensure_address_allowed_async<T>(
368        accessor: &Accessor<T, Self>,
369        address: &str,
370    ) -> Result<(), v4::Error> {
371        // A merry dance to avoid doing the async allow check under the accessor
372        let allowed_host_checker = accessor.with(|mut access| {
373            let host = access.get();
374            host.allowed_host_checker()
375        });
376
377        allowed_host_checker.ensure_address_allowed(address).await
378    }
379
380    async fn open_connection_async<T>(
381        accessor: &Accessor<T, Self>,
382        address: &str,
383        root_ca: Option<HashableCertificate>,
384    ) -> Result<Resource<v4::Connection>, v4::Error> {
385        let (cf, semaphore) = accessor.with(|mut access| {
386            let host = access.get();
387            (host.client_factory.clone(), host.semaphore.clone())
388        });
389
390        let permit = semaphore.acquire().await.map_err(|_| {
391            let err = v4::Error::ConnectionFailed("too many connections".into());
392            traces::mark_as_error(&err, Some(Blame::Guest));
393            err
394        })?;
395
396        let client = cf.get_client(address, root_ca).await.map_err(|e| {
397            let err = v4::Error::ConnectionFailed(format!("{e:?}"));
398            traces::mark_as_error(&err, Some(Blame::Guest));
399            err
400        })?;
401
402        accessor.with(|mut access| {
403            let host = access.get();
404            host.connections
405                .push((client, permit))
406                .map_err(|_| {
407                    let err = v4::Error::ConnectionFailed("too many connections".into());
408                    traces::mark_as_error(&err, Some(Blame::Guest));
409                    err
410                })
411                .map(Resource::new_own)
412        })
413    }
414}
415
416impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore
417    for crate::PgFactorData<CF>
418{
419    async fn build_async<T>(
420        accessor: &Accessor<T, Self>,
421        builder: Resource<v4::ConnectionBuilder>,
422    ) -> Result<Resource<v4::Connection>, v4::Error> {
423        let (address, root_ca) = Self::get_builder_info(accessor, builder)?;
424
425        spin_factor_outbound_networking::record_address_fields(&address);
426
427        Self::ensure_address_allowed_async(accessor, &address)
428            .await
429            .map_err(track_address_check_error_v4)?;
430        Self::open_connection_async(accessor, &address, root_ca).await
431    }
432}
433
434impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
435    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
436        Ok(error)
437    }
438}
439
440impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
441    fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
442        Ok(error)
443    }
444}
445
446impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
447    fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
448        Ok(error)
449    }
450}
451
/// Delegate a connectionless (v1) call to the `v4::HostConnection` implementation.
///
/// Expands to: check `$address` against the allowed outbound hosts, open a
/// temporary connection, invoke `$name` on it, then remove the connection's
/// table entry (releasing its semaphore permit) and yield the call's result
/// with the error converted via `Into`.
///
/// `$statement` and `$params` are evaluated *before* the connection is
/// opened. Call sites pass a fallible conversion ending in `?` as `$params`.
/// If that were evaluated after opening, an early return would skip the
/// cleanup and leak the table entry together with its permit.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $statement:expr, $params:expr)) => {{
        let address = $address;
        $self.ensure_address_allowed(&address).await?;
        // Evaluate the remaining arguments up front so that any `?` inside
        // them returns before a connection exists.
        let statement = $statement;
        let params = $params;
        let connection = match $self.open_connection(&address, None).await {
            Ok(c) => c,
            Err(e) => return Err(e.into()),
        };
        // v1 has no persistent connections, so remove the table entry immediately
        // after the call to release the semaphore permit.
        let rep = connection.rep();
        let result = <Self as v4::HostConnection>::$name($self, connection, statement, params)
            .await
            .map_err(Into::into);
        $self.connections.remove(rep);
        result
    }};
}
470
471impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
472
473impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
474    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
475    async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
476        self.otel.reparent_tracing_span();
477        spin_factor_outbound_networking::record_address_fields(&address);
478
479        self.ensure_address_allowed(&address)
480            .await
481            .map_err(v2::Error::from)
482            .map_err(track_address_check_error_v2)?;
483        self.open_connection(&address, None)
484            .await
485            .map_err(v2::Error::from)
486    }
487
488    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
489    async fn execute(
490        &mut self,
491        connection: Resource<v2::Connection>,
492        statement: String,
493        params: Vec<v2_types::ParameterValue>,
494    ) -> Result<u64, v2::Error> {
495        self.otel.reparent_tracing_span();
496        let params = v2_params_to_v3(params).inspect_err(|e| {
497            traces::mark_as_error(e, Some(Blame::Guest));
498        })?;
499        self.get_client(connection)
500            .await
501            .map_err(v2::Error::from)?
502            .execute(statement, params)
503            .await
504            .map_err(v2::Error::from)
505            .map_err(track_db_error_on_span_v2)
506    }
507
508    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
509    async fn query(
510        &mut self,
511        connection: Resource<v2::Connection>,
512        statement: String,
513        params: Vec<v2_types::ParameterValue>,
514    ) -> Result<v2_types::RowSet, v2::Error> {
515        self.otel.reparent_tracing_span();
516        let params = v2_params_to_v3(params).inspect_err(|e| {
517            traces::mark_as_error(e, Some(Blame::Guest));
518        })?;
519        Ok(self
520            .get_client(connection)
521            .await
522            .map_err(v2::Error::from)?
523            .query(statement, params, MAX_HOST_BUFFERED_BYTES)
524            .await
525            .map_err(v2::Error::from)
526            .map_err(track_db_error_on_span_v2)?
527            .into())
528    }
529
530    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
531        self.connections.remove(connection.rep());
532        Ok(())
533    }
534}
535
536impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
537    async fn execute(
538        &mut self,
539        address: String,
540        statement: String,
541        params: Vec<v1_types::ParameterValue>,
542    ) -> Result<u64, v1::PgError> {
543        delegate!(
544            self.execute(
545                address,
546                statement,
547                params
548                    .into_iter()
549                    .map(TryInto::try_into)
550                    .collect::<Result<Vec<_>, _>>()?
551            )
552        )
553    }
554
555    async fn query(
556        &mut self,
557        address: String,
558        statement: String,
559        params: Vec<v1_types::ParameterValue>,
560    ) -> Result<v1_types::RowSet, v1::PgError> {
561        delegate!(
562            self.query(
563                address,
564                statement,
565                params
566                    .into_iter()
567                    .map(TryInto::try_into)
568                    .collect::<Result<Vec<_>, _>>()?
569            )
570        )
571        .map(Into::into)
572    }
573
574    fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
575        Ok(error)
576    }
577}
578
579/// Mark errors from `ensure_address_allowed` on the current span.
580///
581/// Address check errors where the check infrastructure itself fails (`Other`) are Host-blamed.
582/// All other errors (address not permitted, malformed address, unsupported socket type) are
583/// Guest-blamed since the guest supplied the address.
584fn track_address_check_error_v4(err: v4::Error) -> v4::Error {
585    let blame = match &err {
586        v4::Error::Other(_) => Blame::Host,
587        _ => Blame::Guest,
588    };
589    traces::mark_as_error(&err, Some(blame));
590    err
591}
592
593fn track_address_check_error_v3(err: v3::Error) -> v3::Error {
594    let blame = match &err {
595        v3::Error::Other(_) => Blame::Host,
596        _ => Blame::Guest,
597    };
598    traces::mark_as_error(&err, Some(blame));
599    err
600}
601
602fn track_address_check_error_v2(err: v2::Error) -> v2::Error {
603    let blame = match &err {
604        v2::Error::Other(_) => Blame::Host,
605        _ => Blame::Guest,
606    };
607    traces::mark_as_error(&err, Some(blame));
608    err
609}
610
611/// Mark errors from actual DB client calls (execute/query) on the current span.
612fn track_db_error_on_span_v4(err: v4::Error) -> v4::Error {
613    let blame = match &err {
614        // The guest brings their own database, so connection failures during
615        // execution (dropped connection, auth rejected mid-session, etc.) are
616        // the guest's problem, not the host's.
617        v4::Error::ConnectionFailed(_) => Blame::Guest,
618        v4::Error::BadParameter(_) => Blame::Guest,
619        v4::Error::QueryFailed(_) => Blame::Guest,
620        // The host is responsible for mapping DB wire types to WIT types;
621        // a conversion failure is a host-side limitation or bug.
622        v4::Error::ValueConversionFailed(_) => Blame::Host,
623        v4::Error::Other(_) => Blame::Host,
624    };
625    traces::mark_as_error(&err, Some(blame));
626    err
627}
628
629fn track_db_error_on_span_v3(err: v3::Error) -> v3::Error {
630    let blame = match &err {
631        v3::Error::ConnectionFailed(_) => Blame::Guest,
632        v3::Error::BadParameter(_) => Blame::Guest,
633        v3::Error::QueryFailed(_) => Blame::Guest,
634        v3::Error::ValueConversionFailed(_) => Blame::Host,
635        v3::Error::Other(_) => Blame::Host,
636    };
637    traces::mark_as_error(&err, Some(blame));
638    err
639}
640
641fn track_db_error_on_span_v2(err: v2::Error) -> v2::Error {
642    let blame = match &err {
643        v2::Error::ConnectionFailed(_) => Blame::Guest,
644        v2::Error::BadParameter(_) => Blame::Guest,
645        v2::Error::QueryFailed(_) => Blame::Guest,
646        v2::Error::ValueConversionFailed(_) => Blame::Host,
647        v2::Error::Other(_) => Blame::Host,
648    };
649    traces::mark_as_error(&err, Some(blame));
650    err
651}