Skip to main content

spin_factor_outbound_mysql/
host.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader};
5use spin_factor_outbound_networking::ConnectionPermit;
6use spin_telemetry::traces::{self, Blame};
7use spin_world::MAX_HOST_BUFFERED_BYTES;
8use spin_world::spin::mysql::mysql as v3;
9use spin_world::v1::mysql as v1;
10use spin_world::v2::mysql as v2;
11use spin_world::v2::rdbms_types as v2_types;
12use tokio::sync::Mutex;
13use tracing::field::Empty;
14use tracing::{Level, instrument};
15
16use crate::client::Client;
17use crate::{InstanceState, InstanceStateInner, MysqlFactorData};
18
19impl<C: Client> InstanceStateInner<C> {
20    async fn open_connection(
21        &mut self,
22        address: &str,
23        permit: ConnectionPermit,
24    ) -> Result<u32, v2::Error> {
25        spin_factor_outbound_networking::record_address_fields(address);
26
27        if !self.is_address_allowed(address).await.map_err(|e| {
28            // The allow-list check infrastructure itself failed; that's a
29            // host problem, not anything the guest did wrong.
30            let err = v2::Error::Other(e.to_string());
31            traces::mark_as_error(&err, Some(Blame::Host));
32            err
33        })? {
34            // The check succeeded but returned false: the guest supplied an
35            // address that isn't on the allow list.
36            let err = v2::Error::ConnectionFailed(format!("address {address} is not permitted"));
37            traces::mark_as_error(&err, Some(Blame::Guest));
38            return Err(err);
39        }
40        let client = C::build_client(address).await.map_err(|e| {
41            // The guest supplies the address and credentials; connection
42            // failures (wrong password, TLS error, unreachable host, etc.)
43            // are the guest's problem.
44            let err = v2::Error::ConnectionFailed(format!("{e:?}"));
45            traces::mark_as_error(&err, Some(Blame::Guest));
46            err
47        })?;
48        self.connections
49            .push((Arc::new(Mutex::new(client)), permit))
50            .map_err(|_| {
51                // The guest exceeded the host-imposed connection limit.
52                let err = v2::Error::ConnectionFailed("too many connections".into());
53                traces::mark_as_error(&err, Some(Blame::Guest));
54                err
55            })
56    }
57
58    fn get_client(&mut self, connection: u32) -> Result<Arc<Mutex<C>>, v2::Error> {
59        self.connections
60            .get(connection)
61            .map(|(conn, _permit)| conn.clone())
62            .ok_or_else(|| {
63                // The connection table is managed entirely by the host, so a
64                // missing handle indicates a host-side bug, not a guest mistake.
65                let err = v2::Error::ConnectionFailed("no connection found".into());
66                traces::mark_as_error(&err, Some(Blame::Host));
67                err
68            })
69    }
70
71    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
72        self.allowed_hosts.check_url(address, "mysql").await
73    }
74}
75
76impl<C: Client> v3::Host for InstanceState<C> {
77    fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
78        Ok(error)
79    }
80}
81
82impl<C: Client> v3::HostConnection for InstanceState<C> {
83    async fn drop(&mut self, connection: Resource<v3::Connection>) -> Result<()> {
84        let mut state = self.inner.lock().await;
85        state.connections.remove(connection.rep());
86        Ok(())
87    }
88}
89
90type QueryTuple = (
91    Vec<v3::Column>,
92    StreamReader<v3::Row>,
93    FutureReader<Result<(), v3::Error>>,
94);
95
96impl<C: Client> v3::HostConnectionWithStore for MysqlFactorData<C> {
97    #[instrument(name = "spin_outbound_mysql.open", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
98    async fn open<T>(
99        accessor: &Accessor<T, Self>,
100        address: String,
101    ) -> Result<Resource<v3::Connection>, v3::Error> {
102        let (state_arc, semaphore) = accessor.with(|mut access| {
103            let host = access.get();
104            (host.inner.clone(), host.semaphore.clone())
105        });
106        let permit = semaphore
107            .acquire()
108            .await
109            .map_err(|_| v3::Error::ConnectionFailed("too many connections".into()))?;
110        let mut state = state_arc.lock().await;
111        state.otel.reparent_tracing_span();
112        Ok(Resource::new_own(
113            state.open_connection(&address, permit).await?,
114        ))
115    }
116
117    #[instrument(name = "spin_outbound_mysql.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
118    async fn execute<T>(
119        accessor: &Accessor<T, Self>,
120        connection: Resource<v3::Connection>,
121        statement: String,
122        params: Vec<v3::ParameterValue>,
123    ) -> Result<(), v3::Error> {
124        let state = accessor.with(|mut access| access.get().inner.clone());
125        let client = {
126            let mut state = state.lock().await;
127            state.otel.reparent_tracing_span();
128            state.get_client(connection.rep())?
129        };
130        client
131            .lock()
132            .await
133            .execute(statement, params.into_iter().map(Into::into).collect())
134            .await
135            .map_err(track_db_error_on_span)?;
136        Ok(())
137    }
138
139    #[instrument(name = "spin_outbound_mysql.query", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
140    async fn query<T>(
141        accessor: &Accessor<T, Self>,
142        connection: Resource<v3::Connection>,
143        statement: String,
144        params: Vec<v3::ParameterValue>,
145    ) -> Result<QueryTuple, v3::Error> {
146        let state = accessor.with(|mut access| access.get().inner.clone());
147        let client = {
148            let mut state = state.lock().await;
149            state.otel.reparent_tracing_span();
150            state.get_client(connection.rep())?
151        };
152
153        let (columns, stream, future) =
154            C::query_async(client, statement, params, MAX_HOST_BUFFERED_BYTES)
155                .await
156                .map_err(|v| v3::Error::from(track_db_error_on_span(v2::Error::from(v))))?;
157
158        let (stream, future) = accessor
159            .with(|mut access| {
160                anyhow::Ok((
161                    StreamReader::new(&mut access, spin_wasi_async::stream::producer(stream))?,
162                    FutureReader::new(&mut access, future)?,
163                ))
164            })
165            .map_err(|e| {
166                // Setting up the async stream/future channels is a host
167                // implementation detail; if it fails, that's a host bug.
168                let err = v3::Error::Other(e.to_string());
169                traces::mark_as_error(&err, Some(Blame::Host));
170                err
171            })?;
172
173        Ok((columns, stream, future))
174    }
175}
176
177impl<C: Client> v2::Host for InstanceState<C> {}
178
179impl<C: Client> v2::HostConnection for InstanceState<C> {
180    #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
181    async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
182        let permit = self
183            .semaphore
184            .acquire()
185            .await
186            .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?;
187        let mut state = self.inner.lock().await;
188        state.otel.reparent_tracing_span();
189        state
190            .open_connection(&address, permit)
191            .await
192            .map(Resource::new_own)
193    }
194
195    #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
196    async fn execute(
197        &mut self,
198        connection: Resource<v2::Connection>,
199        statement: String,
200        params: Vec<v2_types::ParameterValue>,
201    ) -> Result<(), v2::Error> {
202        let mut state = self.inner.lock().await;
203        state.otel.reparent_tracing_span();
204        state
205            .get_client(connection.rep())?
206            .lock()
207            .await
208            .execute(statement, params)
209            .await
210            .map_err(track_db_error_on_span)
211    }
212
213    #[instrument(name = "spin_outbound_mysql.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
214    async fn query(
215        &mut self,
216        connection: Resource<v2::Connection>,
217        statement: String,
218        params: Vec<v2_types::ParameterValue>,
219    ) -> Result<v2_types::RowSet, v2::Error> {
220        let mut state = self.inner.lock().await;
221        state.otel.reparent_tracing_span();
222        state
223            .get_client(connection.rep())?
224            .lock()
225            .await
226            .query(statement, params, MAX_HOST_BUFFERED_BYTES)
227            .await
228            .map_err(track_db_error_on_span)
229    }
230
231    async fn drop(&mut self, connection: Resource<v2::Connection>) -> Result<()> {
232        let mut state = self.inner.lock().await;
233        state.connections.remove(connection.rep());
234        Ok(())
235    }
236}
237
238impl<C: Send> v2_types::Host for InstanceState<C> {
239    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
240        Ok(error)
241    }
242}
243
/// Runs a v1 call by delegating to the matching `v2::HostConnection` method.
///
/// v1 passes an address with every call instead of holding a connection. The
/// macro therefore:
/// 1. acquires a permit and opens a connection to `$address`;
/// 2. calls `v2::HostConnection::$name` with that connection and the remaining
///    arguments, converting the error with `Into`;
/// 3. always removes the connection-table entry afterwards, whether the call
///    succeeded or not.
///
/// Failures while acquiring the permit or opening the connection return early
/// via `?`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        let permit = $self
            .semaphore
            .acquire()
            .await
            .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?;
        // The state guard only lives for this statement.
        let rep = $self.inner.lock().await.open_connection(&$address, permit).await?;
        let result = <Self as v2::HostConnection>::$name($self, Resource::new_own(rep), $($arg),*)
            .await
            .map_err(Into::into);
        // v1 has no persistent connections: drop the table entry right away so
        // its semaphore permit is released.
        $self.inner.lock().await.connections.remove(rep);
        result
    }};
}
266
267impl<C: Client> v1::Host for InstanceState<C> {
268    async fn execute(
269        &mut self,
270        address: String,
271        statement: String,
272        params: Vec<v1::ParameterValue>,
273    ) -> Result<(), v1::MysqlError> {
274        delegate!(self.execute(
275            address,
276            statement,
277            params.into_iter().map(Into::into).collect()
278        ))
279    }
280
281    async fn query(
282        &mut self,
283        address: String,
284        statement: String,
285        params: Vec<v1::ParameterValue>,
286    ) -> Result<v1::RowSet, v1::MysqlError> {
287        delegate!(self.query(
288            address,
289            statement,
290            params.into_iter().map(Into::into).collect()
291        ))
292        .map(Into::into)
293    }
294
295    fn convert_mysql_error(&mut self, error: v1::MysqlError) -> Result<v1::MysqlError> {
296        Ok(error)
297    }
298}
299
300/// Only for actual DB client calls (execute/query).
301/// Blame is inferred from the error variant returned by the DB driver.
302fn track_db_error_on_span(err: v2::Error) -> v2::Error {
303    let blame = match &err {
304        // The guest brings their own database, so connection failures during
305        // execution (dropped connection, auth rejected mid-session, etc.) are
306        // the guest's problem, not the host's.
307        v2::Error::ConnectionFailed(_) => Blame::Guest,
308        v2::Error::BadParameter(_) => Blame::Guest,
309        v2::Error::QueryFailed(_) => Blame::Guest,
310        // The host is responsible for mapping DB wire types to WIT types;
311        // a conversion failure is a host-side limitation or bug.
312        v2::Error::ValueConversionFailed(_) => Blame::Host,
313        v2::Error::Other(_) => Blame::Host,
314    };
315    traces::mark_as_error(&err, Some(blame));
316    err
317}