// spin_factor_outbound_networking/runtime_config/spin.rs

1use anyhow::{bail, ensure, Context};
2use ip_network::IpNetwork;
3use rustls_pki_types::pem::PemObject;
4use serde::{Deserialize, Deserializer};
5use spin_factors::runtime_config::toml::GetTomlValue;
6use std::{
7    borrow::Cow,
8    path::{Path, PathBuf},
9};
10
11use super::ClientTlsRuntimeConfig;
12
/// Reads Spin's outbound networking runtime configuration from a TOML table.
pub struct SpinRuntimeConfig {
    /// Base directory that certificate and key file paths taken from the
    /// configuration are joined onto before the files are read.
    runtime_config_dir: PathBuf,
}
17
18impl SpinRuntimeConfig {
19    /// Creates a new `SpinRuntimeConfig`.
20    ///
21    /// The given `runtime_config_dir` will be used as the root to resolve any
22    /// relative paths.
23    pub fn new(runtime_config_dir: impl Into<PathBuf>) -> Self {
24        Self {
25            runtime_config_dir: runtime_config_dir.into(),
26        }
27    }
28
29    /// Get the runtime configuration for client TLS from a TOML table.
30    ///
31    /// Expects table to be in the format:
32    /// ````toml
33    /// [outbound_networking]
34    /// block_networks = ["1.1.1.1/32", "private"]
35    ///
36    /// [[client_tls]]
37    /// component_ids = ["example-component"]
38    /// hosts = ["example.com"]
39    /// ca_use_webpki_roots = true
40    /// ca_roots_file = "path/to/roots.crt"
41    /// client_cert_file = "path/to/client.crt"
42    /// client_private_key_file = "path/to/client.key"
43    /// ```
44    pub fn config_from_table(
45        &self,
46        table: &impl GetTomlValue,
47    ) -> anyhow::Result<Option<super::RuntimeConfig>> {
48        let maybe_blocked_networks = self
49            .blocked_networks_from_table(table)
50            .context("failed to parse [outbound_networking] table")?;
51        let maybe_tls_configs = self
52            .tls_configs_from_table(table)
53            .context("failed to parse [[client_tls]] table")?;
54
55        if maybe_blocked_networks.is_none() && maybe_tls_configs.is_none() {
56            return Ok(None);
57        }
58
59        let (blocked_ip_networks, block_private_networks) =
60            maybe_blocked_networks.unwrap_or_default();
61
62        let client_tls_configs = maybe_tls_configs.unwrap_or_default();
63
64        let runtime_config = super::RuntimeConfig {
65            blocked_ip_networks,
66            block_private_networks,
67            client_tls_configs,
68        };
69        Ok(Some(runtime_config))
70    }
71
72    /// Attempts to parse (blocked_ip_networks, block_private_networks) from a
73    /// `[outbound_networking]` table.
74    fn blocked_networks_from_table(
75        &self,
76        table: &impl GetTomlValue,
77    ) -> anyhow::Result<Option<(Vec<ip_network::IpNetwork>, bool)>> {
78        let Some(value) = table.get("outbound_networking") else {
79            return Ok(None);
80        };
81        let outbound_networking: OutboundNetworkingToml = value.clone().try_into()?;
82
83        let mut ip_networks = vec![];
84        let mut private_networks = false;
85        for block_network in outbound_networking.block_networks {
86            match block_network {
87                CidrOrPrivate::Cidr(ip_network) => ip_networks.push(ip_network),
88                CidrOrPrivate::Private => {
89                    private_networks = true;
90                }
91            }
92        }
93        Ok(Some((ip_networks, private_networks)))
94    }
95
96    fn tls_configs_from_table<T: GetTomlValue>(
97        &self,
98        table: &T,
99    ) -> anyhow::Result<Option<Vec<ClientTlsRuntimeConfig>>> {
100        let Some(array) = table.get("client_tls") else {
101            return Ok(None);
102        };
103        let toml_configs: Vec<ClientTlsToml> = array.clone().try_into()?;
104
105        let tls_configs = toml_configs
106            .into_iter()
107            .map(|toml_config| self.load_tls_config(toml_config))
108            .collect::<anyhow::Result<Vec<_>>>()
109            .context("failed to parse TLS config")?;
110        Ok(Some(tls_configs))
111    }
112
113    fn load_tls_config(
114        &self,
115        toml_config: ClientTlsToml,
116    ) -> anyhow::Result<ClientTlsRuntimeConfig> {
117        let ClientTlsToml {
118            component_ids,
119            hosts,
120            ca_use_webpki_roots,
121            ca_roots_file,
122            client_cert_file,
123            client_private_key_file,
124        } = toml_config;
125        ensure!(
126            !component_ids.is_empty(),
127            "'component_ids' list may not be empty"
128        );
129        ensure!(!hosts.is_empty(), "'hosts' list may not be empty");
130
131        let components = component_ids.into_iter().map(Into::into).collect();
132
133        let hosts = hosts
134            .iter()
135            .map(|host| {
136                host.parse()
137                    .map_err(|err| anyhow::anyhow!("invalid host {host:?}: {err:?}"))
138            })
139            .collect::<anyhow::Result<Vec<_>>>()?;
140
141        let use_webpki_roots = if let Some(ca_use_webpki_roots) = ca_use_webpki_roots {
142            ca_use_webpki_roots
143        } else {
144            // Use webpki roots by default *unless* explicit roots were given
145            ca_roots_file.is_none()
146        };
147
148        let root_certificates = ca_roots_file
149            .map(|path| self.load_certs(path))
150            .transpose()?
151            .unwrap_or_default();
152
153        let client_cert = match (client_cert_file, client_private_key_file) {
154            (Some(cert_path), Some(key_path)) => Some(super::ClientCertRuntimeConfig {
155                cert_chain: self.load_certs(cert_path)?,
156                key_der: self.load_key(key_path)?,
157            }),
158            (None, None) => None,
159            (Some(_), None) => bail!("client_cert_file specified without client_private_key_file"),
160            (None, Some(_)) => bail!("client_private_key_file specified without client_cert_file"),
161        };
162
163        Ok(ClientTlsRuntimeConfig {
164            components,
165            hosts,
166            root_certificates,
167            use_webpki_roots,
168            client_cert,
169        })
170    }
171
172    // Parse certs from the provided file
173    fn load_certs(
174        &self,
175        path: impl AsRef<Path>,
176    ) -> anyhow::Result<Vec<rustls_pki_types::CertificateDer<'static>>> {
177        let path = self.runtime_config_dir.join(path);
178        rustls_pki_types::CertificateDer::pem_file_iter(&path)
179            .and_then(Iterator::collect)
180            .with_context(|| format!("failed to load certificate(s) from '{}'", path.display()))
181    }
182
183    // Parse a private key from the provided file
184    fn load_key(
185        &self,
186        path: impl AsRef<Path>,
187    ) -> anyhow::Result<rustls_pki_types::PrivateKeyDer<'static>> {
188        let path = self.runtime_config_dir.join(path);
189        rustls_pki_types::PrivateKeyDer::from_pem_file(&path)
190            .with_context(|| format!("failed to load key from '{}'", path.display()))
191    }
192}
193
194#[derive(Debug, Deserialize)]
195#[serde(deny_unknown_fields)]
196struct ClientTlsToml {
197    component_ids: Vec<spin_serde::KebabId>,
198    #[serde(deserialize_with = "deserialize_hosts")]
199    hosts: Vec<String>,
200    ca_use_webpki_roots: Option<bool>,
201    ca_roots_file: Option<PathBuf>,
202    client_cert_file: Option<PathBuf>,
203    client_private_key_file: Option<PathBuf>,
204}
205
206fn deserialize_hosts<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {
207    let hosts = Vec::<String>::deserialize(deserializer)?;
208    for host in &hosts {
209        crate::tls::validate_host(host).map_err(serde::de::Error::custom)?;
210    }
211    Ok(hosts)
212}
213
214#[derive(Debug, Default, Deserialize)]
215#[serde(deny_unknown_fields)]
216struct OutboundNetworkingToml {
217    #[serde(default)]
218    block_networks: Vec<CidrOrPrivate>,
219}
220
221#[derive(Debug)]
222enum CidrOrPrivate {
223    Cidr(ip_network::IpNetwork),
224    Private,
225}
226
227impl<'de> Deserialize<'de> for CidrOrPrivate {
228    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229    where
230        D: Deserializer<'de>,
231    {
232        let s = Cow::<str>::deserialize(deserializer)?;
233        if s == "private" {
234            return Ok(Self::Private);
235        }
236        if let Ok(net) = IpNetwork::from_str_truncate(&s) {
237            return Ok(Self::Cidr(net));
238        }
239        Err(serde::de::Error::invalid_value(
240            serde::de::Unexpected::Str(&s),
241            &"an IP network in CIDR notation or the keyword 'private'",
242        ))
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blocked_networks::tests::cidr;

    /// Directory that the PEM file names used in the TLS tests are resolved against.
    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    /// A table with neither relevant section produces no config at all.
    #[test]
    fn test_no_config() -> anyhow::Result<()> {
        let table = toml::toml! {
            [some_other_config]
            relevant = false
        };
        let maybe_config = SpinRuntimeConfig::new("").config_from_table(&table)?;
        assert!(maybe_config.is_none(), "{maybe_config:?}");
        Ok(())
    }

    /// An `[outbound_networking]` table without blocked networks still yields
    /// a config, with nothing blocked.
    #[test]
    fn test_no_blocked_networks() -> anyhow::Result<()> {
        let tables = [
            toml::toml! {
                [outbound_networking]
            },
            toml::toml! {
                [outbound_networking]
                block_networks = []
            },
        ];
        for table in &tables {
            let parsed = SpinRuntimeConfig::new("")
                .config_from_table(table)
                .with_context(|| table.to_string())?;
            let config = parsed.context("expected config, got None")?;
            assert!(config.blocked_ip_networks.is_empty(), "{config:?}");
            assert!(!config.block_private_networks);
        }
        Ok(())
    }

    /// CIDR entries and the `private` keyword are both picked up.
    #[test]
    fn test_some_blocked_networks() -> anyhow::Result<()> {
        let table = toml::toml! {
            [outbound_networking]
            block_networks = ["1.1.1.1/32", "8.8.8.8/16", "private"]
        };
        let config = SpinRuntimeConfig::new("")
            .config_from_table(&table)
            .context("config_from_table")?
            .context("expected config, got None")?;
        let blocked = &config.blocked_ip_networks;
        assert!(blocked.contains(&cidr("1.1.1.1/32")));
        // Networks get normalized ("truncated")
        assert!(blocked.contains(&cidr("8.8.0.0/16")));
        assert!(config.block_private_networks, "{config:?}");
        Ok(())
    }

    /// Only the required keys are given; webpki roots are enabled.
    #[test]
    fn test_min_tls_config() -> anyhow::Result<()> {
        let runtime = SpinRuntimeConfig::new("/doesnt-matter");
        let table = toml::toml! {
            [[client_tls]]
            component_ids = ["test-component"]
            hosts = ["test-host"]

        };

        let tls_configs = runtime
            .tls_configs_from_table(&table)?
            .context("missing config section")?;
        assert_eq!(tls_configs.len(), 1);

        let tls = &tls_configs[0];
        assert_eq!(tls.components, ["test-component"]);
        assert_eq!(tls.hosts[0].as_str(), "test-host");
        assert!(tls.use_webpki_roots);
        Ok(())
    }

    /// Every supported key is given and all referenced files load.
    #[test]
    fn test_max_tls_config() -> anyhow::Result<()> {
        let runtime = SpinRuntimeConfig::new(TESTDATA_DIR);
        let table = toml::toml! {
            [[client_tls]]
            component_ids = ["test-component"]
            hosts = ["test-host"]
            ca_use_webpki_roots = true
            ca_roots_file = "valid-cert.pem"
            client_cert_file = "valid-cert.pem"
            client_private_key_file = "valid-private-key.pem"
        };

        let tls_configs = runtime
            .tls_configs_from_table(&table)?
            .context("missing config section")?;
        assert_eq!(tls_configs.len(), 1);

        let tls = &tls_configs[0];
        assert!(tls.use_webpki_roots);
        assert_eq!(tls.root_certificates.len(), 2);
        assert!(tls.client_cert.is_some());
        Ok(())
    }

    /// Supplying a roots file without `ca_use_webpki_roots` turns webpki roots off.
    #[test]
    fn test_use_webpki_roots_default_with_explicit_roots() -> anyhow::Result<()> {
        let runtime = SpinRuntimeConfig::new(TESTDATA_DIR);
        let table = toml::toml! {
            [[client_tls]]
            component_ids = ["test-component"]
            hosts = ["test-host"]
            ca_roots_file = "valid-cert.pem"
        };

        let tls_configs = runtime
            .tls_configs_from_table(&table)?
            .context("missing config section")?;

        assert!(!tls_configs[0].use_webpki_roots);
        Ok(())
    }

    /// Loading the `invalid-cert.pem` roots file is an error.
    #[test]
    fn test_invalid_cert() {
        let runtime = SpinRuntimeConfig::new(TESTDATA_DIR);
        let table = toml::toml! {
            [[client_tls]]
            component_ids = ["test-component"]
            hosts = ["test-host"]
            ca_roots_file = "invalid-cert.pem"
        };

        let outcome = runtime.tls_configs_from_table(&table);
        outcome.unwrap_err();
    }

    /// Loading the `invalid-key.pem` private key file is an error.
    #[test]
    fn test_invalid_private_key() {
        let runtime = SpinRuntimeConfig::new(TESTDATA_DIR);
        let table = toml::toml! {
            [[client_tls]]
            component_ids = ["test-component"]
            hosts = ["test-host"]
            client_cert_file = "valid-cert.pem"
            client_private_key_file = "invalid-key.pem"
        };

        let outcome = runtime.tls_configs_from_table(&table);
        outcome.unwrap_err();
    }
}