Skip to main content

spin_factor_outbound_mysql/
client.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use futures::{future::FutureExt as _, stream::TryStreamExt as _};
5use mysql_async::consts::ColumnType;
6use mysql_async::prelude::{FromValue, Queryable as _};
7use mysql_async::{Conn as MysqlClient, Opts, OptsBuilder, SslOpts, from_value_opt};
8use spin_core::async_trait;
9use spin_world::spin::mysql::mysql as v3;
10use spin_world::v2::mysql::{self as v2};
11use spin_world::v2::rdbms_types::{
12    self as v2_types, Column, DbDataType, DbValue, ParameterValue, RowSet,
13};
14use tokio::sync::{Mutex, mpsc, oneshot};
15use url::Url;
16
17#[async_trait]
18pub trait Client: Send + Sync + 'static {
19    async fn build_client(address: &str) -> Result<Self>
20    where
21        Self: Sized;
22
23    async fn execute(
24        &mut self,
25        statement: String,
26        params: Vec<ParameterValue>,
27    ) -> Result<(), v2::Error>;
28
29    async fn query(
30        &mut self,
31        statement: String,
32        params: Vec<ParameterValue>,
33        max_result_bytes: usize,
34    ) -> Result<RowSet, v2::Error>;
35
36    async fn query_async(
37        client: Arc<Mutex<Self>>,
38        statement: String,
39        params: Vec<v3::ParameterValue>,
40        max_result_bytes: usize,
41    ) -> Result<
42        (
43            Vec<v3::Column>,
44            mpsc::Receiver<v3::Row>,
45            oneshot::Receiver<Result<(), v3::Error>>,
46        ),
47        v3::Error,
48    >;
49}
50
51#[async_trait]
52impl Client for MysqlClient {
53    async fn build_client(address: &str) -> Result<Self>
54    where
55        Self: Sized,
56    {
57        tracing::debug!("Build new connection: {}", address);
58
59        let opts = build_opts(address)?;
60
61        let connection_pool = mysql_async::Pool::new(opts);
62
63        connection_pool.get_conn().await.map_err(|e| anyhow!(e))
64    }
65
66    async fn execute(
67        &mut self,
68        statement: String,
69        params: Vec<ParameterValue>,
70    ) -> Result<(), v2::Error> {
71        let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
72        let parameters = mysql_async::Params::Positional(db_params);
73
74        self.exec_batch(&statement, &[parameters])
75            .await
76            .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))
77    }
78
79    async fn query(
80        &mut self,
81        statement: String,
82        params: Vec<ParameterValue>,
83        max_result_bytes: usize,
84    ) -> Result<RowSet, v2::Error> {
85        let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
86        let parameters = mysql_async::Params::Positional(db_params);
87
88        let mut query_result = self
89            .exec_iter(&statement, parameters)
90            .await
91            .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))?;
92
93        let columns = convert_columns(query_result.columns());
94
95        let mut query_result = query_result
96            .stream()
97            .await
98            .map_err(|e| v2::Error::Other(e.to_string()))?
99            .ok_or_else(|| v2::Error::Other("unable to stream query result".into()))?;
100
101        let mut rows = Vec::new();
102        let mut byte_count = std::mem::size_of::<RowSet>();
103        while let Some(row) = query_result
104            .try_next()
105            .await
106            .map_err(|e| v2::Error::Other(e.to_string()))?
107        {
108            let row = convert_row(row, &columns)?;
109            byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
110            if byte_count > max_result_bytes {
111                return Err(v2::Error::Other(format!(
112                    "query result exceeds limit of {max_result_bytes} bytes"
113                )));
114            }
115            rows.push(row);
116        }
117
118        Ok(v2_types::RowSet { columns, rows })
119    }
120
121    async fn query_async(
122        client: Arc<Mutex<Self>>,
123        statement: String,
124        params: Vec<v3::ParameterValue>,
125        max_result_bytes: usize,
126    ) -> Result<
127        (
128            Vec<v3::Column>,
129            mpsc::Receiver<v3::Row>,
130            oneshot::Receiver<Result<(), v3::Error>>,
131        ),
132        v3::Error,
133    > {
134        let db_params = params
135            .into_iter()
136            .map(|v| to_sql_parameter(v2_types::ParameterValue::from(v)))
137            .collect::<Vec<_>>();
138        let parameters = mysql_async::Params::Positional(db_params);
139
140        let (rows_tx, rows_rx) = mpsc::channel(4);
141        let (err_tx, err_rx) = oneshot::channel();
142        let (columns_tx, columns_rx) = oneshot::channel();
143
144        let mut byte_count = std::mem::size_of::<(
145            Vec<v3::Column>,
146            mpsc::Receiver<v3::Row>,
147            oneshot::Receiver<Result<(), v3::Error>>,
148        )>();
149
150        tokio::spawn(
151            async move {
152                let mut client = client.lock().await;
153
154                let mut query_result = client
155                    .exec_iter(&statement, parameters)
156                    .await
157                    .map_err(|e| v3::Error::QueryFailed(format!("{e:?}")))?;
158
159                let columns = convert_columns(query_result.columns());
160
161                _ = columns_tx.send(
162                    columns
163                        .iter()
164                        .map(|v| v3::Column::from(v.clone()))
165                        .collect(),
166                );
167
168                let mut query_result = query_result
169                    .stream()
170                    .await
171                    .map_err(|e| v3::Error::Other(e.to_string()))?
172                    .ok_or_else(|| v3::Error::Other("unable to stream query result".into()))?;
173
174                while let Some(row) = query_result
175                    .try_next()
176                    .await
177                    .map_err(|e| v3::Error::Other(e.to_string()))?
178                {
179                    let row = convert_row(row, &columns).map_err(v3::Error::from)?;
180
181                    byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
182                    if byte_count > max_result_bytes {
183                        return Err(v3::Error::Other(format!(
184                            "query result exceeds limit of {max_result_bytes} bytes"
185                        )));
186                    }
187
188                    rows_tx
189                        .send(row.into_iter().map(v3::DbValue::from).collect())
190                        .await
191                        .map_err(|e| v3::Error::Other(format!("async error: {e}")))?;
192                }
193
194                Ok(())
195            }
196            .map(move |result| {
197                _ = err_tx.send(result);
198            }),
199        );
200
201        let columns = columns_rx
202            .await
203            .map_err(|e| v3::Error::Other(format!("async error: {e}")))?;
204
205        Ok((columns, rows_rx, err_rx))
206    }
207}
208
209fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
210    match value {
211        ParameterValue::Boolean(v) => mysql_async::Value::from(v),
212        ParameterValue::Int32(v) => mysql_async::Value::from(v),
213        ParameterValue::Int64(v) => mysql_async::Value::from(v),
214        ParameterValue::Int8(v) => mysql_async::Value::from(v),
215        ParameterValue::Int16(v) => mysql_async::Value::from(v),
216        ParameterValue::Floating32(v) => mysql_async::Value::from(v),
217        ParameterValue::Floating64(v) => mysql_async::Value::from(v),
218        ParameterValue::Uint8(v) => mysql_async::Value::from(v),
219        ParameterValue::Uint16(v) => mysql_async::Value::from(v),
220        ParameterValue::Uint32(v) => mysql_async::Value::from(v),
221        ParameterValue::Uint64(v) => mysql_async::Value::from(v),
222        ParameterValue::Str(v) => mysql_async::Value::from(v),
223        ParameterValue::Binary(v) => mysql_async::Value::from(v),
224        ParameterValue::DbNull => mysql_async::Value::NULL,
225    }
226}
227
228fn convert_columns(columns: Option<Arc<[mysql_async::Column]>>) -> Vec<Column> {
229    match columns {
230        Some(columns) => columns.iter().map(convert_column).collect(),
231        None => vec![],
232    }
233}
234
235fn convert_column(column: &mysql_async::Column) -> Column {
236    let name = column.name_str().into_owned();
237    let data_type = convert_data_type(column);
238
239    Column { name, data_type }
240}
241
242fn convert_data_type(column: &mysql_async::Column) -> DbDataType {
243    let column_type = column.column_type();
244
245    if column_type.is_numeric_type() {
246        convert_numeric_type(column)
247    } else if column_type.is_character_type() {
248        convert_character_type(column)
249    } else {
250        DbDataType::Other
251    }
252}
253
254fn convert_character_type(column: &mysql_async::Column) -> DbDataType {
255    match (column.column_type(), is_binary(column)) {
256        (ColumnType::MYSQL_TYPE_BLOB, false) => DbDataType::Str, // TEXT type
257        (ColumnType::MYSQL_TYPE_BLOB, _) => DbDataType::Binary,
258        (ColumnType::MYSQL_TYPE_LONG_BLOB, _) => DbDataType::Binary,
259        (ColumnType::MYSQL_TYPE_MEDIUM_BLOB, _) => DbDataType::Binary,
260        (ColumnType::MYSQL_TYPE_STRING, true) => DbDataType::Binary, // BINARY type
261        (ColumnType::MYSQL_TYPE_STRING, _) => DbDataType::Str,
262        (ColumnType::MYSQL_TYPE_VAR_STRING, true) => DbDataType::Binary, // VARBINARY type
263        (ColumnType::MYSQL_TYPE_VAR_STRING, _) => DbDataType::Str,
264        (_, _) => DbDataType::Other,
265    }
266}
267
268fn convert_numeric_type(column: &mysql_async::Column) -> DbDataType {
269    match (column.column_type(), is_signed(column)) {
270        (ColumnType::MYSQL_TYPE_DOUBLE, _) => DbDataType::Floating64,
271        (ColumnType::MYSQL_TYPE_FLOAT, _) => DbDataType::Floating32,
272        (ColumnType::MYSQL_TYPE_INT24, true) => DbDataType::Int32,
273        (ColumnType::MYSQL_TYPE_INT24, false) => DbDataType::Uint32,
274        (ColumnType::MYSQL_TYPE_LONG, true) => DbDataType::Int32,
275        (ColumnType::MYSQL_TYPE_LONG, false) => DbDataType::Uint32,
276        (ColumnType::MYSQL_TYPE_LONGLONG, true) => DbDataType::Int64,
277        (ColumnType::MYSQL_TYPE_LONGLONG, false) => DbDataType::Uint64,
278        (ColumnType::MYSQL_TYPE_SHORT, true) => DbDataType::Int16,
279        (ColumnType::MYSQL_TYPE_SHORT, false) => DbDataType::Uint16,
280        (ColumnType::MYSQL_TYPE_TINY, true) => DbDataType::Int8,
281        (ColumnType::MYSQL_TYPE_TINY, false) => DbDataType::Uint8,
282        (_, _) => DbDataType::Other,
283    }
284}
285
286fn is_signed(column: &mysql_async::Column) -> bool {
287    !column
288        .flags()
289        .contains(mysql_async::consts::ColumnFlags::UNSIGNED_FLAG)
290}
291
292fn is_binary(column: &mysql_async::Column) -> bool {
293    column
294        .flags()
295        .contains(mysql_async::consts::ColumnFlags::BINARY_FLAG)
296}
297
298fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, v2::Error> {
299    let mut result = Vec::with_capacity(row.len());
300    for index in 0..row.len() {
301        result.push(convert_entry(&mut row, index, columns)?);
302    }
303    Ok(result)
304}
305
306fn convert_entry(
307    row: &mut mysql_async::Row,
308    index: usize,
309    columns: &[Column],
310) -> Result<DbValue, v2::Error> {
311    match (row.take(index), columns.get(index)) {
312        (None, _) => Ok(DbValue::DbNull), // TODO: is this right or is this an "index out of range" thing
313        (_, None) => Err(v2::Error::Other(format!(
314            "Can't get column at index {index}"
315        ))),
316        (Some(mysql_async::Value::NULL), _) => Ok(DbValue::DbNull),
317        (Some(value), Some(column)) => convert_value(value, column),
318    }
319}
320
321fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, v2::Error> {
322    match column.data_type {
323        DbDataType::Binary => convert_value_to::<Vec<u8>>(value).map(DbValue::Binary),
324        DbDataType::Boolean => convert_value_to::<bool>(value).map(DbValue::Boolean),
325        DbDataType::Floating32 => convert_value_to::<f32>(value).map(DbValue::Floating32),
326        DbDataType::Floating64 => convert_value_to::<f64>(value).map(DbValue::Floating64),
327        DbDataType::Int8 => convert_value_to::<i8>(value).map(DbValue::Int8),
328        DbDataType::Int16 => convert_value_to::<i16>(value).map(DbValue::Int16),
329        DbDataType::Int32 => convert_value_to::<i32>(value).map(DbValue::Int32),
330        DbDataType::Int64 => convert_value_to::<i64>(value).map(DbValue::Int64),
331        DbDataType::Str => convert_value_to::<String>(value).map(DbValue::Str),
332        DbDataType::Uint8 => convert_value_to::<u8>(value).map(DbValue::Uint8),
333        DbDataType::Uint16 => convert_value_to::<u16>(value).map(DbValue::Uint16),
334        DbDataType::Uint32 => convert_value_to::<u32>(value).map(DbValue::Uint32),
335        DbDataType::Uint64 => convert_value_to::<u64>(value).map(DbValue::Uint64),
336        DbDataType::Other => Err(v2::Error::ValueConversionFailed(format!(
337            "Cannot convert value {:?} in column {} data type {:?}",
338            value, column.name, column.data_type
339        ))),
340    }
341}
342
/// Returns `true` if the query-string key `s` names the SSL mode, i.e. is
/// `ssl-mode` or `sslmode` compared case-insensitively.
fn is_ssl_param(s: &str) -> bool {
    let key = s.to_lowercase();
    key == "ssl-mode" || key == "sslmode"
}
346
347/// The mysql_async crate blows up if you pass it an SSL parameter and doesn't support SSL opts properly. This function
348/// is a workaround to manually set SSL opts if the user requests them.
349///
350/// We only support ssl-mode in the query as per
351/// https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-connp-props-security.html#cj-conn-prop_sslMode.
352///
353/// An issue has been filed in the upstream repository https://github.com/blackbeam/mysql_async/issues/225.
354fn build_opts(address: &str) -> Result<Opts, mysql_async::Error> {
355    let url = Url::parse(address)?;
356
357    let use_ssl = url
358        .query_pairs()
359        .any(|(k, v)| is_ssl_param(&k) && v.to_lowercase() != "disabled");
360
361    let query_without_ssl: Vec<(_, _)> = url
362        .query_pairs()
363        .filter(|(k, _v)| !is_ssl_param(k))
364        .collect();
365    let mut cleaned_url = url.clone();
366    cleaned_url.set_query(None);
367    cleaned_url
368        .query_pairs_mut()
369        .extend_pairs(query_without_ssl);
370
371    Ok(OptsBuilder::from_opts(cleaned_url.as_str())
372        .ssl_opts(if use_ssl {
373            Some(SslOpts::default())
374        } else {
375            None
376        })
377        .into())
378}
379
380fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, v2::Error> {
381    from_value_opt::<T>(value).map_err(|e| v2::Error::ValueConversionFailed(format!("{e}")))
382}
383
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `address` with `build_opts` and reports whether SSL options were set.
    fn ssl_requested(address: &str) -> bool {
        build_opts(address)
            .expect("address should parse")
            .ssl_opts()
            .is_some()
    }

    #[test]
    fn test_mysql_address_without_ssl_mode() {
        assert!(!ssl_requested("mysql://myuser:password@127.0.0.1/db"));
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_disabled() {
        assert!(!ssl_requested(
            "mysql://myuser:password@127.0.0.1/db?ssl-mode=DISABLED"
        ));
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_verify_ca() {
        assert!(ssl_requested(
            "mysql://myuser:password@127.0.0.1/db?sslMode=VERIFY_CA"
        ));
    }

    #[test]
    fn test_mysql_address_with_more_to_query() {
        let address = "mysql://myuser:password@127.0.0.1/db?SsLmOdE=VERIFY_CA&pool_max=10";
        assert!(ssl_requested(address));

        let opts = build_opts(address).unwrap();
        assert_eq!(opts.pool_opts().constraints().max(), 10);
    }
}