spin_factor_outbound_mysql/
host.rs1use anyhow::Result;
2use spin_core::wasmtime::component::Resource;
3use spin_world::v1::mysql as v1;
4use spin_world::v2::mysql::{self as v2, Connection};
5use spin_world::v2::rdbms_types as v2_types;
6use spin_world::v2::rdbms_types::ParameterValue;
7use spin_world::MAX_HOST_BUFFERED_BYTES;
8use tracing::field::Empty;
9use tracing::{instrument, Level};
10
11use crate::client::Client;
12use crate::InstanceState;
13
14impl<C: Client> InstanceState<C> {
15 async fn open_connection(&mut self, address: &str) -> Result<Resource<Connection>, v2::Error> {
16 self.connections
17 .push(
18 C::build_client(address)
19 .await
20 .map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
21 )
22 .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
23 .map(Resource::new_own)
24 }
25
26 async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&mut C, v2::Error> {
27 self.connections
28 .get_mut(connection.rep())
29 .ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
30 }
31
32 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
33 self.allowed_hosts.check_url(address, "mysql").await
34 }
35}
36
37impl<C: Client> v2::Host for InstanceState<C> {}
38
39impl<C: Client> v2::HostConnection for InstanceState<C> {
40 #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
41 async fn open(&mut self, address: String) -> Result<Resource<Connection>, v2::Error> {
42 self.otel.reparent_tracing_span();
43 spin_factor_outbound_networking::record_address_fields(&address);
44
45 if !self
46 .is_address_allowed(&address)
47 .await
48 .map_err(|e| v2::Error::Other(e.to_string()))?
49 {
50 return Err(v2::Error::ConnectionFailed(format!(
51 "address {address} is not permitted"
52 )));
53 }
54 self.open_connection(&address).await
55 }
56
57 #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
58 async fn execute(
59 &mut self,
60 connection: Resource<Connection>,
61 statement: String,
62 params: Vec<ParameterValue>,
63 ) -> Result<(), v2::Error> {
64 self.otel.reparent_tracing_span();
65 self.get_client(connection)
66 .await?
67 .execute(statement, params)
68 .await
69 }
70
71 #[instrument(name = "spin_outbound_mysql.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))]
72 async fn query(
73 &mut self,
74 connection: Resource<Connection>,
75 statement: String,
76 params: Vec<ParameterValue>,
77 ) -> Result<v2_types::RowSet, v2::Error> {
78 self.otel.reparent_tracing_span();
79 self.get_client(connection)
80 .await?
81 .query(statement, params, MAX_HOST_BUFFERED_BYTES)
82 .await
83 }
84
85 async fn drop(&mut self, connection: Resource<Connection>) -> Result<()> {
86 self.connections.remove(connection.rep());
87 Ok(())
88 }
89}
90
91impl<C: Send> v2_types::Host for InstanceState<C> {
92 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
93 Ok(error)
94 }
95}
96
/// Delegates a v1 (address-per-call) operation to the `v2::HostConnection`
/// method of the same name.
///
/// Intended for use inside an `async fn` returning
/// `Result<_, v1::MysqlError>`; the expansion uses `return` and `?`.
/// It performs, in order:
///
/// 1. the allowed-hosts check for `$address`, returning
///    `v1::MysqlError::ConnectionFailed` when the address is not permitted;
/// 2. opening a fresh connection to `$address`;
/// 3. the call to `<Self as v2::HostConnection>::$name` with that connection
///    and the remaining arguments, converting the error with `Into`;
/// 4. removal of the connection from the table, whether or not the call
///    succeeded.
///
/// Step 4 matters because the handle created in step 2 is never handed to the
/// guest, so nothing else would ever remove the entry. Without it every v1
/// call would leave one more entry in `connections`.
///
/// Note that `$address` is expanded more than once, so it should be a plain
/// place expression (such as a local variable) rather than a call.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self
            .is_address_allowed(&$address)
            .await
            .map_err(|e| v2::Error::Other(e.to_string()))?
        {
            return Err(v1::MysqlError::ConnectionFailed(format!(
                "address {} is not permitted", $address
            )));
        }
        let connection = $self.open_connection(&$address).await?;
        // Remember the table key before the handle is moved into the v2 call.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*)
            .await
            .map_err(Into::into);
        // The connection only exists for this one call: release it now so
        // repeated v1 calls do not accumulate entries in the table.
        $self.connections.remove(rep);
        result
    }};
}
114
115impl<C: Client> v1::Host for InstanceState<C> {
116 async fn execute(
117 &mut self,
118 address: String,
119 statement: String,
120 params: Vec<v1::ParameterValue>,
121 ) -> Result<(), v1::MysqlError> {
122 delegate!(self.execute(
123 address,
124 statement,
125 params.into_iter().map(Into::into).collect()
126 ))
127 }
128
129 async fn query(
130 &mut self,
131 address: String,
132 statement: String,
133 params: Vec<v1::ParameterValue>,
134 ) -> Result<v1::RowSet, v1::MysqlError> {
135 delegate!(self.query(
136 address,
137 statement,
138 params.into_iter().map(Into::into).collect()
139 ))
140 .map(Into::into)
141 }
142
143 fn convert_mysql_error(&mut self, error: v1::MysqlError) -> Result<v1::MysqlError> {
144 Ok(error)
145 }
146}