Skip to main content

spin_sqlite_inproc/
lib.rs

1use std::{
2    path::PathBuf,
3    sync::OnceLock,
4    sync::{Arc, Mutex},
5};
6
7use anyhow::Context as _;
8use async_trait::async_trait;
9use spin_factor_sqlite::{Connection, QueryAsyncResult};
10use spin_world::spin::sqlite3_1_0::sqlite;
11use spin_world::spin::sqlite3_1_0::sqlite::{self as v3};
12
/// Where an in-process sqlite database is stored.
#[derive(Clone, Debug)]
pub enum InProcDatabaseLocation {
    /// A transient database that lives only in memory.
    InMemory,
    /// A database stored in the file at this path.
    Path(PathBuf),
}
21
22impl InProcDatabaseLocation {
23    /// Convert an optional path to a database location.
24    ///
25    /// Ensures that the parent directory of the database exists. If path is None, then an in memory
26    /// database will be used.
27    pub fn from_path(path: Option<PathBuf>) -> anyhow::Result<Self> {
28        match path {
29            Some(path) => {
30                // Create the store's parent directory if necessary
31                if let Some(parent) = path.parent() {
32                    std::fs::create_dir_all(parent).with_context(|| {
33                        format!(
34                            "failed to create sqlite database directory '{}'",
35                            parent.display()
36                        )
37                    })?;
38                }
39                Ok(Self::Path(path))
40            }
41            None => Ok(Self::InMemory),
42        }
43    }
44}
45
46/// A connection to a sqlite database
47pub struct InProcConnection {
48    location: InProcDatabaseLocation,
49    allow_attach_file: bool,
50    connection: OnceLock<Arc<Mutex<rusqlite::Connection>>>,
51}
52
53impl InProcConnection {
54    pub fn new(
55        location: InProcDatabaseLocation,
56        allow_attach_file: bool,
57    ) -> Result<Self, sqlite::Error> {
58        let connection = OnceLock::new();
59        Ok(Self {
60            location,
61            allow_attach_file,
62            connection,
63        })
64    }
65
66    pub fn db_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
67        if let Some(c) = self.connection.get() {
68            return Ok(c.clone());
69        }
70        // Only create the connection if we failed to get it.
71        // We might do duplicate work here if there's a race, but that's fine.
72        let new = self.create_connection()?;
73        Ok(self.connection.get_or_init(|| new)).cloned()
74    }
75
76    fn create_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
77        let connection = match &self.location {
78            InProcDatabaseLocation::InMemory => rusqlite::Connection::open_in_memory(),
79            InProcDatabaseLocation::Path(path) => rusqlite::Connection::open(path),
80        }
81        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
82        if !self.allow_attach_file {
83            connection.authorizer(Some(|ctx: rusqlite::hooks::AuthContext<'_>| {
84                use rusqlite::hooks::{AuthAction, Authorization};
85                match ctx.action {
86                    // Deny attaching files except tempfile ("") and in-memory (":memory:") databases
87                    AuthAction::Attach { filename } if !matches!(filename, "" | ":memory:") => {
88                        Authorization::Deny
89                    }
90                    _ => Authorization::Allow,
91                }
92            }));
93        }
94        Ok(Arc::new(Mutex::new(connection)))
95    }
96}
97
98#[async_trait]
99impl Connection for InProcConnection {
100    async fn query(
101        &self,
102        query: &str,
103        parameters: Vec<sqlite::Value>,
104        max_result_bytes: usize,
105    ) -> Result<sqlite::QueryResult, sqlite::Error> {
106        let connection = self.db_connection()?;
107        let query = query.to_owned();
108        // Tell the tokio runtime that we're going to block while making the query
109        tokio::task::spawn_blocking(move || {
110            execute_query(&connection, &query, parameters, max_result_bytes)
111        })
112        .await
113        .context("internal runtime error")
114        .map_err(|e| sqlite::Error::Io(e.to_string()))?
115    }
116
117    async fn query_async(
118        &self,
119        query: &str,
120        parameters: Vec<v3::Value>,
121        max_result_bytes: usize,
122    ) -> Result<QueryAsyncResult, v3::Error> {
123        let connection = self.db_connection()?;
124        let query = query.to_owned();
125
126        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
127        let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
128        let (err_tx, err_rx) = tokio::sync::oneshot::channel();
129
130        let the_work = move || {
131            let conn = connection.lock().unwrap();
132            let mut statement = match conn.prepare_cached(&query) {
133                Ok(s) => s,
134                Err(e) => {
135                    _ = cols_tx.send(Default::default());
136                    return Err(io_error_v3(e));
137                }
138            };
139            let columns: Vec<_> = statement
140                .column_names()
141                .into_iter()
142                .map(ToOwned::to_owned)
143                .collect();
144            cols_tx
145                .send(columns)
146                .map_err(|_| v3::Error::Io("column send error".into()))?;
147
148            let mut rows = statement
149                .query(rusqlite::params_from_iter(convert_data(
150                    parameters.into_iter(),
151                )))
152                .map_err(io_error_v3)?;
153
154            loop {
155                let row = match rows.next().map_err(io_error_v3)? {
156                    None => break,
157                    Some(r) => r,
158                };
159
160                let row = convert_row(row).map_err(io_error_v3)?;
161
162                if row.values.iter().map(|v| v.memory_size()).sum::<usize>() > max_result_bytes {
163                    return Err(sqlite::Error::Io(format!(
164                        "query result exceeds limit of {max_result_bytes} bytes"
165                    )));
166                }
167
168                rows_tx
169                    .blocking_send(row)
170                    .map_err(|_| v3::Error::Io("row send error".into()))?;
171            }
172
173            Ok(())
174        };
175        tokio::task::spawn_blocking(move || {
176            let res = the_work();
177            _ = err_tx.send(res);
178        });
179
180        let columns = cols_rx.await.map_err(|e| v3::Error::Io(e.to_string()))?;
181
182        Ok(QueryAsyncResult {
183            columns,
184            rows: rows_rx,
185            error: err_rx,
186        })
187    }
188
189    async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
190        let connection = self.db_connection()?;
191        let statements = statements.to_owned();
192        tokio::task::spawn_blocking(move || {
193            let conn = connection.lock().unwrap();
194            conn.execute_batch(&statements)
195                .context("failed to execute batch statements")
196        })
197        .await?
198        .context("failed to spawn blocking task")?;
199        Ok(())
200    }
201
202    async fn changes(&self) -> Result<u64, sqlite::Error> {
203        let connection = self.db_connection()?;
204        let conn = connection.lock().unwrap();
205        Ok(conn.changes())
206    }
207
208    async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
209        let connection = self.db_connection()?;
210        let conn = connection.lock().unwrap();
211        Ok(conn.last_insert_rowid())
212    }
213
214    fn summary(&self) -> Option<String> {
215        Some(match &self.location {
216            InProcDatabaseLocation::InMemory => "a temporary in-memory database".to_string(),
217            InProcDatabaseLocation::Path(path) => format!("\"{}\"", path.display()),
218        })
219    }
220}
221
222fn io_error_v3(err: rusqlite::Error) -> v3::Error {
223    v3::Error::Io(err.to_string())
224}
225
226fn convert_row(row: &rusqlite::Row) -> Result<sqlite::RowResult, rusqlite::Error> {
227    let mut values = vec![];
228    for column in 0.. {
229        let value = row.get::<usize, ValueWrapper>(column);
230        if let Err(rusqlite::Error::InvalidColumnIndex(_)) = value {
231            break;
232        }
233        let value = value?.0;
234        values.push(value);
235    }
236    Ok(sqlite::RowResult { values })
237}
238
239// This function lives outside the query function to make it more readable.
240fn execute_query(
241    connection: &Mutex<rusqlite::Connection>,
242    query: &str,
243    parameters: Vec<sqlite::Value>,
244    max_result_bytes: usize,
245) -> Result<sqlite::QueryResult, sqlite::Error> {
246    let conn = connection.lock().unwrap();
247    let mut statement = conn
248        .prepare_cached(query)
249        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
250    let columns = statement
251        .column_names()
252        .into_iter()
253        .map(ToOwned::to_owned)
254        .collect();
255    let mut byte_count = std::mem::size_of::<sqlite::QueryResult>();
256    let rows = statement
257        .query_map(
258            rusqlite::params_from_iter(convert_data(parameters.into_iter())),
259            convert_row,
260        )
261        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
262    let rows = rows
263        .into_iter()
264        .map(|r| match r {
265            Ok(r) => {
266                byte_count += r.values.iter().map(|v| v.memory_size()).sum::<usize>();
267                if byte_count > max_result_bytes {
268                    Err(sqlite::Error::Io(format!(
269                        "query result exceeds limit of {max_result_bytes} bytes"
270                    )))
271                } else {
272                    Ok(r)
273                }
274            }
275            Err(e) => Err(sqlite::Error::Io(e.to_string())),
276        })
277        .collect::<Result<_, sqlite::Error>>()?;
278    Ok(sqlite::QueryResult { columns, rows })
279}
280
281fn convert_data(
282    arguments: impl Iterator<Item = sqlite::Value>,
283) -> impl Iterator<Item = rusqlite::types::Value> {
284    arguments.map(|a| match a {
285        sqlite::Value::Null => rusqlite::types::Value::Null,
286        sqlite::Value::Integer(i) => rusqlite::types::Value::Integer(i),
287        sqlite::Value::Real(r) => rusqlite::types::Value::Real(r),
288        sqlite::Value::Text(t) => rusqlite::types::Value::Text(t),
289        sqlite::Value::Blob(b) => rusqlite::types::Value::Blob(b),
290    })
291}
292
293// A wrapper around sqlite::Value so that we can convert from rusqlite ValueRef
294struct ValueWrapper(sqlite::Value);
295
296impl rusqlite::types::FromSql for ValueWrapper {
297    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
298        let value = match value {
299            rusqlite::types::ValueRef::Null => sqlite::Value::Null,
300            rusqlite::types::ValueRef::Integer(i) => sqlite::Value::Integer(i),
301            rusqlite::types::ValueRef::Real(f) => sqlite::Value::Real(f),
302            rusqlite::types::ValueRef::Text(t) => {
303                sqlite::Value::Text(String::from_utf8(t.to_vec()).unwrap())
304            }
305            rusqlite::types::ValueRef::Blob(b) => sqlite::Value::Blob(b.to_vec()),
306        };
307        Ok(ValueWrapper(value))
308    }
309}