spin_factor_outbound_mysql/
host.rs1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::v1::mysql as v1;
4use spin_world::v2::mysql::{self as v2, Connection};
5use spin_world::v2::rdbms_types as v2_types;
6use spin_world::v2::rdbms_types::ParameterValue;
7use tracing::field::Empty;
8use tracing::{instrument, Level};
9
10use crate::client::Client;
11use crate::InstanceState;
12
13impl<C: Client> InstanceState<C> {
14 async fn open_connection(&mut self, address: &str) -> Result<Resource<Connection>, v2::Error> {
15 self.connections
16 .push(
17 C::build_client(address)
18 .await
19 .map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
20 )
21 .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
22 .map(Resource::new_own)
23 }
24
25 async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&mut C, v2::Error> {
26 self.connections
27 .get_mut(connection.rep())
28 .ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
29 }
30
31 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
32 self.allowed_hosts.check_url(address, "mysql").await
33 }
34}
35
36impl<C: Client> v2::Host for InstanceState<C> {}
37
38impl<C: Client> v2::HostConnection for InstanceState<C> {
39 #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
40 async fn open(&mut self, address: String) -> Result<Resource<Connection>, v2::Error> {
41 spin_factor_outbound_networking::record_address_fields(&address);
42
43 if !self
44 .is_address_allowed(&address)
45 .await
46 .map_err(|e| v2::Error::Other(e.to_string()))?
47 {
48 return Err(v2::Error::ConnectionFailed(format!(
49 "address {address} is not permitted"
50 )));
51 }
52 self.open_connection(&address).await
53 }
54
55 #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
56 async fn execute(
57 &mut self,
58 connection: Resource<Connection>,
59 statement: String,
60 params: Vec<ParameterValue>,
61 ) -> Result<(), v2::Error> {
62 self.get_client(connection)
63 .await?
64 .execute(statement, params)
65 .await
66 }
67
68 #[instrument(name = "spin_outbound_mysql.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
69 async fn query(
70 &mut self,
71 connection: Resource<Connection>,
72 statement: String,
73 params: Vec<ParameterValue>,
74 ) -> Result<v2_types::RowSet, v2::Error> {
75 self.get_client(connection)
76 .await?
77 .query(statement, params)
78 .await
79 }
80
81 async fn drop(&mut self, connection: Resource<Connection>) -> Result<()> {
82 self.connections.remove(connection.rep());
83 Ok(())
84 }
85}
86
87impl<C: Send> v2_types::Host for InstanceState<C> {
88 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
89 Ok(error)
90 }
91}
92
/// Runs a v1 (connectionless) MySQL call by routing it through the v2
/// `HostConnection` implementation.
///
/// Must be expanded inside a v1 `Host` method whose error type is
/// `v1::MysqlError`. The expansion:
/// 1. checks `$address` against the allowed outbound hosts and returns
///    `ConnectionFailed` if it is not permitted;
/// 2. opens a fresh connection for that address;
/// 3. invokes the v2 `HostConnection` method `$name` with that connection and
///    the remaining arguments;
/// 4. removes the temporary connection from the connection table, whether the
///    call succeeded or failed, and converts any v2 error into the v1 error.
///
/// The v1 API never hands the connection back to the guest, so nothing else
/// can release it. Without step 4, every v1 call would leave its client in the
/// table, and repeated calls would eventually make `open_connection` fail with
/// "too many connections".
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),* $(,)?)) => {{
        // Evaluate the address expression exactly once.
        let addr = $address;
        if !$self
            .is_address_allowed(&addr)
            .await
            .map_err(|e| v2::Error::Other(e.to_string()))?
        {
            return Err(v1::MysqlError::ConnectionFailed(format!(
                "address {} is not permitted", addr
            )));
        }
        // `?` converts the v2 error through the same `From` impl the allow-list
        // check above already relies on.
        let connection = $self.open_connection(&addr).await?;
        // Remember the table slot: `connection` is moved into the v2 call below.
        let rep = connection.rep();
        let result =
            <Self as v2::HostConnection>::$name(&mut *$self, connection, $($arg),*).await;
        // Release the per-call client on both success and failure.
        $self.connections.remove(rep);
        result.map_err(Into::into)
    }};
}
110
111impl<C: Client> v1::Host for InstanceState<C> {
112 async fn execute(
113 &mut self,
114 address: String,
115 statement: String,
116 params: Vec<v1::ParameterValue>,
117 ) -> Result<(), v1::MysqlError> {
118 delegate!(self.execute(
119 address,
120 statement,
121 params.into_iter().map(Into::into).collect()
122 ))
123 }
124
125 async fn query(
126 &mut self,
127 address: String,
128 statement: String,
129 params: Vec<v1::ParameterValue>,
130 ) -> Result<v1::RowSet, v1::MysqlError> {
131 delegate!(self.query(
132 address,
133 statement,
134 params.into_iter().map(Into::into).collect()
135 ))
136 .map(Into::into)
137 }
138
139 fn convert_mysql_error(&mut self, error: v1::MysqlError) -> Result<v1::MysqlError> {
140 Ok(error)
141 }
142}