spin_factor_outbound_networking/
tls.rs1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use anyhow::{Context, ensure};
4
5use crate::runtime_config::{ClientCertRuntimeConfig, ClientTlsRuntimeConfig};
6
7#[derive(Default)]
9pub struct TlsClientConfigs {
10 component_host_tls_client_configs: HashMap<String, HostTlsClientConfigs>,
12 default_tls_client_config: TlsClientConfig,
14}
15
16impl TlsClientConfigs {
17 pub(crate) fn new(
18 client_tls_configs: impl IntoIterator<Item = ClientTlsRuntimeConfig>,
19 ) -> anyhow::Result<Self> {
20 let mut component_host_tls_client_configs = HashMap::<String, HostTlsClientConfigs>::new();
22 for ClientTlsRuntimeConfig {
23 components,
24 hosts,
25 root_certificates,
26 use_platform_roots,
27 use_webpki_roots,
28 client_cert,
29 } in client_tls_configs
30 {
31 ensure!(
32 !components.is_empty(),
33 "client TLS 'components' list may not be empty"
34 );
35 ensure!(
36 !hosts.is_empty(),
37 "client TLS 'hosts' list may not be empty"
38 );
39 let tls_client_config = TlsClientConfig::new(
40 use_platform_roots,
41 use_webpki_roots,
42 root_certificates,
43 client_cert,
44 )
45 .context("error building TLS client config")?;
46 for component in components {
47 let host_configs = component_host_tls_client_configs
48 .entry(component.clone())
49 .or_default();
50 for host in &hosts {
51 validate_host(host)?;
52 Arc::get_mut(host_configs)
54 .unwrap()
55 .entry(host.clone())
56 .or_insert_with(|| tls_client_config.clone());
57 }
58 }
59 }
60
61 Ok(Self {
62 component_host_tls_client_configs,
63 ..Default::default()
64 })
65 }
66
67 pub fn get_component_tls_configs(&self, component_id: &str) -> ComponentTlsClientConfigs {
69 let host_client_configs = self
70 .component_host_tls_client_configs
71 .get(component_id)
72 .cloned();
73 ComponentTlsClientConfigs {
74 host_client_configs,
75 default_client_config: self.default_tls_client_config.clone(),
76 }
77 }
78}
79
80type HostTlsClientConfigs = Arc<HashMap<String, TlsClientConfig>>;
82
83#[derive(Clone)]
85pub struct ComponentTlsClientConfigs {
86 pub(crate) host_client_configs: Option<HostTlsClientConfigs>,
87 pub(crate) default_client_config: TlsClientConfig,
88}
89
90impl ComponentTlsClientConfigs {
91 pub fn get_client_config(&self, host: &str) -> &TlsClientConfig {
93 self.host_client_configs
94 .as_ref()
95 .and_then(|configs| configs.get(host))
96 .unwrap_or(&self.default_client_config)
97 }
98}
99
100#[derive(Clone)]
102pub struct TlsClientConfig(Arc<rustls::ClientConfig>);
103
104impl TlsClientConfig {
105 fn new(
106 use_platform_roots: bool,
107 use_webpki_roots: bool,
108 root_certificates: Vec<rustls_pki_types::CertificateDer<'static>>,
109 client_cert: Option<ClientCertRuntimeConfig>,
110 ) -> anyhow::Result<Self> {
111 anyhow::ensure!(
112 use_platform_roots || use_webpki_roots || !root_certificates.is_empty(),
113 "at least one of 'use_platform_roots', 'use_webpki_roots', or 'root_certificates' must be set"
114 );
115
116 let mut extra_roots: Box<dyn Iterator<Item = _>> = Box::new(root_certificates.into_iter());
117 if use_webpki_roots {
118 extra_roots = Box::new(
119 extra_roots.chain(webpki_root_certs::TLS_SERVER_ROOT_CERTS.iter().cloned()),
120 );
121 }
122
123 let builder = rustls::ClientConfig::builder();
124 let builder = if use_platform_roots {
125 let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
126 extra_roots,
127 builder.crypto_provider().clone(),
128 )
129 .context("failed to initialize platform certificate verifier")?;
130 builder
131 .dangerous()
132 .with_custom_certificate_verifier(Arc::new(verifier))
133 } else {
134 let mut root_store = rustls::RootCertStore::empty();
135 extra_roots.try_for_each(|cert| root_store.add(cert))?;
136 builder.with_root_certificates(root_store)
137 };
138
139 let client_config = if let Some(ClientCertRuntimeConfig {
140 cert_chain,
141 key_der,
142 }) = client_cert
143 {
144 builder.with_client_auth_cert(cert_chain, key_der)?
145 } else {
146 builder.with_no_client_auth()
147 };
148 Ok(Self(client_config.into()))
149 }
150
151 pub fn inner(&self) -> Arc<rustls::ClientConfig> {
153 self.0.clone()
154 }
155}
156
157impl Deref for TlsClientConfig {
158 type Target = rustls::ClientConfig;
159
160 fn deref(&self) -> &Self::Target {
161 &self.0
162 }
163}
164
165impl Default for TlsClientConfig {
166 fn default() -> Self {
167 Self::new(true, true, vec![], None).expect("default client config should be valid")
168 }
169}
170
171pub(crate) fn validate_host(host: &str) -> anyhow::Result<()> {
173 let authority: http::uri::Authority = host
174 .parse()
175 .with_context(|| format!("invalid TLS 'host' {host:?}"))?;
176 ensure!(
177 authority.port().is_none(),
178 "invalid TLS 'host' {host:?}; ports not currently supported"
179 );
180 Ok(())
181}
182
#[cfg(test)]
mod tests {
    use std::path::Path;

    use anyhow::Context;
    use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};

    use super::*;

    /// An empty set of configs must still hand out the default config.
    #[test]
    fn test_empty_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::default();
        // Only checks that the lookup succeeds without panicking.
        configs.get_tls_client_config("foo", "bar");
        Ok(())
    }

    /// A configured (component, host) pair gets its own config, distinct from
    /// the default handed to other components.
    #[test]
    fn test_minimal_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![],
            use_platform_roots: false,
            use_webpki_roots: true,
            client_cert: None,
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        let default_config = configs.get_tls_client_config("other_component", "test-host");
        // The two handles must not point at the same underlying config.
        assert!(!Arc::ptr_eq(&config.0, &default_config.0));
        Ok(())
    }

    /// A config with explicit roots and a client certificate exposes that
    /// certificate through its resolver.
    #[test]
    fn test_maximal_config() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![test_certs[0].clone()],
            use_platform_roots: false,
            use_webpki_roots: false,
            client_cert: Some(ClientCertRuntimeConfig {
                cert_chain: test_certs,
                key_der: test_key,
            }),
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        assert!(config.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    /// Overlapping entries: the first match wins for a component, and each
    /// component is resolved independently.
    #[test]
    fn test_config_overrides() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into()],
                hosts: vec!["test-host".into()],
                client_cert: Some(ClientCertRuntimeConfig {
                    cert_chain: test_certs,
                    key_der: test_key,
                }),
                ..Default::default()
            },
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into(), "test-component2".into()],
                hosts: vec!["test-host".into()],
                ..Default::default()
            },
        ])?;
        // The first entry (with a client cert) wins for test-component1.
        let config1 = configs.get_tls_client_config("test-component1", "test-host");
        assert!(config1.client_auth_cert_resolver.has_certs());

        // test-component2 is only covered by the second entry (no client cert).
        // The ID must match the configured one exactly; otherwise this would
        // only exercise the default fallback.
        let config2 = configs.get_tls_client_config("test-component2", "test-host");
        assert!(!config2.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    /// Loads the certificate chain from `testdata/valid-cert.pem`.
    fn test_certs() -> anyhow::Result<Vec<CertificateDer<'static>>> {
        CertificateDer::pem_file_iter(Path::new(TESTDATA_DIR).join("valid-cert.pem"))?
            .collect::<Result<Vec<_>, _>>()
            .context("certs")
    }

    /// Loads the private key from `testdata/valid-private-key.pem`.
    fn test_key() -> anyhow::Result<PrivateKeyDer<'static>> {
        PrivateKeyDer::from_pem_file(Path::new(TESTDATA_DIR).join("valid-private-key.pem"))
            .context("key")
    }

    impl TlsClientConfigs {
        /// Test helper: resolves the config for a (component, host) pair in one call.
        fn get_tls_client_config(&self, component_id: &str, host: &str) -> TlsClientConfig {
            let component_config = self.get_component_tls_configs(component_id);
            component_config.get_client_config(host).clone()
        }
    }
}