spin_factor_outbound_mysql/client.rs

1use std::sync::Arc;
2
3use anyhow::{anyhow, Result};
4use mysql_async::consts::ColumnType;
5use mysql_async::prelude::{FromValue, Queryable as _};
6use mysql_async::{from_value_opt, Conn as MysqlClient, Opts, OptsBuilder, SslOpts};
7use spin_core::async_trait;
8use spin_world::v2::mysql::{self as v2};
9use spin_world::v2::rdbms_types::{
10    self as v2_types, Column, DbDataType, DbValue, ParameterValue, RowSet,
11};
12use url::Url;
13
14#[async_trait]
15pub trait Client: Send + Sync + 'static {
16    async fn build_client(address: &str) -> Result<Self>
17    where
18        Self: Sized;
19
20    async fn execute(
21        &mut self,
22        statement: String,
23        params: Vec<ParameterValue>,
24    ) -> Result<(), v2::Error>;
25
26    async fn query(
27        &mut self,
28        statement: String,
29        params: Vec<ParameterValue>,
30    ) -> Result<RowSet, v2::Error>;
31}
32
33#[async_trait]
34impl Client for MysqlClient {
35    async fn build_client(address: &str) -> Result<Self>
36    where
37        Self: Sized,
38    {
39        tracing::debug!("Build new connection: {}", address);
40
41        let opts = build_opts(address)?;
42
43        let connection_pool = mysql_async::Pool::new(opts);
44
45        connection_pool.get_conn().await.map_err(|e| anyhow!(e))
46    }
47
48    async fn execute(
49        &mut self,
50        statement: String,
51        params: Vec<ParameterValue>,
52    ) -> Result<(), v2::Error> {
53        let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
54        let parameters = mysql_async::Params::Positional(db_params);
55
56        self.exec_batch(&statement, &[parameters])
57            .await
58            .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))
59    }
60
61    async fn query(
62        &mut self,
63        statement: String,
64        params: Vec<ParameterValue>,
65    ) -> Result<RowSet, v2::Error> {
66        let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
67        let parameters = mysql_async::Params::Positional(db_params);
68
69        let mut query_result = self
70            .exec_iter(&statement, parameters)
71            .await
72            .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;
73
74        // We have to get these before collect() destroys them
75        let columns = convert_columns(query_result.columns());
76
77        match query_result.collect::<mysql_async::Row>().await {
78            Err(e) => Err(v2::Error::Other(e.to_string())),
79            Ok(result_set) => {
80                let rows = result_set
81                    .into_iter()
82                    .map(|row| convert_row(row, &columns))
83                    .collect::<Result<Vec<_>, _>>()?;
84
85                Ok(v2_types::RowSet { columns, rows })
86            }
87        }
88    }
89}
90
91fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
92    match value {
93        ParameterValue::Boolean(v) => mysql_async::Value::from(v),
94        ParameterValue::Int32(v) => mysql_async::Value::from(v),
95        ParameterValue::Int64(v) => mysql_async::Value::from(v),
96        ParameterValue::Int8(v) => mysql_async::Value::from(v),
97        ParameterValue::Int16(v) => mysql_async::Value::from(v),
98        ParameterValue::Floating32(v) => mysql_async::Value::from(v),
99        ParameterValue::Floating64(v) => mysql_async::Value::from(v),
100        ParameterValue::Uint8(v) => mysql_async::Value::from(v),
101        ParameterValue::Uint16(v) => mysql_async::Value::from(v),
102        ParameterValue::Uint32(v) => mysql_async::Value::from(v),
103        ParameterValue::Uint64(v) => mysql_async::Value::from(v),
104        ParameterValue::Str(v) => mysql_async::Value::from(v),
105        ParameterValue::Binary(v) => mysql_async::Value::from(v),
106        ParameterValue::DbNull => mysql_async::Value::NULL,
107    }
108}
109
110fn convert_columns(columns: Option<Arc<[mysql_async::Column]>>) -> Vec<Column> {
111    match columns {
112        Some(columns) => columns.iter().map(convert_column).collect(),
113        None => vec![],
114    }
115}
116
117fn convert_column(column: &mysql_async::Column) -> Column {
118    let name = column.name_str().into_owned();
119    let data_type = convert_data_type(column);
120
121    Column { name, data_type }
122}
123
124fn convert_data_type(column: &mysql_async::Column) -> DbDataType {
125    let column_type = column.column_type();
126
127    if column_type.is_numeric_type() {
128        convert_numeric_type(column)
129    } else if column_type.is_character_type() {
130        convert_character_type(column)
131    } else {
132        DbDataType::Other
133    }
134}
135
136fn convert_character_type(column: &mysql_async::Column) -> DbDataType {
137    match (column.column_type(), is_binary(column)) {
138        (ColumnType::MYSQL_TYPE_BLOB, false) => DbDataType::Str, // TEXT type
139        (ColumnType::MYSQL_TYPE_BLOB, _) => DbDataType::Binary,
140        (ColumnType::MYSQL_TYPE_LONG_BLOB, _) => DbDataType::Binary,
141        (ColumnType::MYSQL_TYPE_MEDIUM_BLOB, _) => DbDataType::Binary,
142        (ColumnType::MYSQL_TYPE_STRING, true) => DbDataType::Binary, // BINARY type
143        (ColumnType::MYSQL_TYPE_STRING, _) => DbDataType::Str,
144        (ColumnType::MYSQL_TYPE_VAR_STRING, true) => DbDataType::Binary, // VARBINARY type
145        (ColumnType::MYSQL_TYPE_VAR_STRING, _) => DbDataType::Str,
146        (_, _) => DbDataType::Other,
147    }
148}
149
150fn convert_numeric_type(column: &mysql_async::Column) -> DbDataType {
151    match (column.column_type(), is_signed(column)) {
152        (ColumnType::MYSQL_TYPE_DOUBLE, _) => DbDataType::Floating64,
153        (ColumnType::MYSQL_TYPE_FLOAT, _) => DbDataType::Floating32,
154        (ColumnType::MYSQL_TYPE_INT24, true) => DbDataType::Int32,
155        (ColumnType::MYSQL_TYPE_INT24, false) => DbDataType::Uint32,
156        (ColumnType::MYSQL_TYPE_LONG, true) => DbDataType::Int32,
157        (ColumnType::MYSQL_TYPE_LONG, false) => DbDataType::Uint32,
158        (ColumnType::MYSQL_TYPE_LONGLONG, true) => DbDataType::Int64,
159        (ColumnType::MYSQL_TYPE_LONGLONG, false) => DbDataType::Uint64,
160        (ColumnType::MYSQL_TYPE_SHORT, true) => DbDataType::Int16,
161        (ColumnType::MYSQL_TYPE_SHORT, false) => DbDataType::Uint16,
162        (ColumnType::MYSQL_TYPE_TINY, true) => DbDataType::Int8,
163        (ColumnType::MYSQL_TYPE_TINY, false) => DbDataType::Uint8,
164        (_, _) => DbDataType::Other,
165    }
166}
167
168fn is_signed(column: &mysql_async::Column) -> bool {
169    !column
170        .flags()
171        .contains(mysql_async::consts::ColumnFlags::UNSIGNED_FLAG)
172}
173
174fn is_binary(column: &mysql_async::Column) -> bool {
175    column
176        .flags()
177        .contains(mysql_async::consts::ColumnFlags::BINARY_FLAG)
178}
179
180fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, v2::Error> {
181    let mut result = Vec::with_capacity(row.len());
182    for index in 0..row.len() {
183        result.push(convert_entry(&mut row, index, columns)?);
184    }
185    Ok(result)
186}
187
188fn convert_entry(
189    row: &mut mysql_async::Row,
190    index: usize,
191    columns: &[Column],
192) -> Result<DbValue, v2::Error> {
193    match (row.take(index), columns.get(index)) {
194        (None, _) => Ok(DbValue::DbNull), // TODO: is this right or is this an "index out of range" thing
195        (_, None) => Err(v2::Error::Other(format!(
196            "Can't get column at index {}",
197            index
198        ))),
199        (Some(mysql_async::Value::NULL), _) => Ok(DbValue::DbNull),
200        (Some(value), Some(column)) => convert_value(value, column),
201    }
202}
203
204fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, v2::Error> {
205    match column.data_type {
206        DbDataType::Binary => convert_value_to::<Vec<u8>>(value).map(DbValue::Binary),
207        DbDataType::Boolean => convert_value_to::<bool>(value).map(DbValue::Boolean),
208        DbDataType::Floating32 => convert_value_to::<f32>(value).map(DbValue::Floating32),
209        DbDataType::Floating64 => convert_value_to::<f64>(value).map(DbValue::Floating64),
210        DbDataType::Int8 => convert_value_to::<i8>(value).map(DbValue::Int8),
211        DbDataType::Int16 => convert_value_to::<i16>(value).map(DbValue::Int16),
212        DbDataType::Int32 => convert_value_to::<i32>(value).map(DbValue::Int32),
213        DbDataType::Int64 => convert_value_to::<i64>(value).map(DbValue::Int64),
214        DbDataType::Str => convert_value_to::<String>(value).map(DbValue::Str),
215        DbDataType::Uint8 => convert_value_to::<u8>(value).map(DbValue::Uint8),
216        DbDataType::Uint16 => convert_value_to::<u16>(value).map(DbValue::Uint16),
217        DbDataType::Uint32 => convert_value_to::<u32>(value).map(DbValue::Uint32),
218        DbDataType::Uint64 => convert_value_to::<u64>(value).map(DbValue::Uint64),
219        DbDataType::Other => Err(v2::Error::ValueConversionFailed(format!(
220            "Cannot convert value {:?} in column {} data type {:?}",
221            value, column.name, column.data_type
222        ))),
223    }
224}
225
/// Returns `true` if the query-parameter name `s` selects the SSL mode (`ssl-mode` or
/// `sslmode`, compared case-insensitively).
fn is_ssl_param(s: &str) -> bool {
    matches!(s.to_lowercase().as_str(), "ssl-mode" | "sslmode")
}
229
230/// The mysql_async crate blows up if you pass it an SSL parameter and doesn't support SSL opts properly. This function
231/// is a workaround to manually set SSL opts if the user requests them.
232///
233/// We only support ssl-mode in the query as per
234/// https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-connp-props-security.html#cj-conn-prop_sslMode.
235///
236/// An issue has been filed in the upstream repository https://github.com/blackbeam/mysql_async/issues/225.
237fn build_opts(address: &str) -> Result<Opts, mysql_async::Error> {
238    let url = Url::parse(address)?;
239
240    let use_ssl = url
241        .query_pairs()
242        .any(|(k, v)| is_ssl_param(&k) && v.to_lowercase() != "disabled");
243
244    let query_without_ssl: Vec<(_, _)> = url
245        .query_pairs()
246        .filter(|(k, _v)| !is_ssl_param(k))
247        .collect();
248    let mut cleaned_url = url.clone();
249    cleaned_url.set_query(None);
250    cleaned_url
251        .query_pairs_mut()
252        .extend_pairs(query_without_ssl);
253
254    Ok(OptsBuilder::from_opts(cleaned_url.as_str())
255        .ssl_opts(if use_ssl {
256            Some(SslOpts::default())
257        } else {
258            None
259        })
260        .into())
261}
262
263fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, v2::Error> {
264    from_value_opt::<T>(value).map_err(|e| v2::Error::ValueConversionFailed(format!("{}", e)))
265}
266
#[cfg(test)]
mod test {
    use super::*;

    /// Builds options for `address` (panicking if it is rejected) and reports whether SSL
    /// options ended up configured.
    fn ssl_configured(address: &str) -> bool {
        build_opts(address).unwrap().ssl_opts().is_some()
    }

    #[test]
    fn test_mysql_address_without_ssl_mode() {
        assert!(!ssl_configured("mysql://myuser:password@127.0.0.1/db"))
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_disabled() {
        assert!(!ssl_configured(
            "mysql://myuser:password@127.0.0.1/db?ssl-mode=DISABLED"
        ))
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_verify_ca() {
        assert!(ssl_configured(
            "mysql://myuser:password@127.0.0.1/db?sslMode=VERIFY_CA"
        ))
    }

    #[test]
    fn test_mysql_address_with_more_to_query() {
        let address = "mysql://myuser:password@127.0.0.1/db?SsLmOdE=VERIFY_CA&pool_max=10";
        assert!(ssl_configured(address));

        let opts = build_opts(address).unwrap();
        assert_eq!(opts.pool_opts().constraints().max(), 10)
    }
}