spin_factor_outbound_mysql/
client.rs1use std::sync::Arc;
2
3use anyhow::{anyhow, Result};
4use mysql_async::consts::ColumnType;
5use mysql_async::prelude::{FromValue, Queryable as _};
6use mysql_async::{from_value_opt, Conn as MysqlClient, Opts, OptsBuilder, SslOpts};
7use spin_core::async_trait;
8use spin_world::v2::mysql::{self as v2};
9use spin_world::v2::rdbms_types::{
10 self as v2_types, Column, DbDataType, DbValue, ParameterValue, RowSet,
11};
12use url::Url;
13
14#[async_trait]
15pub trait Client: Send + Sync + 'static {
16 async fn build_client(address: &str) -> Result<Self>
17 where
18 Self: Sized;
19
20 async fn execute(
21 &mut self,
22 statement: String,
23 params: Vec<ParameterValue>,
24 ) -> Result<(), v2::Error>;
25
26 async fn query(
27 &mut self,
28 statement: String,
29 params: Vec<ParameterValue>,
30 ) -> Result<RowSet, v2::Error>;
31}
32
33#[async_trait]
34impl Client for MysqlClient {
35 async fn build_client(address: &str) -> Result<Self>
36 where
37 Self: Sized,
38 {
39 tracing::debug!("Build new connection: {}", address);
40
41 let opts = build_opts(address)?;
42
43 let connection_pool = mysql_async::Pool::new(opts);
44
45 connection_pool.get_conn().await.map_err(|e| anyhow!(e))
46 }
47
48 async fn execute(
49 &mut self,
50 statement: String,
51 params: Vec<ParameterValue>,
52 ) -> Result<(), v2::Error> {
53 let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
54 let parameters = mysql_async::Params::Positional(db_params);
55
56 self.exec_batch(&statement, &[parameters])
57 .await
58 .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))
59 }
60
61 async fn query(
62 &mut self,
63 statement: String,
64 params: Vec<ParameterValue>,
65 ) -> Result<RowSet, v2::Error> {
66 let db_params = params.into_iter().map(to_sql_parameter).collect::<Vec<_>>();
67 let parameters = mysql_async::Params::Positional(db_params);
68
69 let mut query_result = self
70 .exec_iter(&statement, parameters)
71 .await
72 .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;
73
74 let columns = convert_columns(query_result.columns());
76
77 match query_result.collect::<mysql_async::Row>().await {
78 Err(e) => Err(v2::Error::Other(e.to_string())),
79 Ok(result_set) => {
80 let rows = result_set
81 .into_iter()
82 .map(|row| convert_row(row, &columns))
83 .collect::<Result<Vec<_>, _>>()?;
84
85 Ok(v2_types::RowSet { columns, rows })
86 }
87 }
88 }
89}
90
91fn to_sql_parameter(value: ParameterValue) -> mysql_async::Value {
92 match value {
93 ParameterValue::Boolean(v) => mysql_async::Value::from(v),
94 ParameterValue::Int32(v) => mysql_async::Value::from(v),
95 ParameterValue::Int64(v) => mysql_async::Value::from(v),
96 ParameterValue::Int8(v) => mysql_async::Value::from(v),
97 ParameterValue::Int16(v) => mysql_async::Value::from(v),
98 ParameterValue::Floating32(v) => mysql_async::Value::from(v),
99 ParameterValue::Floating64(v) => mysql_async::Value::from(v),
100 ParameterValue::Uint8(v) => mysql_async::Value::from(v),
101 ParameterValue::Uint16(v) => mysql_async::Value::from(v),
102 ParameterValue::Uint32(v) => mysql_async::Value::from(v),
103 ParameterValue::Uint64(v) => mysql_async::Value::from(v),
104 ParameterValue::Str(v) => mysql_async::Value::from(v),
105 ParameterValue::Binary(v) => mysql_async::Value::from(v),
106 ParameterValue::DbNull => mysql_async::Value::NULL,
107 }
108}
109
110fn convert_columns(columns: Option<Arc<[mysql_async::Column]>>) -> Vec<Column> {
111 match columns {
112 Some(columns) => columns.iter().map(convert_column).collect(),
113 None => vec![],
114 }
115}
116
117fn convert_column(column: &mysql_async::Column) -> Column {
118 let name = column.name_str().into_owned();
119 let data_type = convert_data_type(column);
120
121 Column { name, data_type }
122}
123
124fn convert_data_type(column: &mysql_async::Column) -> DbDataType {
125 let column_type = column.column_type();
126
127 if column_type.is_numeric_type() {
128 convert_numeric_type(column)
129 } else if column_type.is_character_type() {
130 convert_character_type(column)
131 } else {
132 DbDataType::Other
133 }
134}
135
136fn convert_character_type(column: &mysql_async::Column) -> DbDataType {
137 match (column.column_type(), is_binary(column)) {
138 (ColumnType::MYSQL_TYPE_BLOB, false) => DbDataType::Str, (ColumnType::MYSQL_TYPE_BLOB, _) => DbDataType::Binary,
140 (ColumnType::MYSQL_TYPE_LONG_BLOB, _) => DbDataType::Binary,
141 (ColumnType::MYSQL_TYPE_MEDIUM_BLOB, _) => DbDataType::Binary,
142 (ColumnType::MYSQL_TYPE_STRING, true) => DbDataType::Binary, (ColumnType::MYSQL_TYPE_STRING, _) => DbDataType::Str,
144 (ColumnType::MYSQL_TYPE_VAR_STRING, true) => DbDataType::Binary, (ColumnType::MYSQL_TYPE_VAR_STRING, _) => DbDataType::Str,
146 (_, _) => DbDataType::Other,
147 }
148}
149
150fn convert_numeric_type(column: &mysql_async::Column) -> DbDataType {
151 match (column.column_type(), is_signed(column)) {
152 (ColumnType::MYSQL_TYPE_DOUBLE, _) => DbDataType::Floating64,
153 (ColumnType::MYSQL_TYPE_FLOAT, _) => DbDataType::Floating32,
154 (ColumnType::MYSQL_TYPE_INT24, true) => DbDataType::Int32,
155 (ColumnType::MYSQL_TYPE_INT24, false) => DbDataType::Uint32,
156 (ColumnType::MYSQL_TYPE_LONG, true) => DbDataType::Int32,
157 (ColumnType::MYSQL_TYPE_LONG, false) => DbDataType::Uint32,
158 (ColumnType::MYSQL_TYPE_LONGLONG, true) => DbDataType::Int64,
159 (ColumnType::MYSQL_TYPE_LONGLONG, false) => DbDataType::Uint64,
160 (ColumnType::MYSQL_TYPE_SHORT, true) => DbDataType::Int16,
161 (ColumnType::MYSQL_TYPE_SHORT, false) => DbDataType::Uint16,
162 (ColumnType::MYSQL_TYPE_TINY, true) => DbDataType::Int8,
163 (ColumnType::MYSQL_TYPE_TINY, false) => DbDataType::Uint8,
164 (_, _) => DbDataType::Other,
165 }
166}
167
168fn is_signed(column: &mysql_async::Column) -> bool {
169 !column
170 .flags()
171 .contains(mysql_async::consts::ColumnFlags::UNSIGNED_FLAG)
172}
173
174fn is_binary(column: &mysql_async::Column) -> bool {
175 column
176 .flags()
177 .contains(mysql_async::consts::ColumnFlags::BINARY_FLAG)
178}
179
180fn convert_row(mut row: mysql_async::Row, columns: &[Column]) -> Result<Vec<DbValue>, v2::Error> {
181 let mut result = Vec::with_capacity(row.len());
182 for index in 0..row.len() {
183 result.push(convert_entry(&mut row, index, columns)?);
184 }
185 Ok(result)
186}
187
188fn convert_entry(
189 row: &mut mysql_async::Row,
190 index: usize,
191 columns: &[Column],
192) -> Result<DbValue, v2::Error> {
193 match (row.take(index), columns.get(index)) {
194 (None, _) => Ok(DbValue::DbNull), (_, None) => Err(v2::Error::Other(format!(
196 "Can't get column at index {}",
197 index
198 ))),
199 (Some(mysql_async::Value::NULL), _) => Ok(DbValue::DbNull),
200 (Some(value), Some(column)) => convert_value(value, column),
201 }
202}
203
204fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue, v2::Error> {
205 match column.data_type {
206 DbDataType::Binary => convert_value_to::<Vec<u8>>(value).map(DbValue::Binary),
207 DbDataType::Boolean => convert_value_to::<bool>(value).map(DbValue::Boolean),
208 DbDataType::Floating32 => convert_value_to::<f32>(value).map(DbValue::Floating32),
209 DbDataType::Floating64 => convert_value_to::<f64>(value).map(DbValue::Floating64),
210 DbDataType::Int8 => convert_value_to::<i8>(value).map(DbValue::Int8),
211 DbDataType::Int16 => convert_value_to::<i16>(value).map(DbValue::Int16),
212 DbDataType::Int32 => convert_value_to::<i32>(value).map(DbValue::Int32),
213 DbDataType::Int64 => convert_value_to::<i64>(value).map(DbValue::Int64),
214 DbDataType::Str => convert_value_to::<String>(value).map(DbValue::Str),
215 DbDataType::Uint8 => convert_value_to::<u8>(value).map(DbValue::Uint8),
216 DbDataType::Uint16 => convert_value_to::<u16>(value).map(DbValue::Uint16),
217 DbDataType::Uint32 => convert_value_to::<u32>(value).map(DbValue::Uint32),
218 DbDataType::Uint64 => convert_value_to::<u64>(value).map(DbValue::Uint64),
219 DbDataType::Other => Err(v2::Error::ValueConversionFailed(format!(
220 "Cannot convert value {:?} in column {} data type {:?}",
221 value, column.name, column.data_type
222 ))),
223 }
224}
225
/// Returns `true` if the query-parameter key `s` names the SSL mode.
///
/// Both `ssl-mode` and `sslmode` are accepted, compared case-insensitively.
fn is_ssl_param(s: &str) -> bool {
    let key = s.to_lowercase();
    matches!(key.as_str(), "ssl-mode" | "sslmode")
}
229
230fn build_opts(address: &str) -> Result<Opts, mysql_async::Error> {
238 let url = Url::parse(address)?;
239
240 let use_ssl = url
241 .query_pairs()
242 .any(|(k, v)| is_ssl_param(&k) && v.to_lowercase() != "disabled");
243
244 let query_without_ssl: Vec<(_, _)> = url
245 .query_pairs()
246 .filter(|(k, _v)| !is_ssl_param(k))
247 .collect();
248 let mut cleaned_url = url.clone();
249 cleaned_url.set_query(None);
250 cleaned_url
251 .query_pairs_mut()
252 .extend_pairs(query_without_ssl);
253
254 Ok(OptsBuilder::from_opts(cleaned_url.as_str())
255 .ssl_opts(if use_ssl {
256 Some(SslOpts::default())
257 } else {
258 None
259 })
260 .into())
261}
262
263fn convert_value_to<T: FromValue>(value: mysql_async::Value) -> Result<T, v2::Error> {
264 from_value_opt::<T>(value).map_err(|e| v2::Error::ValueConversionFailed(format!("{}", e)))
265}
266
#[cfg(test)]
mod test {
    use super::*;

    /// Builds options for `address` (which must be valid) and reports whether
    /// TLS was enabled.
    fn ssl_enabled(address: &str) -> bool {
        build_opts(address).unwrap().ssl_opts().is_some()
    }

    #[test]
    fn test_mysql_address_without_ssl_mode() {
        assert!(!ssl_enabled("mysql://myuser:password@127.0.0.1/db"));
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_disabled() {
        assert!(!ssl_enabled(
            "mysql://myuser:password@127.0.0.1/db?ssl-mode=DISABLED"
        ));
    }

    #[test]
    fn test_mysql_address_with_ssl_mode_verify_ca() {
        assert!(ssl_enabled(
            "mysql://myuser:password@127.0.0.1/db?sslMode=VERIFY_CA"
        ));
    }

    #[test]
    fn test_mysql_address_with_more_to_query() {
        let address = "mysql://myuser:password@127.0.0.1/db?SsLmOdE=VERIFY_CA&pool_max=10";
        let opts = build_opts(address).unwrap();

        assert!(opts.ssl_opts().is_some());
        assert_eq!(opts.pool_opts().constraints().max(), 10);
    }
}