Skip to main content

spin_factor_outbound_pg/
host.rs

1use anyhow::Result;
2use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
3use spin_world::spin::postgres3_0_0::postgres::{self as v3};
4use spin_world::spin::postgres4_2_0::postgres::{self as v4};
5use spin_world::v1::postgres as v1;
6use spin_world::v1::rdbms_types as v1_types;
7use spin_world::v2::postgres::{self as v2};
8use spin_world::v2::rdbms_types as v2_types;
9use spin_world::MAX_HOST_BUFFERED_BYTES;
10use tracing::field::Empty;
11use tracing::instrument;
12use tracing::Level;
13
14use crate::allowed_hosts::AllowedHostChecker;
15use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult};
16use crate::InstanceState;
17
18impl<CF: ClientFactory> InstanceState<CF> {
19    async fn open_connection<Conn: 'static>(
20        &mut self,
21        address: &str,
22        root_ca: Option<HashableCertificate>,
23    ) -> Result<Resource<Conn>, v4::Error> {
24        self.connections
25            .push(
26                self.client_factory
27                    .get_client(address, root_ca)
28                    .await
29                    .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?,
30            )
31            .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
32            .map(Resource::new_own)
33    }
34
35    async fn get_client<Conn: 'static>(
36        &self,
37        connection: Resource<Conn>,
38    ) -> Result<&CF::Client, v4::Error> {
39        self.connections
40            .get(connection.rep())
41            .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into()))
42    }
43
44    fn allowed_host_checker(&self) -> AllowedHostChecker {
45        self.allowed_host_checker.clone()
46    }
47
48    #[allow(clippy::result_large_err)]
49    async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
50        self.allowed_host_checker
51            .ensure_address_allowed(address)
52            .await
53    }
54}
55
56fn v2_params_to_v3(
57    params: Vec<v2_types::ParameterValue>,
58) -> Result<Vec<v4::ParameterValue>, v2::Error> {
59    params.into_iter().map(|p| p.try_into()).collect()
60}
61
62fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
63    params.into_iter().map(|p| p.into()).collect()
64}
65
66impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
67    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
68    async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
69        spin_factor_outbound_networking::record_address_fields(&address);
70
71        self.ensure_address_allowed(&address).await?;
72
73        Ok(self.open_connection(&address, None).await?)
74    }
75
76    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
77    async fn execute(
78        &mut self,
79        connection: Resource<v3::Connection>,
80        statement: String,
81        params: Vec<v3::ParameterValue>,
82    ) -> Result<u64, v3::Error> {
83        Ok(self
84            .get_client(connection)
85            .await?
86            .execute(statement, v3_params_to_v4(params))
87            .await?)
88    }
89
90    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
91    async fn query(
92        &mut self,
93        connection: Resource<v3::Connection>,
94        statement: String,
95        params: Vec<v3::ParameterValue>,
96    ) -> Result<v3::RowSet, v3::Error> {
97        Ok(self
98            .get_client(connection)
99            .await?
100            .query(statement, v3_params_to_v4(params), MAX_HOST_BUFFERED_BYTES)
101            .await?
102            .into())
103    }
104
105    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
106        self.connections.remove(connection.rep());
107        Ok(())
108    }
109}
110
111pub(crate) struct ConnectionBuilder {
112    address: String,
113    root_ca: Option<HashableCertificate>,
114}
115
116impl<CF: ClientFactory> v4::HostConnectionBuilder for InstanceState<CF> {
117    async fn new(&mut self, address: String) -> Result<Resource<v4::ConnectionBuilder>> {
118        let builder = ConnectionBuilder {
119            address,
120            root_ca: None,
121        };
122        let rep = self
123            .builders
124            .push(builder)
125            .map_err(|_| anyhow::anyhow!("out of builder table space"))?;
126        let rsrc = Resource::new_own(rep);
127        Ok(rsrc)
128    }
129
130    async fn set_ca_root(
131        &mut self,
132        self_: Resource<v4::ConnectionBuilder>,
133        certificate: String,
134    ) -> Result<(), v4::Error> {
135        let root_ca = HashableCertificate::from_pem(&certificate)
136            .map_err(|e| v4::Error::Other(format!("invalid root certificate: {e}")))?;
137        let builder = self
138            .builders
139            .get_mut(self_.rep())
140            .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
141        builder.root_ca = Some(root_ca);
142        Ok(())
143    }
144
145    async fn build(
146        &mut self,
147        self_: Resource<v4::ConnectionBuilder>,
148    ) -> Result<Resource<v4::Connection>, v4::Error> {
149        let (address, root_ca) = self.get_builder_info(self_.rep())?;
150        self.open_connection(&address, root_ca).await
151    }
152
153    async fn drop(&mut self, builder: Resource<v4::ConnectionBuilder>) -> Result<()> {
154        self.builders.remove(builder.rep());
155        Ok(())
156    }
157}
158
159impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
160    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
161    async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
162        spin_factor_outbound_networking::record_address_fields(&address);
163
164        self.ensure_address_allowed(&address).await?;
165
166        self.open_connection(&address, None).await
167    }
168
169    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
170    async fn execute(
171        &mut self,
172        connection: Resource<v4::Connection>,
173        statement: String,
174        params: Vec<v4::ParameterValue>,
175    ) -> Result<u64, v4::Error> {
176        self.get_client(connection)
177            .await?
178            .execute(statement, params)
179            .await
180    }
181
182    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
183    async fn query(
184        &mut self,
185        connection: Resource<v4::Connection>,
186        statement: String,
187        params: Vec<v4::ParameterValue>,
188    ) -> Result<v4::RowSet, v4::Error> {
189        self.get_client(connection)
190            .await?
191            .query(statement, params, MAX_HOST_BUFFERED_BYTES)
192            .await
193    }
194
195    async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
196        self.connections.remove(connection.rep());
197        Ok(())
198    }
199}
200
201impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore
202    for crate::PgFactorData<CF>
203{
204    #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
205    async fn open_async<T>(
206        accessor: &Accessor<T, Self>,
207        address: String,
208    ) -> Result<Resource<v4::Connection>, v4::Error> {
209        spin_factor_outbound_networking::record_address_fields(&address);
210
211        Self::ensure_address_allowed_async(accessor, &address).await?;
212        Self::open_connection_async(accessor, &address, None).await
213    }
214
215    #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
216    async fn execute_async<T>(
217        accessor: &Accessor<T, Self>,
218        connection: Resource<v4::Connection>,
219        statement: String,
220        params: Vec<v4::ParameterValue>,
221    ) -> Result<u64, v4::Error> {
222        let client = accessor.with(|mut access| {
223            let host = access.get();
224            host.connections.get(connection.rep()).unwrap().clone()
225        });
226
227        client.execute(statement, params).await
228    }
229
230    #[allow(clippy::type_complexity)] // blame bindgen, clippy, blame bindgen
231    #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
232    async fn query_async<T>(
233        accessor: &Accessor<T, Self>,
234        connection: Resource<v4::Connection>,
235        statement: String,
236        params: Vec<v4::ParameterValue>,
237    ) -> Result<
238        (
239            Vec<v4::Column>,
240            StreamReader<v4::Row>,
241            FutureReader<Result<(), v4::Error>>,
242        ),
243        v4::Error,
244    > {
245        let client = accessor.with(|mut access| {
246            let host = access.get();
247            host.connections.get(connection.rep()).unwrap().clone()
248        });
249
250        let QueryAsyncResult {
251            columns,
252            rows,
253            error,
254        } = client
255            .query_async(statement, params, MAX_HOST_BUFFERED_BYTES)
256            .await?;
257
258        let row_producer = spin_wasi_async::stream::producer(rows);
259
260        let (sr, efr) = accessor.with(|mut access| {
261            let sr = StreamReader::new(&mut access, row_producer);
262            let efr = FutureReader::new(&mut access, error);
263            (sr, efr)
264        });
265
266        Ok((columns, sr, efr))
267    }
268}
269
270impl<CF: ClientFactory> InstanceState<CF> {
271    #[allow(clippy::result_large_err)]
272    fn get_builder_info(
273        &mut self,
274        builder_rep: u32,
275    ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
276        let builder = self
277            .builders
278            .get_mut(builder_rep)
279            .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?;
280
281        let address = builder.address.clone();
282        let root_ca = builder.root_ca.clone();
283
284        Ok((address, root_ca))
285    }
286}
287
288impl<CF: ClientFactory> crate::PgFactorData<CF> {
289    #[allow(clippy::result_large_err)]
290    fn get_builder_info<T>(
291        accessor: &Accessor<T, Self>,
292        builder: Resource<v4::ConnectionBuilder>,
293    ) -> Result<(String, Option<HashableCertificate>), v4::Error> {
294        let builder_rep = builder.rep();
295        accessor.with(|mut access| {
296            let host = access.get();
297            host.get_builder_info(builder_rep)
298        })
299    }
300
301    async fn ensure_address_allowed_async<T>(
302        accessor: &Accessor<T, Self>,
303        address: &str,
304    ) -> Result<(), v4::Error> {
305        // A merry dance to avoid doing the async allow check under the accessor
306        let allowed_host_checker = accessor.with(|mut access| {
307            let host = access.get();
308            host.allowed_host_checker()
309        });
310
311        allowed_host_checker.ensure_address_allowed(address).await
312    }
313
314    async fn open_connection_async<T>(
315        accessor: &Accessor<T, Self>,
316        address: &str,
317        root_ca: Option<HashableCertificate>,
318    ) -> Result<Resource<v4::Connection>, v4::Error> {
319        let cf = accessor.with(|mut access| {
320            let host = access.get();
321            host.client_factory.clone()
322        });
323
324        let client = cf
325            .get_client(address, root_ca)
326            .await
327            .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?;
328
329        let rsrc = accessor.with(|mut access| {
330            let host = access.get();
331            host.connections
332                .push(client)
333                .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
334                .map(Resource::new_own)
335        });
336
337        rsrc
338    }
339}
340
341impl<CF: ClientFactory> spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore
342    for crate::PgFactorData<CF>
343{
344    async fn build_async<T>(
345        accessor: &Accessor<T, Self>,
346        builder: Resource<v4::ConnectionBuilder>,
347    ) -> Result<Resource<v4::Connection>, v4::Error> {
348        let (address, root_ca) = Self::get_builder_info(accessor, builder)?;
349
350        spin_factor_outbound_networking::record_address_fields(&address);
351
352        Self::ensure_address_allowed_async(accessor, &address).await?;
353        Self::open_connection_async(accessor, &address, root_ca).await
354    }
355}
356
357impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
358    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
359        Ok(error)
360    }
361}
362
363impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
364    fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
365        Ok(error)
366    }
367}
368
369impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
370    fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
371        Ok(error)
372    }
373}
374
/// Delegate a v1 function call to the corresponding `v4::HostConnection` method.
///
/// v1 functions take an address rather than a connection handle. Each
/// invocation therefore:
/// 1. checks the address against the allowed hosts,
/// 2. opens a connection used only for this call,
/// 3. invokes the v4 method on it,
/// 4. removes the connection from the table again.
///
/// Without step 4, every v1 call would leave a connection in the table until
/// the instance goes away, which can eventually make `push` fail with
/// "too many connections".
///
/// Note: the argument expressions are evaluated after the connection has been
/// opened. If one of them returns early (e.g. via `?`), the removal is skipped,
/// so callers should pass already-computed values.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        $self.ensure_address_allowed(&$address).await?;
        let connection: Resource<v4::Connection> =
            $self.open_connection(&$address, None).await?;
        // The v4 method consumes the handle, so remember its table key first.
        let rep = connection.rep();
        let result = <Self as v4::HostConnection>::$name($self, connection, $($arg),*).await;
        $self.connections.remove(rep);
        result.map_err(|e| e.into())
    }};
}
388
389impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
390
391impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
392    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
393    async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
394        self.otel.reparent_tracing_span();
395        spin_factor_outbound_networking::record_address_fields(&address);
396
397        self.ensure_address_allowed(&address).await?;
398        Ok(self.open_connection(&address, None).await?)
399    }
400
401    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
402    async fn execute(
403        &mut self,
404        connection: Resource<v2::Connection>,
405        statement: String,
406        params: Vec<v2_types::ParameterValue>,
407    ) -> Result<u64, v2::Error> {
408        self.otel.reparent_tracing_span();
409        Ok(self
410            .get_client(connection)
411            .await?
412            .execute(statement, v2_params_to_v3(params)?)
413            .await?)
414    }
415
416    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
417    async fn query(
418        &mut self,
419        connection: Resource<v2::Connection>,
420        statement: String,
421        params: Vec<v2_types::ParameterValue>,
422    ) -> Result<v2_types::RowSet, v2::Error> {
423        self.otel.reparent_tracing_span();
424        Ok(self
425            .get_client(connection)
426            .await?
427            .query(statement, v2_params_to_v3(params)?, MAX_HOST_BUFFERED_BYTES)
428            .await?
429            .into())
430    }
431
432    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
433        self.connections.remove(connection.rep());
434        Ok(())
435    }
436}
437
438impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
439    async fn execute(
440        &mut self,
441        address: String,
442        statement: String,
443        params: Vec<v1_types::ParameterValue>,
444    ) -> Result<u64, v1::PgError> {
445        delegate!(self.execute(
446            address,
447            statement,
448            params
449                .into_iter()
450                .map(TryInto::try_into)
451                .collect::<Result<Vec<_>, _>>()?
452        ))
453    }
454
455    async fn query(
456        &mut self,
457        address: String,
458        statement: String,
459        params: Vec<v1_types::ParameterValue>,
460    ) -> Result<v1_types::RowSet, v1::PgError> {
461        delegate!(self.query(
462            address,
463            statement,
464            params
465                .into_iter()
466                .map(TryInto::try_into)
467                .collect::<Result<Vec<_>, _>>()?
468        ))
469        .map(Into::into)
470    }
471
472    fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
473        Ok(error)
474    }
475}