spin_sqlite_inproc/
lib.rs

1use std::{
2    path::PathBuf,
3    sync::OnceLock,
4    sync::{Arc, Mutex},
5};
6
7use anyhow::Context as _;
8use async_trait::async_trait;
9use spin_factor_sqlite::Connection;
10use spin_world::spin::sqlite::sqlite;
11
12/// The location of an in-process sqlite database.
13#[derive(Debug, Clone)]
14pub enum InProcDatabaseLocation {
15    /// An in-memory sqlite database.
16    InMemory,
17    /// The path to the sqlite database.
18    Path(PathBuf),
19}
20
21impl InProcDatabaseLocation {
22    /// Convert an optional path to a database location.
23    ///
24    /// Ensures that the parent directory of the database exists. If path is None, then an in memory
25    /// database will be used.
26    pub fn from_path(path: Option<PathBuf>) -> anyhow::Result<Self> {
27        match path {
28            Some(path) => {
29                // Create the store's parent directory if necessary
30                if let Some(parent) = path.parent() {
31                    std::fs::create_dir_all(parent).with_context(|| {
32                        format!(
33                            "failed to create sqlite database directory '{}'",
34                            parent.display()
35                        )
36                    })?;
37                }
38                Ok(Self::Path(path))
39            }
40            None => Ok(Self::InMemory),
41        }
42    }
43}
44
45/// A connection to a sqlite database
46pub struct InProcConnection {
47    location: InProcDatabaseLocation,
48    connection: OnceLock<Arc<Mutex<rusqlite::Connection>>>,
49}
50
51impl InProcConnection {
52    pub fn new(location: InProcDatabaseLocation) -> Result<Self, sqlite::Error> {
53        let connection = OnceLock::new();
54        Ok(Self {
55            location,
56            connection,
57        })
58    }
59
60    pub fn db_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
61        if let Some(c) = self.connection.get() {
62            return Ok(c.clone());
63        }
64        // Only create the connection if we failed to get it.
65        // We might do duplicate work here if there's a race, but that's fine.
66        let new = self.create_connection()?;
67        Ok(self.connection.get_or_init(|| new)).cloned()
68    }
69
70    fn create_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
71        let connection = match &self.location {
72            InProcDatabaseLocation::InMemory => rusqlite::Connection::open_in_memory(),
73            InProcDatabaseLocation::Path(path) => rusqlite::Connection::open(path),
74        }
75        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
76        Ok(Arc::new(Mutex::new(connection)))
77    }
78}
79
80#[async_trait]
81impl Connection for InProcConnection {
82    async fn query(
83        &self,
84        query: &str,
85        parameters: Vec<sqlite::Value>,
86    ) -> Result<sqlite::QueryResult, sqlite::Error> {
87        let connection = self.db_connection()?;
88        let query = query.to_owned();
89        // Tell the tokio runtime that we're going to block while making the query
90        tokio::task::spawn_blocking(move || execute_query(&connection, &query, parameters))
91            .await
92            .context("internal runtime error")
93            .map_err(|e| sqlite::Error::Io(e.to_string()))?
94    }
95
96    async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
97        let connection = self.db_connection()?;
98        let statements = statements.to_owned();
99        tokio::task::spawn_blocking(move || {
100            let conn = connection.lock().unwrap();
101            conn.execute_batch(&statements)
102                .context("failed to execute batch statements")
103        })
104        .await?
105        .context("failed to spawn blocking task")?;
106        Ok(())
107    }
108
109    async fn changes(&self) -> Result<u64, sqlite::Error> {
110        let connection = self.db_connection()?;
111        let conn = connection.lock().unwrap();
112        Ok(conn.changes())
113    }
114
115    async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
116        let connection = self.db_connection()?;
117        let conn = connection.lock().unwrap();
118        Ok(conn.last_insert_rowid())
119    }
120
121    fn summary(&self) -> Option<String> {
122        Some(match &self.location {
123            InProcDatabaseLocation::InMemory => "a temporary in-memory database".to_string(),
124            InProcDatabaseLocation::Path(path) => format!("\"{}\"", path.display()),
125        })
126    }
127}
128
129// This function lives outside the query function to make it more readable.
130fn execute_query(
131    connection: &Mutex<rusqlite::Connection>,
132    query: &str,
133    parameters: Vec<sqlite::Value>,
134) -> Result<sqlite::QueryResult, sqlite::Error> {
135    let conn = connection.lock().unwrap();
136    let mut statement = conn
137        .prepare_cached(query)
138        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
139    let columns = statement
140        .column_names()
141        .into_iter()
142        .map(ToOwned::to_owned)
143        .collect();
144    let rows = statement
145        .query_map(
146            rusqlite::params_from_iter(convert_data(parameters.into_iter())),
147            |row| {
148                let mut values = vec![];
149                for column in 0.. {
150                    let value = row.get::<usize, ValueWrapper>(column);
151                    if let Err(rusqlite::Error::InvalidColumnIndex(_)) = value {
152                        break;
153                    }
154                    let value = value?.0;
155                    values.push(value);
156                }
157                Ok(sqlite::RowResult { values })
158            },
159        )
160        .map_err(|e| sqlite::Error::Io(e.to_string()))?;
161    let rows = rows
162        .into_iter()
163        .map(|r| r.map_err(|e| sqlite::Error::Io(e.to_string())))
164        .collect::<Result<_, sqlite::Error>>()?;
165    Ok(sqlite::QueryResult { columns, rows })
166}
167
168fn convert_data(
169    arguments: impl Iterator<Item = sqlite::Value>,
170) -> impl Iterator<Item = rusqlite::types::Value> {
171    arguments.map(|a| match a {
172        sqlite::Value::Null => rusqlite::types::Value::Null,
173        sqlite::Value::Integer(i) => rusqlite::types::Value::Integer(i),
174        sqlite::Value::Real(r) => rusqlite::types::Value::Real(r),
175        sqlite::Value::Text(t) => rusqlite::types::Value::Text(t),
176        sqlite::Value::Blob(b) => rusqlite::types::Value::Blob(b),
177    })
178}
179
180// A wrapper around sqlite::Value so that we can convert from rusqlite ValueRef
181struct ValueWrapper(sqlite::Value);
182
183impl rusqlite::types::FromSql for ValueWrapper {
184    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
185        let value = match value {
186            rusqlite::types::ValueRef::Null => sqlite::Value::Null,
187            rusqlite::types::ValueRef::Integer(i) => sqlite::Value::Integer(i),
188            rusqlite::types::ValueRef::Real(f) => sqlite::Value::Real(f),
189            rusqlite::types::ValueRef::Text(t) => {
190                sqlite::Value::Text(String::from_utf8(t.to_vec()).unwrap())
191            }
192            rusqlite::types::ValueRef::Blob(b) => sqlite::Value::Blob(b.to_vec()),
193        };
194        Ok(ValueWrapper(value))
195    }
196}