Skip to main content

spin_factor_outbound_pg/
host.rs

1#![allow(clippy::result_large_err)]
2
3use anyhow::Result;
4use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
5use spin_world::MAX_HOST_BUFFERED_BYTES;
6use spin_world::spin::postgres3_0_0::postgres::{self as v3};
7use spin_world::spin::postgres4_2_0::postgres::{self as v4};
8use spin_world::v1::postgres as v1;
9use spin_world::v1::rdbms_types as v1_types;
10use spin_world::v2::postgres::{self as v2};
11use spin_world::v2::rdbms_types as v2_types;
12use tracing::Level;
13use tracing::field::Empty;
14use tracing::instrument;
15
16use crate::InstanceState;
17use crate::allowed_hosts::AllowedHostChecker;
18use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult};
19
20impl<CF: ClientFactory> InstanceState<CF> {
21    async fn open_connection<Conn: 'static>(
22        &mut self,
23        address: &str,
24        root_ca: Option<HashableCertificate>,
25    ) -> Result<Resource<Conn>, v4::Error> {
26        self.connections
27            .push(
28                self.client_factory
29                    .get_client(address, root_ca)
30                    .await
31                    .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?,
32            )
33            .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
34            .map(Resource::new_own)
35    }
36
37    async fn get_client<Conn: 'static>(
38        &self,
39        connection: Resource<Conn>,
40    ) -> Result<&CF::Client, v4::Error> {
41        self.connections
42            .get(connection.rep())
43            .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into()))
44    }
45
46    fn allowed_host_checker(&self) -> AllowedHostChecker {
47        self.allowed_host_checker.clone()
48    }
49
50    #[allow(clippy::result_large_err)]
51    async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
52        self.allowed_host_checker
53            .ensure_address_allowed(address)
54            .await
55    }
56}
57
58fn v2_params_to_v3(
59    params: Vec<v2_types::ParameterValue>,
60) -> Result<Vec<v4::ParameterValue>, v2::Error> {
61    params.into_iter().map(|p| p.try_into()).collect()
62}
63
64fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
65    params.into_iter().map(|p| p.into()).collect()
66}
67
68impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
69    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
70    async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
71        spin_factor_outbound_networking::record_address_fields(&address);
72
73        self.ensure_address_allowed(&address).await?;
74
75        Ok(self.open_connection(&address, None).await?)
76    }
77
78    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
79    async fn execute(
80        &mut self,
81        connection: Resource<v3::Connection>,
82        statement: String,
83        params: Vec<v3::ParameterValue>,
84    ) -> Result<u64, v3::Error> {
85        Ok(self
86            .get_client(connection)
87            .await?
88            .execute(statement, v3_params_to_v4(params))
89            .await?)
90    }
91
92    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
93    async fn query(
94        &mut self,
95        connection: Resource<v3::Connection>,
96        statement: String,
97        params: Vec<v3::ParameterValue>,
98    ) -> Result<v3::RowSet, v3::Error> {
99        Ok(self
100            .get_client(connection)
101            .await?
102            .query(statement, v3_params_to_v4(params), MAX_HOST_BUFFERED_BYTES)
103            .await?
104            .into())
105    }
106
107    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
108        self.connections.remove(connection.rep());
109        Ok(())
110    }
111}
112
113pub(crate) struct ConnectionBuilder {
114    address: String,
115    root_ca: Option<HashableCertificate>,
116}
117
118impl<CF: ClientFactory> v4::HostConnectionBuilder for InstanceState<CF> {
119    async fn new(&mut self, address: String) -> Result<Resource<v4::ConnectionBuilder>> {
120        let builder = ConnectionBuilder {
121            address,
122            root_ca: None,
123        };
124        let rep = self
125            .builders
126            .push(builder)
127            .map_err(|_| anyhow::anyhow!("out of builder table space"))?;
128        let rsrc = Resource::new_own(rep);
129        Ok(rsrc)
130    }
131
132    async fn set_ca_root(
133        &mut self,
134        self_: Resource<v4::ConnectionBuilder>,
135        certificate: String,
136    ) -> Result<(), v4::Error> {
137        let root_ca = HashableCertificate::from_pem(&certificate)
138            .map_err(|e| v4::Error::Other(format!("invalid root certificate: {e}")))?;
139        let builder = self
140            .builders
141            .get_mut(self_.rep())
142            .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
143        builder.root_ca = Some(root_ca);
144        Ok(())
145    }
146
147    async fn build(
148        &mut self,
149        self_: Resource<v4::ConnectionBuilder>,
150    ) -> Result<Resource<v4::Connection>, v4::Error> {
151        let (address, root_ca) = self.get_builder_info(self_.rep())?;
152        self.open_connection(&address, root_ca).await
153    }
154
155    async fn drop(&mut self, builder: Resource<v4::ConnectionBuilder>) -> Result<()> {
156        self.builders.remove(builder.rep());
157        Ok(())
158    }
159}
160
161impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
162    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
163    async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
164        spin_factor_outbound_networking::record_address_fields(&address);
165
166        self.ensure_address_allowed(&address).await?;
167
168        self.open_connection(&address, None).await
169    }
170
171    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
172    async fn execute(
173        &mut self,
174        connection: Resource<v4::Connection>,
175        statement: String,
176        params: Vec<v4::ParameterValue>,
177    ) -> Result<u64, v4::Error> {
178        self.get_client(connection)
179            .await?
180            .execute(statement, params)
181            .await
182    }
183
184    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
185    async fn query(
186        &mut self,
187        connection: Resource<v4::Connection>,
188        statement: String,
189        params: Vec<v4::ParameterValue>,
190    ) -> Result<v4::RowSet, v4::Error> {
191        self.get_client(connection)
192            .await?
193            .query(statement, params, MAX_HOST_BUFFERED_BYTES)
194            .await
195    }
196
197    async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
198        self.connections.remove(connection.rep());
199        Ok(())
200    }
201}
202
203impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore
204    for crate::PgFactorData<CF>
205{
206    #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
207    async fn open_async<T>(
208        accessor: &Accessor<T, Self>,
209        address: String,
210    ) -> Result<Resource<v4::Connection>, v4::Error> {
211        spin_factor_outbound_networking::record_address_fields(&address);
212
213        Self::ensure_address_allowed_async(accessor, &address).await?;
214        Self::open_connection_async(accessor, &address, None).await
215    }
216
217    #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
218    async fn execute_async<T>(
219        accessor: &Accessor<T, Self>,
220        connection: Resource<v4::Connection>,
221        statement: String,
222        params: Vec<v4::ParameterValue>,
223    ) -> Result<u64, v4::Error> {
224        let client = accessor.with(|mut access| {
225            let host = access.get();
226            host.connections.get(connection.rep()).unwrap().clone()
227        });
228
229        client.execute(statement, params).await
230    }
231
232    #[allow(clippy::type_complexity)] // blame bindgen, clippy, blame bindgen
233    #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
234    async fn query_async<T>(
235        accessor: &Accessor<T, Self>,
236        connection: Resource<v4::Connection>,
237        statement: String,
238        params: Vec<v4::ParameterValue>,
239    ) -> Result<
240        (
241            Vec<v4::Column>,
242            StreamReader<v4::Row>,
243            FutureReader<Result<(), v4::Error>>,
244        ),
245        v4::Error,
246    > {
247        let client = accessor.with(|mut access| {
248            let host = access.get();
249            host.connections.get(connection.rep()).unwrap().clone()
250        });
251
252        let QueryAsyncResult {
253            columns,
254            rows,
255            error,
256        } = client
257            .query_async(statement, params, MAX_HOST_BUFFERED_BYTES)
258            .await?;
259
260        let row_producer = spin_wasi_async::stream::producer(rows);
261
262        let (sr, efr) = accessor
263            .with(|mut access| {
264                let sr = StreamReader::new(&mut access, row_producer)?;
265                let efr = FutureReader::new(&mut access, error)?;
266                anyhow::Ok((sr, efr))
267            })
268            .map_err(|e| v4::Error::Other(e.to_string()))?;
269
270        Ok((columns, sr, efr))
271    }
272}
273
274impl<CF: ClientFactory> InstanceState<CF> {
275    #[allow(clippy::result_large_err)]
276    fn get_builder_info(
277        &mut self,
278        builder_rep: u32,
279    ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
280        let builder = self
281            .builders
282            .get_mut(builder_rep)
283            .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
284
285        let address = builder.address.clone();
286        let root_ca = builder.root_ca.clone();
287
288        Ok((address, root_ca))
289    }
290}
291
292impl<CF: ClientFactory> crate::PgFactorData<CF> {
293    #[allow(clippy::result_large_err)]
294    fn get_builder_info<T>(
295        accessor: &Accessor<T, Self>,
296        builder: Resource<v4::ConnectionBuilder>,
297    ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
298        let builder_rep = builder.rep();
299        accessor.with(|mut access| {
300            let host = access.get();
301            host.get_builder_info(builder_rep)
302        })
303    }
304
305    async fn ensure_address_allowed_async<T>(
306        accessor: &Accessor<T, Self>,
307        address: &str,
308    ) -> Result<(), v4::Error> {
309        // A merry dance to avoid doing the async allow check under the accessor
310        let allowed_host_checker = accessor.with(|mut access| {
311            let host = access.get();
312            host.allowed_host_checker()
313        });
314
315        allowed_host_checker.ensure_address_allowed(address).await
316    }
317
318    async fn open_connection_async<T>(
319        accessor: &Accessor<T, Self>,
320        address: &str,
321        root_ca: Option<HashableCertificate>,
322    ) -> Result<Resource<v4::Connection>, v4::Error> {
323        let cf = accessor.with(|mut access| {
324            let host = access.get();
325            host.client_factory.clone()
326        });
327
328        let client = cf
329            .get_client(address, root_ca)
330            .await
331            .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?;
332
333        accessor.with(|mut access| {
334            let host = access.get();
335            host.connections
336                .push(client)
337                .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
338                .map(Resource::new_own)
339        })
340    }
341}
342
343impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore
344    for crate::PgFactorData<CF>
345{
346    async fn build_async<T>(
347        accessor: &Accessor<T, Self>,
348        builder: Resource<v4::ConnectionBuilder>,
349    ) -> Result<Resource<v4::Connection>, v4::Error> {
350        let (address, root_ca) = Self::get_builder_info(accessor, builder)?;
351
352        spin_factor_outbound_networking::record_address_fields(&address);
353
354        Self::ensure_address_allowed_async(accessor, &address).await?;
355        Self::open_connection_async(accessor, &address, root_ca).await
356    }
357}
358
359impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
360    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
361        Ok(error)
362    }
363}
364
365impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
366    fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
367        Ok(error)
368    }
369}
370
371impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
372    fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
373        Ok(error)
374    }
375}
376
/// Delegate a connection-per-call (v1-style) function to the `v4::HostConnection`
/// implementation.
///
/// The expansion checks `$address` against the allowed outbound hosts, opens a
/// connection, and calls `v4::HostConnection::$name` on it. It then removes the
/// connection from `$self.connections` whether the call succeeded or not, so repeated
/// calls do not accumulate table entries. Errors are converted with `Into`.
///
/// Note: the `$arg` expressions are evaluated after the connection is opened. If one of
/// them returns early (e.g. via `?`), the cleanup is skipped, so convert arguments
/// before invoking the macro.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        $self.ensure_address_allowed(&$address).await?;
        let connection = match $self
            .open_connection::<v4::Connection>(&$address, None)
            .await
        {
            Ok(c) => c,
            Err(e) => return Err(e.into()),
        };
        // Remember the slot so it can be released after the call consumes the handle.
        let rep = connection.rep();
        let result =
            <Self as v4::HostConnection>::$name(&mut *$self, connection, $($arg),*).await;
        $self.connections.remove(rep);
        result.map_err(|e| e.into())
    }};
}
390
391impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
392
393impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
394    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
395    async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
396        self.otel.reparent_tracing_span();
397        spin_factor_outbound_networking::record_address_fields(&address);
398
399        self.ensure_address_allowed(&address).await?;
400        Ok(self.open_connection(&address, None).await?)
401    }
402
403    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
404    async fn execute(
405        &mut self,
406        connection: Resource<v2::Connection>,
407        statement: String,
408        params: Vec<v2_types::ParameterValue>,
409    ) -> Result<u64, v2::Error> {
410        self.otel.reparent_tracing_span();
411        Ok(self
412            .get_client(connection)
413            .await?
414            .execute(statement, v2_params_to_v3(params)?)
415            .await?)
416    }
417
418    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
419    async fn query(
420        &mut self,
421        connection: Resource<v2::Connection>,
422        statement: String,
423        params: Vec<v2_types::ParameterValue>,
424    ) -> Result<v2_types::RowSet, v2::Error> {
425        self.otel.reparent_tracing_span();
426        Ok(self
427            .get_client(connection)
428            .await?
429            .query(statement, v2_params_to_v3(params)?, MAX_HOST_BUFFERED_BYTES)
430            .await?
431            .into())
432    }
433
434    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
435        self.connections.remove(connection.rep());
436        Ok(())
437    }
438}
439
440impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
441    async fn execute(
442        &mut self,
443        address: String,
444        statement: String,
445        params: Vec<v1_types::ParameterValue>,
446    ) -> Result<u64, v1::PgError> {
447        delegate!(
448            self.execute(
449                address,
450                statement,
451                params
452                    .into_iter()
453                    .map(TryInto::try_into)
454                    .collect::<Result<Vec<_>, _>>()?
455            )
456        )
457    }
458
459    async fn query(
460        &mut self,
461        address: String,
462        statement: String,
463        params: Vec<v1_types::ParameterValue>,
464    ) -> Result<v1_types::RowSet, v1::PgError> {
465        delegate!(
466            self.query(
467                address,
468                statement,
469                params
470                    .into_iter()
471                    .map(TryInto::try_into)
472                    .collect::<Result<Vec<_>, _>>()?
473            )
474        )
475        .map(Into::into)
476    }
477
478    fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
479        Ok(error)
480    }
481}