spin_factor_outbound_mysql/
host.rs1use std::sync::Arc;
2
3use anyhow::Result;
4use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
5use spin_factor_outbound_networking::ConnectionPermit;
6use spin_telemetry::traces::{self, Blame};
7use spin_world::MAX_HOST_BUFFERED_BYTES;
8use spin_world::spin::mysql::mysql as v3;
9use spin_world::v1::mysql as v1;
10use spin_world::v2::mysql as v2;
11use spin_world::v2::rdbms_types as v2_types;
12use tokio::sync::Mutex;
13use tracing::field::Empty;
14use tracing::{Level, instrument};
15
16use crate::client::Client;
17use crate::{InstanceState, InstanceStateInner, MysqlFactorData};
18
19impl<C: Client> InstanceStateInner<C> {
20 async fn open_connection(
21 &mut self,
22 address: &str,
23 permit: ConnectionPermit,
24 ) -> Result<u32, v2::Error> {
25 spin_factor_outbound_networking::record_address_fields(address);
26
27 if !self.is_address_allowed(address).await.map_err(|e| {
28 let err = v2::Error::Other(e.to_string());
31 traces::mark_as_error(&err, Some(Blame::Host));
32 err
33 })? {
34 let err = v2::Error::ConnectionFailed(format!("address {address} is not permitted"));
37 traces::mark_as_error(&err, Some(Blame::Guest));
38 return Err(err);
39 }
40 let client = C::build_client(address).await.map_err(|e| {
41 let err = v2::Error::ConnectionFailed(format!("{e:?}"));
45 traces::mark_as_error(&err, Some(Blame::Guest));
46 err
47 })?;
48 self.connections
49 .push((Arc::new(Mutex::new(client)), permit))
50 .map_err(|_| {
51 let err = v2::Error::ConnectionFailed("too many connections".into());
53 traces::mark_as_error(&err, Some(Blame::Guest));
54 err
55 })
56 }
57
58 fn get_client(&mut self, connection: u32) -> Result<Arc<Mutex<C>>, v2::Error> {
59 self.connections
60 .get(connection)
61 .map(|(conn, _permit)| conn.clone())
62 .ok_or_else(|| {
63 let err = v2::Error::ConnectionFailed("no connection found".into());
66 traces::mark_as_error(&err, Some(Blame::Host));
67 err
68 })
69 }
70
71 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
72 self.allowed_hosts.check_url(address, "mysql").await
73 }
74}
75
76impl<C: Client> v3::Host for InstanceState<C> {
77 fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
78 Ok(error)
79 }
80}
81
82impl<C: Client> v3::HostConnection for InstanceState<C> {
83 async fn drop(&mut self, connection: Resource<v3::Connection>) -> Result<()> {
84 let mut state = self.inner.lock().await;
85 state.connections.remove(connection.rep());
86 Ok(())
87 }
88}
89
90type QueryTuple = (
91 Vec<v3::Column>,
92 StreamReader<v3::Row>,
93 FutureReader<Result<(), v3::Error>>,
94);
95
96impl<C: Client> v3::HostConnectionWithStore for MysqlFactorData<C> {
97 #[instrument(name = "spin_outbound_mysql.open", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
98 async fn open<T>(
99 accessor: &Accessor<T, Self>,
100 address: String,
101 ) -> Result<Resource<v3::Connection>, v3::Error> {
102 let (state_arc, semaphore) = accessor.with(|mut access| {
103 let host = access.get();
104 (host.inner.clone(), host.semaphore.clone())
105 });
106 let permit = semaphore
107 .acquire()
108 .await
109 .map_err(|_| v3::Error::ConnectionFailed("too many connections".into()))?;
110 let mut state = state_arc.lock().await;
111 state.otel.reparent_tracing_span();
112 Ok(Resource::new_own(
113 state.open_connection(&address, permit).await?,
114 ))
115 }
116
117 #[instrument(name = "spin_outbound_mysql.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
118 async fn execute<T>(
119 accessor: &Accessor<T, Self>,
120 connection: Resource<v3::Connection>,
121 statement: String,
122 params: Vec<v3::ParameterValue>,
123 ) -> Result<(), v3::Error> {
124 let state = accessor.with(|mut access| access.get().inner.clone());
125 let client = {
126 let mut state = state.lock().await;
127 state.otel.reparent_tracing_span();
128 state.get_client(connection.rep())?
129 };
130 client
131 .lock()
132 .await
133 .execute(statement, params.into_iter().map(Into::into).collect())
134 .await
135 .map_err(track_db_error_on_span)?;
136 Ok(())
137 }
138
139 #[instrument(name = "spin_outbound_mysql.query", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
140 async fn query<T>(
141 accessor: &Accessor<T, Self>,
142 connection: Resource<v3::Connection>,
143 statement: String,
144 params: Vec<v3::ParameterValue>,
145 ) -> Result<QueryTuple, v3::Error> {
146 let state = accessor.with(|mut access| access.get().inner.clone());
147 let client = {
148 let mut state = state.lock().await;
149 state.otel.reparent_tracing_span();
150 state.get_client(connection.rep())?
151 };
152
153 let (columns, stream, future) =
154 C::query_async(client, statement, params, MAX_HOST_BUFFERED_BYTES)
155 .await
156 .map_err(|v| v3::Error::from(track_db_error_on_span(v2::Error::from(v))))?;
157
158 let (stream, future) = accessor
159 .with(|mut access| {
160 anyhow::Ok((
161 StreamReader::new(&mut access, spin_wasi_async::stream::producer(stream))?,
162 FutureReader::new(&mut access, future)?,
163 ))
164 })
165 .map_err(|e| {
166 let err = v3::Error::Other(e.to_string());
169 traces::mark_as_error(&err, Some(Blame::Host));
170 err
171 })?;
172
173 Ok((columns, stream, future))
174 }
175}
176
177impl<C: Client> v2::Host for InstanceState<C> {}
178
179impl<C: Client> v2::HostConnection for InstanceState<C> {
180 #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
181 async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
182 let permit = self
183 .semaphore
184 .acquire()
185 .await
186 .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?;
187 let mut state = self.inner.lock().await;
188 state.otel.reparent_tracing_span();
189 state
190 .open_connection(&address, permit)
191 .await
192 .map(Resource::new_own)
193 }
194
195 #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
196 async fn execute(
197 &mut self,
198 connection: Resource<v2::Connection>,
199 statement: String,
200 params: Vec<v2_types::ParameterValue>,
201 ) -> Result<(), v2::Error> {
202 let mut state = self.inner.lock().await;
203 state.otel.reparent_tracing_span();
204 state
205 .get_client(connection.rep())?
206 .lock()
207 .await
208 .execute(statement, params)
209 .await
210 .map_err(track_db_error_on_span)
211 }
212
213 #[instrument(name = "spin_outbound_mysql.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
214 async fn query(
215 &mut self,
216 connection: Resource<v2::Connection>,
217 statement: String,
218 params: Vec<v2_types::ParameterValue>,
219 ) -> Result<v2_types::RowSet, v2::Error> {
220 let mut state = self.inner.lock().await;
221 state.otel.reparent_tracing_span();
222 state
223 .get_client(connection.rep())?
224 .lock()
225 .await
226 .query(statement, params, MAX_HOST_BUFFERED_BYTES)
227 .await
228 .map_err(track_db_error_on_span)
229 }
230
231 async fn drop(&mut self, connection: Resource<v2::Connection>) -> Result<()> {
232 let mut state = self.inner.lock().await;
233 state.connections.remove(connection.rep());
234 Ok(())
235 }
236}
237
238impl<C: Send> v2_types::Host for InstanceState<C> {
239 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
240 Ok(error)
241 }
242}
243
/// Implements a one-shot v1 call on top of the v2 connection API.
///
/// The macro acquires a permit, opens a temporary connection to `$addr`,
/// forwards the call (with the remaining arguments) to the v2 `HostConnection`
/// method `$method`, and then removes the temporary connection from the table
/// whatever the call's outcome. Errors from the permit/open steps propagate
/// via `?`. The forwarded call's error is converted with `Into`.
macro_rules! delegate {
    ($host:ident.$method:ident($addr:expr, $($rest:expr),*)) => {{
        let permit = $host
            .semaphore
            .acquire()
            .await
            .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?;
        // The instance-state guard is a temporary and is released at the end
        // of this statement.
        let rep = $host
            .inner
            .lock()
            .await
            .open_connection(&$addr, permit)
            .await?;
        let outcome = <Self as v2::HostConnection>::$method(
            $host,
            Resource::new_own(rep),
            $($rest),*
        )
        .await
        .map_err(Into::into);
        $host.inner.lock().await.connections.remove(rep);
        outcome
    }};
}
266
267impl<C: Client> v1::Host for InstanceState<C> {
268 async fn execute(
269 &mut self,
270 address: String,
271 statement: String,
272 params: Vec<v1::ParameterValue>,
273 ) -> Result<(), v1::MysqlError> {
274 delegate!(self.execute(
275 address,
276 statement,
277 params.into_iter().map(Into::into).collect()
278 ))
279 }
280
281 async fn query(
282 &mut self,
283 address: String,
284 statement: String,
285 params: Vec<v1::ParameterValue>,
286 ) -> Result<v1::RowSet, v1::MysqlError> {
287 delegate!(self.query(
288 address,
289 statement,
290 params.into_iter().map(Into::into).collect()
291 ))
292 .map(Into::into)
293 }
294
295 fn convert_mysql_error(&mut self, error: v1::MysqlError) -> Result<v1::MysqlError> {
296 Ok(error)
297 }
298}
299
300fn track_db_error_on_span(err: v2::Error) -> v2::Error {
303 let blame = match &err {
304 v2::Error::ConnectionFailed(_) => Blame::Guest,
308 v2::Error::BadParameter(_) => Blame::Guest,
309 v2::Error::QueryFailed(_) => Blame::Guest,
310 v2::Error::ValueConversionFailed(_) => Blame::Host,
313 v2::Error::Other(_) => Blame::Host,
314 };
315 traces::mark_as_error(&err, Some(blame));
316 err
317}