Skip to main content

spin_sqlite_inproc/
lib.rs

1use std::{
2    path::PathBuf,
3    sync::OnceLock,
4    sync::{Arc, Mutex},
5};
6
7use anyhow::Context as _;
8use async_trait::async_trait;
9use spin_factor_sqlite::{Connection, QueryAsyncResult};
10use spin_world::spin::sqlite3_1_0::sqlite;
11use spin_world::spin::sqlite3_1_0::sqlite::{self as v3};
12
/// The location of an in-process sqlite database.
///
/// Determines whether a connection is opened in memory or against a path on the filesystem.
#[derive(Debug, Clone)]
pub enum InProcDatabaseLocation {
    /// An in-memory sqlite database.
    InMemory,
    /// The path to the sqlite database.
    Path(PathBuf),
}
21
22impl InProcDatabaseLocation {
23    /// Convert an optional path to a database location.
24    ///
25    /// Ensures that the parent directory of the database exists. If path is None, then an in memory
26    /// database will be used.
27    pub fn from_path(path: Option<PathBuf>) -> anyhow::Result<Self> {
28        match path {
29            Some(path) => {
30                // Create the store's parent directory if necessary
31                if let Some(parent) = path.parent() {
32                    std::fs::create_dir_all(parent).with_context(|| {
33                        format!(
34                            "failed to create sqlite database directory '{}'",
35                            parent.display()
36                        )
37                    })?;
38                }
39                Ok(Self::Path(path))
40            }
41            None => Ok(Self::InMemory),
42        }
43    }
44}
45
/// A connection to a sqlite database
///
/// The underlying `rusqlite` connection is not opened when this value is constructed; it is
/// created on the first call to `db_connection` and cached for later calls.
pub struct InProcConnection {
    /// Where the database lives (in memory or at a path).
    location: InProcDatabaseLocation,
    /// The lazily-created connection, shared behind a mutex.
    connection: OnceLock<Arc<Mutex<rusqlite::Connection>>>,
}
51
52impl InProcConnection {
53    pub fn new(location: InProcDatabaseLocation) -> Result<Self, sqlite::Error> {
54        let connection = OnceLock::new();
55        Ok(Self {
56            location,
57            connection,
58        })
59    }
60
61    pub fn db_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
62        if let Some(c) = self.connection.get() {
63            return Ok(c.clone());
64        }
65        // Only create the connection if we failed to get it.
66        // We might do duplicate work here if there's a race, but that's fine.
67        let new = self.create_connection()?;
68        Ok(self.connection.get_or_init(|| new)).cloned()
69    }
70
71    fn create_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
72        let connection = match &self.location {
73            InProcDatabaseLocation::InMemory => rusqlite::Connection::open_in_memory(),
74            InProcDatabaseLocation::Path(path) => rusqlite::Connection::open(path),
75        }
76        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
77        Ok(Arc::new(Mutex::new(connection)))
78    }
79}
80
81#[async_trait]
82impl Connection for InProcConnection {
83    async fn query(
84        &self,
85        query: &str,
86        parameters: Vec<sqlite::Value>,
87        max_result_bytes: usize,
88    ) -> Result<sqlite::QueryResult, sqlite::Error> {
89        let connection = self.db_connection()?;
90        let query = query.to_owned();
91        // Tell the tokio runtime that we're going to block while making the query
92        tokio::task::spawn_blocking(move || {
93            execute_query(&connection, &query, parameters, max_result_bytes)
94        })
95        .await
96        .context("internal runtime error")
97        .map_err(|e| sqlite::Error::Io(e.to_string()))?
98    }
99
100    async fn query_async(
101        &self,
102        query: &str,
103        parameters: Vec<v3::Value>,
104        max_result_bytes: usize,
105    ) -> Result<QueryAsyncResult, v3::Error> {
106        let connection = self.db_connection()?;
107        let query = query.to_owned();
108
109        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
110        let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
111        let (err_tx, err_rx) = tokio::sync::oneshot::channel();
112
113        let the_work = move || {
114            let conn = connection.lock().unwrap();
115            let mut statement = match conn.prepare_cached(&query) {
116                Ok(s) => s,
117                Err(e) => {
118                    _ = cols_tx.send(Default::default());
119                    return Err(io_error_v3(e));
120                }
121            };
122            let columns: Vec<_> = statement
123                .column_names()
124                .into_iter()
125                .map(ToOwned::to_owned)
126                .collect();
127            cols_tx
128                .send(columns)
129                .map_err(|_| v3::Error::Io("column send error".into()))?;
130
131            let mut rows = statement
132                .query(rusqlite::params_from_iter(convert_data(
133                    parameters.into_iter(),
134                )))
135                .map_err(io_error_v3)?;
136
137            loop {
138                let row = match rows.next().map_err(io_error_v3)? {
139                    None => break,
140                    Some(r) => r,
141                };
142
143                let row = convert_row(row).map_err(io_error_v3)?;
144
145                if row.values.iter().map(|v| v.memory_size()).sum::<usize>() > max_result_bytes {
146                    return Err(sqlite::Error::Io(format!(
147                        "query result exceeds limit of {max_result_bytes} bytes"
148                    )));
149                }
150
151                println!("it sends the row");
152                rows_tx
153                    .blocking_send(row)
154                    .map_err(|_| v3::Error::Io("row send error".into()))?;
155            }
156
157            Ok(())
158        };
159        tokio::task::spawn_blocking(move || {
160            let res = the_work();
161            _ = err_tx.send(res);
162        });
163
164        let columns = cols_rx.await.map_err(|e| v3::Error::Io(e.to_string()))?;
165
166        Ok(QueryAsyncResult {
167            columns,
168            rows: rows_rx,
169            error: err_rx,
170        })
171    }
172
173    async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
174        let connection = self.db_connection()?;
175        let statements = statements.to_owned();
176        tokio::task::spawn_blocking(move || {
177            let conn = connection.lock().unwrap();
178            conn.execute_batch(&statements)
179                .context("failed to execute batch statements")
180        })
181        .await?
182        .context("failed to spawn blocking task")?;
183        Ok(())
184    }
185
186    async fn changes(&self) -> Result<u64, sqlite::Error> {
187        let connection = self.db_connection()?;
188        let conn = connection.lock().unwrap();
189        Ok(conn.changes())
190    }
191
192    async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
193        let connection = self.db_connection()?;
194        let conn = connection.lock().unwrap();
195        Ok(conn.last_insert_rowid())
196    }
197
198    fn summary(&self) -> Option<String> {
199        Some(match &self.location {
200            InProcDatabaseLocation::InMemory => "a temporary in-memory database".to_string(),
201            InProcDatabaseLocation::Path(path) => format!("\"{}\"", path.display()),
202        })
203    }
204}
205
206fn io_error_v3(err: rusqlite::Error) -> v3::Error {
207    v3::Error::Io(err.to_string())
208}
209
210fn convert_row(row: &rusqlite::Row) -> Result<sqlite::RowResult, rusqlite::Error> {
211    let mut values = vec![];
212    for column in 0.. {
213        let value = row.get::<usize, ValueWrapper>(column);
214        if let Err(rusqlite::Error::InvalidColumnIndex(_)) = value {
215            break;
216        }
217        let value = value?.0;
218        values.push(value);
219    }
220    Ok(sqlite::RowResult { values })
221}
222
223// This function lives outside the query function to make it more readable.
224fn execute_query(
225    connection: &Mutex<rusqlite::Connection>,
226    query: &str,
227    parameters: Vec<sqlite::Value>,
228    max_result_bytes: usize,
229) -> Result<sqlite::QueryResult, sqlite::Error> {
230    let conn = connection.lock().unwrap();
231    let mut statement = conn
232        .prepare_cached(query)
233        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
234    let columns = statement
235        .column_names()
236        .into_iter()
237        .map(ToOwned::to_owned)
238        .collect();
239    let mut byte_count = std::mem::size_of::<sqlite::QueryResult>();
240    let rows = statement
241        .query_map(
242            rusqlite::params_from_iter(convert_data(parameters.into_iter())),
243            convert_row,
244        )
245        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
246    let rows = rows
247        .into_iter()
248        .map(|r| match r {
249            Ok(r) => {
250                byte_count += r.values.iter().map(|v| v.memory_size()).sum::<usize>();
251                if byte_count > max_result_bytes {
252                    Err(sqlite::Error::Io(format!(
253                        "query result exceeds limit of {max_result_bytes} bytes"
254                    )))
255                } else {
256                    Ok(r)
257                }
258            }
259            Err(e) => Err(sqlite::Error::Io(e.to_string())),
260        })
261        .collect::<Result<_, sqlite::Error>>()?;
262    Ok(sqlite::QueryResult { columns, rows })
263}
264
265fn convert_data(
266    arguments: impl Iterator<Item = sqlite::Value>,
267) -> impl Iterator<Item = rusqlite::types::Value> {
268    arguments.map(|a| match a {
269        sqlite::Value::Null => rusqlite::types::Value::Null,
270        sqlite::Value::Integer(i) => rusqlite::types::Value::Integer(i),
271        sqlite::Value::Real(r) => rusqlite::types::Value::Real(r),
272        sqlite::Value::Text(t) => rusqlite::types::Value::Text(t),
273        sqlite::Value::Blob(b) => rusqlite::types::Value::Blob(b),
274    })
275}
276
// A wrapper around sqlite::Value so that we can convert from rusqlite ValueRef.
// The wrapper exists so that `rusqlite::types::FromSql` can be implemented for it in this file.
struct ValueWrapper(sqlite::Value);
279
280impl rusqlite::types::FromSql for ValueWrapper {
281    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
282        let value = match value {
283            rusqlite::types::ValueRef::Null => sqlite::Value::Null,
284            rusqlite::types::ValueRef::Integer(i) => sqlite::Value::Integer(i),
285            rusqlite::types::ValueRef::Real(f) => sqlite::Value::Real(f),
286            rusqlite::types::ValueRef::Text(t) => {
287                sqlite::Value::Text(String::from_utf8(t.to_vec()).unwrap())
288            }
289            rusqlite::types::ValueRef::Blob(b) => sqlite::Value::Blob(b.to_vec()),
290        };
291        Ok(ValueWrapper(value))
292    }
293}