spin_sqlite_inproc/
lib.rs1use std::{
2 path::PathBuf,
3 sync::OnceLock,
4 sync::{Arc, Mutex},
5};
6
7use anyhow::Context as _;
8use async_trait::async_trait;
9use spin_factor_sqlite::Connection;
10use spin_world::spin::sqlite::sqlite;
11
/// Where the in-process SQLite database backing an [`InProcConnection`] lives.
#[derive(Clone, Debug)]
pub enum InProcDatabaseLocation {
    /// A database opened with `rusqlite::Connection::open_in_memory`; nothing is
    /// written to disk.
    InMemory,
    /// A database file at the given filesystem path.
    Path(PathBuf),
}
20
21impl InProcDatabaseLocation {
22 pub fn from_path(path: Option<PathBuf>) -> anyhow::Result<Self> {
27 match path {
28 Some(path) => {
29 if let Some(parent) = path.parent() {
31 std::fs::create_dir_all(parent).with_context(|| {
32 format!(
33 "failed to create sqlite database directory '{}'",
34 parent.display()
35 )
36 })?;
37 }
38 Ok(Self::Path(path))
39 }
40 None => Ok(Self::InMemory),
41 }
42 }
43}
44
45pub struct InProcConnection {
47 location: InProcDatabaseLocation,
48 connection: OnceLock<Arc<Mutex<rusqlite::Connection>>>,
49}
50
51impl InProcConnection {
52 pub fn new(location: InProcDatabaseLocation) -> Result<Self, sqlite::Error> {
53 let connection = OnceLock::new();
54 Ok(Self {
55 location,
56 connection,
57 })
58 }
59
60 pub fn db_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
61 if let Some(c) = self.connection.get() {
62 return Ok(c.clone());
63 }
64 let new = self.create_connection()?;
67 Ok(self.connection.get_or_init(|| new)).cloned()
68 }
69
70 fn create_connection(&self) -> Result<Arc<Mutex<rusqlite::Connection>>, sqlite::Error> {
71 let connection = match &self.location {
72 InProcDatabaseLocation::InMemory => rusqlite::Connection::open_in_memory(),
73 InProcDatabaseLocation::Path(path) => rusqlite::Connection::open(path),
74 }
75 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
76 Ok(Arc::new(Mutex::new(connection)))
77 }
78}
79
80#[async_trait]
81impl Connection for InProcConnection {
82 async fn query(
83 &self,
84 query: &str,
85 parameters: Vec<sqlite::Value>,
86 ) -> Result<sqlite::QueryResult, sqlite::Error> {
87 let connection = self.db_connection()?;
88 let query = query.to_owned();
89 tokio::task::spawn_blocking(move || execute_query(&connection, &query, parameters))
91 .await
92 .context("internal runtime error")
93 .map_err(|e| sqlite::Error::Io(e.to_string()))?
94 }
95
96 async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
97 let connection = self.db_connection()?;
98 let statements = statements.to_owned();
99 tokio::task::spawn_blocking(move || {
100 let conn = connection.lock().unwrap();
101 conn.execute_batch(&statements)
102 .context("failed to execute batch statements")
103 })
104 .await?
105 .context("failed to spawn blocking task")?;
106 Ok(())
107 }
108
109 async fn changes(&self) -> Result<u64, sqlite::Error> {
110 let connection = self.db_connection()?;
111 let conn = connection.lock().unwrap();
112 Ok(conn.changes())
113 }
114
115 async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
116 let connection = self.db_connection()?;
117 let conn = connection.lock().unwrap();
118 Ok(conn.last_insert_rowid())
119 }
120
121 fn summary(&self) -> Option<String> {
122 Some(match &self.location {
123 InProcDatabaseLocation::InMemory => "a temporary in-memory database".to_string(),
124 InProcDatabaseLocation::Path(path) => format!("\"{}\"", path.display()),
125 })
126 }
127}
128
129fn execute_query(
131 connection: &Mutex<rusqlite::Connection>,
132 query: &str,
133 parameters: Vec<sqlite::Value>,
134) -> Result<sqlite::QueryResult, sqlite::Error> {
135 let conn = connection.lock().unwrap();
136 let mut statement = conn
137 .prepare_cached(query)
138 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
139 let columns = statement
140 .column_names()
141 .into_iter()
142 .map(ToOwned::to_owned)
143 .collect();
144 let rows = statement
145 .query_map(
146 rusqlite::params_from_iter(convert_data(parameters.into_iter())),
147 |row| {
148 let mut values = vec![];
149 for column in 0.. {
150 let value = row.get::<usize, ValueWrapper>(column);
151 if let Err(rusqlite::Error::InvalidColumnIndex(_)) = value {
152 break;
153 }
154 let value = value?.0;
155 values.push(value);
156 }
157 Ok(sqlite::RowResult { values })
158 },
159 )
160 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
161 let rows = rows
162 .into_iter()
163 .map(|r| r.map_err(|e| sqlite::Error::Io(e.to_string())))
164 .collect::<Result<_, sqlite::Error>>()?;
165 Ok(sqlite::QueryResult { columns, rows })
166}
167
168fn convert_data(
169 arguments: impl Iterator<Item = sqlite::Value>,
170) -> impl Iterator<Item = rusqlite::types::Value> {
171 arguments.map(|a| match a {
172 sqlite::Value::Null => rusqlite::types::Value::Null,
173 sqlite::Value::Integer(i) => rusqlite::types::Value::Integer(i),
174 sqlite::Value::Real(r) => rusqlite::types::Value::Real(r),
175 sqlite::Value::Text(t) => rusqlite::types::Value::Text(t),
176 sqlite::Value::Blob(b) => rusqlite::types::Value::Blob(b),
177 })
178}
179
180struct ValueWrapper(sqlite::Value);
182
183impl rusqlite::types::FromSql for ValueWrapper {
184 fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
185 let value = match value {
186 rusqlite::types::ValueRef::Null => sqlite::Value::Null,
187 rusqlite::types::ValueRef::Integer(i) => sqlite::Value::Integer(i),
188 rusqlite::types::ValueRef::Real(f) => sqlite::Value::Real(f),
189 rusqlite::types::ValueRef::Text(t) => {
190 sqlite::Value::Text(String::from_utf8(t.to_vec()).unwrap())
191 }
192 rusqlite::types::ValueRef::Blob(b) => sqlite::Value::Blob(b.to_vec()),
193 };
194 Ok(ValueWrapper(value))
195 }
196}