1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::spin::postgres3_0_0::postgres::{self as v3};
4use spin_world::spin::postgres4_0_0::postgres::{self as v4};
5use spin_world::v1::postgres as v1;
6use spin_world::v1::rdbms_types as v1_types;
7use spin_world::v2::postgres::{self as v2};
8use spin_world::v2::rdbms_types as v2_types;
9use tracing::field::Empty;
10use tracing::instrument;
11use tracing::Level;
12
13use crate::client::{Client, ClientFactory};
14use crate::InstanceState;
15
16impl<CF: ClientFactory> InstanceState<CF> {
17 async fn open_connection<Conn: 'static>(
18 &mut self,
19 address: &str,
20 ) -> Result<Resource<Conn>, v4::Error> {
21 self.connections
22 .push(
23 self.client_factory
24 .get_client(address)
25 .await
26 .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?,
27 )
28 .map_err(|_| v4::Error::ConnectionFailed("too many connections".into()))
29 .map(Resource::new_own)
30 }
31
32 async fn get_client<Conn: 'static>(
33 &self,
34 connection: Resource<Conn>,
35 ) -> Result<&CF::Client, v4::Error> {
36 self.connections
37 .get(connection.rep())
38 .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into()))
39 }
40
41 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
42 let Ok(config) = address.parse::<tokio_postgres::Config>() else {
43 return Ok(false);
44 };
45 for (i, host) in config.get_hosts().iter().enumerate() {
46 match host {
47 tokio_postgres::config::Host::Tcp(address) => {
48 let ports = config.get_ports();
49 let port =
53 ports
54 .get(i)
55 .or_else(|| if ports.len() == 1 { ports.get(1) } else { None });
56 let port_str = port.map(|p| format!(":{p}")).unwrap_or_default();
57 let url = format!("{address}{port_str}");
58 if !self.allowed_hosts.check_url(&url, "postgres").await? {
59 return Ok(false);
60 }
61 }
62 #[cfg(unix)]
63 tokio_postgres::config::Host::Unix(_) => return Ok(false),
64 }
65 }
66 Ok(true)
67 }
68}
69
70fn v2_params_to_v3(
71 params: Vec<v2_types::ParameterValue>,
72) -> Result<Vec<v4::ParameterValue>, v2::Error> {
73 params.into_iter().map(|p| p.try_into()).collect()
74}
75
76fn v3_params_to_v4(params: Vec<v3::ParameterValue>) -> Vec<v4::ParameterValue> {
77 params.into_iter().map(|p| p.into()).collect()
78}
79
80impl<CF: ClientFactory> v3::HostConnection for InstanceState<CF> {
81 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
82 async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
83 spin_factor_outbound_networking::record_address_fields(&address);
84
85 if !self
86 .is_address_allowed(&address)
87 .await
88 .map_err(|e| v3::Error::Other(e.to_string()))?
89 {
90 return Err(v3::Error::ConnectionFailed(format!(
91 "address {address} is not permitted"
92 )));
93 }
94 Ok(self.open_connection(&address).await?)
95 }
96
97 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
98 async fn execute(
99 &mut self,
100 connection: Resource<v3::Connection>,
101 statement: String,
102 params: Vec<v3::ParameterValue>,
103 ) -> Result<u64, v3::Error> {
104 Ok(self
105 .get_client(connection)
106 .await?
107 .execute(statement, v3_params_to_v4(params))
108 .await?)
109 }
110
111 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
112 async fn query(
113 &mut self,
114 connection: Resource<v3::Connection>,
115 statement: String,
116 params: Vec<v3::ParameterValue>,
117 ) -> Result<v3::RowSet, v3::Error> {
118 Ok(self
119 .get_client(connection)
120 .await?
121 .query(statement, v3_params_to_v4(params))
122 .await?
123 .into())
124 }
125
126 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
127 self.connections.remove(connection.rep());
128 Ok(())
129 }
130}
131
132impl<CF: ClientFactory> v4::HostConnection for InstanceState<CF> {
133 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
134 async fn open(&mut self, address: String) -> Result<Resource<v4::Connection>, v4::Error> {
135 spin_factor_outbound_networking::record_address_fields(&address);
136
137 if !self
138 .is_address_allowed(&address)
139 .await
140 .map_err(|e| v4::Error::Other(e.to_string()))?
141 {
142 return Err(v4::Error::ConnectionFailed(format!(
143 "address {address} is not permitted"
144 )));
145 }
146 self.open_connection(&address).await
147 }
148
149 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
150 async fn execute(
151 &mut self,
152 connection: Resource<v4::Connection>,
153 statement: String,
154 params: Vec<v4::ParameterValue>,
155 ) -> Result<u64, v4::Error> {
156 self.get_client(connection)
157 .await?
158 .execute(statement, params)
159 .await
160 }
161
162 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
163 async fn query(
164 &mut self,
165 connection: Resource<v4::Connection>,
166 statement: String,
167 params: Vec<v4::ParameterValue>,
168 ) -> Result<v4::RowSet, v4::Error> {
169 self.get_client(connection)
170 .await?
171 .query(statement, params)
172 .await
173 }
174
175 async fn drop(&mut self, connection: Resource<v4::Connection>) -> anyhow::Result<()> {
176 self.connections.remove(connection.rep());
177 Ok(())
178 }
179}
180
181impl<CF: ClientFactory> v2_types::Host for InstanceState<CF> {
182 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
183 Ok(error)
184 }
185}
186
187impl<CF: ClientFactory> v3::Host for InstanceState<CF> {
188 fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
189 Ok(error)
190 }
191}
192
193impl<CF: ClientFactory> v4::Host for InstanceState<CF> {
194 fn convert_error(&mut self, error: v4::Error) -> Result<v4::Error> {
195 Ok(error)
196 }
197}
198
/// Implements a connectionless v1 call on top of the v4 connection-based API.
///
/// Intended for use inside an `async fn` returning `Result<_, v1::PgError>`. It:
/// 1. checks that `$address` is permitted;
/// 2. evaluates the statement and parameter arguments;
/// 3. opens a temporary connection and invokes `v4::HostConnection::$name` with it;
/// 4. removes that connection from the connection table again.
///
/// The v1 API gives the guest no connection handle to drop. Without the explicit removal,
/// every call would leave one entry in `self.connections`, and `open_connection` would
/// eventually fail once the table stops accepting entries.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $statement:expr, $params:expr)) => {{
        let addr: &str = &$address;
        if !$self
            .is_address_allowed(addr)
            .await
            .map_err(|e| v4::Error::Other(e.to_string()))?
        {
            return Err(v1::PgError::ConnectionFailed(format!(
                "address {} is not permitted",
                addr
            )));
        }
        // Evaluate the arguments before opening the connection. They may return early
        // (e.g. a `?` on a parameter conversion), which would otherwise skip the cleanup
        // below and leave the connection in the table.
        let statement_arg: String = $statement;
        let params_arg: Vec<v4::ParameterValue> = $params;
        let connection: Resource<v4::Connection> = match $self.open_connection(addr).await {
            Ok(c) => c,
            Err(e) => return Err(e.into()),
        };
        let rep = connection.rep();
        let result =
            <Self as v4::HostConnection>::$name($self, connection, statement_arg, params_arg)
                .await;
        // Release the temporary connection on both the success and the error path.
        $self.connections.remove(rep);
        result.map_err(|e| e.into())
    }};
}
216
217impl<CF: ClientFactory> v2::Host for InstanceState<CF> {}
218
219impl<CF: ClientFactory> v2::HostConnection for InstanceState<CF> {
220 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
221 async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
222 spin_factor_outbound_networking::record_address_fields(&address);
223
224 if !self
225 .is_address_allowed(&address)
226 .await
227 .map_err(|e| v2::Error::Other(e.to_string()))?
228 {
229 return Err(v2::Error::ConnectionFailed(format!(
230 "address {address} is not permitted"
231 )));
232 }
233 Ok(self.open_connection(&address).await?)
234 }
235
236 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
237 async fn execute(
238 &mut self,
239 connection: Resource<v2::Connection>,
240 statement: String,
241 params: Vec<v2_types::ParameterValue>,
242 ) -> Result<u64, v2::Error> {
243 Ok(self
244 .get_client(connection)
245 .await?
246 .execute(statement, v2_params_to_v3(params)?)
247 .await?)
248 }
249
250 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
251 async fn query(
252 &mut self,
253 connection: Resource<v2::Connection>,
254 statement: String,
255 params: Vec<v2_types::ParameterValue>,
256 ) -> Result<v2_types::RowSet, v2::Error> {
257 Ok(self
258 .get_client(connection)
259 .await?
260 .query(statement, v2_params_to_v3(params)?)
261 .await?
262 .into())
263 }
264
265 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
266 self.connections.remove(connection.rep());
267 Ok(())
268 }
269}
270
271impl<CF: ClientFactory> v1::Host for InstanceState<CF> {
272 async fn execute(
273 &mut self,
274 address: String,
275 statement: String,
276 params: Vec<v1_types::ParameterValue>,
277 ) -> Result<u64, v1::PgError> {
278 delegate!(self.execute(
279 address,
280 statement,
281 params
282 .into_iter()
283 .map(TryInto::try_into)
284 .collect::<Result<Vec<_>, _>>()?
285 ))
286 }
287
288 async fn query(
289 &mut self,
290 address: String,
291 statement: String,
292 params: Vec<v1_types::ParameterValue>,
293 ) -> Result<v1_types::RowSet, v1::PgError> {
294 delegate!(self.query(
295 address,
296 statement,
297 params
298 .into_iter()
299 .map(TryInto::try_into)
300 .collect::<Result<Vec<_>, _>>()?
301 ))
302 .map(Into::into)
303 }
304
305 fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
306 Ok(error)
307 }
308}