spin_factor_outbound_networking/runtime_config/spin.rs

1use anyhow::{bail, ensure, Context};
2use serde::{Deserialize, Deserializer};
3use spin_factors::runtime_config::toml::GetTomlValue;
4use std::io;
5use std::{
6    fs,
7    path::{Path, PathBuf},
8};
9
10use super::{validate_host, TlsConfig};
11
12/// Spin's default handling of the runtime configuration for outbound TLS.
13pub struct SpinTlsRuntimeConfig {
14    runtime_config_dir: PathBuf,
15}
16
17impl SpinTlsRuntimeConfig {
18    /// Creates a new `SpinTlsRuntimeConfig`.
19    ///
20    /// The given `runtime_config_dir` will be used as the root to resolve any
21    /// relative paths.
22    pub fn new(runtime_config_dir: impl Into<PathBuf>) -> Self {
23        Self {
24            runtime_config_dir: runtime_config_dir.into(),
25        }
26    }
27
28    /// Get the runtime configuration for client TLS from a TOML table.
29    ///
30    /// Expects table to be in the format:
31    /// ````toml
32    /// [[client_tls]]
33    /// component_ids = ["example-component"]
34    /// hosts = ["example.com"]
35    /// ca_use_webpki_roots = true
36    /// ca_roots_file = "path/to/roots.crt"
37    /// client_cert_file = "path/to/client.crt"
38    /// client_private_key_file = "path/to/client.key"
39    /// ```
40    pub fn config_from_table(
41        &self,
42        table: &impl GetTomlValue,
43    ) -> anyhow::Result<Option<super::RuntimeConfig>> {
44        let Some(tls_configs) = self.tls_configs_from_table(table)? else {
45            return Ok(None);
46        };
47        let runtime_config = super::RuntimeConfig::new(tls_configs)?;
48        Ok(Some(runtime_config))
49    }
50
51    fn tls_configs_from_table<T: GetTomlValue>(
52        &self,
53        table: &T,
54    ) -> anyhow::Result<Option<Vec<TlsConfig>>> {
55        let Some(array) = table.get("client_tls") else {
56            return Ok(None);
57        };
58        let toml_configs: Vec<RuntimeConfigToml> = array.clone().try_into()?;
59
60        let tls_configs = toml_configs
61            .into_iter()
62            .map(|toml_config| self.load_tls_config(toml_config))
63            .collect::<anyhow::Result<Vec<_>>>()
64            .context("failed to parse TLS configs from TOML")?;
65        Ok(Some(tls_configs))
66    }
67
68    fn load_tls_config(&self, toml_config: RuntimeConfigToml) -> anyhow::Result<TlsConfig> {
69        let RuntimeConfigToml {
70            component_ids,
71            hosts,
72            ca_use_webpki_roots,
73            ca_roots_file,
74            client_cert_file,
75            client_private_key_file,
76        } = toml_config;
77        ensure!(
78            !component_ids.is_empty(),
79            "[[client_tls]] 'component_ids' list may not be empty"
80        );
81        ensure!(
82            !hosts.is_empty(),
83            "[[client_tls]] 'hosts' list may not be empty"
84        );
85
86        let components = component_ids.into_iter().map(Into::into).collect();
87
88        let hosts = hosts
89            .iter()
90            .map(|host| {
91                host.parse()
92                    .map_err(|err| anyhow::anyhow!("invalid host {host:?}: {err:?}"))
93            })
94            .collect::<anyhow::Result<Vec<_>>>()?;
95
96        let use_webpki_roots = if let Some(ca_use_webpki_roots) = ca_use_webpki_roots {
97            ca_use_webpki_roots
98        } else {
99            // Use webpki roots by default *unless* explicit roots were given
100            ca_roots_file.is_none()
101        };
102
103        let root_certificates = ca_roots_file
104            .map(|path| self.load_certs(path))
105            .transpose()?
106            .unwrap_or_default();
107
108        let client_cert = match (client_cert_file, client_private_key_file) {
109            (Some(cert_path), Some(key_path)) => Some(super::ClientCertConfig {
110                cert_chain: self.load_certs(cert_path)?,
111                key_der: self.load_key(key_path)?,
112            }),
113            (None, None) => None,
114            (Some(_), None) => bail!("client_cert_file specified without client_private_key_file"),
115            (None, Some(_)) => bail!("client_private_key_file specified without client_cert_file"),
116        };
117
118        Ok(TlsConfig {
119            components,
120            hosts,
121            root_certificates,
122            use_webpki_roots,
123            client_cert,
124        })
125    }
126
127    // Parse certs from the provided file
128    fn load_certs(
129        &self,
130        path: impl AsRef<Path>,
131    ) -> io::Result<Vec<rustls_pki_types::CertificateDer<'static>>> {
132        let path = self.runtime_config_dir.join(path);
133        rustls_pemfile::certs(&mut io::BufReader::new(fs::File::open(path).map_err(
134            |err| {
135                io::Error::new(
136                    io::ErrorKind::InvalidInput,
137                    format!("failed to read cert file {:?}", err),
138                )
139            },
140        )?))
141        .collect()
142    }
143
144    // Parse a private key from the provided file
145    fn load_key(
146        &self,
147        path: impl AsRef<Path>,
148    ) -> anyhow::Result<rustls_pki_types::PrivateKeyDer<'static>> {
149        let path = self.runtime_config_dir.join(path);
150        let file = fs::File::open(&path)
151            .with_context(|| format!("failed to read private key from '{}'", path.display()))?;
152        Ok(rustls_pemfile::private_key(&mut io::BufReader::new(file))
153            .with_context(|| format!("failed to parse private key from '{}'", path.display()))?
154            .ok_or_else(|| {
155                io::Error::new(
156                    io::ErrorKind::InvalidInput,
157                    format!(
158                        "private key file '{}' contains no private keys",
159                        path.display()
160                    ),
161                )
162            })?)
163    }
164}
165
166#[derive(Debug, Deserialize)]
167#[serde(deny_unknown_fields)]
168pub struct RuntimeConfigToml {
169    component_ids: Vec<spin_serde::KebabId>,
170    #[serde(deserialize_with = "deserialize_hosts")]
171    hosts: Vec<String>,
172    ca_use_webpki_roots: Option<bool>,
173    ca_roots_file: Option<PathBuf>,
174    client_cert_file: Option<PathBuf>,
175    client_private_key_file: Option<PathBuf>,
176}
177
178fn deserialize_hosts<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {
179    let hosts = Vec::<String>::deserialize(deserializer)?;
180    for host in &hosts {
181        validate_host(host).map_err(serde::de::Error::custom)?;
182    }
183    Ok(hosts)
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory containing the PEM fixtures referenced by these tests.
    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    #[test]
    fn test_min_config() -> anyhow::Result<()> {
        let config = SpinTlsRuntimeConfig::new("/doesnt-matter");

        let tls_configs = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]

            })?
            .context("missing config section")?;
        assert_eq!(tls_configs.len(), 1);

        assert_eq!(tls_configs[0].components, ["test-component"]);
        assert_eq!(tls_configs[0].hosts[0].as_str(), "test-host");
        assert!(tls_configs[0].use_webpki_roots);
        Ok(())
    }

    #[test]
    fn test_max_config() -> anyhow::Result<()> {
        let config = SpinTlsRuntimeConfig::new(TESTDATA_DIR);

        let tls_configs = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                ca_use_webpki_roots = true
                ca_roots_file = "valid-cert.pem"
                client_cert_file = "valid-cert.pem"
                client_private_key_file = "valid-private-key.pem"
            })?
            .context("missing config section")?;
        assert_eq!(tls_configs.len(), 1);

        assert!(tls_configs[0].use_webpki_roots);
        assert_eq!(tls_configs[0].root_certificates.len(), 2);
        assert!(tls_configs[0].client_cert.is_some());
        Ok(())
    }

    #[test]
    fn test_use_webpki_roots_default_with_explicit_roots() -> anyhow::Result<()> {
        let config = SpinTlsRuntimeConfig::new(TESTDATA_DIR);

        let tls_configs = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                ca_roots_file = "valid-cert.pem"
            })?
            .context("missing config section")?;

        assert!(!tls_configs[0].use_webpki_roots);
        Ok(())
    }

    #[test]
    fn test_missing_section_is_none() -> anyhow::Result<()> {
        let config = SpinTlsRuntimeConfig::new("/doesnt-matter");

        let tls_configs = config.tls_configs_from_table(&toml::toml! {
            [unrelated]
            key = "value"
        })?;

        assert!(tls_configs.is_none());
        Ok(())
    }

    #[test]
    fn test_empty_component_ids() {
        let config = SpinTlsRuntimeConfig::new("/doesnt-matter");

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = []
                hosts = ["test-host"]
            })
            .unwrap_err();
    }

    #[test]
    fn test_empty_hosts() {
        let config = SpinTlsRuntimeConfig::new("/doesnt-matter");

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = []
            })
            .unwrap_err();
    }

    #[test]
    fn test_unknown_field() {
        let config = SpinTlsRuntimeConfig::new("/doesnt-matter");

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                not_a_real_field = true
            })
            .unwrap_err();
    }

    #[test]
    fn test_client_cert_without_key() {
        let config = SpinTlsRuntimeConfig::new(TESTDATA_DIR);

        let err = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                client_cert_file = "valid-cert.pem"
            })
            .unwrap_err();
        assert!(format!("{err:#}").contains("without client_private_key_file"));
    }

    #[test]
    fn test_client_key_without_cert() {
        let config = SpinTlsRuntimeConfig::new(TESTDATA_DIR);

        let err = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                client_private_key_file = "valid-private-key.pem"
            })
            .unwrap_err();
        assert!(format!("{err:#}").contains("without client_cert_file"));
    }

    #[test]
    fn test_missing_cert_file() {
        let config = SpinTlsRuntimeConfig::new(TESTDATA_DIR);

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                ca_roots_file = "does-not-exist.pem"
            })
            .unwrap_err();
    }

    #[test]
    fn test_invalid_cert() {
        let config = SpinTlsRuntimeConfig::new(TESTDATA_DIR);

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                ca_roots_file = "invalid-cert.pem"
            })
            .unwrap_err();
    }

    #[test]
    fn test_invalid_private_key() {
        let config = SpinTlsRuntimeConfig::new(TESTDATA_DIR);

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                client_cert_file = "valid-cert.pem"
                client_private_key_file = "invalid-key.pem"
            })
            .unwrap_err();
    }
}