1use std::{
2 collections::HashMap,
3 env::VarError,
4 path::{Path, PathBuf},
5 sync::OnceLock,
6};
7
8use serde::Deserialize;
9use spin_expressions::{Key, Provider};
10use spin_factors::anyhow::{self, Context as _};
11use spin_world::async_trait;
12use tracing::{instrument, Level};
13
14#[derive(Debug, Default, Deserialize)]
16#[serde(deny_unknown_fields)]
17pub struct EnvVariablesConfig {
18 #[serde(default)]
22 pub prefix: Option<String>,
23 #[serde(default)]
25 pub dotenv_path: Option<PathBuf>,
26}
27
/// Prefix used to build environment variable names when no prefix is configured.
const DEFAULT_ENV_PREFIX: &str = "SPIN_VARIABLE";

/// Boxed function that reads a single environment variable by name.
///
/// The `Send + Sync` bounds allow the boxed closure (and any value owning it)
/// to be moved to and shared between threads.
type EnvFetcherFn = Box<dyn Fn(&str) -> Result<String, VarError> + Sync + Send>;
31
/// Variable provider that resolves keys from environment variables, falling
/// back to an optional dotenv file when a variable is not present.
pub struct EnvVariablesProvider {
    /// Prefix joined to the upper-cased key with `_`; `DEFAULT_ENV_PREFIX`
    /// is used when this is `None`.
    prefix: Option<String>,
    /// Function used to read an environment variable by its full name.
    env_fetcher: EnvFetcherFn,
    /// Dotenv file consulted when `env_fetcher` reports `VarError::NotPresent`;
    /// `None` disables the fallback.
    dotenv_path: Option<PathBuf>,
    /// Parsed contents of the dotenv file, filled on first successful load.
    dotenv_cache: OnceLock<HashMap<String, String>>,
}
39
40impl Default for EnvVariablesProvider {
41 fn default() -> Self {
42 Self {
43 prefix: None,
44 env_fetcher: Box::new(|s| std::env::var(s)),
45 dotenv_path: Some(".env".into()),
46 dotenv_cache: Default::default(),
47 }
48 }
49}
50
51impl EnvVariablesProvider {
52 pub fn new(
60 prefix: Option<impl Into<String>>,
61 env_fetcher: impl Fn(&str) -> Result<String, VarError> + Send + Sync + 'static,
62 dotenv_path: Option<PathBuf>,
63 ) -> Self {
64 Self {
65 prefix: prefix.map(Into::into),
66 dotenv_path,
67 env_fetcher: Box::new(env_fetcher),
68 dotenv_cache: Default::default(),
69 }
70 }
71
72 fn get_sync(&self, key: &Key) -> anyhow::Result<Option<String>> {
74 let prefix = self
75 .prefix
76 .clone()
77 .unwrap_or_else(|| DEFAULT_ENV_PREFIX.to_string());
78
79 let upper_key = key.as_ref().to_ascii_uppercase();
80 let env_key = format!("{prefix}_{upper_key}");
81
82 self.query_env(&env_key)
83 }
84
85 fn query_env(&self, env_key: &str) -> anyhow::Result<Option<String>> {
87 match (self.env_fetcher)(env_key) {
88 Err(std::env::VarError::NotPresent) => self.get_dotenv(env_key),
89 other => other
90 .map(Some)
91 .with_context(|| format!("failed to resolve env var {env_key}")),
92 }
93 }
94
95 fn get_dotenv(&self, key: &str) -> anyhow::Result<Option<String>> {
96 let Some(dotenv_path) = self.dotenv_path.as_deref() else {
97 return Ok(None);
98 };
99 let cache = match self.dotenv_cache.get() {
100 Some(cache) => cache,
101 None => {
102 let cache = load_dotenv(dotenv_path)?;
103 let _ = self.dotenv_cache.set(cache);
104 self.dotenv_cache.get().unwrap()
107 }
108 };
109 Ok(cache.get(key).cloned())
110 }
111}
112
113impl std::fmt::Debug for EnvVariablesProvider {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("EnvProvider")
116 .field("prefix", &self.prefix)
117 .field("dotenv_path", &self.dotenv_path)
118 .finish()
119 }
120}
121
122fn load_dotenv(dotenv_path: &Path) -> anyhow::Result<HashMap<String, String>> {
123 Ok(dotenvy::from_path_iter(dotenv_path)
124 .into_iter()
125 .flatten()
126 .collect::<Result<HashMap<String, String>, _>>()?)
127}
128
129#[async_trait]
130impl Provider for EnvVariablesProvider {
131 #[instrument(name = "spin_variables.get_from_env", level = Level::DEBUG, skip(self), err(level = Level::INFO))]
132 async fn get(&self, key: &Key) -> anyhow::Result<Option<String>> {
133 tokio::task::block_in_place(|| self.get_sync(key))
134 }
135}
136
#[cfg(test)]
mod test {
    use std::env::temp_dir;

    use super::*;

    /// In-memory stand-in for the process environment.
    struct TestEnv {
        map: HashMap<String, String>,
    }

    impl TestEnv {
        fn new() -> Self {
            Self {
                map: Default::default(),
            }
        }

        fn insert(&mut self, key: &str, value: &str) {
            self.map.insert(key.to_string(), value.to_string());
        }

        fn get(&self, key: &str) -> Result<String, VarError> {
            self.map.get(key).cloned().ok_or(VarError::NotPresent)
        }
    }

    /// A dotenv file in the temp dir whose name is unique to this process and
    /// test, so concurrent test runs do not collide. Removed on drop.
    struct TempDotenv(PathBuf);

    impl TempDotenv {
        fn new(name: &str, contents: &str) -> Self {
            let path = temp_dir().join(format!(
                "spin-env-provider-test-{}-{name}",
                std::process::id()
            ));
            std::fs::write(&path, contents).unwrap();
            Self(path)
        }
    }

    impl Drop for TempDotenv {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp file must not fail the test.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn provider_get() {
        let mut env = TestEnv::new();
        env.insert("TESTING_SPIN_ENV_KEY1", "val");
        let key1 = Key::new("env_key1").unwrap();
        assert_eq!(
            EnvVariablesProvider::new(Some("TESTING_SPIN"), move |key| env.get(key), None)
                .get_sync(&key1)
                .unwrap(),
            Some("val".to_string())
        );
    }

    #[test]
    fn provider_get_default_prefix() {
        let mut env = TestEnv::new();
        env.insert("SPIN_VARIABLE_ENV_KEY3", "default_val");
        let key = Key::new("env_key3").unwrap();
        assert_eq!(
            EnvVariablesProvider::new(None::<String>, move |key| env.get(key), None)
                .get_sync(&key)
                .unwrap(),
            Some("default_val".to_string())
        );
    }

    #[test]
    fn provider_get_dotenv() {
        let dotenv = TempDotenv::new("dotenv", "TESTING_SPIN_ENV_KEY2=dotenv_val");

        let key = Key::new("env_key2").unwrap();
        assert_eq!(
            EnvVariablesProvider::new(
                Some("TESTING_SPIN"),
                |_| Err(VarError::NotPresent),
                Some(dotenv.0.clone())
            )
            .get_sync(&key)
            .unwrap(),
            Some("dotenv_val".to_string())
        );
    }

    #[test]
    fn provider_env_overrides_dotenv() {
        let dotenv = TempDotenv::new("override", "TESTING_SPIN_ENV_KEY4=dotenv_val");
        let mut env = TestEnv::new();
        env.insert("TESTING_SPIN_ENV_KEY4", "env_val");

        let key = Key::new("env_key4").unwrap();
        assert_eq!(
            EnvVariablesProvider::new(
                Some("TESTING_SPIN"),
                move |key| env.get(key),
                Some(dotenv.0.clone())
            )
            .get_sync(&key)
            .unwrap(),
            Some("env_val".to_string())
        );
    }

    #[test]
    fn provider_get_missing() {
        let key = Key::new("definitely_not_set").unwrap();
        assert_eq!(
            EnvVariablesProvider::new(
                Some("TESTING_SPIN"),
                |_| Err(VarError::NotPresent),
                Default::default()
            )
            .get_sync(&key)
            .unwrap(),
            None
        );
    }
}