spin_factor_outbound_networking/
tls.rs

1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use anyhow::{ensure, Context};
4
5use crate::runtime_config::{ClientCertRuntimeConfig, ClientTlsRuntimeConfig};
6
7/// TLS client configs
8#[derive(Default)]
9pub struct TlsClientConfigs {
10    /// Shared map of component ID -> HostTlsClientConfigs
11    component_host_tls_client_configs: HashMap<String, HostTlsClientConfigs>,
12    /// The default [`ClientConfig`] for a host if one is not explicitly configured for it.
13    default_tls_client_config: TlsClientConfig,
14}
15
16impl TlsClientConfigs {
17    pub(crate) fn new(
18        client_tls_configs: impl IntoIterator<Item = ClientTlsRuntimeConfig>,
19    ) -> anyhow::Result<Self> {
20        // Construct nested map of <component ID> -> <host authority> -> TLS client config
21        let mut component_host_tls_client_configs = HashMap::<String, HostTlsClientConfigs>::new();
22        for ClientTlsRuntimeConfig {
23            components,
24            hosts,
25            root_certificates,
26            use_webpki_roots,
27            client_cert,
28        } in client_tls_configs
29        {
30            ensure!(
31                !components.is_empty(),
32                "client TLS 'components' list may not be empty"
33            );
34            ensure!(
35                !hosts.is_empty(),
36                "client TLS 'hosts' list may not be empty"
37            );
38            let tls_client_config =
39                TlsClientConfig::new(root_certificates, use_webpki_roots, client_cert)
40                    .context("error building TLS client config")?;
41            for component in components {
42                let host_configs = component_host_tls_client_configs
43                    .entry(component.clone())
44                    .or_default();
45                for host in &hosts {
46                    validate_host(host)?;
47                    // First matching (component, host) pair wins
48                    Arc::get_mut(host_configs)
49                        .unwrap()
50                        .entry(host.clone())
51                        .or_insert_with(|| tls_client_config.clone());
52                }
53            }
54        }
55
56        Ok(Self {
57            component_host_tls_client_configs,
58            ..Default::default()
59        })
60    }
61
62    /// Returns [`ComponentTlsClientConfigs`] for the given component.
63    pub fn get_component_tls_configs(&self, component_id: &str) -> ComponentTlsClientConfigs {
64        let host_client_configs = self
65            .component_host_tls_client_configs
66            .get(component_id)
67            .cloned();
68        ComponentTlsClientConfigs {
69            host_client_configs,
70            default_client_config: self.default_tls_client_config.clone(),
71        }
72    }
73}
74
75/// Shared maps of host authority -> TlsClientConfig
76type HostTlsClientConfigs = Arc<HashMap<String, TlsClientConfig>>;
77
78/// TLS configurations for a specific component.
79#[derive(Clone)]
80pub struct ComponentTlsClientConfigs {
81    pub(crate) host_client_configs: Option<HostTlsClientConfigs>,
82    pub(crate) default_client_config: TlsClientConfig,
83}
84
85impl ComponentTlsClientConfigs {
86    /// Returns a [`ClientConfig`] for the given host authority.
87    pub fn get_client_config(&self, host: &str) -> &TlsClientConfig {
88        self.host_client_configs
89            .as_ref()
90            .and_then(|configs| configs.get(host))
91            .unwrap_or(&self.default_client_config)
92    }
93}
94
95/// Shared TLS client configuration
96#[derive(Clone)]
97pub struct TlsClientConfig(Arc<rustls::ClientConfig>);
98
99impl TlsClientConfig {
100    fn new(
101        root_certificates: Vec<rustls_pki_types::CertificateDer<'static>>,
102        use_webpki_roots: bool,
103        client_cert: Option<ClientCertRuntimeConfig>,
104    ) -> anyhow::Result<Self> {
105        let mut root_store = rustls::RootCertStore::empty();
106        if use_webpki_roots {
107            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
108        }
109        for cert in root_certificates {
110            root_store.add(cert)?;
111        }
112
113        let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
114
115        let client_config = if let Some(ClientCertRuntimeConfig {
116            cert_chain,
117            key_der,
118        }) = client_cert
119        {
120            builder.with_client_auth_cert(cert_chain, key_der)?
121        } else {
122            builder.with_no_client_auth()
123        };
124        Ok(Self(client_config.into()))
125    }
126
127    /// Returns the inner [`rustls::ClientConfig`] for consumption by rustls APIs.
128    pub fn inner(&self) -> Arc<rustls::ClientConfig> {
129        self.0.clone()
130    }
131}
132
133impl Deref for TlsClientConfig {
134    type Target = rustls::ClientConfig;
135
136    fn deref(&self) -> &Self::Target {
137        &self.0
138    }
139}
140
141impl Default for TlsClientConfig {
142    fn default() -> Self {
143        Self::new(vec![], true, None).expect("default client config should be valid")
144    }
145}
146
147/// Validate host name (authority without port)
148pub(crate) fn validate_host(host: &str) -> anyhow::Result<()> {
149    let authority: http::uri::Authority = host
150        .parse()
151        .with_context(|| format!("invalid TLS 'host' {host:?}"))?;
152    ensure!(
153        authority.port().is_none(),
154        "invalid TLS 'host' {host:?}; ports not currently supported"
155    );
156    Ok(())
157}
158
#[cfg(test)]
mod tests {
    use std::path::Path;

    use anyhow::Context;
    use rustls_pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer};

    use super::*;

    #[test]
    fn test_empty_config() -> anyhow::Result<()> {
        // Just make sure the default doesn't panic
        let configs = TlsClientConfigs::default();
        configs.get_tls_client_config("foo", "bar");
        Ok(())
    }

    #[test]
    fn test_minimal_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![],
            use_webpki_roots: false,
            client_cert: None,
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        // Check that we didn't just get the default
        let default_config = configs.get_tls_client_config("other_component", "test-host");
        assert!(!Arc::ptr_eq(&config.0, &default_config.0));
        Ok(())
    }

    #[test]
    fn test_maximal_config() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![test_certs[0].clone()],
            use_webpki_roots: false,
            client_cert: Some(ClientCertRuntimeConfig {
                cert_chain: test_certs,
                key_der: test_key,
            }),
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        assert!(config.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    #[test]
    fn test_config_overrides() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into()],
                hosts: vec!["test-host".into()],
                client_cert: Some(ClientCertRuntimeConfig {
                    cert_chain: test_certs,
                    key_der: test_key,
                }),
                ..Default::default()
            },
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into(), "test-component2".into()],
                hosts: vec!["test-host".into()],
                ..Default::default()
            },
        ])?;
        // First match wins
        let config1 = configs.get_tls_client_config("test-component1", "test-host");
        assert!(config1.client_auth_cert_resolver.has_certs());

        // Correctly select by differing component ID: "test-component2" only appears in
        // the second entry, which has no client cert.
        let config2 = configs.get_tls_client_config("test-component2", "test-host");
        assert!(!config2.client_auth_cert_resolver.has_certs());
        // ...and it must be that entry's config, not the default fallback.
        let default_config = configs.get_tls_client_config("other-component", "test-host");
        assert!(!Arc::ptr_eq(&config2.0, &default_config.0));
        Ok(())
    }

    #[test]
    fn test_invalid_configs() {
        let empty_components = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec![],
            hosts: vec!["test-host".into()],
            ..Default::default()
        }]);
        assert!(empty_components.is_err());

        let empty_hosts = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec![],
            ..Default::default()
        }]);
        assert!(empty_hosts.is_err());

        let host_with_port = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host:443".into()],
            ..Default::default()
        }]);
        assert!(host_with_port.is_err());

        assert!(validate_host("example.com").is_ok());
        assert!(validate_host("example.com:443").is_err());
    }

    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    fn test_certs() -> anyhow::Result<Vec<CertificateDer<'static>>> {
        CertificateDer::pem_file_iter(Path::new(TESTDATA_DIR).join("valid-cert.pem"))?
            .collect::<Result<Vec<_>, _>>()
            .context("certs")
    }

    fn test_key() -> anyhow::Result<PrivateKeyDer<'static>> {
        PrivateKeyDer::from_pem_file(Path::new(TESTDATA_DIR).join("valid-private-key.pem"))
            .context("key")
    }

    impl TlsClientConfigs {
        fn get_tls_client_config(&self, component_id: &str, host: &str) -> TlsClientConfig {
            let component_config = self.get_component_tls_configs(component_id);
            component_config.get_client_config(host).clone()
        }
    }
}