spin_factor_outbound_networking/runtime_config/
spin.rs1use anyhow::{bail, ensure, Context};
2use ip_network::IpNetwork;
3use rustls_pki_types::pem::PemObject;
4use serde::{Deserialize, Deserializer};
5use spin_factors::runtime_config::toml::GetTomlValue;
6use std::{
7 borrow::Cow,
8 path::{Path, PathBuf},
9};
10
11use super::ClientTlsRuntimeConfig;
12
/// Parses the outbound networking sections of a Spin runtime config file.
#[derive(Debug)]
pub struct SpinRuntimeConfig {
    /// Directory against which file paths in the config (CA roots, client
    /// certificate, client key) are joined.
    runtime_config_dir: PathBuf,
}
17
18impl SpinRuntimeConfig {
19 pub fn new(runtime_config_dir: impl Into<PathBuf>) -> Self {
24 Self {
25 runtime_config_dir: runtime_config_dir.into(),
26 }
27 }
28
29 pub fn config_from_table(
45 &self,
46 table: &impl GetTomlValue,
47 ) -> anyhow::Result<Option<super::RuntimeConfig>> {
48 let maybe_blocked_networks = self
49 .blocked_networks_from_table(table)
50 .context("failed to parse [outbound_networking] table")?;
51 let maybe_tls_configs = self
52 .tls_configs_from_table(table)
53 .context("failed to parse [[client_tls]] table")?;
54
55 if maybe_blocked_networks.is_none() && maybe_tls_configs.is_none() {
56 return Ok(None);
57 }
58
59 let (blocked_ip_networks, block_private_networks) =
60 maybe_blocked_networks.unwrap_or_default();
61
62 let client_tls_configs = maybe_tls_configs.unwrap_or_default();
63
64 let runtime_config = super::RuntimeConfig {
65 blocked_ip_networks,
66 block_private_networks,
67 client_tls_configs,
68 };
69 Ok(Some(runtime_config))
70 }
71
72 fn blocked_networks_from_table(
75 &self,
76 table: &impl GetTomlValue,
77 ) -> anyhow::Result<Option<(Vec<ip_network::IpNetwork>, bool)>> {
78 let Some(value) = table.get("outbound_networking") else {
79 return Ok(None);
80 };
81 let outbound_networking: OutboundNetworkingToml = value.clone().try_into()?;
82
83 let mut ip_networks = vec![];
84 let mut private_networks = false;
85 for block_network in outbound_networking.block_networks {
86 match block_network {
87 CidrOrPrivate::Cidr(ip_network) => ip_networks.push(ip_network),
88 CidrOrPrivate::Private => {
89 private_networks = true;
90 }
91 }
92 }
93 Ok(Some((ip_networks, private_networks)))
94 }
95
96 fn tls_configs_from_table<T: GetTomlValue>(
97 &self,
98 table: &T,
99 ) -> anyhow::Result<Option<Vec<ClientTlsRuntimeConfig>>> {
100 let Some(array) = table.get("client_tls") else {
101 return Ok(None);
102 };
103 let toml_configs: Vec<ClientTlsToml> = array.clone().try_into()?;
104
105 let tls_configs = toml_configs
106 .into_iter()
107 .map(|toml_config| self.load_tls_config(toml_config))
108 .collect::<anyhow::Result<Vec<_>>>()
109 .context("failed to parse TLS config")?;
110 Ok(Some(tls_configs))
111 }
112
113 fn load_tls_config(
114 &self,
115 toml_config: ClientTlsToml,
116 ) -> anyhow::Result<ClientTlsRuntimeConfig> {
117 let ClientTlsToml {
118 component_ids,
119 hosts,
120 ca_use_webpki_roots,
121 ca_roots_file,
122 client_cert_file,
123 client_private_key_file,
124 } = toml_config;
125 ensure!(
126 !component_ids.is_empty(),
127 "'component_ids' list may not be empty"
128 );
129 ensure!(!hosts.is_empty(), "'hosts' list may not be empty");
130
131 let components = component_ids.into_iter().map(Into::into).collect();
132
133 let hosts = hosts
134 .iter()
135 .map(|host| {
136 host.parse()
137 .map_err(|err| anyhow::anyhow!("invalid host {host:?}: {err:?}"))
138 })
139 .collect::<anyhow::Result<Vec<_>>>()?;
140
141 let use_webpki_roots = if let Some(ca_use_webpki_roots) = ca_use_webpki_roots {
142 ca_use_webpki_roots
143 } else {
144 ca_roots_file.is_none()
146 };
147
148 let root_certificates = ca_roots_file
149 .map(|path| self.load_certs(path))
150 .transpose()?
151 .unwrap_or_default();
152
153 let client_cert = match (client_cert_file, client_private_key_file) {
154 (Some(cert_path), Some(key_path)) => Some(super::ClientCertRuntimeConfig {
155 cert_chain: self.load_certs(cert_path)?,
156 key_der: self.load_key(key_path)?,
157 }),
158 (None, None) => None,
159 (Some(_), None) => bail!("client_cert_file specified without client_private_key_file"),
160 (None, Some(_)) => bail!("client_private_key_file specified without client_cert_file"),
161 };
162
163 Ok(ClientTlsRuntimeConfig {
164 components,
165 hosts,
166 root_certificates,
167 use_webpki_roots,
168 client_cert,
169 })
170 }
171
172 fn load_certs(
174 &self,
175 path: impl AsRef<Path>,
176 ) -> anyhow::Result<Vec<rustls_pki_types::CertificateDer<'static>>> {
177 let path = self.runtime_config_dir.join(path);
178 rustls_pki_types::CertificateDer::pem_file_iter(&path)
179 .and_then(Iterator::collect)
180 .with_context(|| format!("failed to load certificate(s) from '{}'", path.display()))
181 }
182
183 fn load_key(
185 &self,
186 path: impl AsRef<Path>,
187 ) -> anyhow::Result<rustls_pki_types::PrivateKeyDer<'static>> {
188 let path = self.runtime_config_dir.join(path);
189 rustls_pki_types::PrivateKeyDer::from_pem_file(&path)
190 .with_context(|| format!("failed to load key from '{}'", path.display()))
191 }
192}
193
194#[derive(Debug, Deserialize)]
195#[serde(deny_unknown_fields)]
196struct ClientTlsToml {
197 component_ids: Vec<spin_serde::KebabId>,
198 #[serde(deserialize_with = "deserialize_hosts")]
199 hosts: Vec<String>,
200 ca_use_webpki_roots: Option<bool>,
201 ca_roots_file: Option<PathBuf>,
202 client_cert_file: Option<PathBuf>,
203 client_private_key_file: Option<PathBuf>,
204}
205
206fn deserialize_hosts<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {
207 let hosts = Vec::<String>::deserialize(deserializer)?;
208 for host in &hosts {
209 crate::tls::validate_host(host).map_err(serde::de::Error::custom)?;
210 }
211 Ok(hosts)
212}
213
214#[derive(Debug, Default, Deserialize)]
215#[serde(deny_unknown_fields)]
216struct OutboundNetworkingToml {
217 #[serde(default)]
218 block_networks: Vec<CidrOrPrivate>,
219}
220
221#[derive(Debug)]
222enum CidrOrPrivate {
223 Cidr(ip_network::IpNetwork),
224 Private,
225}
226
227impl<'de> Deserialize<'de> for CidrOrPrivate {
228 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229 where
230 D: Deserializer<'de>,
231 {
232 let s = Cow::<str>::deserialize(deserializer)?;
233 if s == "private" {
234 return Ok(Self::Private);
235 }
236 if let Ok(net) = IpNetwork::from_str_truncate(&s) {
237 return Ok(Self::Cidr(net));
238 }
239 Err(serde::de::Error::invalid_value(
240 serde::de::Unexpected::Str(&s),
241 &"an IP network in CIDR notation or the keyword 'private'",
242 ))
243 }
244}
245
#[cfg(test)]
mod tests {
    use crate::blocked_networks::tests::cidr;

    use super::*;

    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    #[test]
    fn test_no_config() -> anyhow::Result<()> {
        let maybe_config = SpinRuntimeConfig::new("").config_from_table(&toml::toml! {
            [some_other_config]
            relevant = false
        })?;
        assert!(maybe_config.is_none(), "{maybe_config:?}");
        Ok(())
    }

    #[test]
    fn test_no_blocked_networks() -> anyhow::Result<()> {
        for table in &[
            toml::toml! {
                [outbound_networking]
            },
            toml::toml! {
                [outbound_networking]
                block_networks = []
            },
        ] {
            let config = SpinRuntimeConfig::new("")
                .config_from_table(table)
                .with_context(|| table.to_string())?
                .context("expected config, got None")?;
            assert!(config.blocked_ip_networks.is_empty(), "{config:?}");
            assert!(!config.block_private_networks);
        }
        Ok(())
    }

    #[test]
    fn test_some_blocked_networks() -> anyhow::Result<()> {
        let config = SpinRuntimeConfig::new("")
            .config_from_table(&toml::toml! {
                [outbound_networking]
                block_networks = ["1.1.1.1/32", "8.8.8.8/16", "private"]
            })
            .context("config_from_table")?
            .context("expected config, got None")?;
        // "private" sets the flag and must not add an IP network entry.
        assert_eq!(config.blocked_ip_networks.len(), 2, "{config:?}");
        assert!(config.blocked_ip_networks.contains(&cidr("1.1.1.1/32")));
        assert!(config.blocked_ip_networks.contains(&cidr("8.8.0.0/16")));
        assert!(config.block_private_networks, "{config:?}");
        Ok(())
    }

    #[test]
    fn test_invalid_blocked_network() {
        SpinRuntimeConfig::new("")
            .config_from_table(&toml::toml! {
                [outbound_networking]
                block_networks = ["not-a-network"]
            })
            .unwrap_err();
    }

    #[test]
    fn test_unknown_outbound_networking_field() {
        SpinRuntimeConfig::new("")
            .config_from_table(&toml::toml! {
                [outbound_networking]
                block_network = ["private"]
            })
            .unwrap_err();
    }

    #[test]
    fn test_min_tls_config() -> anyhow::Result<()> {
        let config = SpinRuntimeConfig::new("/doesnt-matter");

        let tls_configs = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
            })?
            .context("missing config section")?;
        assert_eq!(tls_configs.len(), 1);

        assert_eq!(tls_configs[0].components, ["test-component"]);
        assert_eq!(tls_configs[0].hosts[0].as_str(), "test-host");
        assert!(tls_configs[0].use_webpki_roots);
        Ok(())
    }

    #[test]
    fn test_max_tls_config() -> anyhow::Result<()> {
        let config = SpinRuntimeConfig::new(TESTDATA_DIR);

        let tls_configs = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                ca_use_webpki_roots = true
                ca_roots_file = "valid-cert.pem"
                client_cert_file = "valid-cert.pem"
                client_private_key_file = "valid-private-key.pem"
            })?
            .context("missing config section")?;
        assert_eq!(tls_configs.len(), 1);

        assert!(tls_configs[0].use_webpki_roots);
        assert_eq!(tls_configs[0].root_certificates.len(), 2);
        assert!(tls_configs[0].client_cert.is_some());
        Ok(())
    }

    #[test]
    fn test_use_webpki_roots_default_with_explicit_roots() -> anyhow::Result<()> {
        let config = SpinRuntimeConfig::new(TESTDATA_DIR);

        let tls_configs = config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                ca_roots_file = "valid-cert.pem"
            })?
            .context("missing config section")?;

        assert!(!tls_configs[0].use_webpki_roots);
        Ok(())
    }

    #[test]
    fn test_empty_component_ids() {
        SpinRuntimeConfig::new("/doesnt-matter")
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = []
                hosts = ["test-host"]
            })
            .unwrap_err();
    }

    #[test]
    fn test_empty_hosts() {
        SpinRuntimeConfig::new("/doesnt-matter")
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = []
            })
            .unwrap_err();
    }

    #[test]
    fn test_client_cert_without_private_key() {
        SpinRuntimeConfig::new(TESTDATA_DIR)
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                client_cert_file = "valid-cert.pem"
            })
            .unwrap_err();
    }

    #[test]
    fn test_private_key_without_client_cert() {
        SpinRuntimeConfig::new(TESTDATA_DIR)
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                client_private_key_file = "valid-private-key.pem"
            })
            .unwrap_err();
    }

    #[test]
    fn test_invalid_cert() {
        let config = SpinRuntimeConfig::new(TESTDATA_DIR);

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                ca_roots_file = "invalid-cert.pem"
            })
            .unwrap_err();
    }

    #[test]
    fn test_invalid_private_key() {
        let config = SpinRuntimeConfig::new(TESTDATA_DIR);

        config
            .tls_configs_from_table(&toml::toml! {
                [[client_tls]]
                component_ids = ["test-component"]
                hosts = ["test-host"]
                client_cert_file = "valid-cert.pem"
                client_private_key_file = "invalid-key.pem"
            })
            .unwrap_err();
    }
}