1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use spin_core::wasmtime::component::{Accessor, FutureReader, StreamReader};
5use spin_factor_otel::OtelFactorState;
6use spin_factors::wasmtime::component::Resource;
7use spin_factors::{anyhow, SelfInstanceBuilder};
8use spin_world::spin::sqlite3_1_0::sqlite as v3;
9use spin_world::v1::sqlite as v1;
10use spin_world::v2::sqlite as v2;
11use spin_world::MAX_HOST_BUFFERED_BYTES;
12use tracing::field::Empty;
13use tracing::{instrument, Level};
14
15use crate::{Connection, ConnectionCreator, QueryAsyncResult};
16
/// Host-side state backing the v1, v2 and v3 SQLite host interfaces.
pub struct InstanceState {
    /// Names of the databases this state may open; any other name is rejected with
    /// `AccessDenied` before a connection is attempted.
    allowed_databases: Arc<HashSet<String>>,
    /// Currently open connections. The index returned by `push` becomes the `rep` of the
    /// guest's connection resource handle, and lookups/removals use that same `rep`.
    connections: spin_resource_table::Table<Arc<dyn Connection>>,
    /// Per-database-name factories used to create a new connection when a database is opened.
    connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
    /// OpenTelemetry factor state; used to reparent host tracing spans.
    otel: OtelFactorState,
}
25
26impl InstanceState {
27 pub fn new(
31 allowed_databases: Arc<HashSet<String>>,
32 connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
33 otel: OtelFactorState,
34 ) -> Self {
35 Self {
36 allowed_databases,
37 connections: spin_resource_table::Table::new(256),
38 connection_creators,
39 otel,
40 }
41 }
42
43 fn get_connection<T: 'static>(
45 &self,
46 connection: Resource<T>,
47 ) -> Result<Arc<dyn Connection>, v3::Error> {
48 self.connections
49 .get(connection.rep())
50 .cloned()
51 .ok_or(v3::Error::InvalidConnection)
52 }
53
54 async fn open_impl<T: 'static>(&mut self, database: String) -> Result<Resource<T>, v3::Error> {
55 if !self.allowed_databases.contains(&database) {
56 return Err(v3::Error::AccessDenied);
57 }
58 let conn = self
59 .connection_creators
60 .get(&database)
61 .ok_or(v3::Error::NoSuchDatabase)?
62 .create_connection(&database)
63 .await?;
64 tracing::Span::current().record(
65 "sqlite.backend",
66 conn.summary().as_deref().unwrap_or("unknown"),
67 );
68 self.connections
69 .push(conn)
70 .map_err(|()| v3::Error::Io("too many connections opened".to_string()))
71 .map(Resource::new_own)
72 }
73
74 async fn execute_impl<T: 'static>(
75 &mut self,
76 connection: Resource<T>,
77 query: String,
78 parameters: Vec<v3::Value>,
79 ) -> Result<v3::QueryResult, v3::Error> {
80 let conn = self.get_connection(connection)?;
81 tracing::Span::current().record(
82 "sqlite.backend",
83 conn.summary().as_deref().unwrap_or("unknown"),
84 );
85 conn.query(&query, parameters, MAX_HOST_BUFFERED_BYTES)
86 .await
87 }
88
89 pub fn allowed_databases(&self) -> &HashSet<String> {
91 &self.allowed_databases
92 }
93}
94
// Empty impl: no methods are overridden, so `SelfInstanceBuilder`'s provided behavior applies.
// Presumably (per the trait name) `InstanceState` acts as its own instance builder — confirm in spin_factors.
impl SelfInstanceBuilder for InstanceState {}
96
97impl v3::Host for InstanceState {
98 fn convert_error(&mut self, error: v3::Error) -> anyhow::Result<v3::Error> {
99 Ok(error)
100 }
101}
102
103impl v3::HostConnection for InstanceState {
104 #[instrument(name = "spin_sqlite.open", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", sqlite.backend = Empty))]
105 async fn open(&mut self, database: String) -> Result<Resource<v3::Connection>, v3::Error> {
106 self.open_impl(database).await
107 }
108
109 #[instrument(name = "spin_sqlite.execute", skip(self, connection, parameters), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", otel.name = query, sqlite.backend = Empty))]
110 async fn execute(
111 &mut self,
112 connection: Resource<v3::Connection>,
113 query: String,
114 parameters: Vec<v3::Value>,
115 ) -> Result<v3::QueryResult, v3::Error> {
116 self.execute_impl(connection, query, parameters).await
117 }
118
119 async fn changes(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<u64> {
120 let conn = match self.get_connection(connection) {
121 Ok(c) => c,
122 Err(err) => return Err(err.into()),
123 };
124 tracing::Span::current().record(
125 "sqlite.backend",
126 conn.summary().as_deref().unwrap_or("unknown"),
127 );
128 conn.changes().await.map_err(|e| e.into())
129 }
130
131 async fn last_insert_rowid(
132 &mut self,
133 connection: Resource<v3::Connection>,
134 ) -> anyhow::Result<i64> {
135 let conn = match self.get_connection(connection) {
136 Ok(c) => c,
137 Err(err) => return Err(err.into()),
138 };
139 tracing::Span::current().record(
140 "sqlite.backend",
141 conn.summary().as_deref().unwrap_or("unknown"),
142 );
143 conn.last_insert_rowid().await.map_err(|e| e.into())
144 }
145
146 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
147 let _ = self.connections.remove(connection.rep());
148 Ok(())
149 }
150}
151
152impl v3::HostConnectionWithStore for crate::SqliteFactorData {
153 async fn open_async<T>(
154 accessor: &Accessor<T, Self>,
155 database: String,
156 ) -> Result<Resource<v3::Connection>, v3::Error> {
157 let conn_creator = accessor.with(|mut access| {
160 let host = access.get();
161 if !host.allowed_databases.contains(&database) {
162 return Err(v3::Error::AccessDenied);
163 }
164 host.connection_creators
165 .get(&database)
166 .ok_or(v3::Error::NoSuchDatabase)
167 .cloned()
168 })?;
169
170 let conn = conn_creator.create_connection(&database).await?;
171
172 tracing::Span::current().record(
173 "sqlite.backend",
174 conn.summary().as_deref().unwrap_or("unknown"),
175 );
176
177 let resource = accessor.with(|mut access| {
178 let host = access.get();
179 host.connections
180 .push(conn)
181 .map_err(|()| v3::Error::Io("too many connections opened".to_string()))
182 .map(Resource::new_own)
183 });
184
185 resource
186 }
187
188 async fn execute_async<T>(
189 accessor: &Accessor<T, Self>,
190 connection: Resource<v3::Connection>,
191 query: String,
192 parameters: Vec<v3::Value>,
193 ) -> Result<
194 (
195 Vec<String>,
196 StreamReader<v3::RowResult>,
197 FutureReader<Result<(), v3::Error>>,
198 ),
199 v3::Error,
200 > {
201 let conn = accessor.with(|mut access| {
202 let host = access.get();
203 host.get_connection(connection)
204 })?;
205
206 tracing::Span::current().record(
207 "sqlite.backend",
208 conn.summary().as_deref().unwrap_or("unknown"),
209 );
210
211 let QueryAsyncResult {
212 columns,
213 rows,
214 error,
215 } = conn
216 .query_async(&query, parameters, MAX_HOST_BUFFERED_BYTES)
217 .await?;
218 let row_producer = spin_wasi_async::stream::producer(rows);
219
220 let (sr, efr) = accessor.with(|mut access| {
221 let sr = StreamReader::new(&mut access, row_producer);
222 let efr = FutureReader::new(&mut access, error);
223 (sr, efr)
224 });
225
226 Ok((columns, sr, efr))
227 }
228
229 async fn changes_async<T>(
230 accessor: &Accessor<T, Self>,
231 connection: Resource<v3::Connection>,
232 ) -> anyhow::Result<u64> {
233 let conn = accessor.with(|mut access| {
234 let host = access.get();
235 host.get_connection(connection)
236 });
237
238 let conn = match conn {
239 Ok(c) => c,
240 Err(err) => return Err(err.into()),
241 };
242 tracing::Span::current().record(
243 "sqlite.backend",
244 conn.summary().as_deref().unwrap_or("unknown"),
245 );
246 conn.changes().await.map_err(|e| e.into())
247 }
248
249 async fn last_insert_rowid_async<T>(
250 accessor: &Accessor<T, Self>,
251 connection: Resource<v3::Connection>,
252 ) -> anyhow::Result<i64> {
253 let conn = accessor.with(|mut access| {
254 let host = access.get();
255 host.get_connection(connection)
256 });
257
258 let conn = match conn {
259 Ok(c) => c,
260 Err(err) => return Err(err.into()),
261 };
262 tracing::Span::current().record(
263 "sqlite.backend",
264 conn.summary().as_deref().unwrap_or("unknown"),
265 );
266 conn.last_insert_rowid().await.map_err(|e| e.into())
267 }
268}
269
270impl v2::Host for InstanceState {
271 fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {
272 Ok(error)
273 }
274}
275
276impl v2::HostConnection for InstanceState {
277 #[instrument(name = "spin_sqlite.open", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", sqlite.backend = Empty))]
278 async fn open(&mut self, database: String) -> Result<Resource<v2::Connection>, v2::Error> {
279 self.otel.reparent_tracing_span();
280 self.open_impl(database).await.map_err(to_v2_error)
281 }
282
283 #[instrument(name = "spin_sqlite.execute", skip(self, connection, parameters), err(level = Level::INFO), fields(otel.kind = "client", db.system = "sqlite", otel.name = query, sqlite.backend = Empty))]
284 async fn execute(
285 &mut self,
286 connection: Resource<v2::Connection>,
287 query: String,
288 parameters: Vec<v2::Value>,
289 ) -> Result<v2::QueryResult, v2::Error> {
290 self.otel.reparent_tracing_span();
291 self.execute_impl(
292 connection,
293 query,
294 parameters.into_iter().map(from_v2_value).collect(),
295 )
296 .await
297 .map(to_v2_query_result)
298 .map_err(to_v2_error)
299 }
300
301 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
302 let _ = self.connections.remove(connection.rep());
303 Ok(())
304 }
305}
306
307impl v1::Host for InstanceState {
308 async fn open(&mut self, database: String) -> Result<u32, v1::Error> {
309 let result = <Self as v3::HostConnection>::open(self, database).await;
310 result.map_err(to_legacy_error).map(|s| s.rep())
311 }
312
313 async fn execute(
314 &mut self,
315 connection: u32,
316 query: String,
317 parameters: Vec<spin_world::v1::sqlite::Value>,
318 ) -> Result<spin_world::v1::sqlite::QueryResult, v1::Error> {
319 let this = Resource::new_borrow(connection);
320 let result = <Self as v3::HostConnection>::execute(
321 self,
322 this,
323 query,
324 parameters.into_iter().map(from_legacy_value).collect(),
325 )
326 .await;
327 result.map_err(to_legacy_error).map(to_legacy_query_result)
328 }
329
330 async fn close(&mut self, connection: u32) -> anyhow::Result<()> {
331 <Self as v2::HostConnection>::drop(self, Resource::new_own(connection)).await
332 }
333
334 fn convert_error(&mut self, error: v1::Error) -> anyhow::Result<v1::Error> {
335 Ok(error)
336 }
337}
338
339fn to_v2_error(error: v3::Error) -> v2::Error {
340 match error {
341 v3::Error::NoSuchDatabase => v2::Error::NoSuchDatabase,
342 v3::Error::AccessDenied => v2::Error::AccessDenied,
343 v3::Error::InvalidConnection => v2::Error::InvalidConnection,
344 v3::Error::DatabaseFull => v2::Error::DatabaseFull,
345 v3::Error::Io(s) => v2::Error::Io(s),
346 }
347}
348
349fn to_legacy_error(error: v3::Error) -> v1::Error {
350 match error {
351 v3::Error::NoSuchDatabase => v1::Error::NoSuchDatabase,
352 v3::Error::AccessDenied => v1::Error::AccessDenied,
353 v3::Error::InvalidConnection => v1::Error::InvalidConnection,
354 v3::Error::DatabaseFull => v1::Error::DatabaseFull,
355 v3::Error::Io(s) => v1::Error::Io(s),
356 }
357}
358
359fn to_v2_query_result(result: v3::QueryResult) -> v2::QueryResult {
360 v2::QueryResult {
361 columns: result.columns,
362 rows: result.rows.into_iter().map(to_v2_row_result).collect(),
363 }
364}
365
366fn to_legacy_query_result(result: v3::QueryResult) -> v1::QueryResult {
367 v1::QueryResult {
368 columns: result.columns,
369 rows: result.rows.into_iter().map(to_legacy_row_result).collect(),
370 }
371}
372
373fn to_v2_row_result(result: v3::RowResult) -> v2::RowResult {
374 v2::RowResult {
375 values: result.values.into_iter().map(to_v2_value).collect(),
376 }
377}
378
379fn to_legacy_row_result(result: v3::RowResult) -> v1::RowResult {
380 v1::RowResult {
381 values: result.values.into_iter().map(to_legacy_value).collect(),
382 }
383}
384
385fn to_v2_value(value: v3::Value) -> v2::Value {
386 match value {
387 v3::Value::Integer(i) => v2::Value::Integer(i),
388 v3::Value::Real(r) => v2::Value::Real(r),
389 v3::Value::Text(t) => v2::Value::Text(t),
390 v3::Value::Blob(b) => v2::Value::Blob(b),
391 v3::Value::Null => v2::Value::Null,
392 }
393}
394
395fn to_legacy_value(value: v3::Value) -> v1::Value {
396 match value {
397 v3::Value::Integer(i) => v1::Value::Integer(i),
398 v3::Value::Real(r) => v1::Value::Real(r),
399 v3::Value::Text(t) => v1::Value::Text(t),
400 v3::Value::Blob(b) => v1::Value::Blob(b),
401 v3::Value::Null => v1::Value::Null,
402 }
403}
404
405fn from_v2_value(value: v2::Value) -> v3::Value {
406 match value {
407 v2::Value::Integer(i) => v3::Value::Integer(i),
408 v2::Value::Real(r) => v3::Value::Real(r),
409 v2::Value::Text(t) => v3::Value::Text(t),
410 v2::Value::Blob(b) => v3::Value::Blob(b),
411 v2::Value::Null => v3::Value::Null,
412 }
413}
414
415fn from_legacy_value(value: v1::Value) -> v3::Value {
416 match value {
417 v1::Value::Integer(i) => v3::Value::Integer(i),
418 v1::Value::Real(r) => v3::Value::Real(r),
419 v1::Value::Text(t) => v3::Value::Text(t),
420 v1::Value::Blob(b) => v3::Value::Blob(b),
421 v1::Value::Null => v3::Value::Null,
422 }
423}