spin_factor_outbound_mysql/
client.rs1use std::sync::Arc;
2
3use anyhow::{anyhow, Result};
4use futures::stream::TryStreamExt as _;
5use mysql_async::consts::ColumnType;
6use mysql_async::prelude::{FromValue, Queryable as _};
7use mysql_async::{from_value_opt, Conn as MysqlClient, Opts, OptsBuilder, SslOpts};
8use spin_core::async_trait;
9use spin_world::v2::mysql::{self as v2};
10use spin_world::v2::rdbms_types::{
11 self as v2_types, Column, DbDataType, DbValue, ParameterValue, RowSet,
12};
13use url::Url;
14
15#[async_trait]
16pub trait Client: Send + Sync + 'static {
17 async fn build_client(address: &str) -> Result<Self>
18 where
19 Self: Sized;
20
21 async fn execute(
22 &mut self,
23 statement: String,
24 params: Vec<ParameterValue>,
25 ) -> Result<(), v2::Error>;
26
27 async fn query(
28 &mut self,
29 statement: String,
30 params: Vec<ParameterValue>,
31 max_result_bytes: usize,
32 ) -> Result<RowSet, v2::Error>;
33}
34
35#[async_trait]
36impl Client for MysqlClient {
37 async fn build_client(address: &str) -> Result<Self>
38 where
39 Self: Sized,
40 {
41 tracing::debug!("Build new connection: {}", address);
42
43 let opts = build_opts(address)?;
44
45 let connection_pool = mysql_async::Pool::new(opts);
46
47 connection_pool.get_conn().await.map_err(|e| anyhow!(e))
48 }
49
50 async fn execute(
51 &mut self,
52 statement: String,
53 params: Vec<ParameterValue>,
54 ) -> Result<(), v2::Error> {
55 let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
56 let parameters = mysql_async::Params::Positional(db_params);
57
58 self.exec_batch(&statement, &[parameters])
59 .await
60 .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))
61 }
62
63 async fn query(
64 &mut self,
65 statement: String,
66 params: Vec<ParameterValue>,
67 max_result_bytes: usize,
68 ) -> Result<RowSet, v2::Error> {
69 let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
70 let parameters = mysql_async::Params::Positional(db_params);
71
72 let mut query_result = self
73 .exec_iter(&statement, parameters)
74 .await
75 .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))?;
76
77 let columns = convert_columns(query_result.columns());
78
79 let mut query_result = query_result
80 .stream()
81 .await
82 .map_err(|e| v2::Error::Other(e.to_string()))?
83 .ok_or_else(|| v2::Error::Other("unable to stream query result".into()))?;
84
85 let mut rows = Vec::new();
86 let mut byte_count = std::mem::size_of::<RowSet>();
87 while let Some(row) = query_result
88 .try_next()
89 .await
90 .map_err(|e| v2::Error::Other(e.to_string()))?
91 {
92 let row = convert_row(row, &columns)?;
93 byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
94 if byte_count > max_result_bytes {
95 return Err(v2::Error::Other(format!(
96 "query result exceeds limit of {max_result_bytes} bytes"
97 )));
98 }
99 rows.push(row);
100 }
101
102 Ok(v2_types::RowSet { columns, rows })
103 }
104}
105
106fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
107 match value {
108 ParameterValue::Boolean(v) => mysql_async::Value::from(v),
109 ParameterValue::Int32(v) => mysql_async::Value::from(v),
110 ParameterValue::Int64(v) => mysql_async::Value::from(v),
111 ParameterValue::Int8(v) => mysql_async::Value::from(v),
112 ParameterValue::Int16(v) => mysql_async::Value::from(v),
113 ParameterValue::Floating32(v) => mysql_async::Value::from(v),
114 ParameterValue::Floating64(v) => mysql_async::Value::from(v),
115 ParameterValue::Uint8(v) => mysql_async::Value::from(v),
116 ParameterValue::Uint16(v) => mysql_async::Value::from(v),
117 ParameterValue::Uint32(v) => mysql_async::Value::from(v),
118 ParameterValue::Uint64(v) => mysql_async::Value::from(v),
119 ParameterValue::Str(v) => mysql_async::Value::from(v),
120 ParameterValue::Binary(v) => mysql_async::Value::from(v),
121 ParameterValue::DbNull => mysql_async::Value::NULL,
122 }
123}
124
125fn convert_columns(columns: Option<Arc<[mysql_async::Column]>>) -> Vec<Column> {
126 match columns {
127 Some(columns) => columns.iter().map(convert_column).collect(),
128 None => vec![],
129 }
130}
131
132fn convert_column(column: &mysql_async::Column) -> Column {
133 let name = column.name_str().into_owned();
134 let data_type = convert_data_type(column);
135
136 Column { name, data_type }
137}
138
139fn convert_data_type(column: &mysql_async::Column) -> DbDataType {
140 let column_type = column.column_type();
141
142 if column_type.is_numeric_type() {
143 convert_numeric_type(column)
144 } else if column_type.is_character_type() {
145 convert_character_type(column)
146 } else {
147 DbDataType::Other
148 }
149}
150
151fn convert_character_type(column: &mysql_async::Column) -> DbDataType {
152 match (column.column_type(), is_binary(column)) {
153 (ColumnType::MYSQL_TYPE_BLOB, false) => DbDataType::Str, (ColumnType::MYSQL_TYPE_BLOB, _) => DbDataType::Binary,
155 (ColumnType::MYSQL_TYPE_LONG_BLOB, _) => DbDataType::Binary,
156 (ColumnType::MYSQL_TYPE_MEDIUM_BLOB, _) => DbDataType::Binary,
157 (ColumnType::MYSQL_TYPE_STRING, true) => DbDataType::Binary, (ColumnType::MYSQL_TYPE_STRING, _) => DbDataType::Str,
159 (ColumnType::MYSQL_TYPE_VAR_STRING, true) => DbDataType::Binary, (ColumnType::MYSQL_TYPE_VAR_STRING, _) => DbDataType::Str,
161 (_, _) => DbDataType::Other,
162 }
163}
164
165fn convert_numeric_type(column: &mysql_async::Column) -> DbDataType {
166 match (column.column_type(), is_signed(column)) {
167 (ColumnType::MYSQL_TYPE_DOUBLE, _) => DbDataType::Floating64,
168 (ColumnType::MYSQL_TYPE_FLOAT, _) => DbDataType::Floating32,
169 (ColumnType::MYSQL_TYPE_INT24, true) => DbDataType::Int32,
170 (ColumnType::MYSQL_TYPE_INT24, false) => DbDataType::Uint32,
171 (ColumnType::MYSQL_TYPE_LONG, true) => DbDataType::Int32,
172 (ColumnType::MYSQL_TYPE_LONG, false) => DbDataType::Uint32,
173 (ColumnType::MYSQL_TYPE_LONGLONG, true) => DbDataType::Int64,
174 (ColumnType::MYSQL_TYPE_LONGLONG, false) => DbDataType::Uint64,
175 (ColumnType::MYSQL_TYPE_SHORT, true) => DbDataType::Int16,
176 (ColumnType::MYSQL_TYPE_SHORT, false) => DbDataType::Uint16,
177 (ColumnType::MYSQL_TYPE_TINY, true) => DbDataType::Int8,
178 (ColumnType::MYSQL_TYPE_TINY, false) => DbDataType::Uint8,
179 (_, _) => DbDataType::Other,
180 }
181}
182
183fn is_signed(column: &mysql_async::Column) -> bool {
184 !column
185 .flags()
186 .contains(mysql_async::consts::ColumnFlags::UNSIGNED_FLAG)
187}
188
189fn is_binary(column: &mysql_async::Column) -> bool {
190 column
191 .flags()
192 .contains(mysql_async::consts::ColumnFlags::BINARY_FLAG)
193}
194
195fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, v2::Error> {
196 let mut result = Vec::with_capacity(row.len());
197 for index in 0..row.len() {
198 result.push(convert_entry(&mut row, index, columns)?);
199 }
200 Ok(result)
201}
202
203fn convert_entry(
204 row: &mut mysql_async::Row,
205 index: usize,
206 columns: &[Column],
207) -> Result<DbValue, v2::Error> {
208 match (row.take(index), columns.get(index)) {
209 (None, _) => Ok(DbValue::DbNull), (_, None) => Err(v2::Error::Other(format!(
211 "Can't get column at index {index}"
212 ))),
213 (Some(mysql_async::Value::NULL), _) => Ok(DbValue::DbNull),
214 (Some(value), Some(column)) => convert_value(value, column),
215 }
216}
217
218fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, v2::Error> {
219 match column.data_type {
220 DbDataType::Binary => convert_value_to::<Vec<u8>>(value).map(DbValue::Binary),
221 DbDataType::Boolean => convert_value_to::<bool>(value).map(DbValue::Boolean),
222 DbDataType::Floating32 => convert_value_to::<f32>(value).map(DbValue::Floating32),
223 DbDataType::Floating64 => convert_value_to::<f64>(value).map(DbValue::Floating64),
224 DbDataType::Int8 => convert_value_to::<i8>(value).map(DbValue::Int8),
225 DbDataType::Int16 => convert_value_to::<i16>(value).map(DbValue::Int16),
226 DbDataType::Int32 => convert_value_to::<i32>(value).map(DbValue::Int32),
227 DbDataType::Int64 => convert_value_to::<i64>(value).map(DbValue::Int64),
228 DbDataType::Str => convert_value_to::<String>(value).map(DbValue::Str),
229 DbDataType::Uint8 => convert_value_to::<u8>(value).map(DbValue::Uint8),
230 DbDataType::Uint16 => convert_value_to::<u16>(value).map(DbValue::Uint16),
231 DbDataType::Uint32 => convert_value_to::<u32>(value).map(DbValue::Uint32),
232 DbDataType::Uint64 => convert_value_to::<u64>(value).map(DbValue::Uint64),
233 DbDataType::Other => Err(v2::Error::ValueConversionFailed(format!(
234 "Cannot convert value {:?} in column {} data type {:?}",
235 value, column.name, column.data_type
236 ))),
237 }
238}
239
/// Reports whether a URL query key selects the TLS mode.
///
/// The comparison is case-insensitive and accepts both the hyphenated
/// `ssl-mode` spelling and the compact `sslmode` spelling.
fn is_ssl_param(s: &str) -> bool {
    matches!(s.to_lowercase().as_str(), "ssl-mode" | "sslmode")
}
243
244fn build_opts(address: &str) -> Result<Opts, mysql_async::Error> {
252 let url = Url::parse(address)?;
253
254 let use_ssl = url
255 .query_pairs()
256 .any(|(k, v)| is_ssl_param(&k) && v.to_lowercase() != "disabled");
257
258 let query_without_ssl: Vec<(_, _)> = url
259 .query_pairs()
260 .filter(|(k, _v)| !is_ssl_param(k))
261 .collect();
262 let mut cleaned_url = url.clone();
263 cleaned_url.set_query(None);
264 cleaned_url
265 .query_pairs_mut()
266 .extend_pairs(query_without_ssl);
267
268 Ok(OptsBuilder::from_opts(cleaned_url.as_str())
269 .ssl_opts(if use_ssl {
270 Some(SslOpts::default())
271 } else {
272 None
273 })
274 .into())
275}
276
277fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, v2::Error> {
278 from_value_opt::<T>(value).map_err(|e| v2::Error::ValueConversionFailed(format!("{e}")))
279}
280
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `address` (panicking if it is rejected) and reports whether TLS
    /// options were configured.
    fn ssl_configured(address: &str) -> bool {
        build_opts(address).unwrap().ssl_opts().is_some()
    }

    #[test]
    fn test_mysql_address_without_ssl_mode() {
        assert!(!ssl_configured("mysql://myuser:password@127.0.0.1/db"))
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_disabled() {
        assert!(!ssl_configured(
            "mysql://myuser:password@127.0.0.1/db?ssl-mode=DISABLED"
        ))
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_verify_ca() {
        assert!(ssl_configured(
            "mysql://myuser:password@127.0.0.1/db?sslMode=VERIFY_CA"
        ))
    }

    #[test]
    fn test_mysql_address_with_more_to_query() {
        let address = "mysql://myuser:password@127.0.0.1/db?SsLmOdE=VERIFY_CA&pool_max=10";
        let opts = build_opts(address).unwrap();
        assert!(opts.ssl_opts().is_some());
        assert_eq!(opts.pool_opts().constraints().max(), 10)
    }
}