Skip to main content

spin_sqlite_libsql/
lib.rs

1use std::sync::Arc;
2
3use anyhow::Context;
4use async_trait::async_trait;
5use spin_factor_sqlite::{Connection, QueryAsyncResult};
6use spin_world::spin::sqlite3_1_0::sqlite as v3;
7use spin_world::spin::sqlite3_1_0::sqlite::{self, RowResult};
8use tokio::sync::OnceCell;
9
10/// A lazy wrapper around a [`LibSqlConnection`] that implements the [`Connection`] trait.
11pub struct LazyLibSqlConnection {
12    url: String,
13    token: String,
14    // Since the libSQL client can only be created asynchronously, we wait until
15    // we're in the `Connection` implementation to create. Since we only want to do
16    // this once, we use a `OnceCell` to store it.
17    inner: OnceCell<Arc<LibSqlConnection>>,
18}
19
20impl LazyLibSqlConnection {
21    pub fn new(url: String, token: String) -> Self {
22        Self {
23            url,
24            token,
25            inner: OnceCell::new(),
26        }
27    }
28
29    pub async fn get_or_create_connection(&self) -> Result<&Arc<LibSqlConnection>, v3::Error> {
30        self.inner
31            .get_or_try_init(|| async {
32                LibSqlConnection::create(self.url.clone(), self.token.clone())
33                    .await
34                    .context("failed to create SQLite client")
35                    .map(Arc::new)
36            })
37            .await
38            .map_err(|_| v3::Error::InvalidConnection)
39    }
40}
41
42#[async_trait]
43impl Connection for LazyLibSqlConnection {
44    async fn query(
45        &self,
46        query: &str,
47        parameters: Vec<v3::Value>,
48        max_result_bytes: usize,
49    ) -> Result<v3::QueryResult, v3::Error> {
50        let client = self.get_or_create_connection().await?;
51        client.query(query, parameters, max_result_bytes).await
52    }
53
54    async fn query_async(
55        &self,
56        query: &str,
57        parameters: Vec<v3::Value>,
58        max_result_bytes: usize,
59    ) -> Result<QueryAsyncResult, v3::Error> {
60        let client = self.get_or_create_connection().await?.clone();
61        let query = query.to_string();
62
63        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
64        let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
65        let (err_tx, err_rx) = tokio::sync::oneshot::channel();
66
67        let the_work = async move {
68            let result = client
69                .inner
70                .query(&query, convert_parameters(&parameters))
71                .await
72                .map_err(|e| v3::Error::Io(e.to_string()));
73
74            let mut rows = match result {
75                Ok(r) => r,
76                Err(e) => {
77                    _ = cols_tx.send(Default::default());
78                    return Err(e);
79                }
80            };
81
82            let columns = columns(&rows);
83            cols_tx
84                .send(columns)
85                .map_err(|_| v3::Error::Io("column send error".into()))?;
86
87            let column_count = rows.column_count();
88
89            loop {
90                let row = match rows.next().await.map_err(io_error_v3)? {
91                    Some(r) => r,
92                    None => break,
93                };
94                let row = convert_row(row, column_count);
95                if row.values.iter().map(|v| v.memory_size()).sum::<usize>() > max_result_bytes {
96                    return Err(v3::Error::Io(format!(
97                        "query result exceeds limit of {max_result_bytes} bytes"
98                    )));
99                }
100                rows_tx
101                    .send(row)
102                    .await
103                    .map_err(|_| v3::Error::Io("column send error".into()))?;
104            }
105
106            Ok(())
107        };
108        tokio::spawn(async move {
109            let res = the_work.await;
110            _ = err_tx.send(res);
111        });
112
113        let columns = cols_rx.await.map_err(|e| v3::Error::Io(e.to_string()))?;
114
115        Ok(QueryAsyncResult {
116            columns,
117            rows: rows_rx,
118            error: err_rx,
119        })
120    }
121
122    async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
123        let client = self.get_or_create_connection().await?;
124        client.execute_batch(statements).await
125    }
126
127    async fn changes(&self) -> Result<u64, sqlite::Error> {
128        let client = self.get_or_create_connection().await?;
129        Ok(client.changes())
130    }
131
132    async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
133        let client = self.get_or_create_connection().await?;
134        Ok(client.last_insert_rowid())
135    }
136
137    fn summary(&self) -> Option<String> {
138        Some(format!("libSQL at {}", self.url))
139    }
140}
141
142/// An open connection to a libSQL server.
143#[derive(Clone)]
144pub struct LibSqlConnection {
145    inner: libsql::Connection,
146}
147
148impl LibSqlConnection {
149    pub async fn create(url: String, token: String) -> anyhow::Result<Self> {
150        let db = libsql::Builder::new_remote(url, token).build().await?;
151        let inner = db.connect()?;
152        Ok(Self { inner })
153    }
154}
155
156impl LibSqlConnection {
157    pub async fn query(
158        &self,
159        query: &str,
160        parameters: Vec<sqlite::Value>,
161        max_result_bytes: usize,
162    ) -> Result<sqlite::QueryResult, sqlite::Error> {
163        let result = self
164            .inner
165            .query(query, convert_parameters(&parameters))
166            .await
167            .map_err(|e| sqlite::Error::Io(e.to_string()))?;
168
169        Ok(sqlite::QueryResult {
170            columns: columns(&result),
171            rows: convert_rows(result, max_result_bytes)
172                .await
173                .map_err(|e| sqlite::Error::Io(e.to_string()))?,
174        })
175    }
176
177    pub async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
178        self.inner.execute_batch(statements).await?;
179
180        Ok(())
181    }
182
183    pub fn changes(&self) -> u64 {
184        self.inner.changes()
185    }
186
187    pub fn last_insert_rowid(&self) -> i64 {
188        self.inner.last_insert_rowid()
189    }
190}
191
192fn columns(rows: &libsql::Rows) -> Vec<String> {
193    (0..rows.column_count())
194        .map(|index| rows.column_name(index).unwrap_or("").to_owned())
195        .collect()
196}
197
198async fn convert_rows(
199    mut rows: libsql::Rows,
200    max_result_bytes: usize,
201) -> anyhow::Result<Vec<RowResult>> {
202    let mut result_rows = vec![];
203
204    let column_count = rows.column_count();
205    let mut byte_count = 0;
206    while let Some(row) = rows.next().await? {
207        let row = convert_row(row, column_count);
208        byte_count += row.values.iter().map(|v| v.memory_size()).sum::<usize>();
209        if byte_count > max_result_bytes {
210            anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
211        }
212        result_rows.push(row);
213    }
214
215    Ok(result_rows)
216}
217
218fn convert_row(row: libsql::Row, column_count: i32) -> RowResult {
219    let values = (0..column_count)
220        .map(|index| convert_value(row.get_value(index).unwrap()))
221        .collect();
222    RowResult { values }
223}
224
225fn convert_value(v: libsql::Value) -> sqlite::Value {
226    use libsql::Value;
227
228    match v {
229        Value::Null => sqlite::Value::Null,
230        Value::Integer(value) => sqlite::Value::Integer(value),
231        Value::Real(value) => sqlite::Value::Real(value),
232        Value::Text(value) => sqlite::Value::Text(value),
233        Value::Blob(value) => sqlite::Value::Blob(value),
234    }
235}
236
237fn convert_parameters(parameters: &[sqlite::Value]) -> Vec<libsql::Value> {
238    use libsql::Value;
239
240    parameters
241        .iter()
242        .map(|v| match v {
243            sqlite::Value::Integer(value) => Value::Integer(*value),
244            sqlite::Value::Real(value) => Value::Real(*value),
245            sqlite::Value::Text(t) => Value::Text(t.clone()),
246            sqlite::Value::Blob(b) => Value::Blob(b.clone()),
247            sqlite::Value::Null => Value::Null,
248        })
249        .collect()
250}
251
252fn io_error_v3(err: libsql::Error) -> v3::Error {
253    v3::Error::Io(err.to_string())
254}