spin_trigger/cli/
sqlite_statements.rs1use anyhow::Context as _;
2use spin_core::async_trait;
3use spin_factor_sqlite::SqliteFactor;
4use spin_factors::RuntimeFactors;
5use spin_factors_executor::ExecutorHooks;
6use spin_world::MAX_HOST_BUFFERED_BYTES;
7
/// Label of the sqlite database that statements without an `@` prefix are run against.
const DEFAULT_SQLITE_LABEL: &str = "default";
10
/// Executor hook that runs a fixed list of sqlite statements when an app is configured.
///
/// The hook is a no-op for apps whose sqlite app state cannot be obtained (see the
/// `ExecutorHooks` implementation in this file).
///
/// `Debug` and `Clone` are derived so the hook can be logged and duplicated like any other
/// plain-data public type.
#[derive(Debug, Clone)]
pub struct SqlStatementExecutorHook {
    /// Statements to run, in order. An entry starting with `@` is a `file` or `file:label`
    /// reference to a file of SQL; any other entry is SQL text for the default database.
    sql_statements: Vec<String>,
}
18
19impl SqlStatementExecutorHook {
20 pub fn new(sql_statements: Vec<String>) -> Self {
24 Self { sql_statements }
25 }
26
27 pub async fn execute(&self, sqlite: &spin_factor_sqlite::AppState) -> anyhow::Result<()> {
29 if self.sql_statements.is_empty() {
30 return Ok(());
31 }
32 let get_database = |label| async move {
33 sqlite
34 .get_connection(label)
35 .await
36 .transpose()
37 .with_context(|| format!("failed connect to database with label '{label}'"))
38 };
39
40 for statement in &self.sql_statements {
41 if let Some(config) = statement.strip_prefix('@') {
42 let (file, label) = parse_file_and_label(config)?;
43 let database = get_database(label).await?.with_context(|| {
44 format!(
45 "based on the '@{config}' a registered database named '{label}' was expected but not found."
46 )
47 })?;
48 let sql = std::fs::read_to_string(file).with_context(|| {
49 format!("could not read file '{file}' containing sql statements")
50 })?;
51 database.execute_batch(&sql).await.with_context(|| {
52 format!("failed to execute sql against database '{label}' from file '{file}'")
53 })?;
54 } else {
55 let Some(default) = get_database(DEFAULT_SQLITE_LABEL).await? else {
56 debug_assert!(
57 false,
58 "the '{DEFAULT_SQLITE_LABEL}' sqlite database should always be available but for some reason was not"
59 );
60 return Ok(());
61 };
62 default
63 .query(statement, Vec::new(), MAX_HOST_BUFFERED_BYTES)
64 .await
65 .with_context(|| format!("failed to execute following sql statement against default database: '{statement}'"))?;
66 }
67 }
68 Ok(())
69 }
70}
71
72#[async_trait]
73impl<F, U> ExecutorHooks<F, U> for SqlStatementExecutorHook
74where
75 F: RuntimeFactors,
76{
77 async fn configure_app(
78 &self,
79 configured_app: &spin_factors::ConfiguredApp<F>,
80 ) -> anyhow::Result<()> {
81 let Some(sqlite) = configured_app.app_state::<SqliteFactor>().ok() else {
82 return Ok(());
83 };
84 self.execute(sqlite).await?;
85 Ok(())
86 }
87}
88
89fn parse_file_and_label(config: &str) -> anyhow::Result<(&str, &str)> {
91 let config = config.trim();
92 if config.is_empty() {
93 anyhow::bail!("database configuration is empty in the '@{config}' sqlite statement");
94 }
95 let (file, label) = match config.split_once(':') {
96 Some((_, label)) if label.trim().is_empty() => {
97 anyhow::bail!("database label is empty in the '@{config}' sqlite statement")
98 }
99 Some((file, _)) if file.trim().is_empty() => {
100 anyhow::bail!("file path is empty in the '@{config}' sqlite statement")
101 }
102 Some((file, label)) => (file.trim(), label.trim()),
103 None => (config, "default"),
104 };
105 Ok((file, label))
106}
107
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, VecDeque};
    use std::sync::mpsc::Sender;
    use std::sync::Arc;

    use spin_core::async_trait;
    use spin_factor_sqlite::{Connection, ConnectionCreator, QueryAsyncResult};
    use spin_world::spin::sqlite3_1_0::sqlite as v3;
    use tempfile::NamedTempFile;

    use super::*;

    /// Covers the accepted `file` / `file:label` forms and the three rejected inputs.
    #[test]
    fn test_parse_file_and_label() {
        // Accepted forms.
        assert_eq!(
            parse_file_and_label("file:label").unwrap(),
            ("file", "label")
        );
        assert_eq!(parse_file_and_label("file").unwrap(), ("file", "default"));

        // Rejected: empty label, empty file path, empty input.
        assert!(parse_file_and_label("file:").is_err());
        assert!(parse_file_and_label(":label").is_err());
        assert!(parse_file_and_label("").is_err());
    }

    /// Runs one plain statement and one `@file:label` statement through the hook and checks
    /// the exact sequence of calls the mock connections observed.
    #[tokio::test]
    async fn test_execute() {
        let sql_file = NamedTempFile::new().unwrap();
        std::fs::write(&sql_file, "select 2;").unwrap();

        let statements = vec![
            "SELECT 1;".to_string(),
            format!("@{path}:label", path = sql_file.path().display()),
        ];
        let hook = SqlStatementExecutorHook::new(statements);

        // Both labels share one creator, so every action lands on the same channel in the
        // order it happened.
        let (tx, rx) = std::sync::mpsc::channel();
        let creator = Arc::new(MockCreator { tx });
        let mut connection_creators = HashMap::new();
        connection_creators.insert(
            "default".into(),
            Arc::clone(&creator) as Arc<dyn ConnectionCreator>,
        );
        connection_creators.insert("label".into(), creator);
        let sqlite = spin_factor_sqlite::AppState::new(Default::default(), connection_creators);

        assert!(hook.execute(&sqlite).await.is_ok());

        let mut expected = VecDeque::from([
            Action::CreateConnection("default".to_string()),
            Action::Query("SELECT 1;".to_string()),
            Action::CreateConnection("label".to_string()),
            Action::Execute("select 2;".to_string()),
        ]);
        // Drain whatever was recorded; each action must match the next expected one.
        for action in rx.try_iter() {
            assert_eq!(action, expected.pop_front().unwrap(), "unexpected action");
        }

        assert!(
            expected.is_empty(),
            "Expected actions were never seen: {:?}",
            expected
        );
    }

    /// Connection creator that records each requested label and hands out mock connections.
    struct MockCreator {
        tx: Sender<Action>,
    }

    impl MockCreator {
        /// Records that a connection for `label` was requested.
        fn push(&self, label: &str) {
            let action = Action::CreateConnection(label.to_string());
            self.tx.send(action).unwrap();
        }
    }

    #[async_trait]
    impl ConnectionCreator for MockCreator {
        async fn create_connection(
            &self,
            label: &str,
        ) -> Result<Arc<dyn Connection + 'static>, v3::Error> {
            self.push(label);
            // The connection reports to the same channel as its creator.
            let connection = MockConnection {
                tx: self.tx.clone(),
            };
            Ok(Arc::new(connection))
        }
    }

    /// Connection that records every call on the shared channel and returns canned results.
    struct MockConnection {
        tx: Sender<Action>,
    }

    #[async_trait]
    impl Connection for MockConnection {
        async fn query(
            &self,
            query: &str,
            parameters: Vec<v3::Value>,
            _max_result_bytes: usize,
        ) -> Result<v3::QueryResult, v3::Error> {
            // Parameters are accepted but not inspected.
            let _ = parameters;
            self.tx.send(Action::Query(query.to_string())).unwrap();
            Ok(v3::QueryResult {
                columns: Vec::new(),
                rows: Vec::new(),
            })
        }

        async fn query_async(
            &self,
            query: &str,
            parameters: Vec<v3::Value>,
            _max_result_bytes: usize,
        ) -> Result<QueryAsyncResult, v3::Error> {
            // Parameters are accepted but not inspected.
            let _ = parameters;
            self.tx.send(Action::Query(query.to_string())).unwrap();
            // The sending halves are dropped on return, so nothing is ever delivered on
            // either receiver.
            let (_row_tx, row_rx) = tokio::sync::mpsc::channel(1);
            let (_error_tx, error_rx) = tokio::sync::oneshot::channel();
            Ok(QueryAsyncResult {
                columns: Vec::new(),
                rows: row_rx,
                error: error_rx,
            })
        }

        async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
            let action = Action::Execute(statements.to_string());
            self.tx.send(action).unwrap();
            Ok(())
        }

        async fn changes(&self) -> Result<u64, v3::Error> {
            self.tx.send(Action::Changes).unwrap();
            Ok(123)
        }

        async fn last_insert_rowid(&self) -> Result<i64, v3::Error> {
            self.tx.send(Action::LastInsertRowid).unwrap();
            Ok(456)
        }
    }

    /// One observable call made against the mock creator or a mock connection.
    #[derive(Debug, PartialEq)]
    enum Action {
        CreateConnection(String),
        Query(String),
        Execute(String),
        Changes,
        LastInsertRowid,
    }
}