spin_factor_outbound_pg/
host.rs

1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::spin::postgres::postgres::{self as v3};
4use spin_world::v1::postgres as v1;
5use spin_world::v1::rdbms_types as v1_types;
6use spin_world::v2::postgres::{self as v2};
7use spin_world::v2::rdbms_types as v2_types;
8use tracing::field::Empty;
9use tracing::instrument;
10use tracing::Level;
11
12use crate::client::Client;
13use crate::InstanceState;
14
15impl<C: Client> InstanceState<C> {
16    async fn open_connection<Conn: 'static>(
17        &mut self,
18        address: &str,
19    ) -> Result<Resource<Conn>, v3::Error> {
20        self.connections
21            .push(
22                C::build_client(address)
23                    .await
24                    .map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?,
25            )
26            .map_err(|_| v3::Error::ConnectionFailed("too many connections".into()))
27            .map(Resource::new_own)
28    }
29
30    async fn get_client<Conn: 'static>(
31        &mut self,
32        connection: Resource<Conn>,
33    ) -> Result<&C, v3::Error> {
34        self.connections
35            .get(connection.rep())
36            .ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into()))
37    }
38
39    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
40        let Ok(config) = address.parse::<tokio_postgres::Config>() else {
41            return Ok(false);
42        };
43        for (i, host) in config.get_hosts().iter().enumerate() {
44            match host {
45                tokio_postgres::config::Host::Tcp(address) => {
46                    let ports = config.get_ports();
47                    // The port we use is either:
48                    // * The port at the same index as the host
49                    // * The first port if there is only one port
50                    let port =
51                        ports
52                            .get(i)
53                            .or_else(|| if ports.len() == 1 { ports.get(1) } else { None });
54                    let port_str = port.map(|p| format!(":{}", p)).unwrap_or_default();
55                    let url = format!("{address}{port_str}");
56                    if !self.allowed_hosts.check_url(&url, "postgres").await? {
57                        return Ok(false);
58                    }
59                }
60                #[cfg(unix)]
61                tokio_postgres::config::Host::Unix(_) => return Ok(false),
62            }
63        }
64        Ok(true)
65    }
66}
67
68fn v2_params_to_v3(
69    params: Vec<v2_types::ParameterValue>,
70) -> Result<Vec<v3::ParameterValue>, v2::Error> {
71    params.into_iter().map(|p| p.try_into()).collect()
72}
73
74impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnection
75    for InstanceState<C>
76{
77    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
78    async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
79        spin_factor_outbound_networking::record_address_fields(&address);
80
81        if !self
82            .is_address_allowed(&address)
83            .await
84            .map_err(|e| v3::Error::Other(e.to_string()))?
85        {
86            return Err(v3::Error::ConnectionFailed(format!(
87                "address {address} is not permitted"
88            )));
89        }
90        self.open_connection(&address).await
91    }
92
93    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
94    async fn execute(
95        &mut self,
96        connection: Resource<v3::Connection>,
97        statement: String,
98        params: Vec<v3::ParameterValue>,
99    ) -> Result<u64, v3::Error> {
100        self.get_client(connection)
101            .await?
102            .execute(statement, params)
103            .await
104    }
105
106    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
107    async fn query(
108        &mut self,
109        connection: Resource<v3::Connection>,
110        statement: String,
111        params: Vec<v3::ParameterValue>,
112    ) -> Result<v3::RowSet, v3::Error> {
113        self.get_client(connection)
114            .await?
115            .query(statement, params)
116            .await
117    }
118
119    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
120        self.connections.remove(connection.rep());
121        Ok(())
122    }
123}
124
125impl<C: Send> v2_types::Host for InstanceState<C> {
126    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
127        Ok(error)
128    }
129}
130
131impl<C: Send + Sync + Client> v3::Host for InstanceState<C> {
132    fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
133        Ok(error)
134    }
135}
136
/// Delegate a one-shot v1 call to the v3::HostConnection implementation.
///
/// Expands to a block that does the following, in order:
/// 1. Rejects `$address` with `v1::PgError::ConnectionFailed` if it is not permitted.
/// 2. Evaluates the statement and parameter expressions. Any `?` inside them
///    returns before a connection exists.
/// 3. Opens a temporary connection to `$address`.
/// 4. Calls `<Self as v3::HostConnection>::$name` on that connection.
/// 5. Removes the temporary connection from `self.connections` before returning.
///    The handle is never given to the guest, so nothing else would ever drop it.
///
/// Errors are converted into the caller's error type with `Into`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $statement:expr, $params:expr)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|e| v3::Error::Other(e.to_string()))? {
            return Err(v1::PgError::ConnectionFailed(format!(
                "address {} is not permitted", $address
            )));
        }
        // Evaluate the arguments before opening the connection, so a failed
        // parameter conversion cannot leave an orphaned entry in the table.
        let statement: String = $statement;
        let params: Vec<v3::ParameterValue> = $params;
        let connection = $self.open_connection::<v3::Connection>(&$address).await?;
        // The handle is moved into the call below; keep its table slot so the
        // entry can be released afterwards.
        let rep = connection.rep();
        let result =
            <Self as v3::HostConnection>::$name($self, connection, statement, params).await;
        // Release the per-call connection whether or not the call succeeded.
        $self.connections.remove(rep);
        result.map_err(|e| e.into())
    }};
}
154
155impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}
156
157impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
158    #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
159    async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
160        spin_factor_outbound_networking::record_address_fields(&address);
161
162        if !self
163            .is_address_allowed(&address)
164            .await
165            .map_err(|e| v2::Error::Other(e.to_string()))?
166        {
167            return Err(v2::Error::ConnectionFailed(format!(
168                "address {address} is not permitted"
169            )));
170        }
171        Ok(self.open_connection(&address).await?)
172    }
173
174    #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
175    async fn execute(
176        &mut self,
177        connection: Resource<v2::Connection>,
178        statement: String,
179        params: Vec<v2_types::ParameterValue>,
180    ) -> Result<u64, v2::Error> {
181        Ok(self
182            .get_client(connection)
183            .await?
184            .execute(statement, v2_params_to_v3(params)?)
185            .await?)
186    }
187
188    #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
189    async fn query(
190        &mut self,
191        connection: Resource<v2::Connection>,
192        statement: String,
193        params: Vec<v2_types::ParameterValue>,
194    ) -> Result<v2_types::RowSet, v2::Error> {
195        Ok(self
196            .get_client(connection)
197            .await?
198            .query(statement, v2_params_to_v3(params)?)
199            .await?
200            .into())
201    }
202
203    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
204        self.connections.remove(connection.rep());
205        Ok(())
206    }
207}
208
209impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
210    async fn execute(
211        &mut self,
212        address: String,
213        statement: String,
214        params: Vec<v1_types::ParameterValue>,
215    ) -> Result<u64, v1::PgError> {
216        delegate!(self.execute(
217            address,
218            statement,
219            params
220                .into_iter()
221                .map(TryInto::try_into)
222                .collect::<Result<Vec<_>, _>>()?
223        ))
224    }
225
226    async fn query(
227        &mut self,
228        address: String,
229        statement: String,
230        params: Vec<v1_types::ParameterValue>,
231    ) -> Result<v1_types::RowSet, v1::PgError> {
232        delegate!(self.query(
233            address,
234            statement,
235            params
236                .into_iter()
237                .map(TryInto::try_into)
238                .collect::<Result<Vec<_>, _>>()?
239        ))
240        .map(Into::into)
241    }
242
243    fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
244        Ok(error)
245    }
246}