spin_factor_outbound_mysql/
host.rs

1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::v1::mysql as v1;
4use spin_world::v2::mysql::{self as v2, Connection};
5use spin_world::v2::rdbms_types as v2_types;
6use spin_world::v2::rdbms_types::ParameterValue;
7use tracing::field::Empty;
8use tracing::{instrument, Level};
9
10use crate::client::Client;
11use crate::InstanceState;
12
13impl<C: Client> InstanceState<C> {
14    async fn open_connection(&mut self, address: &str) -> Result<Resource<Connection>, v2::Error> {
15        self.connections
16            .push(
17                C::build_client(address)
18                    .await
19                    .map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
20            )
21            .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
22            .map(Resource::new_own)
23    }
24
25    async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&mut C, v2::Error> {
26        self.connections
27            .get_mut(connection.rep())
28            .ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
29    }
30
31    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
32        self.allowed_hosts.check_url(address, "mysql").await
33    }
34}
35
36impl<C: Client> v2::Host for InstanceState<C> {}
37
38impl<C: Client> v2::HostConnection for InstanceState<C> {
39    #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
40    async fn open(&mut self, address: String) -> Result<Resource<Connection>, v2::Error> {
41        spin_factor_outbound_networking::record_address_fields(&address);
42
43        if !self
44            .is_address_allowed(&address)
45            .await
46            .map_err(|e| v2::Error::Other(e.to_string()))?
47        {
48            return Err(v2::Error::ConnectionFailed(format!(
49                "address {address} is not permitted"
50            )));
51        }
52        self.open_connection(&address).await
53    }
54
55    #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
56    async fn execute(
57        &mut self,
58        connection: Resource<Connection>,
59        statement: String,
60        params: Vec<ParameterValue>,
61    ) -> Result<(), v2::Error> {
62        self.get_client(connection)
63            .await?
64            .execute(statement, params)
65            .await
66    }
67
68    #[instrument(name = "spin_outbound_mysql.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
69    async fn query(
70        &mut self,
71        connection: Resource<Connection>,
72        statement: String,
73        params: Vec<ParameterValue>,
74    ) -> Result<v2_types::RowSet, v2::Error> {
75        self.get_client(connection)
76            .await?
77            .query(statement, params)
78            .await
79    }
80
81    async fn drop(&mut self, connection: Resource<Connection>) -> Result<()> {
82        self.connections.remove(connection.rep());
83        Ok(())
84    }
85}
86
87impl<C: Send> v2_types::Host for InstanceState<C> {
88    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
89        Ok(error)
90    }
91}
92
/// Delegate a v1 call to the corresponding `v2::HostConnection` method.
///
/// v1 functions take an address instead of a connection handle. So for every
/// call this macro does four things:
/// 1. checks the address against the allowed hosts;
/// 2. opens a fresh connection (which occupies a slot in `connections`);
/// 3. invokes the v2 method on that connection;
/// 4. removes the connection from the table again.
///
/// Removal happens whether the v2 call succeeds or fails. The handle is never
/// given to the guest, so nothing else would ever free the slot. Without this
/// step, each v1 call would keep a live client in the table, and calls would
/// eventually fail with "too many connections".
///
/// Errors are converted into `v1::MysqlError` via its `From<v2::Error>` impl.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|e| v2::Error::Other(e.to_string()))? {
            return Err(v1::MysqlError::ConnectionFailed(format!(
                "address {} is not permitted", $address
            )));
        }
        let connection = $self.open_connection(&$address).await?;
        // Capture the table slot before the handle is moved into the v2 call,
        // so the connection can be released afterwards.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*).await;
        // Release the per-call connection regardless of the outcome.
        $self.connections.remove(rep);
        result.map_err(Into::into)
    }};
}
110
111impl<C: Client> v1::Host for InstanceState<C> {
112    async fn execute(
113        &mut self,
114        address: String,
115        statement: String,
116        params: Vec<v1::ParameterValue>,
117    ) -> Result<(), v1::MysqlError> {
118        delegate!(self.execute(
119            address,
120            statement,
121            params.into_iter().map(Into::into).collect()
122        ))
123    }
124
125    async fn query(
126        &mut self,
127        address: String,
128        statement: String,
129        params: Vec<v1::ParameterValue>,
130    ) -> Result<v1::RowSet, v1::MysqlError> {
131        delegate!(self.query(
132            address,
133            statement,
134            params.into_iter().map(Into::into).collect()
135        ))
136        .map(Into::into)
137    }
138
139    fn convert_mysql_error(&mut self, error: v1::MysqlError) -> Result<v1::MysqlError> {
140        Ok(error)
141    }
142}