spin_sqlite_libsql/
lib.rs1use std::sync::Arc;
2
3use anyhow::Context;
4use async_trait::async_trait;
5use spin_factor_sqlite::{Connection, QueryAsyncResult};
6use spin_world::spin::sqlite3_1_0::sqlite as v3;
7use spin_world::spin::sqlite3_1_0::sqlite::{self, RowResult};
8use tokio::sync::OnceCell;
9
10pub struct LazyLibSqlConnection {
12 url: String,
13 token: String,
14 inner: OnceCell<Arc<LibSqlConnection>>,
18}
19
20impl LazyLibSqlConnection {
21 pub fn new(url: String, token: String) -> Self {
22 Self {
23 url,
24 token,
25 inner: OnceCell::new(),
26 }
27 }
28
29 pub async fn get_or_create_connection(&self) -> Result<&Arc<LibSqlConnection>, v3::Error> {
30 self.inner
31 .get_or_try_init(|| async {
32 LibSqlConnection::create(self.url.clone(), self.token.clone())
33 .await
34 .context("failed to create SQLite client")
35 .map(Arc::new)
36 })
37 .await
38 .map_err(|_| v3::Error::InvalidConnection)
39 }
40}
41
42#[async_trait]
43impl Connection for LazyLibSqlConnection {
44 async fn query(
45 &self,
46 query: &str,
47 parameters: Vec<v3::Value>,
48 max_result_bytes: usize,
49 ) -> Result<v3::QueryResult, v3::Error> {
50 let client = self.get_or_create_connection().await?;
51 client.query(query, parameters, max_result_bytes).await
52 }
53
54 async fn query_async(
55 &self,
56 query: &str,
57 parameters: Vec<v3::Value>,
58 max_result_bytes: usize,
59 ) -> Result<QueryAsyncResult, v3::Error> {
60 let client = self.get_or_create_connection().await?.clone();
61 let query = query.to_string();
62
63 let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
64 let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
65 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
66
67 let the_work = async move {
68 let result = client
69 .inner
70 .query(&query, convert_parameters(¶meters))
71 .await
72 .map_err(|e| v3::Error::Io(e.to_string()));
73
74 let mut rows = match result {
75 Ok(r) => r,
76 Err(e) => {
77 _ = cols_tx.send(Default::default());
78 return Err(e);
79 }
80 };
81
82 let columns = columns(&rows);
83 cols_tx
84 .send(columns)
85 .map_err(|_| v3::Error::Io("column send error".into()))?;
86
87 let column_count = rows.column_count();
88
89 loop {
90 let row = match rows.next().await.map_err(io_error_v3)? {
91 Some(r) => r,
92 None => break,
93 };
94 let row = convert_row(row, column_count);
95 if row.values.iter().map(|v| v.memory_size()).sum::<usize>() > max_result_bytes {
96 return Err(v3::Error::Io(format!(
97 "query result exceeds limit of {max_result_bytes} bytes"
98 )));
99 }
100 rows_tx
101 .send(row)
102 .await
103 .map_err(|_| v3::Error::Io("column send error".into()))?;
104 }
105
106 Ok(())
107 };
108 tokio::spawn(async move {
109 let res = the_work.await;
110 _ = err_tx.send(res);
111 });
112
113 let columns = cols_rx.await.map_err(|e| v3::Error::Io(e.to_string()))?;
114
115 Ok(QueryAsyncResult {
116 columns,
117 rows: rows_rx,
118 error: err_rx,
119 })
120 }
121
122 async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
123 let client = self.get_or_create_connection().await?;
124 client.execute_batch(statements).await
125 }
126
127 async fn changes(&self) -> Result<u64, sqlite::Error> {
128 let client = self.get_or_create_connection().await?;
129 Ok(client.changes())
130 }
131
132 async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
133 let client = self.get_or_create_connection().await?;
134 Ok(client.last_insert_rowid())
135 }
136
137 fn summary(&self) -> Option<String> {
138 Some(format!("libSQL at {}", self.url))
139 }
140}
141
142#[derive(Clone)]
144pub struct LibSqlConnection {
145 inner: libsql::Connection,
146}
147
148impl LibSqlConnection {
149 pub async fn create(url: String, token: String) -> anyhow::Result<Self> {
150 let db = libsql::Builder::new_remote(url, token).build().await?;
151 let inner = db.connect()?;
152 Ok(Self { inner })
153 }
154}
155
156impl LibSqlConnection {
157 pub async fn query(
158 &self,
159 query: &str,
160 parameters: Vec<sqlite::Value>,
161 max_result_bytes: usize,
162 ) -> Result<sqlite::QueryResult, sqlite::Error> {
163 let result = self
164 .inner
165 .query(query, convert_parameters(¶meters))
166 .await
167 .map_err(|e| sqlite::Error::Io(e.to_string()))?;
168
169 Ok(sqlite::QueryResult {
170 columns: columns(&result),
171 rows: convert_rows(result, max_result_bytes)
172 .await
173 .map_err(|e| sqlite::Error::Io(e.to_string()))?,
174 })
175 }
176
177 pub async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
178 self.inner.execute_batch(statements).await?;
179
180 Ok(())
181 }
182
183 pub fn changes(&self) -> u64 {
184 self.inner.changes()
185 }
186
187 pub fn last_insert_rowid(&self) -> i64 {
188 self.inner.last_insert_rowid()
189 }
190}
191
192fn columns(rows: &libsql::Rows) -> Vec<String> {
193 (0..rows.column_count())
194 .map(|index| rows.column_name(index).unwrap_or("").to_owned())
195 .collect()
196}
197
198async fn convert_rows(
199 mut rows: libsql::Rows,
200 max_result_bytes: usize,
201) -> anyhow::Result<Vec<RowResult>> {
202 let mut result_rows = vec![];
203
204 let column_count = rows.column_count();
205 let mut byte_count = 0;
206 while let Some(row) = rows.next().await? {
207 let row = convert_row(row, column_count);
208 byte_count += row.values.iter().map(|v| v.memory_size()).sum::<usize>();
209 if byte_count > max_result_bytes {
210 anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
211 }
212 result_rows.push(row);
213 }
214
215 Ok(result_rows)
216}
217
218fn convert_row(row: libsql::Row, column_count: i32) -> RowResult {
219 let values = (0..column_count)
220 .map(|index| convert_value(row.get_value(index).unwrap()))
221 .collect();
222 RowResult { values }
223}
224
225fn convert_value(v: libsql::Value) -> sqlite::Value {
226 use libsql::Value;
227
228 match v {
229 Value::Null => sqlite::Value::Null,
230 Value::Integer(value) => sqlite::Value::Integer(value),
231 Value::Real(value) => sqlite::Value::Real(value),
232 Value::Text(value) => sqlite::Value::Text(value),
233 Value::Blob(value) => sqlite::Value::Blob(value),
234 }
235}
236
237fn convert_parameters(parameters: &[sqlite::Value]) -> Vec<libsql::Value> {
238 use libsql::Value;
239
240 parameters
241 .iter()
242 .map(|v| match v {
243 sqlite::Value::Integer(value) => Value::Integer(*value),
244 sqlite::Value::Real(value) => Value::Real(*value),
245 sqlite::Value::Text(t) => Value::Text(t.clone()),
246 sqlite::Value::Blob(b) => Value::Blob(b.clone()),
247 sqlite::Value::Null => Value::Null,
248 })
249 .collect()
250}
251
252fn io_error_v3(err: libsql::Error) -> v3::Error {
253 v3::Error::Io(err.to_string())
254}