spin_factor_outbound_networking/
tls.rs1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use anyhow::{Context, ensure};
4use rustls::client::danger::ServerCertVerifier;
5
6use crate::runtime_config::{ClientCertRuntimeConfig, ClientTlsRuntimeConfig};
7
8#[derive(Default)]
10pub struct TlsClientConfigs {
11 component_host_tls_client_configs: HashMap<String, HostTlsClientConfigs>,
13 default_tls_client_config: TlsClientConfig,
15}
16
17impl TlsClientConfigs {
18 pub(crate) fn new(
19 client_tls_configs: impl IntoIterator<Item = ClientTlsRuntimeConfig>,
20 ) -> anyhow::Result<Self> {
21 let mut component_host_tls_client_configs = HashMap::<String, HostTlsClientConfigs>::new();
23 for ClientTlsRuntimeConfig {
24 components,
25 hosts,
26 root_certificates,
27 use_platform_roots,
28 use_webpki_roots,
29 client_cert,
30 } in client_tls_configs
31 {
32 ensure!(
33 !components.is_empty(),
34 "client TLS 'components' list may not be empty"
35 );
36 ensure!(
37 !hosts.is_empty(),
38 "client TLS 'hosts' list may not be empty"
39 );
40 let tls_client_config = TlsClientConfig::new(
41 root_certificates,
42 use_webpki_roots,
43 use_platform_roots,
44 client_cert,
45 )
46 .context("error building TLS client config")?;
47 for component in components {
48 let host_configs = component_host_tls_client_configs
49 .entry(component.clone())
50 .or_default();
51 for host in &hosts {
52 validate_host(host)?;
53 Arc::get_mut(host_configs)
55 .unwrap()
56 .entry(host.clone())
57 .or_insert_with(|| tls_client_config.clone());
58 }
59 }
60 }
61
62 Ok(Self {
63 component_host_tls_client_configs,
64 ..Default::default()
65 })
66 }
67
68 pub fn get_component_tls_configs(&self, component_id: &str) -> ComponentTlsClientConfigs {
70 let host_client_configs = self
71 .component_host_tls_client_configs
72 .get(component_id)
73 .cloned();
74 ComponentTlsClientConfigs {
75 host_client_configs,
76 default_client_config: self.default_tls_client_config.clone(),
77 }
78 }
79}
80
81type HostTlsClientConfigs = Arc<HashMap<String, TlsClientConfig>>;
83
84#[derive(Clone)]
86pub struct ComponentTlsClientConfigs {
87 pub(crate) host_client_configs: Option<HostTlsClientConfigs>,
88 pub(crate) default_client_config: TlsClientConfig,
89}
90
91impl ComponentTlsClientConfigs {
92 pub fn get_client_config(&self, host: &str) -> &TlsClientConfig {
94 self.host_client_configs
95 .as_ref()
96 .and_then(|configs| configs.get(host))
97 .unwrap_or(&self.default_client_config)
98 }
99}
100
101#[derive(Clone)]
103pub struct TlsClientConfig(Arc<rustls::ClientConfig>);
104
105impl TlsClientConfig {
106 fn new(
107 root_certificates: Vec<rustls_pki_types::CertificateDer<'static>>,
108 use_webpki_roots: bool,
109 use_platform_roots: bool,
110 client_cert: Option<ClientCertRuntimeConfig>,
111 ) -> anyhow::Result<Self> {
112 let builder = if use_platform_roots {
113 let crypto_provider = rustls::crypto::CryptoProvider::get_default()
114 .cloned()
115 .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider()));
116
117 let verifier: Arc<dyn ServerCertVerifier> = if root_certificates.is_empty() {
118 Arc::new(
119 rustls_platform_verifier::Verifier::new(crypto_provider.clone())
120 .context("failed to initialize platform certificate verifier")?,
121 )
122 } else {
123 Arc::new(
124 rustls_platform_verifier::Verifier::new_with_extra_roots(
125 root_certificates,
126 crypto_provider.clone(),
127 )
128 .context(
129 "failed to initialize platform certificate verifier with extra roots",
130 )?,
131 )
132 };
133
134 rustls::ClientConfig::builder_with_provider(crypto_provider)
135 .with_safe_default_protocol_versions()?
136 .dangerous()
137 .with_custom_certificate_verifier(verifier)
138 } else {
139 let mut root_store = rustls::RootCertStore::empty();
140 if use_webpki_roots {
141 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
142 }
143 for cert in root_certificates {
144 root_store.add(cert)?;
145 }
146 rustls::ClientConfig::builder().with_root_certificates(root_store)
147 };
148
149 let client_config = if let Some(ClientCertRuntimeConfig {
150 cert_chain,
151 key_der,
152 }) = client_cert
153 {
154 builder.with_client_auth_cert(cert_chain, key_der)?
155 } else {
156 builder.with_no_client_auth()
157 };
158 Ok(Self(client_config.into()))
159 }
160
161 pub fn inner(&self) -> Arc<rustls::ClientConfig> {
163 self.0.clone()
164 }
165}
166
167impl Deref for TlsClientConfig {
168 type Target = rustls::ClientConfig;
169
170 fn deref(&self) -> &Self::Target {
171 &self.0
172 }
173}
174
175impl Default for TlsClientConfig {
176 fn default() -> Self {
177 Self::new(vec![], false, true, None).expect("default client config should be valid")
178 }
179}
180
181pub(crate) fn validate_host(host: &str) -> anyhow::Result<()> {
183 let authority: http::uri::Authority = host
184 .parse()
185 .with_context(|| format!("invalid TLS 'host' {host:?}"))?;
186 ensure!(
187 authority.port().is_none(),
188 "invalid TLS 'host' {host:?}; ports not currently supported"
189 );
190 Ok(())
191}
192
#[cfg(test)]
mod tests {
    use std::path::Path;

    use anyhow::Context;
    use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};

    use super::*;

    #[test]
    fn test_empty_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::default();
        // Looking up an unconfigured component/host must fall back without panicking.
        configs.get_tls_client_config("foo", "bar");
        Ok(())
    }

    #[test]
    fn test_minimal_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![],
            use_platform_roots: false,
            use_webpki_roots: false,
            client_cert: None,
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        let default_config = configs.get_tls_client_config("other_component", "test-host");
        // The configured pair must get its own config, not the shared default.
        assert!(!Arc::ptr_eq(&config.0, &default_config.0));
        Ok(())
    }

    #[test]
    fn test_maximal_config() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![test_certs[0].clone()],
            use_platform_roots: false,
            use_webpki_roots: false,
            client_cert: Some(ClientCertRuntimeConfig {
                cert_chain: test_certs,
                key_der: test_key,
            }),
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        assert!(config.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    #[test]
    fn test_config_overrides() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into()],
                hosts: vec!["test-host".into()],
                client_cert: Some(ClientCertRuntimeConfig {
                    cert_chain: test_certs,
                    key_der: test_key,
                }),
                ..Default::default()
            },
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into(), "test-component2".into()],
                hosts: vec!["test-host".into()],
                ..Default::default()
            },
        ])?;
        // The first entry naming ("test-component1", "test-host") wins over the second.
        let config1 = configs.get_tls_client_config("test-component1", "test-host");
        assert!(config1.client_auth_cert_resolver.has_certs());

        // "test-component2" is only named by the second entry, so it gets that
        // entry's config (no client cert) and not the default.
        let config2 = configs.get_tls_client_config("test-component2", "test-host");
        assert!(!config2.client_auth_cert_resolver.has_certs());
        let default_config = configs.get_tls_client_config("other-component", "test-host");
        assert!(!Arc::ptr_eq(&config2.0, &default_config.0));
        Ok(())
    }

    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    fn test_certs() -> anyhow::Result<Vec<CertificateDer<'static>>> {
        CertificateDer::pem_file_iter(Path::new(TESTDATA_DIR).join("valid-cert.pem"))?
            .collect::<Result<Vec<_>, _>>()
            .context("certs")
    }

    fn test_key() -> anyhow::Result<PrivateKeyDer<'static>> {
        PrivateKeyDer::from_pem_file(Path::new(TESTDATA_DIR).join("valid-private-key.pem"))
            .context("key")
    }

    impl TlsClientConfigs {
        /// Test helper: resolves the config for `(component_id, host)` in one call.
        fn get_tls_client_config(&self, component_id: &str, host: &str) -> TlsClientConfig {
            let component_config = self.get_component_tls_configs(component_id);
            component_config.get_client_config(host).clone()
        }
    }
}