// spin_factor_outbound_networking/tls.rs

1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use anyhow::{Context, ensure};
4
5use crate::runtime_config::{ClientCertRuntimeConfig, ClientTlsRuntimeConfig};
6
7/// TLS client configs
8#[derive(Default)]
9pub struct TlsClientConfigs {
10    /// Shared map of component ID -> HostTlsClientConfigs
11    component_host_tls_client_configs: HashMap<String, HostTlsClientConfigs>,
12    /// The default [`ClientConfig`] for a host if one is not explicitly configured for it.
13    default_tls_client_config: TlsClientConfig,
14}
15
16impl TlsClientConfigs {
17    pub(crate) fn new(
18        client_tls_configs: impl IntoIterator<Item = ClientTlsRuntimeConfig>,
19    ) -> anyhow::Result<Self> {
20        // Construct nested map of <component ID> -> <host authority> -> TLS client config
21        let mut component_host_tls_client_configs = HashMap::<String, HostTlsClientConfigs>::new();
22        for ClientTlsRuntimeConfig {
23            components,
24            hosts,
25            root_certificates,
26            use_platform_roots,
27            use_webpki_roots,
28            client_cert,
29        } in client_tls_configs
30        {
31            ensure!(
32                !components.is_empty(),
33                "client TLS 'components' list may not be empty"
34            );
35            ensure!(
36                !hosts.is_empty(),
37                "client TLS 'hosts' list may not be empty"
38            );
39            let tls_client_config = TlsClientConfig::new(
40                use_platform_roots,
41                use_webpki_roots,
42                root_certificates,
43                client_cert,
44            )
45            .context("error building TLS client config")?;
46            for component in components {
47                let host_configs = component_host_tls_client_configs
48                    .entry(component.clone())
49                    .or_default();
50                for host in &hosts {
51                    validate_host(host)?;
52                    // First matching (component, host) pair wins
53                    Arc::get_mut(host_configs)
54                        .unwrap()
55                        .entry(host.clone())
56                        .or_insert_with(|| tls_client_config.clone());
57                }
58            }
59        }
60
61        Ok(Self {
62            component_host_tls_client_configs,
63            ..Default::default()
64        })
65    }
66
67    /// Returns [`ComponentTlsClientConfigs`] for the given component.
68    pub fn get_component_tls_configs(&self, component_id: &str) -> ComponentTlsClientConfigs {
69        let host_client_configs = self
70            .component_host_tls_client_configs
71            .get(component_id)
72            .cloned();
73        ComponentTlsClientConfigs {
74            host_client_configs,
75            default_client_config: self.default_tls_client_config.clone(),
76        }
77    }
78}
79
80/// Shared maps of host authority -> TlsClientConfig
81type HostTlsClientConfigs = Arc<HashMap<String, TlsClientConfig>>;
82
83/// TLS configurations for a specific component.
84#[derive(Clone)]
85pub struct ComponentTlsClientConfigs {
86    pub(crate) host_client_configs: Option<HostTlsClientConfigs>,
87    pub(crate) default_client_config: TlsClientConfig,
88}
89
90impl ComponentTlsClientConfigs {
91    /// Returns a [`ClientConfig`] for the given host authority.
92    pub fn get_client_config(&self, host: &str) -> &TlsClientConfig {
93        self.host_client_configs
94            .as_ref()
95            .and_then(|configs| configs.get(host))
96            .unwrap_or(&self.default_client_config)
97    }
98}
99
100/// Shared TLS client configuration
101#[derive(Clone)]
102pub struct TlsClientConfig(Arc<rustls::ClientConfig>);
103
104impl TlsClientConfig {
105    fn new(
106        use_platform_roots: bool,
107        use_webpki_roots: bool,
108        root_certificates: Vec<rustls_pki_types::CertificateDer<'static>>,
109        client_cert: Option<ClientCertRuntimeConfig>,
110    ) -> anyhow::Result<Self> {
111        anyhow::ensure!(
112            use_platform_roots || use_webpki_roots || !root_certificates.is_empty(),
113            "at least one of 'use_platform_roots', 'use_webpki_roots', or 'root_certificates' must be set"
114        );
115
116        let mut extra_roots: Box<dyn Iterator<Item = _>> = Box::new(root_certificates.into_iter());
117        if use_webpki_roots {
118            extra_roots = Box::new(
119                extra_roots.chain(webpki_root_certs::TLS_SERVER_ROOT_CERTS.iter().cloned()),
120            );
121        }
122
123        let builder = rustls::ClientConfig::builder();
124        let builder = if use_platform_roots {
125            let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
126                extra_roots,
127                builder.crypto_provider().clone(),
128            )
129            .context("failed to initialize platform certificate verifier")?;
130            builder
131                .dangerous()
132                .with_custom_certificate_verifier(Arc::new(verifier))
133        } else {
134            let mut root_store = rustls::RootCertStore::empty();
135            extra_roots.try_for_each(|cert| root_store.add(cert))?;
136            builder.with_root_certificates(root_store)
137        };
138
139        let client_config = if let Some(ClientCertRuntimeConfig {
140            cert_chain,
141            key_der,
142        }) = client_cert
143        {
144            builder.with_client_auth_cert(cert_chain, key_der)?
145        } else {
146            builder.with_no_client_auth()
147        };
148        Ok(Self(client_config.into()))
149    }
150
151    /// Returns the inner [`rustls::ClientConfig`] for consumption by rustls APIs.
152    pub fn inner(&self) -> Arc<rustls::ClientConfig> {
153        self.0.clone()
154    }
155}
156
157impl Deref for TlsClientConfig {
158    type Target = rustls::ClientConfig;
159
160    fn deref(&self) -> &Self::Target {
161        &self.0
162    }
163}
164
165impl Default for TlsClientConfig {
166    fn default() -> Self {
167        Self::new(true, true, vec![], None).expect("default client config should be valid")
168    }
169}
170
171/// Validate host name (authority without port)
172pub(crate) fn validate_host(host: &str) -> anyhow::Result<()> {
173    let authority: http::uri::Authority = host
174        .parse()
175        .with_context(|| format!("invalid TLS 'host' {host:?}"))?;
176    ensure!(
177        authority.port().is_none(),
178        "invalid TLS 'host' {host:?}; ports not currently supported"
179    );
180    Ok(())
181}
182
#[cfg(test)]
mod tests {
    use std::path::Path;

    use anyhow::Context;
    use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};

    use super::*;

    #[test]
    fn test_empty_config() -> anyhow::Result<()> {
        // Just make sure the default doesn't panic
        let configs = TlsClientConfigs::default();
        configs.get_tls_client_config("foo", "bar");
        Ok(())
    }

    #[test]
    fn test_empty_lists_rejected() {
        // An entry must name at least one component...
        let err = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            hosts: vec!["test-host".into()],
            use_webpki_roots: true,
            ..Default::default()
        }])
        .err()
        .expect("empty 'components' list should be rejected");
        assert!(err.to_string().contains("'components'"));

        // ...and at least one host.
        let err = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            use_webpki_roots: true,
            ..Default::default()
        }])
        .err()
        .expect("empty 'hosts' list should be rejected");
        assert!(err.to_string().contains("'hosts'"));
    }

    #[test]
    fn test_validate_host() {
        assert!(validate_host("example.com").is_ok());
        assert!(validate_host("127.0.0.1").is_ok());
        // Ports are not currently supported
        assert!(validate_host("example.com:443").is_err());
        assert!(validate_host("not a host").is_err());
    }

    #[test]
    fn test_minimal_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![],
            use_platform_roots: false,
            use_webpki_roots: true,
            client_cert: None,
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        // Check that we didn't just get the default
        let default_config = configs.get_tls_client_config("other_component", "test-host");
        assert!(!Arc::ptr_eq(&config.0, &default_config.0));
        Ok(())
    }

    #[test]
    fn test_maximal_config() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![test_certs[0].clone()],
            use_platform_roots: false,
            use_webpki_roots: false,
            client_cert: Some(ClientCertRuntimeConfig {
                cert_chain: test_certs,
                key_der: test_key,
            }),
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        assert!(config.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    #[test]
    fn test_config_overrides() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into()],
                hosts: vec!["test-host".into()],
                client_cert: Some(ClientCertRuntimeConfig {
                    cert_chain: test_certs,
                    key_der: test_key,
                }),
                ..Default::default()
            },
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into(), "test-component2".into()],
                hosts: vec!["test-host".into()],
                ..Default::default()
            },
        ])?;
        // First match wins
        let config1 = configs.get_tls_client_config("test-component1", "test-host");
        assert!(config1.client_auth_cert_resolver.has_certs());

        // Correctly select by differing component ID
        let config2 = configs.get_tls_client_config("test-component2", "test-host");
        assert!(!config2.client_auth_cert_resolver.has_certs());
        // ...and make sure that was the second entry's config, not the default
        let default_config = configs.get_tls_client_config("other-component", "test-host");
        assert!(!Arc::ptr_eq(&config2.0, &default_config.0));
        Ok(())
    }

    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    fn test_certs() -> anyhow::Result<Vec<CertificateDer<'static>>> {
        CertificateDer::pem_file_iter(Path::new(TESTDATA_DIR).join("valid-cert.pem"))?
            .collect::<Result<Vec<_>, _>>()
            .context("certs")
    }

    fn test_key() -> anyhow::Result<PrivateKeyDer<'static>> {
        PrivateKeyDer::from_pem_file(Path::new(TESTDATA_DIR).join("valid-private-key.pem"))
            .context("key")
    }

    impl TlsClientConfigs {
        fn get_tls_client_config(&self, component_id: &str, host: &str) -> TlsClientConfig {
            let component_config = self.get_component_tls_configs(component_id);
            component_config.get_client_config(host).clone()
        }
    }
}