Skip to main content

spin_factor_outbound_mysql/
client.rs

1use std::sync::Arc;
2
3use anyhow::{anyhow, Result};
4use futures::stream::TryStreamExt as _;
5use mysql_async::consts::ColumnType;
6use mysql_async::prelude::{FromValue, Queryable as _};
7use mysql_async::{from_value_opt, Conn as MysqlClient, Opts, OptsBuilder, SslOpts};
8use spin_core::async_trait;
9use spin_world::v2::mysql::{self as v2};
10use spin_world::v2::rdbms_types::{
11    self as v2_types, Column, DbDataType, DbValue, ParameterValue, RowSet,
12};
13use url::Url;
14
15#[async_trait]
16pub trait Client: Send + Sync + 'static {
17    async fn build_client(address: &str) -> Result<Self>
18    where
19        Self: Sized;
20
21    async fn execute(
22        &mut self,
23        statement: String,
24        params: Vec<ParameterValue>,
25    ) -> Result<(), v2::Error>;
26
27    async fn query(
28        &mut self,
29        statement: String,
30        params: Vec<ParameterValue>,
31        max_result_bytes: usize,
32    ) -> Result<RowSet, v2::Error>;
33}
34
35#[async_trait]
36impl Client for MysqlClient {
37    async fn build_client(address: &str) -> Result<Self>
38    where
39        Self: Sized,
40    {
41        tracing::debug!("Build new connection: {}", address);
42
43        let opts = build_opts(address)?;
44
45        let connection_pool = mysql_async::Pool::new(opts);
46
47        connection_pool.get_conn().await.map_err(|e| anyhow!(e))
48    }
49
50    async fn execute(
51        &mut self,
52        statement: String,
53        params: Vec<ParameterValue>,
54    ) -> Result<(), v2::Error> {
55        let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
56        let parameters = mysql_async::Params::Positional(db_params);
57
58        self.exec_batch(&statement, &[parameters])
59            .await
60            .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))
61    }
62
63    async fn query(
64        &mut self,
65        statement: String,
66        params: Vec<ParameterValue>,
67        max_result_bytes: usize,
68    ) -> Result<RowSet, v2::Error> {
69        let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
70        let parameters = mysql_async::Params::Positional(db_params);
71
72        let mut query_result = self
73            .exec_iter(&statement, parameters)
74            .await
75            .map_err(|e| v2::Error::QueryFailed(format!("{e:?}")))?;
76
77        let columns = convert_columns(query_result.columns());
78
79        let mut query_result = query_result
80            .stream()
81            .await
82            .map_err(|e| v2::Error::Other(e.to_string()))?
83            .ok_or_else(|| v2::Error::Other("unable to stream query result".into()))?;
84
85        let mut rows = Vec::new();
86        let mut byte_count = std::mem::size_of::<RowSet>();
87        while let Some(row) = query_result
88            .try_next()
89            .await
90            .map_err(|e| v2::Error::Other(e.to_string()))?
91        {
92            let row = convert_row(row, &columns)?;
93            byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
94            if byte_count > max_result_bytes {
95                return Err(v2::Error::Other(format!(
96                    "query result exceeds limit of {max_result_bytes} bytes"
97                )));
98            }
99            rows.push(row);
100        }
101
102        Ok(v2_types::RowSet { columns, rows })
103    }
104}
105
106fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
107    match value {
108        ParameterValue::Boolean(v) => mysql_async::Value::from(v),
109        ParameterValue::Int32(v) => mysql_async::Value::from(v),
110        ParameterValue::Int64(v) => mysql_async::Value::from(v),
111        ParameterValue::Int8(v) => mysql_async::Value::from(v),
112        ParameterValue::Int16(v) => mysql_async::Value::from(v),
113        ParameterValue::Floating32(v) => mysql_async::Value::from(v),
114        ParameterValue::Floating64(v) => mysql_async::Value::from(v),
115        ParameterValue::Uint8(v) => mysql_async::Value::from(v),
116        ParameterValue::Uint16(v) => mysql_async::Value::from(v),
117        ParameterValue::Uint32(v) => mysql_async::Value::from(v),
118        ParameterValue::Uint64(v) => mysql_async::Value::from(v),
119        ParameterValue::Str(v) => mysql_async::Value::from(v),
120        ParameterValue::Binary(v) => mysql_async::Value::from(v),
121        ParameterValue::DbNull => mysql_async::Value::NULL,
122    }
123}
124
125fn convert_columns(columns: Option<Arc<[mysql_async::Column]>>) -> Vec<Column> {
126    match columns {
127        Some(columns) => columns.iter().map(convert_column).collect(),
128        None => vec![],
129    }
130}
131
132fn convert_column(column: &mysql_async::Column) -> Column {
133    let name = column.name_str().into_owned();
134    let data_type = convert_data_type(column);
135
136    Column { name, data_type }
137}
138
139fn convert_data_type(column: &mysql_async::Column) -> DbDataType {
140    let column_type = column.column_type();
141
142    if column_type.is_numeric_type() {
143        convert_numeric_type(column)
144    } else if column_type.is_character_type() {
145        convert_character_type(column)
146    } else {
147        DbDataType::Other
148    }
149}
150
151fn convert_character_type(column: &mysql_async::Column) -> DbDataType {
152    match (column.column_type(), is_binary(column)) {
153        (ColumnType::MYSQL_TYPE_BLOB, false) => DbDataType::Str, // TEXT type
154        (ColumnType::MYSQL_TYPE_BLOB, _) => DbDataType::Binary,
155        (ColumnType::MYSQL_TYPE_LONG_BLOB, _) => DbDataType::Binary,
156        (ColumnType::MYSQL_TYPE_MEDIUM_BLOB, _) => DbDataType::Binary,
157        (ColumnType::MYSQL_TYPE_STRING, true) => DbDataType::Binary, // BINARY type
158        (ColumnType::MYSQL_TYPE_STRING, _) => DbDataType::Str,
159        (ColumnType::MYSQL_TYPE_VAR_STRING, true) => DbDataType::Binary, // VARBINARY type
160        (ColumnType::MYSQL_TYPE_VAR_STRING, _) => DbDataType::Str,
161        (_, _) => DbDataType::Other,
162    }
163}
164
165fn convert_numeric_type(column: &mysql_async::Column) -> DbDataType {
166    match (column.column_type(), is_signed(column)) {
167        (ColumnType::MYSQL_TYPE_DOUBLE, _) => DbDataType::Floating64,
168        (ColumnType::MYSQL_TYPE_FLOAT, _) => DbDataType::Floating32,
169        (ColumnType::MYSQL_TYPE_INT24, true) => DbDataType::Int32,
170        (ColumnType::MYSQL_TYPE_INT24, false) => DbDataType::Uint32,
171        (ColumnType::MYSQL_TYPE_LONG, true) => DbDataType::Int32,
172        (ColumnType::MYSQL_TYPE_LONG, false) => DbDataType::Uint32,
173        (ColumnType::MYSQL_TYPE_LONGLONG, true) => DbDataType::Int64,
174        (ColumnType::MYSQL_TYPE_LONGLONG, false) => DbDataType::Uint64,
175        (ColumnType::MYSQL_TYPE_SHORT, true) => DbDataType::Int16,
176        (ColumnType::MYSQL_TYPE_SHORT, false) => DbDataType::Uint16,
177        (ColumnType::MYSQL_TYPE_TINY, true) => DbDataType::Int8,
178        (ColumnType::MYSQL_TYPE_TINY, false) => DbDataType::Uint8,
179        (_, _) => DbDataType::Other,
180    }
181}
182
183fn is_signed(column: &mysql_async::Column) -> bool {
184    !column
185        .flags()
186        .contains(mysql_async::consts::ColumnFlags::UNSIGNED_FLAG)
187}
188
189fn is_binary(column: &mysql_async::Column) -> bool {
190    column
191        .flags()
192        .contains(mysql_async::consts::ColumnFlags::BINARY_FLAG)
193}
194
195fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, v2::Error> {
196    let mut result = Vec::with_capacity(row.len());
197    for index in 0..row.len() {
198        result.push(convert_entry(&mut row, index, columns)?);
199    }
200    Ok(result)
201}
202
203fn convert_entry(
204    row: &mut mysql_async::Row,
205    index: usize,
206    columns: &[Column],
207) -> Result<DbValue, v2::Error> {
208    match (row.take(index), columns.get(index)) {
209        (None, _) => Ok(DbValue::DbNull), // TODO: is this right or is this an "index out of range" thing
210        (_, None) => Err(v2::Error::Other(format!(
211            "Can't get column at index {index}"
212        ))),
213        (Some(mysql_async::Value::NULL), _) => Ok(DbValue::DbNull),
214        (Some(value), Some(column)) => convert_value(value, column),
215    }
216}
217
218fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, v2::Error> {
219    match column.data_type {
220        DbDataType::Binary => convert_value_to::<Vec<u8>>(value).map(DbValue::Binary),
221        DbDataType::Boolean => convert_value_to::<bool>(value).map(DbValue::Boolean),
222        DbDataType::Floating32 => convert_value_to::<f32>(value).map(DbValue::Floating32),
223        DbDataType::Floating64 => convert_value_to::<f64>(value).map(DbValue::Floating64),
224        DbDataType::Int8 => convert_value_to::<i8>(value).map(DbValue::Int8),
225        DbDataType::Int16 => convert_value_to::<i16>(value).map(DbValue::Int16),
226        DbDataType::Int32 => convert_value_to::<i32>(value).map(DbValue::Int32),
227        DbDataType::Int64 => convert_value_to::<i64>(value).map(DbValue::Int64),
228        DbDataType::Str => convert_value_to::<String>(value).map(DbValue::Str),
229        DbDataType::Uint8 => convert_value_to::<u8>(value).map(DbValue::Uint8),
230        DbDataType::Uint16 => convert_value_to::<u16>(value).map(DbValue::Uint16),
231        DbDataType::Uint32 => convert_value_to::<u32>(value).map(DbValue::Uint32),
232        DbDataType::Uint64 => convert_value_to::<u64>(value).map(DbValue::Uint64),
233        DbDataType::Other => Err(v2::Error::ValueConversionFailed(format!(
234            "Cannot convert value {:?} in column {} data type {:?}",
235            value, column.name, column.data_type
236        ))),
237    }
238}
239
/// Returns `true` if `s` names the SSL-mode URL parameter (`ssl-mode` or `sslmode`),
/// compared case-insensitively.
fn is_ssl_param(s: &str) -> bool {
    matches!(s.to_lowercase().as_str(), "ssl-mode" | "sslmode")
}
243
244/// The mysql_async crate blows up if you pass it an SSL parameter and doesn't support SSL opts properly. This function
245/// is a workaround to manually set SSL opts if the user requests them.
246///
247/// We only support ssl-mode in the query as per
248/// https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-connp-props-security.html#cj-conn-prop_sslMode.
249///
250/// An issue has been filed in the upstream repository https://github.com/blackbeam/mysql_async/issues/225.
251fn build_opts(address: &str) -> Result<Opts, mysql_async::Error> {
252    let url = Url::parse(address)?;
253
254    let use_ssl = url
255        .query_pairs()
256        .any(|(k, v)| is_ssl_param(&k) && v.to_lowercase() != "disabled");
257
258    let query_without_ssl: Vec<(_, _)> = url
259        .query_pairs()
260        .filter(|(k, _v)| !is_ssl_param(k))
261        .collect();
262    let mut cleaned_url = url.clone();
263    cleaned_url.set_query(None);
264    cleaned_url
265        .query_pairs_mut()
266        .extend_pairs(query_without_ssl);
267
268    Ok(OptsBuilder::from_opts(cleaned_url.as_str())
269        .ssl_opts(if use_ssl {
270            Some(SslOpts::default())
271        } else {
272            None
273        })
274        .into())
275}
276
277fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, v2::Error> {
278    from_value_opt::<T>(value).map_err(|e| v2::Error::ValueConversionFailed(format!("{e}")))
279}
280
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `address` with `build_opts`, failing the test if it is rejected.
    fn opts_for(address: &str) -> Opts {
        build_opts(address).unwrap()
    }

    #[test]
    fn test_mysql_address_without_ssl_mode() {
        let opts = opts_for("mysql://myuser:password@127.0.0.1/db");
        assert!(opts.ssl_opts().is_none())
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_disabled() {
        let opts = opts_for("mysql://myuser:password@127.0.0.1/db?ssl-mode=DISABLED");
        assert!(opts.ssl_opts().is_none())
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_verify_ca() {
        let opts = opts_for("mysql://myuser:password@127.0.0.1/db?sslMode=VERIFY_CA");
        assert!(opts.ssl_opts().is_some())
    }

    #[test]
    fn test_mysql_address_with_more_to_query() {
        let opts =
            opts_for("mysql://myuser:password@127.0.0.1/db?SsLmOdE=VERIFY_CA&pool_max=10");
        assert!(opts.ssl_opts().is_some());
        assert_eq!(opts.pool_opts().constraints().max(), 10)
    }
}