spin_factor_sqlite/
lib.rs

1mod host;
2pub mod runtime_config;
3
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6
7use host::InstanceState;
8
9use async_trait::async_trait;
10use spin_factors::{anyhow, Factor};
11use spin_locked_app::MetadataKey;
12use spin_world::spin::sqlite::sqlite as v3;
13use spin_world::v1::sqlite as v1;
14use spin_world::v2::sqlite as v2;
15
16pub use runtime_config::RuntimeConfig;
17
/// The Spin factor that exposes SQLite database access to components.
#[derive(Default)]
pub struct SqliteFactor {
    // Private unit field so the type can only be built through `new`/`default`, not by a
    // struct literal outside this module.
    _priv: (),
}

impl SqliteFactor {
    /// Create a new `SqliteFactor` (identical to `SqliteFactor::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
29
30impl Factor for SqliteFactor {
31    type RuntimeConfig = RuntimeConfig;
32    type AppState = AppState;
33    type InstanceBuilder = InstanceState;
34
35    fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
36        ctx.link_bindings(v1::add_to_linker)?;
37        ctx.link_bindings(v2::add_to_linker)?;
38        ctx.link_bindings(v3::add_to_linker)?;
39        Ok(())
40    }
41
42    fn configure_app<T: spin_factors::RuntimeFactors>(
43        &self,
44        mut ctx: spin_factors::ConfigureAppContext<T, Self>,
45    ) -> anyhow::Result<Self::AppState> {
46        let connection_creators = ctx
47            .take_runtime_config()
48            .unwrap_or_default()
49            .connection_creators;
50
51        let allowed_databases = ctx
52            .app()
53            .components()
54            .map(|component| {
55                Ok((
56                    component.id().to_string(),
57                    Arc::new(
58                        component
59                            .get_metadata(ALLOWED_DATABASES_KEY)?
60                            .unwrap_or_default()
61                            .into_iter()
62                            .collect::<HashSet<_>>(),
63                    ),
64                ))
65            })
66            .collect::<anyhow::Result<HashMap<_, _>>>()?;
67
68        ensure_allowed_databases_are_configured(&allowed_databases, |label| {
69            connection_creators.contains_key(label)
70        })?;
71
72        Ok(AppState::new(allowed_databases, connection_creators))
73    }
74
75    fn prepare<T: spin_factors::RuntimeFactors>(
76        &self,
77        ctx: spin_factors::PrepareContext<T, Self>,
78    ) -> spin_factors::anyhow::Result<Self::InstanceBuilder> {
79        let allowed_databases = ctx
80            .app_state()
81            .allowed_databases
82            .get(ctx.app_component().id())
83            .cloned()
84            .unwrap_or_default();
85        Ok(InstanceState::new(
86            allowed_databases,
87            ctx.app_state().connection_creators.clone(),
88        ))
89    }
90}
91
92/// Ensure that all the databases in the allowed databases list for each component are configured
93fn ensure_allowed_databases_are_configured(
94    allowed_databases: &HashMap<String, Arc<HashSet<String>>>,
95    is_configured: impl Fn(&str) -> bool,
96) -> anyhow::Result<()> {
97    let mut errors = Vec::new();
98    for (component_id, allowed_dbs) in allowed_databases {
99        for allowed in allowed_dbs.iter() {
100            if !is_configured(allowed) {
101                errors.push(format!(
102                    "- Component {component_id} uses database '{allowed}'"
103                ));
104            }
105        }
106    }
107
108    if !errors.is_empty() {
109        let prologue = vec![
110            "One or more components use SQLite databases which are not defined.",
111            "Check the spelling, or pass a runtime configuration file that defines these stores.",
112            "See https://spinframework.dev/dynamic-configuration#sqlite-storage-runtime-configuration",
113            "Details:",
114        ];
115        let lines: Vec<_> = prologue
116            .into_iter()
117            .map(|s| s.to_owned())
118            .chain(errors)
119            .collect();
120        return Err(anyhow::anyhow!(lines.join("\n")));
121    }
122    Ok(())
123}
124
/// Metadata key for the list of database labels a component is allowed to use.
///
/// The value is stored in component metadata under the key `"databases"`. When the app is
/// configured, every label listed here must have a connection creator in the runtime config,
/// or configuration fails.
pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("databases");
127
128#[derive(Clone)]
129pub struct AppState {
130    /// A map from component id to a set of allowed database labels.
131    allowed_databases: HashMap<String, Arc<HashSet<String>>>,
132    /// A mapping from database label to a connection creator.
133    connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
134}
135
136impl AppState {
137    /// Create a new `AppState`
138    pub fn new(
139        allowed_databases: HashMap<String, Arc<HashSet<String>>>,
140        connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
141    ) -> Self {
142        Self {
143            allowed_databases,
144            connection_creators,
145        }
146    }
147
148    /// Get a connection for a given database label.
149    ///
150    /// Returns `None` if there is no connection creator for the given label.
151    pub async fn get_connection(
152        &self,
153        label: &str,
154    ) -> Option<Result<Box<dyn Connection>, v3::Error>> {
155        let connection = self
156            .connection_creators
157            .get(label)?
158            .create_connection(label)
159            .await;
160        Some(connection)
161    }
162
163    /// Returns true if the given database label is used by any component.
164    pub fn database_is_used(&self, label: &str) -> bool {
165        self.allowed_databases
166            .values()
167            .any(|stores| stores.contains(label))
168    }
169}
170
/// A creator of connections for a particular SQLite database.
///
/// Implementors must be `Send + Sync`. `AppState` stores them behind `Arc<dyn ConnectionCreator>`,
/// keyed by database label.
#[async_trait]
pub trait ConnectionCreator: Send + Sync {
    /// Get a *new* [`Connection`]
    ///
    /// The connection should be a new connection, not a reused one.
    ///
    /// `label` is the database label the connection is requested for; `AppState::get_connection`
    /// passes the same label it used to look up this creator. On failure, a `v3::Error` is
    /// returned to the caller.
    async fn create_connection(
        &self,
        label: &str,
    ) -> Result<Box<dyn Connection + 'static>, v3::Error>;
}
182
183#[async_trait]
184impl<F> ConnectionCreator for F
185where
186    F: Fn() -> anyhow::Result<Box<dyn Connection + 'static>> + Send + Sync + 'static,
187{
188    async fn create_connection(
189        &self,
190        label: &str,
191    ) -> Result<Box<dyn Connection + 'static>, v3::Error> {
192        let _ = label;
193        (self)().map_err(|_| v3::Error::InvalidConnection)
194    }
195}
196
/// A trait abstracting over operations to a SQLite database
///
/// Implementors must be `Send + Sync`; connections are handed out as `Box<dyn Connection>`.
#[async_trait]
pub trait Connection: Send + Sync {
    /// Run the SQL text `query` with the given `parameters` and return its `v3::QueryResult`.
    ///
    /// NOTE(review): presumably `parameters` bind to placeholders in `query` in order —
    /// confirm against the implementations.
    async fn query(
        &self,
        query: &str,
        parameters: Vec<v3::Value>,
    ) -> Result<v3::QueryResult, v3::Error>;

    /// Execute the SQL text `statements`, which presumably may contain several statements.
    ///
    /// NOTE(review): unlike the other methods this returns `anyhow::Result` rather than
    /// `v3::Error`; changing it would affect every implementor.
    async fn execute_batch(&self, statements: &str) -> anyhow::Result<()>;

    /// Return a row-change count from the connection.
    ///
    /// NOTE(review): presumably the number of rows changed by the most recent statement,
    /// mirroring SQLite's `sqlite3_changes` — confirm against the implementations.
    async fn changes(&self) -> Result<u64, v3::Error>;

    /// Return the last inserted row id.
    ///
    /// NOTE(review): presumably mirrors SQLite's `sqlite3_last_insert_rowid` — confirm.
    async fn last_insert_rowid(&self) -> Result<i64, v3::Error>;

    /// A human-readable summary of the connection's configuration
    ///
    /// Example: "libSQL at libsql://example.com"
    ///
    /// The default implementation returns `None` (no summary).
    fn summary(&self) -> Option<String> {
        None
    }
}