Skip to main content

spin_sqlite_libsql/
lib.rs

1use anyhow::Context;
2use async_trait::async_trait;
3use spin_factor_sqlite::{Connection, QueryAsyncResult};
4use spin_world::spin::sqlite3_1_0::sqlite as v3;
5use spin_world::spin::sqlite3_1_0::sqlite::{self, RowResult};
6use tokio::sync::OnceCell;
7
8/// A lazy wrapper around a [`LibSqlConnection`] that implements the [`Connection`] trait.
9pub struct LazyLibSqlConnection {
10    url: String,
11    token: String,
12    // Since the libSQL client can only be created asynchronously, we wait until
13    // we're in the `Connection` implementation to create. Since we only want to do
14    // this once, we use a `OnceCell` to store it.
15    inner: OnceCell<LibSqlConnection>,
16}
17
18impl LazyLibSqlConnection {
19    pub fn new(url: String, token: String) -> Self {
20        Self {
21            url,
22            token,
23            inner: OnceCell::new(),
24        }
25    }
26
27    pub async fn get_or_create_connection(&self) -> Result<&LibSqlConnection, v3::Error> {
28        self.inner
29            .get_or_try_init(|| async {
30                LibSqlConnection::create(self.url.clone(), self.token.clone())
31                    .await
32                    .context("failed to create SQLite client")
33            })
34            .await
35            .map_err(|_| v3::Error::InvalidConnection)
36    }
37}
38
39#[async_trait]
40impl Connection for LazyLibSqlConnection {
41    async fn query(
42        &self,
43        query: &str,
44        parameters: Vec<v3::Value>,
45        max_result_bytes: usize,
46    ) -> Result<v3::QueryResult, v3::Error> {
47        let client = self.get_or_create_connection().await?;
48        client.query(query, parameters, max_result_bytes).await
49    }
50
51    async fn query_async(
52        &self,
53        query: &str,
54        parameters: Vec<v3::Value>,
55        max_result_bytes: usize,
56    ) -> Result<QueryAsyncResult, v3::Error> {
57        self.get_or_create_connection()
58            .await?
59            .query_async(query, parameters, max_result_bytes)
60            .await
61    }
62
63    async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
64        let client = self.get_or_create_connection().await?;
65        client.execute_batch(statements).await
66    }
67
68    async fn changes(&self) -> Result<u64, sqlite::Error> {
69        let client = self.get_or_create_connection().await?;
70        Ok(client.changes())
71    }
72
73    async fn last_insert_rowid(&self) -> Result<i64, sqlite::Error> {
74        let client = self.get_or_create_connection().await?;
75        Ok(client.last_insert_rowid())
76    }
77
78    fn summary(&self) -> Option<String> {
79        Some(format!("libSQL at {}", self.url))
80    }
81}
82
83/// An open connection to a libSQL server.
84#[derive(Clone)]
85pub struct LibSqlConnection {
86    inner: libsql::Connection,
87}
88
89impl LibSqlConnection {
90    pub async fn create(url: String, token: String) -> anyhow::Result<Self> {
91        let db = libsql::Builder::new_remote(url, token).build().await?;
92        let inner = db.connect()?;
93        Ok(Self { inner })
94    }
95}
96
97impl LibSqlConnection {
98    pub async fn query(
99        &self,
100        query: &str,
101        parameters: Vec<sqlite::Value>,
102        max_result_bytes: usize,
103    ) -> Result<sqlite::QueryResult, sqlite::Error> {
104        let result = self
105            .inner
106            .query(query, convert_parameters(&parameters))
107            .await
108            .map_err(|e| sqlite::Error::Io(e.to_string()))?;
109
110        Ok(sqlite::QueryResult {
111            columns: columns(&result),
112            rows: convert_rows(result, max_result_bytes)
113                .await
114                .map_err(|e| sqlite::Error::Io(e.to_string()))?,
115        })
116    }
117
118    pub async fn query_async(
119        &self,
120        query: &str,
121        parameters: Vec<v3::Value>,
122        max_result_bytes: usize,
123    ) -> Result<QueryAsyncResult, v3::Error> {
124        let client = self.clone();
125        let query = query.to_string();
126
127        let (cols_tx, cols_rx) = tokio::sync::oneshot::channel();
128        let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
129        let (err_tx, err_rx) = tokio::sync::oneshot::channel();
130
131        let the_work = async move {
132            let result = client
133                .inner
134                .query(&query, convert_parameters(&parameters))
135                .await
136                .map_err(|e| v3::Error::Io(e.to_string()));
137
138            let mut rows = match result {
139                Ok(r) => r,
140                Err(e) => {
141                    _ = cols_tx.send(Default::default());
142                    return Err(e);
143                }
144            };
145
146            let columns = columns(&rows);
147            cols_tx
148                .send(columns)
149                .map_err(|_| v3::Error::Io("column send error".into()))?;
150
151            let column_count = rows.column_count();
152
153            loop {
154                let row = match rows.next().await.map_err(io_error_v3)? {
155                    Some(r) => r,
156                    None => break,
157                };
158                let row = convert_row(row, column_count);
159                if row.values.iter().map(|v| v.memory_size()).sum::<usize>() > max_result_bytes {
160                    return Err(v3::Error::Io(format!(
161                        "query result exceeds limit of {max_result_bytes} bytes"
162                    )));
163                }
164                rows_tx
165                    .send(row)
166                    .await
167                    .map_err(|_| v3::Error::Io("column send error".into()))?;
168            }
169
170            Ok(())
171        };
172        tokio::spawn(async move {
173            let res = the_work.await;
174            _ = err_tx.send(res);
175        });
176
177        let columns = cols_rx.await.map_err(|e| v3::Error::Io(e.to_string()))?;
178
179        Ok(QueryAsyncResult {
180            columns,
181            rows: rows_rx,
182            error: err_rx,
183        })
184    }
185
186    pub async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
187        self.inner.execute_batch(statements).await?;
188
189        Ok(())
190    }
191
192    pub fn changes(&self) -> u64 {
193        self.inner.changes()
194    }
195
196    pub fn last_insert_rowid(&self) -> i64 {
197        self.inner.last_insert_rowid()
198    }
199}
200
201fn columns(rows: &libsql::Rows) -> Vec<String> {
202    (0..rows.column_count())
203        .map(|index| rows.column_name(index).unwrap_or("").to_owned())
204        .collect()
205}
206
207async fn convert_rows(
208    mut rows: libsql::Rows,
209    max_result_bytes: usize,
210) -> anyhow::Result<Vec<RowResult>> {
211    let mut result_rows = vec![];
212
213    let column_count = rows.column_count();
214    let mut byte_count = 0;
215    while let Some(row) = rows.next().await? {
216        let row = convert_row(row, column_count);
217        byte_count += row.values.iter().map(|v| v.memory_size()).sum::<usize>();
218        if byte_count > max_result_bytes {
219            anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
220        }
221        result_rows.push(row);
222    }
223
224    Ok(result_rows)
225}
226
227fn convert_row(row: libsql::Row, column_count: i32) -> RowResult {
228    let values = (0..column_count)
229        .map(|index| convert_value(row.get_value(index).unwrap()))
230        .collect();
231    RowResult { values }
232}
233
234fn convert_value(v: libsql::Value) -> sqlite::Value {
235    use libsql::Value;
236
237    match v {
238        Value::Null => sqlite::Value::Null,
239        Value::Integer(value) => sqlite::Value::Integer(value),
240        Value::Real(value) => sqlite::Value::Real(value),
241        Value::Text(value) => sqlite::Value::Text(value),
242        Value::Blob(value) => sqlite::Value::Blob(value),
243    }
244}
245
246fn convert_parameters(parameters: &[sqlite::Value]) -> Vec<libsql::Value> {
247    use libsql::Value;
248
249    parameters
250        .iter()
251        .map(|v| match v {
252            sqlite::Value::Integer(value) => Value::Integer(*value),
253            sqlite::Value::Real(value) => Value::Real(*value),
254            sqlite::Value::Text(t) => Value::Text(t.clone()),
255            sqlite::Value::Blob(b) => Value::Blob(b.clone()),
256            sqlite::Value::Null => Value::Null,
257        })
258        .collect()
259}
260
261fn io_error_v3(err: libsql::Error) -> v3::Error {
262    v3::Error::Io(err.to_string())
263}