spin_sqlite_inproc/
lib.rs1use std::{
2 path::PathBuf,
3 sync::OnceLock,
4 sync::{Arc, Mutex},
5};
6
7use anyhow::Context as _;
8use async_trait::async_trait;
9use spin_factor_sqlite::{Connection, QueryAsyncResult};
10use spin_world::spin::sqlite3_1_0::sqlite;
11use spin_world::spin::sqlite3_1_0::sqlite::{self as v3};
12
/// Where an in-process SQLite database keeps its data.
#[derive(Clone, Debug)]
pub enum InProcDatabaseLocation {
    /// The database is opened in memory rather than from a file.
    InMemory,
    /// The database is opened from the file at this path.
    Path(PathBuf),
}
21
22impl InProcDatabaseLocation {
23 pub fn from_path(path: Option<PathBuf>) -> anyhow::Result<Self> {
28 match path {
29 Some(path) => {
30 if let Some(parent) = path.parent() {
32 std::fs::create_dir_all(parent).with_context(|| {
33 format!(
34 "failed to create sqlite database directory '{}'",
35 parent.display()
36 )
37 })?;
38 }
39 Ok(Self::Path(path))
40 }
41 None => Ok(Self::InMemory),
42 }
43 }
44}
45
46pub struct InProcConnection {
48 location: InProcDatabaseLocation,
49 connection: OnceLock<Arc<Mutex<rusqlite::Connection>>>,
50}
51
52impl InProcConnection {
53 pub fn new(location: InProcDatabaseLocation) -> Result<Self, sqlite::Error> {
54 let connection = OnceLock::new();
55 Ok(Self {
56 location,
57 connection,
58 })
59 }
60
61 pub fn db_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
62 if let Some(c) = self.connection.get() {
63 return Ok(c.clone());
64 }
65 let new = self.create_connection()?;
68 Ok(self.connection.get_or_init(|| new)).cloned()
69 }
70
71 fn create_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
72 let connection = match &self.location {
73 InProcDatabaseLocation::InMemory => rusqlite::Connection::open_in_memory(),
74 InProcDatabaseLocation::Path(path) => rusqlite::Connection::open(path),
75 }
76 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
77 Ok(Arc::new(Mutex::new(connection)))
78 }
79}
80
81#[async_trait]
82impl Connection for InProcConnection {
83 async fn query(
84 &self,
85 query: &str,
86 parameters: Vec<sqlite::Value>,
87 max_result_bytes: usize,
88 ) -> Result<sqlite::QueryResult, sqlite::Error> {
89 let connection = self.db_connection()?;
90 let query = query.to_owned();
91 tokio::task::spawn_blocking(move || {
93 execute_query(&connection, &query, parameters, max_result_bytes)
94 })
95 .await
96 .context("internal runtime error")
97 .map_err(|e| sqlite::Error::Io(e.to_string()))?
98 }
99
100 async fn query_async(
101 &self,
102 query: &str,
103 parameters: Vec<v3::Value>,
104 max_result_bytes: usize,
105 ) -> Result<QueryAsyncResult, v3::Error> {
106 let connection = self.db_connection()?;
107 let query = query.to_owned();
108
109 let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
110 let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
111 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
112
113 let the_work = move || {
114 let conn = connection.lock().unwrap();
115 let mut statement = match conn.prepare_cached(&query) {
116 Ok(s) => s,
117 Err(e) => {
118 _ = cols_tx.send(Default::default());
119 return Err(io_error_v3(e));
120 }
121 };
122 let columns: Vec<_> = statement
123 .column_names()
124 .into_iter()
125 .map(ToOwned::to_owned)
126 .collect();
127 cols_tx
128 .send(columns)
129 .map_err(|_| v3::Error::Io("column send error".into()))?;
130
131 let mut rows = statement
132 .query(rusqlite::params_from_iter(convert_data(
133 parameters.into_iter(),
134 )))
135 .map_err(io_error_v3)?;
136
137 loop {
138 let row = match rows.next().map_err(io_error_v3)? {
139 None => break,
140 Some(r) => r,
141 };
142
143 let row = convert_row(row).map_err(io_error_v3)?;
144
145 if row.values.iter().map(|v| v.memory_size()).sum::<usize>() > max_result_bytes {
146 return Err(sqlite::Error::Io(format!(
147 "query result exceeds limit of {max_result_bytes} bytes"
148 )));
149 }
150
151 println!("it sends the row");
152 rows_tx
153 .blocking_send(row)
154 .map_err(|_| v3::Error::Io("row send error".into()))?;
155 }
156
157 Ok(())
158 };
159 tokio::task::spawn_blocking(move || {
160 let res = the_work();
161 _ = err_tx.send(res);
162 });
163
164 let columns = cols_rx.await.map_err(|e| v3::Error::Io(e.to_string()))?;
165
166 Ok(QueryAsyncResult {
167 columns,
168 rows: rows_rx,
169 error: err_rx,
170 })
171 }
172
173 async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
174 let connection = self.db_connection()?;
175 let statements = statements.to_owned();
176 tokio::task::spawn_blocking(move || {
177 let conn = connection.lock().unwrap();
178 conn.execute_batch(&statements)
179 .context("failed to execute batch statements")
180 })
181 .await?
182 .context("failed to spawn blocking task")?;
183 Ok(())
184 }
185
186 async fn changes(&self) -> Result<u64, sqlite::Error> {
187 let connection = self.db_connection()?;
188 let conn = connection.lock().unwrap();
189 Ok(conn.changes())
190 }
191
192 async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
193 let connection = self.db_connection()?;
194 let conn = connection.lock().unwrap();
195 Ok(conn.last_insert_rowid())
196 }
197
198 fn summary(&self) -> Option<String> {
199 Some(match &self.location {
200 InProcDatabaseLocation::InMemory => "a temporary in-memory database".to_string(),
201 InProcDatabaseLocation::Path(path) => format!("\"{}\"", path.display()),
202 })
203 }
204}
205
206fn io_error_v3(err: rusqlite::Error) -> v3::Error {
207 v3::Error::Io(err.to_string())
208}
209
210fn convert_row(row: &rusqlite::Row) -> Result<sqlite::RowResult, rusqlite::Error> {
211 let mut values = vec![];
212 for column in 0.. {
213 let value = row.get::<usize, ValueWrapper>(column);
214 if let Err(rusqlite::Error::InvalidColumnIndex(_)) = value {
215 break;
216 }
217 let value = value?.0;
218 values.push(value);
219 }
220 Ok(sqlite::RowResult { values })
221}
222
223fn execute_query(
225 connection: &Mutex<rusqlite::Connection>,
226 query: &str,
227 parameters: Vec<sqlite::Value>,
228 max_result_bytes: usize,
229) -> Result<sqlite::QueryResult, sqlite::Error> {
230 let conn = connection.lock().unwrap();
231 let mut statement = conn
232 .prepare_cached(query)
233 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
234 let columns = statement
235 .column_names()
236 .into_iter()
237 .map(ToOwned::to_owned)
238 .collect();
239 let mut byte_count = std::mem::size_of::<sqlite::QueryResult>();
240 let rows = statement
241 .query_map(
242 rusqlite::params_from_iter(convert_data(parameters.into_iter())),
243 convert_row,
244 )
245 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
246 let rows = rows
247 .into_iter()
248 .map(|r| match r {
249 Ok(r) => {
250 byte_count += r.values.iter().map(|v| v.memory_size()).sum::<usize>();
251 if byte_count > max_result_bytes {
252 Err(sqlite::Error::Io(format!(
253 "query result exceeds limit of {max_result_bytes} bytes"
254 )))
255 } else {
256 Ok(r)
257 }
258 }
259 Err(e) => Err(sqlite::Error::Io(e.to_string())),
260 })
261 .collect::<Result<_, sqlite::Error>>()?;
262 Ok(sqlite::QueryResult { columns, rows })
263}
264
265fn convert_data(
266 arguments: impl Iterator<Item = sqlite::Value>,
267) -> impl Iterator<Item = rusqlite::types::Value> {
268 arguments.map(|a| match a {
269 sqlite::Value::Null => rusqlite::types::Value::Null,
270 sqlite::Value::Integer(i) => rusqlite::types::Value::Integer(i),
271 sqlite::Value::Real(r) => rusqlite::types::Value::Real(r),
272 sqlite::Value::Text(t) => rusqlite::types::Value::Text(t),
273 sqlite::Value::Blob(b) => rusqlite::types::Value::Blob(b),
274 })
275}
276
277struct ValueWrapper(sqlite::Value);
279
280impl rusqlite::types::FromSql for ValueWrapper {
281 fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
282 let value = match value {
283 rusqlite::types::ValueRef::Null => sqlite::Value::Null,
284 rusqlite::types::ValueRef::Integer(i) => sqlite::Value::Integer(i),
285 rusqlite::types::ValueRef::Real(f) => sqlite::Value::Real(f),
286 rusqlite::types::ValueRef::Text(t) => {
287 sqlite::Value::Text(String::from_utf8(t.to_vec()).unwrap())
288 }
289 rusqlite::types::ValueRef::Blob(b) => sqlite::Value::Blob(b.to_vec()),
290 };
291 Ok(ValueWrapper(value))
292 }
293}