Skip to main content

spin_factor_outbound_mysql/
host.rs

1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::v1::mysql as v1;
4use spin_world::v2::mysql::{self as v2, Connection};
5use spin_world::v2::rdbms_types as v2_types;
6use spin_world::v2::rdbms_types::ParameterValue;
7use spin_world::MAX_HOST_BUFFERED_BYTES;
8use tracing::field::Empty;
9use tracing::{instrument, Level};
10
11use crate::client::Client;
12use crate::InstanceState;
13
14impl<C: Client> InstanceState<C> {
15    async fn open_connection(&mut self, address: &str) -> Result<Resource<Connection>, v2::Error> {
16        self.connections
17            .push(
18                C::build_client(address)
19                    .await
20                    .map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
21            )
22            .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
23            .map(Resource::new_own)
24    }
25
26    async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&mut C, v2::Error> {
27        self.connections
28            .get_mut(connection.rep())
29            .ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
30    }
31
32    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
33        self.allowed_hosts.check_url(address, "mysql").await
34    }
35}
36
37impl<C: Client> v2::Host for InstanceState<C> {}
38
39impl<C: Client> v2::HostConnection for InstanceState<C> {
40    #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
41    async fn open(&mut self, address: String) -> Result<Resource<Connection>, v2::Error> {
42        self.otel.reparent_tracing_span();
43        spin_factor_outbound_networking::record_address_fields(&address);
44
45        if !self
46            .is_address_allowed(&address)
47            .await
48            .map_err(|e| v2::Error::Other(e.to_string()))?
49        {
50            return Err(v2::Error::ConnectionFailed(format!(
51                "address {address} is not permitted"
52            )));
53        }
54        self.open_connection(&address).await
55    }
56
57    #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
58    async fn execute(
59        &mut self,
60        connection: Resource<Connection>,
61        statement: String,
62        params: Vec<ParameterValue>,
63    ) -> Result<(), v2::Error> {
64        self.otel.reparent_tracing_span();
65        self.get_client(connection)
66            .await?
67            .execute(statement, params)
68            .await
69    }
70
71    #[instrument(name = "spin_outbound_mysql.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
72    async fn query(
73        &mut self,
74        connection: Resource<Connection>,
75        statement: String,
76        params: Vec<ParameterValue>,
77    ) -> Result<v2_types::RowSet, v2::Error> {
78        self.otel.reparent_tracing_span();
79        self.get_client(connection)
80            .await?
81            .query(statement, params, MAX_HOST_BUFFERED_BYTES)
82            .await
83    }
84
85    async fn drop(&mut self, connection: Resource<Connection>) -> Result<()> {
86        self.connections.remove(connection.rep());
87        Ok(())
88    }
89}
90
91impl<C: Send> v2_types::Host for InstanceState<C> {
92    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
93        Ok(error)
94    }
95}
96
/// Delegate a v1 function call to the v2::HostConnection implementation.
///
/// Expands to: check `$address` against the allowed-hosts policy, open a
/// temporary connection, invoke `v2::HostConnection::$name` on it with the
/// remaining arguments, and convert the v2 error into a v1 error.
///
/// The v1 interface never hands a connection resource to the guest, so the
/// guest has no way to drop it. The temporary connection is therefore removed
/// from the connection table here once the call finishes, whether it succeeded
/// or failed. Otherwise every v1 call would leak a table slot until the table
/// filled up and later calls failed with "too many connections".
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|e| v2::Error::Other(e.to_string()))? {
            return Err(v1::MysqlError::ConnectionFailed(format!(
                "address {} is not permitted", $address
            )));
        }
        let connection = match $self.open_connection(&$address).await {
            Ok(c) => c,
            Err(e) => return Err(e.into()),
        };
        // Remember the table slot: `connection` is moved into the v2 call below.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*).await;
        // Release the temporary connection regardless of the outcome.
        $self.connections.remove(rep);
        result.map_err(Into::into)
    }};
}
114
115impl<C: Client> v1::Host for InstanceState<C> {
116    async fn execute(
117        &mut self,
118        address: String,
119        statement: String,
120        params: Vec<v1::ParameterValue>,
121    ) -> Result<(), v1::MysqlError> {
122        delegate!(self.execute(
123            address,
124            statement,
125            params.into_iter().map(Into::into).collect()
126        ))
127    }
128
129    async fn query(
130        &mut self,
131        address: String,
132        statement: String,
133        params: Vec<v1::ParameterValue>,
134    ) -> Result<v1::RowSet, v1::MysqlError> {
135        delegate!(self.query(
136            address,
137            statement,
138            params.into_iter().map(Into::into).collect()
139        ))
140        .map(Into::into)
141    }
142
143    fn convert_mysql_error(&mut self, error: v1::MysqlError) -> Result<v1::MysqlError> {
144        Ok(error)
145    }
146}