spin_factor_outbound_pg/
host.rs

1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::spin::postgres3_0_0::postgres::{self as v3};
4use spin_world::spin::postgres4_0_0::postgres::{self as v4};
5use spin_world::v1::postgres as v1;
6use spin_world::v1::rdbms_types as v1_types;
7use spin_world::v2::postgres::{self as v2};
8use spin_world::v2::rdbms_types as v2_types;
9use tracing::field::Empty;
10use tracing::instrument;
11use tracing::Level;
12
13use crate::client::{Client, ClientFactory};
14use crate::InstanceState;
15
16impl<CF: ClientFactory> InstanceState<CF> {
17    async fn open_connection<Conn: 'static>(
18        &mut self,
19        address: &str,
20    ) -> Result<Resource<Conn>, v4::Error> {
21        self.connections
22            .push(
23                self.client_factory
24                    .get_client(address)
25                    .await
26                    .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?,
27            )
28            .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
29            .map(Resource::new_own)
30    }
31
32    async fn get_client<Conn: 'static>(
33        &self,
34        connection: Resource<Conn>,
35    ) -> Result<&CF::Client, v4::Error> {
36        self.connections
37            .get(connection.rep())
38            .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into()))
39    }
40
41    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
42        let Ok(config) = address.parse::<tokio_postgres::Config>() else {
43            return Ok(false);
44        };
45        for (i, host) in config.get_hosts().iter().enumerate() {
46            match host {
47                tokio_postgres::config::Host::Tcp(address) => {
48                    let ports = config.get_ports();
49                    // The port we use is either:
50                    // * The port at the same index as the host
51                    // * The first port if there is only one port
52                    let port =
53                        ports
54                            .get(i)
55                            .or_else(|| if ports.len() == 1 { ports.get(1) } else { None });
56                    let port_str = port.map(|p| format!(":{p}")).unwrap_or_default();
57                    let url = format!("{address}{port_str}");
58                    if !self.allowed_hosts.check_url(&url, "postgres").await? {
59                        return Ok(false);
60                    }
61                }
62                #[cfg(unix)]
63                tokio_postgres::config::Host::Unix(_) => return Ok(false),
64            }
65        }
66        Ok(true)
67    }
68}
69
70fn v2_params_to_v3(
71    params: Vec<v2_types::ParameterValue>,
72) -> Result<Vec<v4::ParameterValue>, v2::Error> {
73    params.into_iter().map(|p| p.try_into()).collect()
74}
75
76fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
77    params.into_iter().map(|p| p.into()).collect()
78}
79
80impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
81    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
82    async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
83        spin_factor_outbound_networking::record_address_fields(&address);
84
85        if !self
86            .is_address_allowed(&address)
87            .await
88            .map_err(|e| v3::Error::Other(e.to_string()))?
89        {
90            return Err(v3::Error::ConnectionFailed(format!(
91                "address {address} is not permitted"
92            )));
93        }
94        Ok(self.open_connection(&address).await?)
95    }
96
97    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
98    async fn execute(
99        &mut self,
100        connection: Resource<v3::Connection>,
101        statement: String,
102        params: Vec<v3::ParameterValue>,
103    ) -> Result<u64, v3::Error> {
104        Ok(self
105            .get_client(connection)
106            .await?
107            .execute(statement, v3_params_to_v4(params))
108            .await?)
109    }
110
111    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
112    async fn query(
113        &mut self,
114        connection: Resource<v3::Connection>,
115        statement: String,
116        params: Vec<v3::ParameterValue>,
117    ) -> Result<v3::RowSet, v3::Error> {
118        Ok(self
119            .get_client(connection)
120            .await?
121            .query(statement, v3_params_to_v4(params))
122            .await?
123            .into())
124    }
125
126    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
127        self.connections.remove(connection.rep());
128        Ok(())
129    }
130}
131
132impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
133    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
134    async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
135        spin_factor_outbound_networking::record_address_fields(&address);
136
137        if !self
138            .is_address_allowed(&address)
139            .await
140            .map_err(|e| v4::Error::Other(e.to_string()))?
141        {
142            return Err(v4::Error::ConnectionFailed(format!(
143                "address {address} is not permitted"
144            )));
145        }
146        self.open_connection(&address).await
147    }
148
149    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
150    async fn execute(
151        &mut self,
152        connection: Resource<v4::Connection>,
153        statement: String,
154        params: Vec<v4::ParameterValue>,
155    ) -> Result<u64, v4::Error> {
156        self.get_client(connection)
157            .await?
158            .execute(statement, params)
159            .await
160    }
161
162    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
163    async fn query(
164        &mut self,
165        connection: Resource<v4::Connection>,
166        statement: String,
167        params: Vec<v4::ParameterValue>,
168    ) -> Result<v4::RowSet, v4::Error> {
169        self.get_client(connection)
170            .await?
171            .query(statement, params)
172            .await
173    }
174
175    async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
176        self.connections.remove(connection.rep());
177        Ok(())
178    }
179}
180
181impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
182    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
183        Ok(error)
184    }
185}
186
187impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
188    fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
189        Ok(error)
190    }
191}
192
193impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
194    fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
195        Ok(error)
196    }
197}
198
/// Delegate a v1 `execute`/`query` call to the `v4::HostConnection` implementation.
///
/// v1 has no connection resource: each call names its target address directly. The expansion
/// does the following, in order:
/// 1. checks the address against the allowed hosts;
/// 2. opens a connection for this single call;
/// 3. evaluates the remaining arguments and forwards them to the v4 method;
/// 4. removes the connection from the instance's connection table, whatever the outcome.
///
/// It `return`s a `v1::PgError` from the enclosing function if the address is refused or the
/// connection cannot be opened. Otherwise it evaluates to the v4 method's `Ok` value, with
/// errors converted to `v1::PgError`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|e| v4::Error::Other(e.to_string()))? {
            return Err(v1::PgError::ConnectionFailed(format!(
                "address {} is not permitted", $address
            )));
        }
        let connection: Resource<v4::Connection> = match $self.open_connection(&$address).await {
            Ok(c) => c,
            Err(e) => return Err(e.into()),
        };
        // The connection exists only for this call and v1 has no `drop`, so nothing else
        // would ever free its table slot; remember it for cleanup below.
        let rep = connection.rep();
        // The arguments and the call are evaluated inside an async block so that a `?` in an
        // argument expression (e.g. a failed parameter conversion) cannot skip the cleanup.
        let result = async {
            Ok::<_, v1::PgError>(
                <Self as v4::HostConnection>::$name($self, connection, $($arg),*).await?,
            )
        }
        .await;
        $self.connections.remove(rep);
        result
    }};
}
216
217impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
218
219impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
220    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
221    async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
222        spin_factor_outbound_networking::record_address_fields(&address);
223
224        if !self
225            .is_address_allowed(&address)
226            .await
227            .map_err(|e| v2::Error::Other(e.to_string()))?
228        {
229            return Err(v2::Error::ConnectionFailed(format!(
230                "address {address} is not permitted"
231            )));
232        }
233        Ok(self.open_connection(&address).await?)
234    }
235
236    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
237    async fn execute(
238        &mut self,
239        connection: Resource<v2::Connection>,
240        statement: String,
241        params: Vec<v2_types::ParameterValue>,
242    ) -> Result<u64, v2::Error> {
243        Ok(self
244            .get_client(connection)
245            .await?
246            .execute(statement, v2_params_to_v3(params)?)
247            .await?)
248    }
249
250    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
251    async fn query(
252        &mut self,
253        connection: Resource<v2::Connection>,
254        statement: String,
255        params: Vec<v2_types::ParameterValue>,
256    ) -> Result<v2_types::RowSet, v2::Error> {
257        Ok(self
258            .get_client(connection)
259            .await?
260            .query(statement, v2_params_to_v3(params)?)
261            .await?
262            .into())
263    }
264
265    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
266        self.connections.remove(connection.rep());
267        Ok(())
268    }
269}
270
271impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
272    async fn execute(
273        &mut self,
274        address: String,
275        statement: String,
276        params: Vec<v1_types::ParameterValue>,
277    ) -> Result<u64, v1::PgError> {
278        delegate!(self.execute(
279            address,
280            statement,
281            params
282                .into_iter()
283                .map(TryInto::try_into)
284                .collect::<Result<Vec<_>, _>>()?
285        ))
286    }
287
288    async fn query(
289        &mut self,
290        address: String,
291        statement: String,
292        params: Vec<v1_types::ParameterValue>,
293    ) -> Result<v1_types::RowSet, v1::PgError> {
294        delegate!(self.query(
295            address,
296            statement,
297            params
298                .into_iter()
299                .map(TryInto::try_into)
300                .collect::<Result<Vec<_>, _>>()?
301        ))
302        .map(Into::into)
303    }
304
305    fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
306        Ok(error)
307    }
308}