spin_trigger/cli/
sqlite_statements.rs1use anyhow::Context as _;
2use spin_core::async_trait;
3use spin_factor_sqlite::SqliteFactor;
4use spin_factors::RuntimeFactors;
5use spin_factors_executor::ExecutorHooks;
6
/// Label of the SQLite database used for statements that are not given as an `@` file reference.
const DEFAULT_SQLITE_LABEL: &str = "default";
9
/// Executor hook that runs a list of SQL statements against an app's SQLite databases.
///
/// Each entry is either a literal SQL statement, which is run against the default
/// database, or an `@<file>[:<label>]` reference, which runs the contents of `<file>`
/// against the database registered under `<label>`.
#[derive(Debug)]
pub struct SqlStatementExecutorHook {
    /// The raw statements, kept in the order in which they are executed.
    sql_statements: Vec<String>,
}
17
18impl SqlStatementExecutorHook {
19 pub fn new(sql_statements: Vec<String>) -> Self {
23 Self { sql_statements }
24 }
25
26 pub async fn execute(&self, sqlite: &spin_factor_sqlite::AppState) -> anyhow::Result<()> {
28 if self.sql_statements.is_empty() {
29 return Ok(());
30 }
31 let get_database = |label| async move {
32 sqlite
33 .get_connection(label)
34 .await
35 .transpose()
36 .with_context(|| format!("failed connect to database with label '{label}'"))
37 };
38
39 for statement in &self.sql_statements {
40 if let Some(config) = statement.strip_prefix('@') {
41 let (file, label) = parse_file_and_label(config)?;
42 let database = get_database(label).await?.with_context(|| {
43 format!(
44 "based on the '@{config}' a registered database named '{label}' was expected but not found."
45 )
46 })?;
47 let sql = std::fs::read_to_string(file).with_context(|| {
48 format!("could not read file '{file}' containing sql statements")
49 })?;
50 database.execute_batch(&sql).await.with_context(|| {
51 format!("failed to execute sql against database '{label}' from file '{file}'")
52 })?;
53 } else {
54 let Some(default) = get_database(DEFAULT_SQLITE_LABEL).await? else {
55 debug_assert!(false, "the '{DEFAULT_SQLITE_LABEL}' sqlite database should always be available but for some reason was not");
56 return Ok(());
57 };
58 default
59 .query(statement, Vec::new())
60 .await
61 .with_context(|| format!("failed to execute following sql statement against default database: '{statement}'"))?;
62 }
63 }
64 Ok(())
65 }
66}
67
68#[async_trait]
69impl<F, U> ExecutorHooks<F, U> for SqlStatementExecutorHook
70where
71 F: RuntimeFactors,
72{
73 async fn configure_app(
74 &self,
75 configured_app: &spin_factors::ConfiguredApp<F>,
76 ) -> anyhow::Result<()> {
77 let Some(sqlite) = configured_app.app_state::<SqliteFactor>().ok() else {
78 return Ok(());
79 };
80 self.execute(sqlite).await?;
81 Ok(())
82 }
83}
84
85fn parse_file_and_label(config: &str) -> anyhow::Result<(&str, &str)> {
87 let config = config.trim();
88 if config.is_empty() {
89 anyhow::bail!("database configuration is empty in the '@{config}' sqlite statement");
90 }
91 let (file, label) = match config.split_once(':') {
92 Some((_, label)) if label.trim().is_empty() => {
93 anyhow::bail!("database label is empty in the '@{config}' sqlite statement")
94 }
95 Some((file, _)) if file.trim().is_empty() => {
96 anyhow::bail!("file path is empty in the '@{config}' sqlite statement")
97 }
98 Some((file, label)) => (file.trim(), label.trim()),
99 None => (config, "default"),
100 };
101 Ok((file, label))
102}
103
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::mpsc::Sender;
    use std::sync::Arc;

    use spin_core::async_trait;
    use spin_factor_sqlite::{Connection, ConnectionCreator};
    use spin_world::spin::sqlite::sqlite as v3;
    use tempfile::NamedTempFile;

    use super::*;

    /// Covers the accepted `<file>[:<label>]` forms and the rejected empty forms.
    #[test]
    fn test_parse_file_and_label() {
        // Accepted forms.
        assert_eq!(
            parse_file_and_label("file:label").unwrap(),
            ("file", "label")
        );
        assert_eq!(parse_file_and_label("file").unwrap(), ("file", "default"));
        // Whitespace around the whole input and around each part is trimmed.
        assert_eq!(
            parse_file_and_label(" file : label ").unwrap(),
            ("file", "label")
        );
        // Only the first ':' separates the file from the label.
        assert_eq!(parse_file_and_label("a:b:c").unwrap(), ("a", "b:c"));

        // Rejected forms.
        assert!(parse_file_and_label("file:").is_err());
        assert!(parse_file_and_label("file: ").is_err());
        assert!(parse_file_and_label(":label").is_err());
        assert!(parse_file_and_label(":").is_err());
        assert!(parse_file_and_label("").is_err());
        assert!(parse_file_and_label("   ").is_err());
    }

    /// Runs one literal statement and one `@file:label` statement against mock
    /// databases and checks the exact sequence of actions they trigger.
    #[tokio::test]
    async fn test_execute() {
        let sqlite_file = NamedTempFile::new().unwrap();
        std::fs::write(sqlite_file.path(), "select 2;").unwrap();

        let hook = SqlStatementExecutorHook::new(vec![
            "SELECT 1;".to_string(),
            format!("@{path}:label", path = sqlite_file.path().display()),
        ]);

        // Every mock reports what it was asked to do through this channel.
        let (tx, rx) = std::sync::mpsc::channel();
        let creator = Arc::new(MockCreator { tx });
        let mut connection_creators = HashMap::new();
        connection_creators.insert(
            "default".into(),
            creator.clone() as Arc<dyn ConnectionCreator>,
        );
        connection_creators.insert("label".into(), creator);
        let sqlite = spin_factor_sqlite::AppState::new(Default::default(), connection_creators);

        // `expect` surfaces the underlying error instead of a bare assertion failure.
        hook.execute(&sqlite)
            .await
            .expect("executing the statements against the mock databases should succeed");

        // Drain everything recorded so far and compare the whole sequence at once, so a
        // failure shows both the missing and the unexpected actions.
        let actual: Vec<Action> = rx.try_iter().collect();
        let expected = vec![
            Action::CreateConnection("default".to_string()),
            Action::Query("SELECT 1;".to_string()),
            Action::CreateConnection("label".to_string()),
            Action::Execute("select 2;".to_string()),
        ];
        assert_eq!(actual, expected, "unexpected sequence of database actions");
    }

    /// Connection creator that records each requested label and hands out mock connections.
    struct MockCreator {
        tx: Sender<Action>,
    }

    impl MockCreator {
        /// Records that a connection was requested for `label`.
        fn push(&self, label: &str) {
            self.tx
                .send(Action::CreateConnection(label.to_string()))
                .unwrap();
        }
    }

    #[async_trait]
    impl ConnectionCreator for MockCreator {
        async fn create_connection(
            &self,
            label: &str,
        ) -> Result<Box<dyn Connection + 'static>, v3::Error> {
            self.push(label);
            Ok(Box::new(MockConnection {
                tx: self.tx.clone(),
            }))
        }
    }

    /// Connection that records every call and returns fixed results.
    struct MockConnection {
        tx: Sender<Action>,
    }

    #[async_trait]
    impl Connection for MockConnection {
        async fn query(
            &self,
            query: &str,
            _parameters: Vec<v3::Value>,
        ) -> Result<v3::QueryResult, v3::Error> {
            self.tx.send(Action::Query(query.to_string())).unwrap();
            Ok(v3::QueryResult {
                columns: Vec::new(),
                rows: Vec::new(),
            })
        }

        async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
            self.tx
                .send(Action::Execute(statements.to_string()))
                .unwrap();
            Ok(())
        }

        async fn changes(&self) -> Result<u64, v3::Error> {
            self.tx.send(Action::Changes).unwrap();
            Ok(123)
        }

        async fn last_insert_rowid(&self) -> Result<i64, v3::Error> {
            self.tx.send(Action::LastInsertRowid).unwrap();
            Ok(456)
        }
    }

    /// One observable interaction with the mock creator or a mock connection.
    #[derive(Debug, PartialEq)]
    enum Action {
        CreateConnection(String),
        Query(String),
        Execute(String),
        Changes,
        LastInsertRowid,
    }
}