spin_sqlite_inproc/
lib.rs1use std::{
2 path::PathBuf,
3 sync::OnceLock,
4 sync::{Arc, Mutex},
5};
6
7use anyhow::Context as _;
8use async_trait::async_trait;
9use spin_factor_sqlite::{Connection, QueryAsyncResult};
10use spin_world::spin::sqlite3_1_0::sqlite;
11use spin_world::spin::sqlite3_1_0::sqlite::{self as v3};
12
/// Where the in-process SQLite database is stored.
#[derive(Clone, Debug)]
pub enum InProcDatabaseLocation {
    /// A temporary database held in memory (opened with `open_in_memory`).
    InMemory,
    /// A database file at the given filesystem path.
    Path(PathBuf),
}
21
22impl InProcDatabaseLocation {
23 pub fn from_path(path: Option<PathBuf>) -> anyhow::Result<Self> {
28 match path {
29 Some(path) => {
30 if let Some(parent) = path.parent() {
32 std::fs::create_dir_all(parent).with_context(|| {
33 format!(
34 "failed to create sqlite database directory '{}'",
35 parent.display()
36 )
37 })?;
38 }
39 Ok(Self::Path(path))
40 }
41 None => Ok(Self::InMemory),
42 }
43 }
44}
45
46pub struct InProcConnection {
48 location: InProcDatabaseLocation,
49 allow_attach_file: bool,
50 connection: OnceLock<Arc<Mutex<rusqlite::Connection>>>,
51}
52
53impl InProcConnection {
54 pub fn new(
55 location: InProcDatabaseLocation,
56 allow_attach_file: bool,
57 ) -> Result<Self, sqlite::Error> {
58 let connection = OnceLock::new();
59 Ok(Self {
60 location,
61 allow_attach_file,
62 connection,
63 })
64 }
65
66 pub fn db_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
67 if let Some(c) = self.connection.get() {
68 return Ok(c.clone());
69 }
70 let new = self.create_connection()?;
73 Ok(self.connection.get_or_init(|| new)).cloned()
74 }
75
76 fn create_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
77 let connection = match &self.location {
78 InProcDatabaseLocation::InMemory => rusqlite::Connection::open_in_memory(),
79 InProcDatabaseLocation::Path(path) => rusqlite::Connection::open(path),
80 }
81 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
82 if !self.allow_attach_file {
83 connection.authorizer(Some(|ctx: rusqlite::hooks::AuthContext<'_>| {
84 use rusqlite::hooks::{AuthAction, Authorization};
85 match ctx.action {
86 AuthAction::Attach { filename } if !matches!(filename, "" | ":memory:") => {
88 Authorization::Deny
89 }
90 _ => Authorization::Allow,
91 }
92 }));
93 }
94 Ok(Arc::new(Mutex::new(connection)))
95 }
96}
97
98#[async_trait]
99impl Connection for InProcConnection {
100 async fn query(
101 &self,
102 query: &str,
103 parameters: Vec<sqlite::Value>,
104 max_result_bytes: usize,
105 ) -> Result<sqlite::QueryResult, sqlite::Error> {
106 let connection = self.db_connection()?;
107 let query = query.to_owned();
108 tokio::task::spawn_blocking(move || {
110 execute_query(&connection, &query, parameters, max_result_bytes)
111 })
112 .await
113 .context("internal runtime error")
114 .map_err(|e| sqlite::Error::Io(e.to_string()))?
115 }
116
117 async fn query_async(
118 &self,
119 query: &str,
120 parameters: Vec<v3::Value>,
121 max_result_bytes: usize,
122 ) -> Result<QueryAsyncResult, v3::Error> {
123 let connection = self.db_connection()?;
124 let query = query.to_owned();
125
126 let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
127 let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
128 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
129
130 let the_work = move || {
131 let conn = connection.lock().unwrap();
132 let mut statement = match conn.prepare_cached(&query) {
133 Ok(s) => s,
134 Err(e) => {
135 _ = cols_tx.send(Default::default());
136 return Err(io_error_v3(e));
137 }
138 };
139 let columns: Vec<_> = statement
140 .column_names()
141 .into_iter()
142 .map(ToOwned::to_owned)
143 .collect();
144 cols_tx
145 .send(columns)
146 .map_err(|_| v3::Error::Io("column send error".into()))?;
147
148 let mut rows = statement
149 .query(rusqlite::params_from_iter(convert_data(
150 parameters.into_iter(),
151 )))
152 .map_err(io_error_v3)?;
153
154 loop {
155 let row = match rows.next().map_err(io_error_v3)? {
156 None => break,
157 Some(r) => r,
158 };
159
160 let row = convert_row(row).map_err(io_error_v3)?;
161
162 if row.values.iter().map(|v| v.memory_size()).sum::<usize>() > max_result_bytes {
163 return Err(sqlite::Error::Io(format!(
164 "query result exceeds limit of {max_result_bytes} bytes"
165 )));
166 }
167
168 rows_tx
169 .blocking_send(row)
170 .map_err(|_| v3::Error::Io("row send error".into()))?;
171 }
172
173 Ok(())
174 };
175 tokio::task::spawn_blocking(move || {
176 let res = the_work();
177 _ = err_tx.send(res);
178 });
179
180 let columns = cols_rx.await.map_err(|e| v3::Error::Io(e.to_string()))?;
181
182 Ok(QueryAsyncResult {
183 columns,
184 rows: rows_rx,
185 error: err_rx,
186 })
187 }
188
189 async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
190 let connection = self.db_connection()?;
191 let statements = statements.to_owned();
192 tokio::task::spawn_blocking(move || {
193 let conn = connection.lock().unwrap();
194 conn.execute_batch(&statements)
195 .context("failed to execute batch statements")
196 })
197 .await?
198 .context("failed to spawn blocking task")?;
199 Ok(())
200 }
201
202 async fn changes(&self) -> Result<u64, sqlite::Error> {
203 let connection = self.db_connection()?;
204 let conn = connection.lock().unwrap();
205 Ok(conn.changes())
206 }
207
208 async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
209 let connection = self.db_connection()?;
210 let conn = connection.lock().unwrap();
211 Ok(conn.last_insert_rowid())
212 }
213
214 fn summary(&self) -> Option<String> {
215 Some(match &self.location {
216 InProcDatabaseLocation::InMemory => "a temporary in-memory database".to_string(),
217 InProcDatabaseLocation::Path(path) => format!("\"{}\"", path.display()),
218 })
219 }
220}
221
222fn io_error_v3(err: rusqlite::Error) -> v3::Error {
223 v3::Error::Io(err.to_string())
224}
225
226fn convert_row(row: &rusqlite::Row) -> Result<sqlite::RowResult, rusqlite::Error> {
227 let mut values = vec![];
228 for column in 0.. {
229 let value = row.get::<usize, ValueWrapper>(column);
230 if let Err(rusqlite::Error::InvalidColumnIndex(_)) = value {
231 break;
232 }
233 let value = value?.0;
234 values.push(value);
235 }
236 Ok(sqlite::RowResult { values })
237}
238
239fn execute_query(
241 connection: &Mutex<rusqlite::Connection>,
242 query: &str,
243 parameters: Vec<sqlite::Value>,
244 max_result_bytes: usize,
245) -> Result<sqlite::QueryResult, sqlite::Error> {
246 let conn = connection.lock().unwrap();
247 let mut statement = conn
248 .prepare_cached(query)
249 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
250 let columns = statement
251 .column_names()
252 .into_iter()
253 .map(ToOwned::to_owned)
254 .collect();
255 let mut byte_count = std::mem::size_of::<sqlite::QueryResult>();
256 let rows = statement
257 .query_map(
258 rusqlite::params_from_iter(convert_data(parameters.into_iter())),
259 convert_row,
260 )
261 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
262 let rows = rows
263 .into_iter()
264 .map(|r| match r {
265 Ok(r) => {
266 byte_count += r.values.iter().map(|v| v.memory_size()).sum::<usize>();
267 if byte_count > max_result_bytes {
268 Err(sqlite::Error::Io(format!(
269 "query result exceeds limit of {max_result_bytes} bytes"
270 )))
271 } else {
272 Ok(r)
273 }
274 }
275 Err(e) => Err(sqlite::Error::Io(e.to_string())),
276 })
277 .collect::<Result<_, sqlite::Error>>()?;
278 Ok(sqlite::QueryResult { columns, rows })
279}
280
281fn convert_data(
282 arguments: impl Iterator<Item = sqlite::Value>,
283) -> impl Iterator<Item = rusqlite::types::Value> {
284 arguments.map(|a| match a {
285 sqlite::Value::Null => rusqlite::types::Value::Null,
286 sqlite::Value::Integer(i) => rusqlite::types::Value::Integer(i),
287 sqlite::Value::Real(r) => rusqlite::types::Value::Real(r),
288 sqlite::Value::Text(t) => rusqlite::types::Value::Text(t),
289 sqlite::Value::Blob(b) => rusqlite::types::Value::Blob(b),
290 })
291}
292
293struct ValueWrapper(sqlite::Value);
295
296impl rusqlite::types::FromSql for ValueWrapper {
297 fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
298 let value = match value {
299 rusqlite::types::ValueRef::Null => sqlite::Value::Null,
300 rusqlite::types::ValueRef::Integer(i) => sqlite::Value::Integer(i),
301 rusqlite::types::ValueRef::Real(f) => sqlite::Value::Real(f),
302 rusqlite::types::ValueRef::Text(t) => {
303 sqlite::Value::Text(String::from_utf8(t.to_vec()).unwrap())
304 }
305 rusqlite::types::ValueRef::Blob(b) => sqlite::Value::Blob(b.to_vec()),
306 };
307 Ok(ValueWrapper(value))
308 }
309}