Skip to main content

spin_trigger/cli/
sqlite_statements.rs

1use anyhow::Context as _;
2use spin_core::async_trait;
3use spin_factor_sqlite::SqliteFactor;
4use spin_factors::RuntimeFactors;
5use spin_factors_executor::ExecutorHooks;
6use spin_world::MAX_HOST_BUFFERED_BYTES;
7
/// Label of the sqlite database that raw (non-`@`) SQL statements are run against.
const DEFAULT_SQLITE_LABEL: &str = "default";
10
/// Executor hook that runs a list of SQL statements against the app's sqlite
/// databases once the app has been configured.
///
/// This executor assumes that the configured app has access to `SqliteFactor`.
/// If the app does not have access to `SqliteFactor`, the hook is silently
/// skipped.
#[derive(Debug, Clone)]
pub struct SqlStatementExecutorHook {
    /// Statements to run, in order. Each entry is either raw SQL or an
    /// `@{file:label}` directive.
    sql_statements: Vec<String>,
}
18
19impl SqlStatementExecutorHook {
20    /// Creates a new SqlStatementExecutorHook
21    ///
22    /// The statements can be either a list of raw SQL statements or a list of `@{file:label}` statements.
23    pub fn new(sql_statements: Vec<String>) -> Self {
24        Self { sql_statements }
25    }
26
27    /// Executes the sql statements.
28    pub async fn execute(&self, sqlite: &spin_factor_sqlite::AppState) -> anyhow::Result<()> {
29        if self.sql_statements.is_empty() {
30            return Ok(());
31        }
32        let get_database = |label| async move {
33            sqlite
34                .get_connection(label)
35                .await
36                .transpose()
37                .with_context(|| format!("failed connect to database with label '{label}'"))
38        };
39
40        for statement in &self.sql_statements {
41            if let Some(config) = statement.strip_prefix('@') {
42                let (file, label) = parse_file_and_label(config)?;
43                let database = get_database(label).await?.with_context(|| {
44                    format!(
45                        "based on the '@{config}' a registered database named '{label}' was expected but not found."
46                    )
47                })?;
48                let sql = std::fs::read_to_string(file).with_context(|| {
49                    format!("could not read file '{file}' containing sql statements")
50                })?;
51                database.execute_batch(&sql).await.with_context(|| {
52                    format!("failed to execute sql against database '{label}' from file '{file}'")
53                })?;
54            } else {
55                let Some(default) = get_database(DEFAULT_SQLITE_LABEL).await? else {
56                    debug_assert!(
57                        false,
58                        "the '{DEFAULT_SQLITE_LABEL}' sqlite database should always be available but for some reason was not"
59                    );
60                    return Ok(());
61                };
62                default
63                    .query(statement, Vec::new(), MAX_HOST_BUFFERED_BYTES)
64                    .await
65                    .with_context(|| format!("failed to execute following sql statement against default database: '{statement}'"))?;
66            }
67        }
68        Ok(())
69    }
70}
71
72#[async_trait]
73impl<F, U> ExecutorHooks<F, U> for SqlStatementExecutorHook
74where
75    F: RuntimeFactors,
76{
77    async fn configure_app(
78        &self,
79        configured_app: &spin_factors::ConfiguredApp<F>,
80    ) -> anyhow::Result<()> {
81        let Some(sqlite) = configured_app.app_state::<SqliteFactor>().ok() else {
82            return Ok(());
83        };
84        self.execute(sqlite).await?;
85        Ok(())
86    }
87}
88
89/// Parses a @{file:label} sqlite statement
90fn parse_file_and_label(config: &str) -> anyhow::Result<(&str, &str)> {
91    let config = config.trim();
92    if config.is_empty() {
93        anyhow::bail!("database configuration is empty in the '@{config}' sqlite statement");
94    }
95    let (file, label) = match config.split_once(':') {
96        Some((_, label)) if label.trim().is_empty() => {
97            anyhow::bail!("database label is empty in the '@{config}' sqlite statement")
98        }
99        Some((file, _)) if file.trim().is_empty() => {
100            anyhow::bail!("file path is empty in the '@{config}' sqlite statement")
101        }
102        Some((file, label)) => (file.trim(), label.trim()),
103        None => (config, "default"),
104    };
105    Ok((file, label))
106}
107
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::{collections::VecDeque, sync::mpsc::Sender};

    use spin_core::async_trait;
    use spin_factor_sqlite::{Connection, ConnectionCreator, QueryAsyncResult};
    use spin_world::spin::sqlite3_1_0::sqlite as v3;
    use tempfile::NamedTempFile;

    use super::*;

    #[test]
    fn test_parse_file_and_label() {
        assert_eq!(
            parse_file_and_label("file:label").unwrap(),
            ("file", "label")
        );
        assert!(parse_file_and_label("file:").is_err());
        assert_eq!(parse_file_and_label("file").unwrap(), ("file", "default"));
        assert!(parse_file_and_label(":label").is_err());
        assert!(parse_file_and_label("").is_err());
        // Whitespace around the whole config and around each part is ignored.
        assert_eq!(
            parse_file_and_label("  file : label  ").unwrap(),
            ("file", "label")
        );
        assert!(parse_file_and_label("   ").is_err());
        assert!(parse_file_and_label(":").is_err());
    }

    #[tokio::test]
    async fn test_execute() {
        let sqlite_file = NamedTempFile::new().unwrap();
        std::fs::write(&sqlite_file, "select 2;").unwrap();

        let hook = SqlStatementExecutorHook::new(vec![
            "SELECT 1;".to_string(),
            format!("@{path}:label", path = sqlite_file.path().display()),
        ]);
        let (tx, rx) = std::sync::mpsc::channel();
        let creator = Arc::new(MockCreator { tx });
        let mut connection_creators = HashMap::new();
        connection_creators.insert(
            "default".into(),
            creator.clone() as Arc<dyn ConnectionCreator>,
        );
        connection_creators.insert("label".into(), creator);
        let sqlite = spin_factor_sqlite::AppState::new(Default::default(), connection_creators);
        // `expect` surfaces the underlying error chain if execution fails.
        hook.execute(&sqlite)
            .await
            .expect("executing the sql statements should succeed");

        let mut expected: VecDeque<Action> = vec![
            Action::CreateConnection("default".to_string()),
            Action::Query("SELECT 1;".to_string()),
            Action::CreateConnection("label".to_string()),
            Action::Execute("select 2;".to_string()),
        ]
        .into_iter()
        .collect();
        while let Ok(action) = rx.try_recv() {
            let Some(expected_action) = expected.pop_front() else {
                panic!("unexpected extra action: {action:?}");
            };
            assert_eq!(action, expected_action, "unexpected action");
        }

        assert!(
            expected.is_empty(),
            "Expected actions were never seen: {:?}",
            expected
        );
    }

    #[tokio::test]
    async fn test_execute_without_statements_does_nothing() {
        let hook = SqlStatementExecutorHook::new(Vec::new());
        let (tx, rx) = std::sync::mpsc::channel();
        let mut connection_creators = HashMap::new();
        connection_creators.insert(
            "default".into(),
            Arc::new(MockCreator { tx }) as Arc<dyn ConnectionCreator>,
        );
        let sqlite = spin_factor_sqlite::AppState::new(Default::default(), connection_creators);
        hook.execute(&sqlite)
            .await
            .expect("executing no statements should succeed");
        assert!(
            rx.try_recv().is_err(),
            "no connection should be created when there are no statements"
        );
    }

    #[tokio::test]
    async fn test_execute_fails_for_unregistered_label() {
        // The label lookup fails before the file is read, so the file need not exist.
        let hook = SqlStatementExecutorHook::new(vec!["@unused.sql:missing".to_string()]);
        let (tx, rx) = std::sync::mpsc::channel();
        let mut connection_creators = HashMap::new();
        connection_creators.insert(
            "default".into(),
            Arc::new(MockCreator { tx }) as Arc<dyn ConnectionCreator>,
        );
        let sqlite = spin_factor_sqlite::AppState::new(Default::default(), connection_creators);
        assert!(hook.execute(&sqlite).await.is_err());
        assert!(
            rx.try_recv().is_err(),
            "no connection should be created for an unregistered label"
        );
    }

    struct MockCreator {
        tx: Sender<Action>,
    }

    impl MockCreator {
        /// Records that a connection was requested for `label`.
        fn push(&self, label: &str) {
            self.tx
                .send(Action::CreateConnection(label.to_string()))
                .unwrap();
        }
    }

    #[async_trait]
    impl ConnectionCreator for MockCreator {
        async fn create_connection(
            &self,
            label: &str,
        ) -> Result<Arc<dyn Connection + 'static>, v3::Error> {
            self.push(label);
            Ok(Arc::new(MockConnection {
                tx: self.tx.clone(),
            }))
        }
    }

    /// Connection that records every call as an `Action` instead of touching a database.
    struct MockConnection {
        tx: Sender<Action>,
    }

    #[async_trait]
    impl Connection for MockConnection {
        async fn query(
            &self,
            query: &str,
            _parameters: Vec<v3::Value>,
            _max_result_bytes: usize,
        ) -> Result<v3::QueryResult, v3::Error> {
            self.tx.send(Action::Query(query.to_string())).unwrap();
            Ok(v3::QueryResult {
                columns: Vec::new(),
                rows: Vec::new(),
            })
        }

        async fn query_async(
            &self,
            query: &str,
            _parameters: Vec<v3::Value>,
            _max_result_bytes: usize,
        ) -> Result<QueryAsyncResult, v3::Error> {
            self.tx.send(Action::Query(query.to_string())).unwrap();
            let (_rtx, rrx) = tokio::sync::mpsc::channel(1);
            let (_etx, erx) = tokio::sync::oneshot::channel();
            Ok(QueryAsyncResult {
                columns: Vec::new(),
                rows: rrx,
                error: erx,
            })
        }

        async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
            self.tx
                .send(Action::Execute(statements.to_string()))
                .unwrap();
            Ok(())
        }

        async fn changes(&self) -> Result<u64, v3::Error> {
            self.tx.send(Action::Changes).unwrap();
            Ok(123)
        }

        async fn last_insert_rowid(&self) -> Result<i64, v3::Error> {
            self.tx.send(Action::LastInsertRowid).unwrap();
            Ok(456)
        }
    }

    #[derive(Debug, PartialEq)]
    enum Action {
        CreateConnection(String),
        Query(String),
        Execute(String),
        Changes,
        LastInsertRowid,
    }
}