spin_trigger/cli/
sqlite_statements.rs

1use anyhow::Context as _;
2use spin_core::async_trait;
3use spin_factor_sqlite::SqliteFactor;
4use spin_factors::RuntimeFactors;
5use spin_factors_executor::ExecutorHooks;
6
/// Label of the sqlite database that statements are run against when no other label is given.
const DEFAULT_SQLITE_LABEL: &str = "default";
9
/// ExecutorHook for executing sqlite statements.
///
/// The statements are run against the app's `SqliteFactor` state when the app is configured.
/// If the configured app does not have `SqliteFactor`, the hook does nothing.
#[derive(Debug, Clone)]
pub struct SqlStatementExecutorHook {
    /// Statements to run, in order; see [`SqlStatementExecutorHook::new`] for the accepted forms.
    sql_statements: Vec<String>,
}
17
18impl SqlStatementExecutorHook {
19    /// Creates a new SqlStatementExecutorHook
20    ///
21    /// The statements can be either a list of raw SQL statements or a list of `@{file:label}` statements.
22    pub fn new(sql_statements: Vec<String>) -> Self {
23        Self { sql_statements }
24    }
25
26    /// Executes the sql statements.
27    pub async fn execute(&self, sqlite: &spin_factor_sqlite::AppState) -> anyhow::Result<()> {
28        if self.sql_statements.is_empty() {
29            return Ok(());
30        }
31        let get_database = |label| async move {
32            sqlite
33                .get_connection(label)
34                .await
35                .transpose()
36                .with_context(|| format!("failed connect to database with label '{label}'"))
37        };
38
39        for statement in &self.sql_statements {
40            if let Some(config) = statement.strip_prefix('@') {
41                let (file, label) = parse_file_and_label(config)?;
42                let database = get_database(label).await?.with_context(|| {
43                    format!(
44                        "based on the '@{config}' a registered database named '{label}' was expected but not found."
45                    )
46                })?;
47                let sql = std::fs::read_to_string(file).with_context(|| {
48                    format!("could not read file '{file}' containing sql statements")
49                })?;
50                database.execute_batch(&sql).await.with_context(|| {
51                    format!("failed to execute sql against database '{label}' from file '{file}'")
52                })?;
53            } else {
54                let Some(default) = get_database(DEFAULT_SQLITE_LABEL).await? else {
55                    debug_assert!(false, "the '{DEFAULT_SQLITE_LABEL}' sqlite database should always be available but for some reason was not");
56                    return Ok(());
57                };
58                default
59                    .query(statement, Vec::new())
60                    .await
61                    .with_context(|| format!("failed to execute following sql statement against default database: '{statement}'"))?;
62            }
63        }
64        Ok(())
65    }
66}
67
68#[async_trait]
69impl<F, U> ExecutorHooks<F, U> for SqlStatementExecutorHook
70where
71    F: RuntimeFactors,
72{
73    async fn configure_app(
74        &self,
75        configured_app: &spin_factors::ConfiguredApp<F>,
76    ) -> anyhow::Result<()> {
77        let Some(sqlite) = configured_app.app_state::<SqliteFactor>().ok() else {
78            return Ok(());
79        };
80        self.execute(sqlite).await?;
81        Ok(())
82    }
83}
84
85/// Parses a @{file:label} sqlite statement
86fn parse_file_and_label(config: &str) -> anyhow::Result<(&str, &str)> {
87    let config = config.trim();
88    if config.is_empty() {
89        anyhow::bail!("database configuration is empty in the '@{config}' sqlite statement");
90    }
91    let (file, label) = match config.split_once(':') {
92        Some((_, label)) if label.trim().is_empty() => {
93            anyhow::bail!("database label is empty in the '@{config}' sqlite statement")
94        }
95        Some((file, _)) if file.trim().is_empty() => {
96            anyhow::bail!("file path is empty in the '@{config}' sqlite statement")
97        }
98        Some((file, label)) => (file.trim(), label.trim()),
99        None => (config, "default"),
100    };
101    Ok((file, label))
102}
103
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::mpsc::Sender;
    use std::sync::Arc;

    use spin_core::async_trait;
    use spin_factor_sqlite::{Connection, ConnectionCreator};
    use spin_world::spin::sqlite::sqlite as v3;
    use tempfile::NamedTempFile;

    use super::*;

    #[test]
    fn test_parse_file_and_label() {
        assert_eq!(
            parse_file_and_label("file:label").unwrap(),
            ("file", "label")
        );
        assert!(parse_file_and_label("file:").is_err());
        assert_eq!(parse_file_and_label("file").unwrap(), ("file", "default"));
        assert!(parse_file_and_label(":label").is_err());
        assert!(parse_file_and_label("").is_err());
    }

    #[test]
    fn test_parse_file_and_label_trims_whitespace() {
        assert_eq!(
            parse_file_and_label("  file : label  ").unwrap(),
            ("file", "label")
        );
        assert_eq!(parse_file_and_label(" file ").unwrap(), ("file", "default"));
        assert!(parse_file_and_label("   ").is_err());
        assert!(parse_file_and_label("file:  ").is_err());
    }

    #[tokio::test]
    async fn test_execute() {
        let sqlite_file = NamedTempFile::new().unwrap();
        std::fs::write(&sqlite_file, "select 2;").unwrap();

        let hook = SqlStatementExecutorHook::new(vec![
            "SELECT 1;".to_string(),
            format!("@{path}:label", path = sqlite_file.path().display()),
        ]);
        let (tx, rx) = std::sync::mpsc::channel();
        let creator = Arc::new(MockCreator { tx });
        let mut connection_creators = HashMap::new();
        connection_creators.insert(
            "default".into(),
            creator.clone() as Arc<dyn ConnectionCreator>,
        );
        connection_creators.insert("label".into(), creator);
        let sqlite = spin_factor_sqlite::AppState::new(Default::default(), connection_creators);
        hook.execute(&sqlite)
            .await
            .expect("executing the sql statements should succeed");

        // Compare the whole recorded sequence at once so a failure shows every action that
        // was seen, including unexpected extra or missing ones.
        let actions: Vec<Action> = rx.try_iter().collect();
        assert_eq!(
            actions,
            vec![
                Action::CreateConnection("default".to_string()),
                Action::Query("SELECT 1;".to_string()),
                Action::CreateConnection("label".to_string()),
                Action::Execute("select 2;".to_string()),
            ],
            "unexpected sequence of sqlite actions"
        );
    }

    /// Connection creator that records every connection request on `tx`.
    struct MockCreator {
        tx: Sender<Action>,
    }

    impl MockCreator {
        /// Records that a connection for `label` was requested.
        fn push(&self, label: &str) {
            self.tx
                .send(Action::CreateConnection(label.to_string()))
                .unwrap();
        }
    }

    #[async_trait]
    impl ConnectionCreator for MockCreator {
        async fn create_connection(
            &self,
            label: &str,
        ) -> Result<Box<dyn Connection + 'static>, v3::Error> {
            self.push(label);
            Ok(Box::new(MockConnection {
                tx: self.tx.clone(),
            }))
        }
    }

    /// Connection that records every call on `tx` and returns canned results.
    struct MockConnection {
        tx: Sender<Action>,
    }

    #[async_trait]
    impl Connection for MockConnection {
        async fn query(
            &self,
            query: &str,
            _parameters: Vec<v3::Value>,
        ) -> Result<v3::QueryResult, v3::Error> {
            self.tx.send(Action::Query(query.to_string())).unwrap();
            Ok(v3::QueryResult {
                columns: Vec::new(),
                rows: Vec::new(),
            })
        }

        async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
            self.tx
                .send(Action::Execute(statements.to_string()))
                .unwrap();
            Ok(())
        }

        async fn changes(&self) -> Result<u64, v3::Error> {
            self.tx.send(Action::Changes).unwrap();
            Ok(123)
        }

        async fn last_insert_rowid(&self) -> Result<i64, v3::Error> {
            self.tx.send(Action::LastInsertRowid).unwrap();
            Ok(456)
        }
    }

    /// A single call observed by the mocks, in the order it happened.
    #[derive(Debug, PartialEq)]
    enum Action {
        CreateConnection(String),
        Query(String),
        Execute(String),
        Changes,
        LastInsertRowid,
    }
}