spin_factor_outbound_networking/
tls.rs1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use anyhow::{ensure, Context};
4
5use crate::runtime_config::{ClientCertRuntimeConfig, ClientTlsRuntimeConfig};
6
7#[derive(Default)]
9pub struct TlsClientConfigs {
10 component_host_tls_client_configs: HashMap<String, HostTlsClientConfigs>,
12 default_tls_client_config: TlsClientConfig,
14}
15
16impl TlsClientConfigs {
17 pub(crate) fn new(
18 client_tls_configs: impl IntoIterator<Item = ClientTlsRuntimeConfig>,
19 ) -> anyhow::Result<Self> {
20 let mut component_host_tls_client_configs = HashMap::<String, HostTlsClientConfigs>::new();
22 for ClientTlsRuntimeConfig {
23 components,
24 hosts,
25 root_certificates,
26 use_webpki_roots,
27 client_cert,
28 } in client_tls_configs
29 {
30 ensure!(
31 !components.is_empty(),
32 "client TLS 'components' list may not be empty"
33 );
34 ensure!(
35 !hosts.is_empty(),
36 "client TLS 'hosts' list may not be empty"
37 );
38 let tls_client_config =
39 TlsClientConfig::new(root_certificates, use_webpki_roots, client_cert)
40 .context("error building TLS client config")?;
41 for component in components {
42 let host_configs = component_host_tls_client_configs
43 .entry(component.clone())
44 .or_default();
45 for host in &hosts {
46 validate_host(host)?;
47 Arc::get_mut(host_configs)
49 .unwrap()
50 .entry(host.clone())
51 .or_insert_with(|| tls_client_config.clone());
52 }
53 }
54 }
55
56 Ok(Self {
57 component_host_tls_client_configs,
58 ..Default::default()
59 })
60 }
61
62 pub fn get_component_tls_configs(&self, component_id: &str) -> ComponentTlsClientConfigs {
64 let host_client_configs = self
65 .component_host_tls_client_configs
66 .get(component_id)
67 .cloned();
68 ComponentTlsClientConfigs {
69 host_client_configs,
70 default_client_config: self.default_tls_client_config.clone(),
71 }
72 }
73}
74
75type HostTlsClientConfigs = Arc<HashMap<String, TlsClientConfig>>;
77
78#[derive(Clone)]
80pub struct ComponentTlsClientConfigs {
81 pub(crate) host_client_configs: Option<HostTlsClientConfigs>,
82 pub(crate) default_client_config: TlsClientConfig,
83}
84
85impl ComponentTlsClientConfigs {
86 pub fn get_client_config(&self, host: &str) -> &TlsClientConfig {
88 self.host_client_configs
89 .as_ref()
90 .and_then(|configs| configs.get(host))
91 .unwrap_or(&self.default_client_config)
92 }
93}
94
95#[derive(Clone)]
97pub struct TlsClientConfig(Arc<rustls::ClientConfig>);
98
99impl TlsClientConfig {
100 fn new(
101 root_certificates: Vec<rustls_pki_types::CertificateDer<'static>>,
102 use_webpki_roots: bool,
103 client_cert: Option<ClientCertRuntimeConfig>,
104 ) -> anyhow::Result<Self> {
105 let mut root_store = rustls::RootCertStore::empty();
106 if use_webpki_roots {
107 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
108 }
109 for cert in root_certificates {
110 root_store.add(cert)?;
111 }
112
113 let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
114
115 let client_config = if let Some(ClientCertRuntimeConfig {
116 cert_chain,
117 key_der,
118 }) = client_cert
119 {
120 builder.with_client_auth_cert(cert_chain, key_der)?
121 } else {
122 builder.with_no_client_auth()
123 };
124 Ok(Self(client_config.into()))
125 }
126
127 pub fn inner(&self) -> Arc<rustls::ClientConfig> {
129 self.0.clone()
130 }
131}
132
133impl Deref for TlsClientConfig {
134 type Target = rustls::ClientConfig;
135
136 fn deref(&self) -> &Self::Target {
137 &self.0
138 }
139}
140
141impl Default for TlsClientConfig {
142 fn default() -> Self {
143 Self::new(vec![], true, None).expect("default client config should be valid")
144 }
145}
146
147pub(crate) fn validate_host(host: &str) -> anyhow::Result<()> {
149 let authority: http::uri::Authority = host
150 .parse()
151 .with_context(|| format!("invalid TLS 'host' {host:?}"))?;
152 ensure!(
153 authority.port().is_none(),
154 "invalid TLS 'host' {host:?}; ports not currently supported"
155 );
156 Ok(())
157}
158
#[cfg(test)]
mod tests {
    use std::path::Path;

    use anyhow::Context;
    use rustls_pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer};

    use super::*;

    /// A default (empty) registry must still answer lookups without panicking.
    #[test]
    fn test_empty_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::default();
        // No entries are configured, so this resolves to the default config.
        configs.get_tls_client_config("foo", "bar");
        Ok(())
    }

    /// A configured (component, host) pair gets its own config, distinct from
    /// the default handed to unconfigured components.
    #[test]
    fn test_minimal_config() -> anyhow::Result<()> {
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![],
            use_webpki_roots: false,
            client_cert: None,
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        // Deliberately unconfigured component: falls back to the default config.
        let default_config = configs.get_tls_client_config("other_component", "test-host");
        assert!(!Arc::ptr_eq(&config.0, &default_config.0));
        Ok(())
    }

    /// Custom roots plus a client certificate yield a config that can
    /// present client certs.
    #[test]
    fn test_maximal_config() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([ClientTlsRuntimeConfig {
            components: vec!["test-component".into()],
            hosts: vec!["test-host".into()],
            root_certificates: vec![test_certs[0].clone()],
            use_webpki_roots: false,
            client_cert: Some(ClientCertRuntimeConfig {
                cert_chain: test_certs,
                key_der: test_key,
            }),
        }])?;
        let config = configs.get_tls_client_config("test-component", "test-host");
        assert!(config.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    /// When two entries cover the same (component, host) pair the first one
    /// wins, while a component listed only in the second entry gets the
    /// second entry's config.
    #[test]
    fn test_config_overrides() -> anyhow::Result<()> {
        let test_certs = test_certs()?;
        let test_key = test_key()?;
        let configs = TlsClientConfigs::new([
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into()],
                hosts: vec!["test-host".into()],
                client_cert: Some(ClientCertRuntimeConfig {
                    cert_chain: test_certs,
                    key_der: test_key,
                }),
                ..Default::default()
            },
            ClientTlsRuntimeConfig {
                components: vec!["test-component1".into(), "test-component2".into()],
                hosts: vec!["test-host".into()],
                ..Default::default()
            },
        ])?;
        // First match wins: component1 keeps the client cert from entry one.
        let config1 = configs.get_tls_client_config("test-component1", "test-host");
        assert!(config1.client_auth_cert_resolver.has_certs());

        // Must use the exact configured ID ("test-component2"); a misspelled
        // ID would silently resolve to the default config and make this
        // assertion pass without testing per-component selection.
        let config2 = configs.get_tls_client_config("test-component2", "test-host");
        assert!(!config2.client_auth_cert_resolver.has_certs());
        Ok(())
    }

    const TESTDATA_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata");

    /// Loads the certificate chain from `testdata/valid-cert.pem`.
    fn test_certs() -> anyhow::Result<Vec<CertificateDer<'static>>> {
        CertificateDer::pem_file_iter(Path::new(TESTDATA_DIR).join("valid-cert.pem"))?
            .collect::<Result<Vec<_>, _>>()
            .context("certs")
    }

    /// Loads the private key from `testdata/valid-private-key.pem`.
    fn test_key() -> anyhow::Result<PrivateKeyDer<'static>> {
        PrivateKeyDer::from_pem_file(Path::new(TESTDATA_DIR).join("valid-private-key.pem"))
            .context("key")
    }

    impl TlsClientConfigs {
        /// Test helper: resolves the config for a (component, host) pair in
        /// one call and returns an owned handle.
        fn get_tls_client_config(&self, component_id: &str, host: &str) -> TlsClientConfig {
            let component_config = self.get_component_tls_configs(component_id);
            component_config.get_client_config(host).clone()
        }
    }
}