Skip to main content

spin_factor_sqlite/
lib.rs

1mod host;
2pub mod runtime_config;
3
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6
7use host::InstanceState;
8
9use async_trait::async_trait;
10use spin_factor_otel::OtelFactorState;
11use spin_factors::{anyhow, Factor};
12use spin_locked_app::MetadataKey;
13use spin_world::spin::sqlite3_1_0::sqlite as v3;
14use spin_world::v1::sqlite as v1;
15use spin_world::v2::sqlite as v2;
16
17pub use runtime_config::RuntimeConfig;
18
/// The Spin factor that links the SQLite host interfaces into components.
///
/// The private unit field keeps callers outside this crate from building the
/// value with a struct literal; use [`SqliteFactor::new`] or `Default` instead.
#[derive(Default)]
pub struct SqliteFactor {
    _priv: (),
}

impl SqliteFactor {
    /// Create a new `SqliteFactor`.
    ///
    /// This is equivalent to `SqliteFactor::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
30
31impl Factor for SqliteFactor {
32    type RuntimeConfig = RuntimeConfig;
33    type AppState = AppState;
34    type InstanceBuilder = InstanceState;
35
36    fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
37        ctx.link_bindings(v1::add_to_linker::<_, SqliteFactorData>)?;
38        ctx.link_bindings(v2::add_to_linker::<_, SqliteFactorData>)?;
39        ctx.link_bindings(v3::add_to_linker::<_, SqliteFactorData>)?;
40        Ok(())
41    }
42
43    fn configure_app<T: spin_factors::RuntimeFactors>(
44        &self,
45        mut ctx: spin_factors::ConfigureAppContext<T, Self>,
46    ) -> anyhow::Result<Self::AppState> {
47        let connection_creators = ctx
48            .take_runtime_config()
49            .unwrap_or_default()
50            .connection_creators;
51
52        let allowed_databases = ctx
53            .app()
54            .components()
55            .map(|component| {
56                Ok((
57                    component.id().to_string(),
58                    Arc::new(
59                        component
60                            .get_metadata(ALLOWED_DATABASES_KEY)?
61                            .unwrap_or_default()
62                            .into_iter()
63                            .collect::<HashSet<_>>(),
64                    ),
65                ))
66            })
67            .collect::<anyhow::Result<HashMap<_, _>>>()?;
68
69        ensure_allowed_databases_are_configured(&allowed_databases, |label| {
70            connection_creators.contains_key(label)
71        })?;
72
73        Ok(AppState::new(allowed_databases, connection_creators))
74    }
75
76    fn prepare<T: spin_factors::RuntimeFactors>(
77        &self,
78        mut ctx: spin_factors::PrepareContext<T, Self>,
79    ) -> spin_factors::anyhow::Result<Self::InstanceBuilder> {
80        let allowed_databases = ctx
81            .app_state()
82            .allowed_databases
83            .get(ctx.app_component().id())
84            .cloned()
85            .unwrap_or_default();
86        let otel = OtelFactorState::from_prepare_context(&mut ctx)?;
87        Ok(InstanceState::new(
88            allowed_databases,
89            ctx.app_state().connection_creators.clone(),
90            otel,
91        ))
92    }
93}
94
95/// Ensure that all the databases in the allowed databases list for each component are configured
96fn ensure_allowed_databases_are_configured(
97    allowed_databases: &HashMap<String, Arc<HashSet<String>>>,
98    is_configured: impl Fn(&str) -> bool,
99) -> anyhow::Result<()> {
100    let mut errors = Vec::new();
101    for (component_id, allowed_dbs) in allowed_databases {
102        for allowed in allowed_dbs.iter() {
103            if !is_configured(allowed) {
104                errors.push(format!(
105                    "- Component {component_id} uses database '{allowed}'"
106                ));
107            }
108        }
109    }
110
111    if !errors.is_empty() {
112        let prologue = vec![
113            "One or more components use SQLite databases which are not defined.",
114            "Check the spelling, or pass a runtime configuration file that defines these stores.",
115            "See https://spinframework.dev/dynamic-configuration#sqlite-storage-runtime-configuration",
116            "Details:",
117        ];
118        let lines: Vec<_> = prologue
119            .into_iter()
120            .map(|s| s.to_owned())
121            .chain(errors)
122            .collect();
123        return Err(anyhow::anyhow!(lines.join("\n")));
124    }
125    Ok(())
126}
127
/// Metadata key for a list of allowed databases for a component.
///
/// The value is read from each component's `databases` metadata as a list of
/// database labels. Components without this metadata are allowed no databases.
pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("databases");
130
131#[derive(Clone)]
132pub struct AppState {
133    /// A map from component id to a set of allowed database labels.
134    allowed_databases: HashMap<String, Arc<HashSet<String>>>,
135    /// A mapping from database label to a connection creator.
136    connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
137}
138
139impl AppState {
140    /// Create a new `AppState`
141    pub fn new(
142        allowed_databases: HashMap<String, Arc<HashSet<String>>>,
143        connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
144    ) -> Self {
145        Self {
146            allowed_databases,
147            connection_creators,
148        }
149    }
150
151    /// Get a connection for a given database label.
152    ///
153    /// Returns `None` if there is no connection creator for the given label.
154    pub async fn get_connection(
155        &self,
156        label: &str,
157    ) -> Option<Result<Arc<dyn Connection>, v3::Error>> {
158        let connection = self
159            .connection_creators
160            .get(label)?
161            .create_connection(label)
162            .await;
163        Some(connection)
164    }
165
166    /// Returns true if the given database label is used by any component.
167    pub fn database_is_used(&self, label: &str) -> bool {
168        self.allowed_databases
169            .values()
170            .any(|stores| stores.contains(label))
171    }
172}
173
174/// A creator of a connections for a particular SQLite database.
175#[async_trait]
176pub trait ConnectionCreator: Send + Sync {
177    /// Get a *new* [`Connection`]
178    ///
179    /// The connection should be a new connection, not a reused one.
180    async fn create_connection(
181        &self,
182        label: &str,
183    ) -> Result<Arc<dyn Connection + 'static>, v3::Error>;
184}
185
186#[async_trait]
187impl<F> ConnectionCreator for F
188where
189    F: Fn() -> anyhow::Result<Arc<dyn Connection + 'static>> + Send + Sync + 'static,
190{
191    async fn create_connection(
192        &self,
193        label: &str,
194    ) -> Result<Arc<dyn Connection + 'static>, v3::Error> {
195        let _ = label;
196        (self)().map_err(|_| v3::Error::InvalidConnection)
197    }
198}
199
/// A trait abstracting over operations to a SQLite database.
///
/// Implementations must be `Send + Sync`; `AppState::get_connection` hands
/// them out as `Arc<dyn Connection>`.
#[async_trait]
pub trait Connection: Send + Sync {
    /// Execute `query` with `parameters` bound and return the complete result.
    ///
    /// NOTE(review): `max_result_bytes` presumably caps the size of the
    /// returned result — confirm against implementations.
    async fn query(
        &self,
        query: &str,
        parameters: Vec<v3::Value>,
        max_result_bytes: usize,
    ) -> Result<v3::QueryResult, v3::Error>;

    /// Execute `query` with `parameters` bound and return a [`QueryAsyncResult`].
    ///
    /// In the returned value, rows arrive through a channel instead of being
    /// collected up front.
    ///
    /// NOTE(review): `max_result_bytes` presumably plays the same role as in
    /// `query` — confirm.
    async fn query_async(
        &self,
        query: &str,
        parameters: Vec<v3::Value>,
        max_result_bytes: usize,
    ) -> Result<QueryAsyncResult, v3::Error>;

    /// Execute `statements` as a batch.
    ///
    /// Unlike the other methods, errors are reported as `anyhow::Error`
    /// rather than `v3::Error`.
    ///
    /// NOTE(review): presumably `statements` may hold several SQL statements —
    /// confirm.
    async fn execute_batch(&self, statements: &str) -> anyhow::Result<()>;

    /// Return a count of changed rows.
    ///
    /// NOTE(review): presumably mirrors SQLite's `sqlite3_changes()` (rows
    /// changed by the most recent statement) — confirm.
    async fn changes(&self) -> Result<u64, v3::Error>;

    /// Return the rowid of the last insert.
    ///
    /// NOTE(review): presumably mirrors SQLite's `sqlite3_last_insert_rowid()`
    /// — confirm.
    async fn last_insert_rowid(&self) -> Result<i64, v3::Error>;

    /// A human-readable summary of the connection's configuration
    ///
    /// Example: "libSQL at libsql://example.com"
    ///
    /// The default implementation returns `None`.
    fn summary(&self) -> Option<String> {
        None
    }
}
230
/// The value returned by `Connection::query_async`.
///
/// The column list is available immediately; rows and the final outcome are
/// delivered through channels.
pub struct QueryAsyncResult {
    /// The result's columns (presumably their names — confirm).
    pub columns: Vec<String>,
    /// Receiving end of a bounded tokio channel carrying the result rows.
    pub rows: tokio::sync::mpsc::Receiver<v3::RowResult>,
    /// Receives the query's final outcome: `Ok(())` or a `v3::Error`.
    /// NOTE(review): presumably sent after the last row — confirm ordering
    /// against implementations.
    pub error: tokio::sync::oneshot::Receiver<Result<(), v3::Error>>,
}
236
/// Marker type passed as the `HasData` parameter to each `add_to_linker` call
/// in `SqliteFactor`'s `init`.
///
/// In this file it appears only as a type parameter and is never constructed.
/// Its field is private.
pub struct SqliteFactorData(SqliteFactor);

/// Host functions linked through [`SqliteFactorData`] see their data as
/// `&mut InstanceState`.
impl spin_core::wasmtime::component::HasData for SqliteFactorData {
    type Data<'a> = &'a mut InstanceState;
}