1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use futures::{future::FutureExt as _, stream::TryStreamExt as _};
5use mysql_async::consts::ColumnType;
6use mysql_async::prelude::{FromValue, Queryable as _};
7use mysql_async::{Conn as MysqlClient, Opts, OptsBuilder, SslOpts, from_value_opt};
8use spin_core::async_trait;
9use spin_world::spin::mysql::mysql as v3;
10use spin_world::v2::mysql::{self as v2};
11use spin_world::v2::rdbms_types::{
12 self as v2_types, Column, DbDataType, DbValue, ParameterValue, RowSet,
13};
14use tokio::sync::{Mutex, mpsc, oneshot};
15use url::Url;
16
17#[async_trait]
18pub trait Client: Send + Sync + 'static {
19 async fn build_client(address: &str) -> Result<Self>
20 where
21 Self: Sized;
22
23 async fn execute(
24 &mut self,
25 statement: String,
26 params: Vec<ParameterValue>,
27 ) -> Result<(), v2::Error>;
28
29 async fn query(
30 &mut self,
31 statement: String,
32 params: Vec<ParameterValue>,
33 max_result_bytes: usize,
34 ) -> Result<RowSet, v2::Error>;
35
36 async fn query_async(
37 client: Arc<Mutex<Self>>,
38 statement: String,
39 params: Vec<v3::ParameterValue>,
40 max_result_bytes: usize,
41 ) -> Result<
42 (
43 Vec<v3::Column>,
44 mpsc::Receiver<v3::Row>,
45 oneshot::Receiver<Result<(), v3::Error>>,
46 ),
47 v3::Error,
48 >;
49}
50
51#[async_trait]
52impl Client for MysqlClient {
53 async fn build_client(address: &str) -> Result<Self>
54 where
55 Self: Sized,
56 {
57 tracing::debug!("Build new connection: {}", address);
58
59 let opts = build_opts(address)?;
60
61 let connection_pool = mysql_async::Pool::new(opts);
62
63 connection_pool.get_conn().await.map_err(|e| anyhow!(e))
64 }
65
66 async fn execute(
67 &mut self,
68 statement: String,
69 params: Vec<ParameterValue>,
70 ) -> Result<(), v2::Error> {
71 let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
72 let parameters = mysql_async::Params::Positional(db_params);
73
74 self.exec_batch(&statement, &[parameters])
75 .await
76 .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))
77 }
78
79 async fn query(
80 &mut self,
81 statement: String,
82 params: Vec<ParameterValue>,
83 max_result_bytes: usize,
84 ) -> Result<RowSet, v2::Error> {
85 let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
86 let parameters = mysql_async::Params::Positional(db_params);
87
88 let mut query_result = self
89 .exec_iter(&statement, parameters)
90 .await
91 .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))?;
92
93 let columns = convert_columns(query_result.columns());
94
95 let mut query_result = query_result
96 .stream()
97 .await
98 .map_err(|e| v2::Error::Other(e.to_string()))?
99 .ok_or_else(|| v2::Error::Other("unable to stream query result".into()))?;
100
101 let mut rows = Vec::new();
102 let mut byte_count = std::mem::size_of::<RowSet>();
103 while let Some(row) = query_result
104 .try_next()
105 .await
106 .map_err(|e| v2::Error::Other(e.to_string()))?
107 {
108 let row = convert_row(row, &columns)?;
109 byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
110 if byte_count > max_result_bytes {
111 return Err(v2::Error::Other(format!(
112 "query result exceeds limit of {max_result_bytes} bytes"
113 )));
114 }
115 rows.push(row);
116 }
117
118 Ok(v2_types::RowSet { columns, rows })
119 }
120
121 async fn query_async(
122 client: Arc<Mutex<Self>>,
123 statement: String,
124 params: Vec<v3::ParameterValue>,
125 max_result_bytes: usize,
126 ) -> Result<
127 (
128 Vec<v3::Column>,
129 mpsc::Receiver<v3::Row>,
130 oneshot::Receiver<Result<(), v3::Error>>,
131 ),
132 v3::Error,
133 > {
134 let db_params = params
135 .into_iter()
136 .map(|v| to_sql_parameter(v2_types::ParameterValue::from(v)))
137 .collect::<Vec<_>>();
138 let parameters = mysql_async::Params::Positional(db_params);
139
140 let (rows_tx, rows_rx) = mpsc::channel(4);
141 let (err_tx, err_rx) = oneshot::channel();
142 let (columns_tx, columns_rx) = oneshot::channel();
143
144 let mut byte_count = std::mem::size_of::<(
145 Vec<v3::Column>,
146 mpsc::Receiver<v3::Row>,
147 oneshot::Receiver<Result<(), v3::Error>>,
148 )>();
149
150 tokio::spawn(
151 async move {
152 let mut client = client.lock().await;
153
154 let mut query_result = client
155 .exec_iter(&statement, parameters)
156 .await
157 .map_err(|e| v3::Error::QueryFailed(format!("{e:?}")))?;
158
159 let columns = convert_columns(query_result.columns());
160
161 _ = columns_tx.send(
162 columns
163 .iter()
164 .map(|v| v3::Column::from(v.clone()))
165 .collect(),
166 );
167
168 let mut query_result = query_result
169 .stream()
170 .await
171 .map_err(|e| v3::Error::Other(e.to_string()))?
172 .ok_or_else(|| v3::Error::Other("unable to stream query result".into()))?;
173
174 while let Some(row) = query_result
175 .try_next()
176 .await
177 .map_err(|e| v3::Error::Other(e.to_string()))?
178 {
179 let row = convert_row(row, &columns).map_err(v3::Error::from)?;
180
181 byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
182 if byte_count > max_result_bytes {
183 return Err(v3::Error::Other(format!(
184 "query result exceeds limit of {max_result_bytes} bytes"
185 )));
186 }
187
188 rows_tx
189 .send(row.into_iter().map(v3::DbValue::from).collect())
190 .await
191 .map_err(|e| v3::Error::Other(format!("async error: {e}")))?;
192 }
193
194 Ok(())
195 }
196 .map(move |result| {
197 _ = err_tx.send(result);
198 }),
199 );
200
201 let columns = columns_rx
202 .await
203 .map_err(|e| v3::Error::Other(format!("async error: {e}")))?;
204
205 Ok((columns, rows_rx, err_rx))
206 }
207}
208
209fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
210 match value {
211 ParameterValue::Boolean(v) => mysql_async::Value::from(v),
212 ParameterValue::Int32(v) => mysql_async::Value::from(v),
213 ParameterValue::Int64(v) => mysql_async::Value::from(v),
214 ParameterValue::Int8(v) => mysql_async::Value::from(v),
215 ParameterValue::Int16(v) => mysql_async::Value::from(v),
216 ParameterValue::Floating32(v) => mysql_async::Value::from(v),
217 ParameterValue::Floating64(v) => mysql_async::Value::from(v),
218 ParameterValue::Uint8(v) => mysql_async::Value::from(v),
219 ParameterValue::Uint16(v) => mysql_async::Value::from(v),
220 ParameterValue::Uint32(v) => mysql_async::Value::from(v),
221 ParameterValue::Uint64(v) => mysql_async::Value::from(v),
222 ParameterValue::Str(v) => mysql_async::Value::from(v),
223 ParameterValue::Binary(v) => mysql_async::Value::from(v),
224 ParameterValue::DbNull => mysql_async::Value::NULL,
225 }
226}
227
228fn convert_columns(columns: Option<Arc<[mysql_async::Column]>>) -> Vec<Column> {
229 match columns {
230 Some(columns) => columns.iter().map(convert_column).collect(),
231 None => vec![],
232 }
233}
234
235fn convert_column(column: &mysql_async::Column) -> Column {
236 let name = column.name_str().into_owned();
237 let data_type = convert_data_type(column);
238
239 Column { name, data_type }
240}
241
242fn convert_data_type(column: &mysql_async::Column) -> DbDataType {
243 let column_type = column.column_type();
244
245 if column_type.is_numeric_type() {
246 convert_numeric_type(column)
247 } else if column_type.is_character_type() {
248 convert_character_type(column)
249 } else {
250 DbDataType::Other
251 }
252}
253
254fn convert_character_type(column: &mysql_async::Column) -> DbDataType {
255 match (column.column_type(), is_binary(column)) {
256 (ColumnType::MYSQL_TYPE_BLOB, false) => DbDataType::Str, (ColumnType::MYSQL_TYPE_BLOB, _) => DbDataType::Binary,
258 (ColumnType::MYSQL_TYPE_LONG_BLOB, _) => DbDataType::Binary,
259 (ColumnType::MYSQL_TYPE_MEDIUM_BLOB, _) => DbDataType::Binary,
260 (ColumnType::MYSQL_TYPE_STRING, true) => DbDataType::Binary, (ColumnType::MYSQL_TYPE_STRING, _) => DbDataType::Str,
262 (ColumnType::MYSQL_TYPE_VAR_STRING, true) => DbDataType::Binary, (ColumnType::MYSQL_TYPE_VAR_STRING, _) => DbDataType::Str,
264 (_, _) => DbDataType::Other,
265 }
266}
267
268fn convert_numeric_type(column: &mysql_async::Column) -> DbDataType {
269 match (column.column_type(), is_signed(column)) {
270 (ColumnType::MYSQL_TYPE_DOUBLE, _) => DbDataType::Floating64,
271 (ColumnType::MYSQL_TYPE_FLOAT, _) => DbDataType::Floating32,
272 (ColumnType::MYSQL_TYPE_INT24, true) => DbDataType::Int32,
273 (ColumnType::MYSQL_TYPE_INT24, false) => DbDataType::Uint32,
274 (ColumnType::MYSQL_TYPE_LONG, true) => DbDataType::Int32,
275 (ColumnType::MYSQL_TYPE_LONG, false) => DbDataType::Uint32,
276 (ColumnType::MYSQL_TYPE_LONGLONG, true) => DbDataType::Int64,
277 (ColumnType::MYSQL_TYPE_LONGLONG, false) => DbDataType::Uint64,
278 (ColumnType::MYSQL_TYPE_SHORT, true) => DbDataType::Int16,
279 (ColumnType::MYSQL_TYPE_SHORT, false) => DbDataType::Uint16,
280 (ColumnType::MYSQL_TYPE_TINY, true) => DbDataType::Int8,
281 (ColumnType::MYSQL_TYPE_TINY, false) => DbDataType::Uint8,
282 (_, _) => DbDataType::Other,
283 }
284}
285
286fn is_signed(column: &mysql_async::Column) -> bool {
287 !column
288 .flags()
289 .contains(mysql_async::consts::ColumnFlags::UNSIGNED_FLAG)
290}
291
292fn is_binary(column: &mysql_async::Column) -> bool {
293 column
294 .flags()
295 .contains(mysql_async::consts::ColumnFlags::BINARY_FLAG)
296}
297
298fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, v2::Error> {
299 let mut result = Vec::with_capacity(row.len());
300 for index in 0..row.len() {
301 result.push(convert_entry(&mut row, index, columns)?);
302 }
303 Ok(result)
304}
305
306fn convert_entry(
307 row: &mut mysql_async::Row,
308 index: usize,
309 columns: &[Column],
310) -> Result<DbValue, v2::Error> {
311 match (row.take(index), columns.get(index)) {
312 (None, _) => Ok(DbValue::DbNull), (_, None) => Err(v2::Error::Other(format!(
314 "Can't get column at index {index}"
315 ))),
316 (Some(mysql_async::Value::NULL), _) => Ok(DbValue::DbNull),
317 (Some(value), Some(column)) => convert_value(value, column),
318 }
319}
320
321fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, v2::Error> {
322 match column.data_type {
323 DbDataType::Binary => convert_value_to::<Vec<u8>>(value).map(DbValue::Binary),
324 DbDataType::Boolean => convert_value_to::<bool>(value).map(DbValue::Boolean),
325 DbDataType::Floating32 => convert_value_to::<f32>(value).map(DbValue::Floating32),
326 DbDataType::Floating64 => convert_value_to::<f64>(value).map(DbValue::Floating64),
327 DbDataType::Int8 => convert_value_to::<i8>(value).map(DbValue::Int8),
328 DbDataType::Int16 => convert_value_to::<i16>(value).map(DbValue::Int16),
329 DbDataType::Int32 => convert_value_to::<i32>(value).map(DbValue::Int32),
330 DbDataType::Int64 => convert_value_to::<i64>(value).map(DbValue::Int64),
331 DbDataType::Str => convert_value_to::<String>(value).map(DbValue::Str),
332 DbDataType::Uint8 => convert_value_to::<u8>(value).map(DbValue::Uint8),
333 DbDataType::Uint16 => convert_value_to::<u16>(value).map(DbValue::Uint16),
334 DbDataType::Uint32 => convert_value_to::<u32>(value).map(DbValue::Uint32),
335 DbDataType::Uint64 => convert_value_to::<u64>(value).map(DbValue::Uint64),
336 DbDataType::Other => Err(v2::Error::ValueConversionFailed(format!(
337 "Cannot convert value {:?} in column {} data type {:?}",
338 value, column.name, column.data_type
339 ))),
340 }
341}
342
/// Returns true when a URL query key names the SSL mode (`ssl-mode` or
/// `sslmode`), compared case-insensitively.
fn is_ssl_param(s: &str) -> bool {
    matches!(s.to_lowercase().as_str(), "ssl-mode" | "sslmode")
}
346
347fn build_opts(address: &str) -> Result<Opts, mysql_async::Error> {
355 let url = Url::parse(address)?;
356
357 let use_ssl = url
358 .query_pairs()
359 .any(|(k, v)| is_ssl_param(&k) && v.to_lowercase() != "disabled");
360
361 let query_without_ssl: Vec<(_, _)> = url
362 .query_pairs()
363 .filter(|(k, _v)| !is_ssl_param(k))
364 .collect();
365 let mut cleaned_url = url.clone();
366 cleaned_url.set_query(None);
367 cleaned_url
368 .query_pairs_mut()
369 .extend_pairs(query_without_ssl);
370
371 Ok(OptsBuilder::from_opts(cleaned_url.as_str())
372 .ssl_opts(if use_ssl {
373 Some(SslOpts::default())
374 } else {
375 None
376 })
377 .into())
378}
379
380fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, v2::Error> {
381 from_value_opt::<T>(value).map_err(|e| v2::Error::ValueConversionFailed(format!("{e}")))
382}
383
#[cfg(test)]
mod test {
    use super::*;

    /// Builds options for `address` and reports whether TLS options were set.
    fn ssl_configured(address: &str) -> bool {
        build_opts(address).unwrap().ssl_opts().is_some()
    }

    #[test]
    fn test_mysql_address_without_ssl_mode() {
        assert!(!ssl_configured("mysql://myuser:password@127.0.0.1/db"))
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_disabled() {
        assert!(!ssl_configured(
            "mysql://myuser:password@127.0.0.1/db?ssl-mode=DISABLED"
        ))
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_verify_ca() {
        assert!(ssl_configured(
            "mysql://myuser:password@127.0.0.1/db?sslMode=VERIFY_CA"
        ))
    }

    #[test]
    fn test_mysql_address_with_more_to_query() {
        // Non-SSL parameters must survive the SSL-parameter stripping.
        let address = "mysql://myuser:password@127.0.0.1/db?SsLmOdE=VERIFY_CA&pool_max=10";
        let opts = build_opts(address).unwrap();
        assert!(opts.ssl_opts().is_some());
        assert_eq!(opts.pool_opts().constraints().max(), 10)
    }
}