1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::spin::postgres::postgres::{self as v3};
4use spin_world::v1::postgres as v1;
5use spin_world::v1::rdbms_types as v1_types;
6use spin_world::v2::postgres::{self as v2};
7use spin_world::v2::rdbms_types as v2_types;
8use tracing::field::Empty;
9use tracing::instrument;
10use tracing::Level;
11
12use crate::client::Client;
13use crate::InstanceState;
14
15impl<C: Client> InstanceState<C> {
16 async fn open_connection<Conn: 'static>(
17 &mut self,
18 address: &str,
19 ) -> Result<Resource<Conn>, v3::Error> {
20 self.connections
21 .push(
22 C::build_client(address)
23 .await
24 .map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?,
25 )
26 .map_err(|_| v3::Error::ConnectionFailed("too many connections".into()))
27 .map(Resource::new_own)
28 }
29
30 async fn get_client<Conn: 'static>(
31 &mut self,
32 connection: Resource<Conn>,
33 ) -> Result<&C, v3::Error> {
34 self.connections
35 .get(connection.rep())
36 .ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into()))
37 }
38
39 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
40 let Ok(config) = address.parse::<tokio_postgres::Config>() else {
41 return Ok(false);
42 };
43 for (i, host) in config.get_hosts().iter().enumerate() {
44 match host {
45 tokio_postgres::config::Host::Tcp(address) => {
46 let ports = config.get_ports();
47 let port =
51 ports
52 .get(i)
53 .or_else(|| if ports.len() == 1 { ports.get(1) } else { None });
54 let port_str = port.map(|p| format!(":{}", p)).unwrap_or_default();
55 let url = format!("{address}{port_str}");
56 if !self.allowed_hosts.check_url(&url, "postgres").await? {
57 return Ok(false);
58 }
59 }
60 #[cfg(unix)]
61 tokio_postgres::config::Host::Unix(_) => return Ok(false),
62 }
63 }
64 Ok(true)
65 }
66}
67
68fn v2_params_to_v3(
69 params: Vec<v2_types::ParameterValue>,
70) -> Result<Vec<v3::ParameterValue>, v2::Error> {
71 params.into_iter().map(|p| p.try_into()).collect()
72}
73
74impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnection
75 for InstanceState<C>
76{
77 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
78 async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
79 spin_factor_outbound_networking::record_address_fields(&address);
80
81 if !self
82 .is_address_allowed(&address)
83 .await
84 .map_err(|e| v3::Error::Other(e.to_string()))?
85 {
86 return Err(v3::Error::ConnectionFailed(format!(
87 "address {address} is not permitted"
88 )));
89 }
90 self.open_connection(&address).await
91 }
92
93 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
94 async fn execute(
95 &mut self,
96 connection: Resource<v3::Connection>,
97 statement: String,
98 params: Vec<v3::ParameterValue>,
99 ) -> Result<u64, v3::Error> {
100 self.get_client(connection)
101 .await?
102 .execute(statement, params)
103 .await
104 }
105
106 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
107 async fn query(
108 &mut self,
109 connection: Resource<v3::Connection>,
110 statement: String,
111 params: Vec<v3::ParameterValue>,
112 ) -> Result<v3::RowSet, v3::Error> {
113 self.get_client(connection)
114 .await?
115 .query(statement, params)
116 .await
117 }
118
119 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
120 self.connections.remove(connection.rep());
121 Ok(())
122 }
123}
124
125impl<C: Send> v2_types::Host for InstanceState<C> {
126 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
127 Ok(error)
128 }
129}
130
131impl<C: Send + Sync + Client> v3::Host for InstanceState<C> {
132 fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
133 Ok(error)
134 }
135}
136
/// Delegates an address-per-call (v1) operation to the matching
/// `v3::HostConnection` method.
///
/// Must be expanded inside an `async fn` on `InstanceState` that returns
/// `Result<_, v1::PgError>`. The expansion:
///
/// 1. checks that `$address` is permitted, returning
///    `v1::PgError::ConnectionFailed` if it is not;
/// 2. opens a connection to `$address`;
/// 3. calls `<Self as v3::HostConnection>::$name` with that connection and the
///    remaining arguments;
/// 4. removes the connection from the table again; and
/// 5. evaluates to the call's result with the error converted via `Into`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self
            .is_address_allowed(&$address)
            .await
            .map_err(|e| v3::Error::Other(e.to_string()))?
        {
            return Err(v1::PgError::ConnectionFailed(format!(
                "address {} is not permitted", $address
            )));
        }
        let connection: Resource<v3::Connection> = match $self.open_connection(&$address).await {
            Ok(c) => c,
            Err(e) => return Err(e.into()),
        };
        // The caller of a v1 operation never receives a connection handle, so
        // nothing else can ever drop this entry. Remember its representation
        // and remove it once the call has finished; otherwise every v1 call
        // would leave one more entry in the connection table.
        // NOTE(review): a `?` inside `$arg` still returns before this cleanup.
        let rep = connection.rep();
        let result = <Self as v3::HostConnection>::$name($self, connection, $($arg),*).await;
        $self.connections.remove(rep);
        result.map_err(|e| e.into())
    }};
}
154
155impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}
156
157impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
158 #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
159 async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
160 spin_factor_outbound_networking::record_address_fields(&address);
161
162 if !self
163 .is_address_allowed(&address)
164 .await
165 .map_err(|e| v2::Error::Other(e.to_string()))?
166 {
167 return Err(v2::Error::ConnectionFailed(format!(
168 "address {address} is not permitted"
169 )));
170 }
171 Ok(self.open_connection(&address).await?)
172 }
173
174 #[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
175 async fn execute(
176 &mut self,
177 connection: Resource<v2::Connection>,
178 statement: String,
179 params: Vec<v2_types::ParameterValue>,
180 ) -> Result<u64, v2::Error> {
181 Ok(self
182 .get_client(connection)
183 .await?
184 .execute(statement, v2_params_to_v3(params)?)
185 .await?)
186 }
187
188 #[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
189 async fn query(
190 &mut self,
191 connection: Resource<v2::Connection>,
192 statement: String,
193 params: Vec<v2_types::ParameterValue>,
194 ) -> Result<v2_types::RowSet, v2::Error> {
195 Ok(self
196 .get_client(connection)
197 .await?
198 .query(statement, v2_params_to_v3(params)?)
199 .await?
200 .into())
201 }
202
203 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
204 self.connections.remove(connection.rep());
205 Ok(())
206 }
207}
208
209impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
210 async fn execute(
211 &mut self,
212 address: String,
213 statement: String,
214 params: Vec<v1_types::ParameterValue>,
215 ) -> Result<u64, v1::PgError> {
216 delegate!(self.execute(
217 address,
218 statement,
219 params
220 .into_iter()
221 .map(TryInto::try_into)
222 .collect::<Result<Vec<_>, _>>()?
223 ))
224 }
225
226 async fn query(
227 &mut self,
228 address: String,
229 statement: String,
230 params: Vec<v1_types::ParameterValue>,
231 ) -> Result<v1_types::RowSet, v1::PgError> {
232 delegate!(self.query(
233 address,
234 statement,
235 params
236 .into_iter()
237 .map(TryInto::try_into)
238 .collect::<Result<Vec<_>, _>>()?
239 ))
240 .map(Into::into)
241 }
242
243 fn convert_pg_error(&mut self, error: v1::PgError) -> Result<v1::PgError> {
244 Ok(error)
245 }
246}