Skip to main content

spin_factor_outbound_networking/
tls.rs

1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use anyhow::{Context, ensure};
4use rustls::client::danger::ServerCertVerifier;
5
6use crate::runtime_config::{ClientCertRuntimeConfig, ClientTlsRuntimeConfig};
7
8/// TLS client configs
9#[derive(Default)]
10pub struct TlsClientConfigs {
11    /// Shared map of component ID -> HostTlsClientConfigs
12    component_host_tls_client_configs: HashMap<String, HostTlsClientConfigs>,
13    /// The default [`ClientConfig`] for a host if one is not explicitly configured for it.
14    default_tls_client_config: TlsClientConfig,
15}
16
17impl TlsClientConfigs {
18    pub(crate) fn new(
19        client_tls_configs: impl IntoIterator<Item = ClientTlsRuntimeConfig>,
20    ) -> anyhow::Result<Self> {
21        // Construct nested map of <component ID> -> <host authority> -> TLS client config
22        let mut component_host_tls_client_configs = HashMap::<String, HostTlsClientConfigs>::new();
23        for ClientTlsRuntimeConfig {
24            components,
25            hosts,
26            root_certificates,
27            use_platform_roots,
28            use_webpki_roots,
29            client_cert,
30        } in client_tls_configs
31        {
32            ensure!(
33                !components.is_empty(),
34                "client TLS 'components' list may not be empty"
35            );
36            ensure!(
37                !hosts.is_empty(),
38                "client TLS 'hosts' list may not be empty"
39            );
40            let tls_client_config = TlsClientConfig::new(
41                root_certificates,
42                use_webpki_roots,
43                use_platform_roots,
44                client_cert,
45            )
46            .context("error building TLS client config")?;
47            for component in components {
48                let host_configs = component_host_tls_client_configs
49                    .entry(component.clone())
50                    .or_default();
51                for host in &hosts {
52                    validate_host(host)?;
53                    // First matching (component, host) pair wins
54                    Arc::get_mut(host_configs)
55                        .unwrap()
56                        .entry(host.clone())
57                        .or_insert_with(|| tls_client_config.clone());
58                }
59            }
60        }
61
62        Ok(Self {
63            component_host_tls_client_configs,
64            ..Default::default()
65        })
66    }
67
68    /// Returns [`ComponentTlsClientConfigs`] for the given component.
69    pub fn get_component_tls_configs(&self, component_id: &str) -> ComponentTlsClientConfigs {
70        let host_client_configs = self
71            .component_host_tls_client_configs
72            .get(component_id)
73            .cloned();
74        ComponentTlsClientConfigs {
75            host_client_configs,
76            default_client_config: self.default_tls_client_config.clone(),
77        }
78    }
79}
80
81/// Shared maps of host authority -> TlsClientConfig
82type HostTlsClientConfigs = Arc<HashMap<String, TlsClientConfig>>;
83
84/// TLS configurations for a specific component.
85#[derive(Clone)]
86pub struct ComponentTlsClientConfigs {
87    pub(crate) host_client_configs: Option<HostTlsClientConfigs>,
88    pub(crate) default_client_config: TlsClientConfig,
89}
90
91impl ComponentTlsClientConfigs {
92    /// Returns a [`ClientConfig`] for the given host authority.
93    pub fn get_client_config(&self, host: &str) -> &TlsClientConfig {
94        self.host_client_configs
95            .as_ref()
96            .and_then(|configs| configs.get(host))
97            .unwrap_or(&self.default_client_config)
98    }
99}
100
101/// Shared TLS client configuration
102#[derive(Clone)]
103pub struct TlsClientConfig(Arc<rustls::ClientConfig>);
104
105impl TlsClientConfig {
106    fn new(
107        root_certificates: Vec<rustls_pki_types::CertificateDer<'static>>,
108        use_webpki_roots: bool,
109        use_platform_roots: bool,
110        client_cert: Option<ClientCertRuntimeConfig>,
111    ) -> anyhow::Result<Self> {
112        let builder = if use_platform_roots {
113            let crypto_provider = rustls::crypto::CryptoProvider::get_default()
114                .cloned()
115                .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider()));
116
117            let verifier: Arc<dyn ServerCertVerifier> = if root_certificates.is_empty() {
118                Arc::new(
119                    rustls_platform_verifier::Verifier::new(crypto_provider.clone())
120                        .context("failed to initialize platform certificate verifier")?,
121                )
122            } else {
123                Arc::new(
124                    rustls_platform_verifier::Verifier::new_with_extra_roots(
125                        root_certificates,
126                        crypto_provider.clone(),
127                    )
128                    .context(
129                        "failed to initialize platform certificate verifier with extra roots",
130                    )?,
131                )
132            };
133
134            rustls::ClientConfig::builder_with_provider(crypto_provider)
135                .with_safe_default_protocol_versions()?
136                .dangerous()
137                .with_custom_certificate_verifier(verifier)
138        } else {
139            let mut root_store = rustls::RootCertStore::empty();
140            if use_webpki_roots {
141                root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
142            }
143            for cert in root_certificates {
144                root_store.add(cert)?;
145            }
146            rustls::ClientConfig::builder().with_root_certificates(root_store)
147        };
148
149        let client_config = if let Some(ClientCertRuntimeConfig {
150            cert_chain,
151            key_der,
152        }) = client_cert
153        {
154            builder.with_client_auth_cert(cert_chain, key_der)?
155        } else {
156            builder.with_no_client_auth()
157        };
158        Ok(Self(client_config.into()))
159    }
160
161    /// Returns the inner [`rustls::ClientConfig`] for consumption by rustls APIs.
162    pub fn inner(&self) -> Arc<rustls::ClientConfig> {
163        self.0.clone()
164    }
165}
166
167impl Deref for TlsClientConfig {
168    type Target = rustls::ClientConfig;
169
170    fn deref(&self) -> &Self::Target {
171        &self.0
172    }
173}
174
175impl Default for TlsClientConfig {
176    fn default() -> Self {
177        Self::new(vec![], false, true, None).expect("default client config should be valid")
178    }
179}
180
181/// Validate host name (authority without port)
182pub(crate) fn validate_host(host: &str) -> anyhow::Result<()> {
183    let authority: http::uri::Authority = host
184        .parse()
185        .with_context(|| format!("invalid TLS 'host' {host:?}"))?;
186    ensure!(
187        authority.port().is_none(),
188        "invalid TLS 'host' {host:?}; ports not currently supported"
189    );
190    Ok(())
191}
192
#[cfg(test)]
mod tests {
    use std::path::Path;

    use anyhow::Context;
    use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};

    use super::*;

    #[test]
    fn test_empty_config() -> anyhow::Result<()> {
        // Just make sure the default doesn't panic
        let configs = TlsClientConfigs::default();
        configs.get_tls_client_config("foo", "bar");
        Ok(())
    }

    #[test]
    fn test_minimal_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![],
            use_platform_roots: false,
            use_webpki_roots: false,
            client_cert: None,
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        // Check that we didn't just get the default
        let default_config = configs.get_tls_client_config("other_component", "test-host");
        assert!(!Arc::ptr_eq(&config.0, &default_config.0));
        Ok(())
    }

    #[test]
    fn test_maximal_config() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![test_certs[0].clone()],
            use_platform_roots: false,
            use_webpki_roots: false,
            client_cert: Some(ClientCertRuntimeConfig {
                cert_chain: test_certs,
                key_der: test_key,
            }),
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        assert!(config.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    #[test]
    fn test_config_overrides() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into()],
                hosts: vec!["test-host".into()],
                client_cert: Some(ClientCertRuntimeConfig {
                    cert_chain: test_certs,
                    key_der: test_key,
                }),
                ..Default::default()
            },
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into(), "test-component2".into()],
                hosts: vec!["test-host".into()],
                ..Default::default()
            },
        ])?;
        // First match wins
        let config1 = configs.get_tls_client_config("test-component1", "test-host");
        assert!(config1.client_auth_cert_resolver.has_certs());

        // Correctly select by differing component ID: "test-component2" is only named by
        // the second entry, which has no client cert.
        let config2 = configs.get_tls_client_config("test-component2", "test-host");
        assert!(!config2.client_auth_cert_resolver.has_certs());
        // ...and it must be that entry's config, not the fallback default.
        let default_config = configs.get_tls_client_config("other-component", "test-host");
        assert!(!Arc::ptr_eq(&config2.0, &default_config.0));
        Ok(())
    }

    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    fn test_certs() -> anyhow::Result<Vec<CertificateDer<'static>>> {
        CertificateDer::pem_file_iter(Path::new(TESTDATA_DIR).join("valid-cert.pem"))?
            .collect::<Result<Vec<_>, _>>()
            .context("certs")
    }

    fn test_key() -> anyhow::Result<PrivateKeyDer<'static>> {
        PrivateKeyDer::from_pem_file(Path::new(TESTDATA_DIR).join("valid-private-key.pem"))
            .context("key")
    }

    impl TlsClientConfigs {
        fn get_tls_client_config(&self, component_id: &str, host: &str) -> TlsClientConfig {
            let component_config = self.get_component_tls_configs(component_id);
            component_config.get_client_config(host).clone()
        }
    }
}