Skip to main content

spin_factor_sqlite/
host.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use spin_core::wasmtime::component::{Accessor, FutureReader, StreamReader};
5use spin_factor_otel::OtelFactorState;
6use spin_factors::wasmtime::component::Resource;
7use spin_factors::{anyhow, SelfInstanceBuilder};
8use spin_world::spin::sqlite3_1_0::sqlite as v3;
9use spin_world::v1::sqlite as v1;
10use spin_world::v2::sqlite as v2;
11use spin_world::MAX_HOST_BUFFERED_BYTES;
12use tracing::field::Empty;
13use tracing::{instrument, Level};
14
15use crate::{Connection, ConnectionCreator, QueryAsyncResult};
16
17pub struct InstanceState {
18    allowed_databases: Arc<HashSet<String>>,
19    /// A resource table of connections.
20    connections: spin_resource_table::Table<Arc<dyn Connection>>,
21    /// A map from database label to connection creators.
22    connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
23    otel: OtelFactorState,
24}
25
26impl InstanceState {
27    /// Create a new `InstanceState`
28    ///
29    /// Takes the list of allowed databases, and a function for getting a connection creator given a database label.
30    pub fn new(
31        allowed_databases: Arc<HashSet<String>>,
32        connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
33        otel: OtelFactorState,
34    ) -> Self {
35        Self {
36            allowed_databases,
37            connections: spin_resource_table::Table::new(256),
38            connection_creators,
39            otel,
40        }
41    }
42
43    /// Get a connection for a given database label.
44    fn get_connection<T: 'static>(
45        &self,
46        connection: Resource<T>,
47    ) -> Result<Arc<dyn Connection>, v3::Error> {
48        self.connections
49            .get(connection.rep())
50            .cloned()
51            .ok_or(v3::Error::InvalidConnection)
52    }
53
54    async fn open_impl<T: 'static>(&mut self, database: String) -> Result<Resource<T>, v3::Error> {
55        if !self.allowed_databases.contains(&database) {
56            return Err(v3::Error::AccessDenied);
57        }
58        let conn = self
59            .connection_creators
60            .get(&database)
61            .ok_or(v3::Error::NoSuchDatabase)?
62            .create_connection(&database)
63            .await?;
64        tracing::Span::current().record(
65            "sqlite.backend",
66            conn.summary().as_deref().unwrap_or("unknown"),
67        );
68        self.connections
69            .push(conn)
70            .map_err(|()| v3::Error::Io("too many connections opened".to_string()))
71            .map(Resource::new_own)
72    }
73
74    async fn execute_impl<T: 'static>(
75        &mut self,
76        connection: Resource<T>,
77        query: String,
78        parameters: Vec<v3::Value>,
79    ) -> Result<v3::QueryResult, v3::Error> {
80        let conn = self.get_connection(connection)?;
81        tracing::Span::current().record(
82            "sqlite.backend",
83            conn.summary().as_deref().unwrap_or("unknown"),
84        );
85        conn.query(&query, parameters, MAX_HOST_BUFFERED_BYTES)
86            .await
87    }
88
89    /// Get the set of allowed databases.
90    pub fn allowed_databases(&self) -> &HashSet<String> {
91        &self.allowed_databases
92    }
93}
94
95impl SelfInstanceBuilder for InstanceState {}
96
97impl v3::Host for InstanceState {
98    fn convert_error(&mut self, error: v3::Error) -> anyhow::Result<v3::Error> {
99        Ok(error)
100    }
101}
102
103impl v3::HostConnection for InstanceState {
104    #[instrument(name = "spin_sqlite.open", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", sqlite.backend = Empty))]
105    async fn open(&mut self, database: String) -> Result<Resource<v3::Connection>, v3::Error> {
106        self.open_impl(database).await
107    }
108
109    #[instrument(name = "spin_sqlite.execute", skip(self, connection, parameters), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", otel.name = query, sqlite.backend = Empty))]
110    async fn execute(
111        &mut self,
112        connection: Resource<v3::Connection>,
113        query: String,
114        parameters: Vec<v3::Value>,
115    ) -> Result<v3::QueryResult, v3::Error> {
116        self.execute_impl(connection, query, parameters).await
117    }
118
119    async fn changes(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<u64> {
120        let conn = match self.get_connection(connection) {
121            Ok(c) => c,
122            Err(err) => return Err(err.into()),
123        };
124        tracing::Span::current().record(
125            "sqlite.backend",
126            conn.summary().as_deref().unwrap_or("unknown"),
127        );
128        conn.changes().await.map_err(|e| e.into())
129    }
130
131    async fn last_insert_rowid(
132        &mut self,
133        connection: Resource<v3::Connection>,
134    ) -> anyhow::Result<i64> {
135        let conn = match self.get_connection(connection) {
136            Ok(c) => c,
137            Err(err) => return Err(err.into()),
138        };
139        tracing::Span::current().record(
140            "sqlite.backend",
141            conn.summary().as_deref().unwrap_or("unknown"),
142        );
143        conn.last_insert_rowid().await.map_err(|e| e.into())
144    }
145
146    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
147        let _ = self.connections.remove(connection.rep());
148        Ok(())
149    }
150}
151
152impl v3::HostConnectionWithStore for crate::SqliteFactorData {
153    async fn open_async<T>(
154        accessor: &Accessor<T, Self>,
155        database: String,
156    ) -> Result<Resource<v3::Connection>, v3::Error> {
157        // TODO: this duplicates `open_impl` logic but split up to move
158        // in and out of the Accessor. How to dedupe?
159        let conn_creator = accessor.with(|mut access| {
160            let host = access.get();
161            if !host.allowed_databases.contains(&database) {
162                return Err(v3::Error::AccessDenied);
163            }
164            host.connection_creators
165                .get(&database)
166                .ok_or(v3::Error::NoSuchDatabase)
167                .cloned()
168        })?;
169
170        let conn = conn_creator.create_connection(&database).await?;
171
172        tracing::Span::current().record(
173            "sqlite.backend",
174            conn.summary().as_deref().unwrap_or("unknown"),
175        );
176
177        let resource = accessor.with(|mut access| {
178            let host = access.get();
179            host.connections
180                .push(conn)
181                .map_err(|()| v3::Error::Io("too many connections opened".to_string()))
182                .map(Resource::new_own)
183        });
184
185        resource
186    }
187
188    async fn execute_async<T>(
189        accessor: &Accessor<T, Self>,
190        connection: Resource<v3::Connection>,
191        query: String,
192        parameters: Vec<v3::Value>,
193    ) -> Result<
194        (
195            Vec<String>,
196            StreamReader<v3::RowResult>,
197            FutureReader<Result<(), v3::Error>>,
198        ),
199        v3::Error,
200    > {
201        let conn = accessor.with(|mut access| {
202            let host = access.get();
203            host.get_connection(connection)
204        })?;
205
206        tracing::Span::current().record(
207            "sqlite.backend",
208            conn.summary().as_deref().unwrap_or("unknown"),
209        );
210
211        let QueryAsyncResult {
212            columns,
213            rows,
214            error,
215        } = conn
216            .query_async(&query, parameters, MAX_HOST_BUFFERED_BYTES)
217            .await?;
218        let row_producer = spin_wasi_async::stream::producer(rows);
219
220        let (sr, efr) = accessor.with(|mut access| {
221            let sr = StreamReader::new(&mut access, row_producer);
222            let efr = FutureReader::new(&mut access, error);
223            (sr, efr)
224        });
225
226        Ok((columns, sr, efr))
227    }
228
229    async fn changes_async<T>(
230        accessor: &Accessor<T, Self>,
231        connection: Resource<v3::Connection>,
232    ) -> anyhow::Result<u64> {
233        let conn = accessor.with(|mut access| {
234            let host = access.get();
235            host.get_connection(connection)
236        });
237
238        let conn = match conn {
239            Ok(c) => c,
240            Err(err) => return Err(err.into()),
241        };
242        tracing::Span::current().record(
243            "sqlite.backend",
244            conn.summary().as_deref().unwrap_or("unknown"),
245        );
246        conn.changes().await.map_err(|e| e.into())
247    }
248
249    async fn last_insert_rowid_async<T>(
250        accessor: &Accessor<T, Self>,
251        connection: Resource<v3::Connection>,
252    ) -> anyhow::Result<i64> {
253        let conn = accessor.with(|mut access| {
254            let host = access.get();
255            host.get_connection(connection)
256        });
257
258        let conn = match conn {
259            Ok(c) => c,
260            Err(err) => return Err(err.into()),
261        };
262        tracing::Span::current().record(
263            "sqlite.backend",
264            conn.summary().as_deref().unwrap_or("unknown"),
265        );
266        conn.last_insert_rowid().await.map_err(|e| e.into())
267    }
268}
269
270impl v2::Host for InstanceState {
271    fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {
272        Ok(error)
273    }
274}
275
276impl v2::HostConnection for InstanceState {
277    #[instrument(name = "spin_sqlite.open", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", sqlite.backend = Empty))]
278    async fn open(&mut self, database: String) -> Result<Resource<v2::Connection>, v2::Error> {
279        self.otel.reparent_tracing_span();
280        self.open_impl(database).await.map_err(to_v2_error)
281    }
282
283    #[instrument(name = "spin_sqlite.execute", skip(self, connection, parameters), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", otel.name = query, sqlite.backend = Empty))]
284    async fn execute(
285        &mut self,
286        connection: Resource<v2::Connection>,
287        query: String,
288        parameters: Vec<v2::Value>,
289    ) -> Result<v2::QueryResult, v2::Error> {
290        self.otel.reparent_tracing_span();
291        self.execute_impl(
292            connection,
293            query,
294            parameters.into_iter().map(from_v2_value).collect(),
295        )
296        .await
297        .map(to_v2_query_result)
298        .map_err(to_v2_error)
299    }
300
301    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
302        let _ = self.connections.remove(connection.rep());
303        Ok(())
304    }
305}
306
307impl v1::Host for InstanceState {
308    async fn open(&mut self, database: String) -> Result<u32, v1::Error> {
309        let result = <Self as v3::HostConnection>::open(self, database).await;
310        result.map_err(to_legacy_error).map(|s| s.rep())
311    }
312
313    async fn execute(
314        &mut self,
315        connection: u32,
316        query: String,
317        parameters: Vec<spin_world::v1::sqlite::Value>,
318    ) -> Result<spin_world::v1::sqlite::QueryResult, v1::Error> {
319        let this = Resource::new_borrow(connection);
320        let result = <Self as v3::HostConnection>::execute(
321            self,
322            this,
323            query,
324            parameters.into_iter().map(from_legacy_value).collect(),
325        )
326        .await;
327        result.map_err(to_legacy_error).map(to_legacy_query_result)
328    }
329
330    async fn close(&mut self, connection: u32) -> anyhow::Result<()> {
331        <Self as v2::HostConnection>::drop(self, Resource::new_own(connection)).await
332    }
333
334    fn convert_error(&mut self, error: v1::Error) -> anyhow::Result<v1::Error> {
335        Ok(error)
336    }
337}
338
339fn to_v2_error(error: v3::Error) -> v2::Error {
340    match error {
341        v3::Error::NoSuchDatabase => v2::Error::NoSuchDatabase,
342        v3::Error::AccessDenied => v2::Error::AccessDenied,
343        v3::Error::InvalidConnection => v2::Error::InvalidConnection,
344        v3::Error::DatabaseFull => v2::Error::DatabaseFull,
345        v3::Error::Io(s) => v2::Error::Io(s),
346    }
347}
348
349fn to_legacy_error(error: v3::Error) -> v1::Error {
350    match error {
351        v3::Error::NoSuchDatabase => v1::Error::NoSuchDatabase,
352        v3::Error::AccessDenied => v1::Error::AccessDenied,
353        v3::Error::InvalidConnection => v1::Error::InvalidConnection,
354        v3::Error::DatabaseFull => v1::Error::DatabaseFull,
355        v3::Error::Io(s) => v1::Error::Io(s),
356    }
357}
358
359fn to_v2_query_result(result: v3::QueryResult) -> v2::QueryResult {
360    v2::QueryResult {
361        columns: result.columns,
362        rows: result.rows.into_iter().map(to_v2_row_result).collect(),
363    }
364}
365
366fn to_legacy_query_result(result: v3::QueryResult) -> v1::QueryResult {
367    v1::QueryResult {
368        columns: result.columns,
369        rows: result.rows.into_iter().map(to_legacy_row_result).collect(),
370    }
371}
372
373fn to_v2_row_result(result: v3::RowResult) -> v2::RowResult {
374    v2::RowResult {
375        values: result.values.into_iter().map(to_v2_value).collect(),
376    }
377}
378
379fn to_legacy_row_result(result: v3::RowResult) -> v1::RowResult {
380    v1::RowResult {
381        values: result.values.into_iter().map(to_legacy_value).collect(),
382    }
383}
384
385fn to_v2_value(value: v3::Value) -> v2::Value {
386    match value {
387        v3::Value::Integer(i) => v2::Value::Integer(i),
388        v3::Value::Real(r) => v2::Value::Real(r),
389        v3::Value::Text(t) => v2::Value::Text(t),
390        v3::Value::Blob(b) => v2::Value::Blob(b),
391        v3::Value::Null => v2::Value::Null,
392    }
393}
394
395fn to_legacy_value(value: v3::Value) -> v1::Value {
396    match value {
397        v3::Value::Integer(i) => v1::Value::Integer(i),
398        v3::Value::Real(r) => v1::Value::Real(r),
399        v3::Value::Text(t) => v1::Value::Text(t),
400        v3::Value::Blob(b) => v1::Value::Blob(b),
401        v3::Value::Null => v1::Value::Null,
402    }
403}
404
405fn from_v2_value(value: v2::Value) -> v3::Value {
406    match value {
407        v2::Value::Integer(i) => v3::Value::Integer(i),
408        v2::Value::Real(r) => v3::Value::Real(r),
409        v2::Value::Text(t) => v3::Value::Text(t),
410        v2::Value::Blob(b) => v3::Value::Blob(b),
411        v2::Value::Null => v3::Value::Null,
412    }
413}
414
415fn from_legacy_value(value: v1::Value) -> v3::Value {
416    match value {
417        v1::Value::Integer(i) => v3::Value::Integer(i),
418        v1::Value::Real(r) => v3::Value::Real(r),
419        v1::Value::Text(t) => v3::Value::Text(t),
420        v1::Value::Blob(b) => v3::Value::Blob(b),
421        v1::Value::Null => v3::Value::Null,
422    }
423}