spin_factor_outbound_networking/
lib.rs1mod allowed_hosts;
2mod blocked_networks;
3pub mod runtime_config;
4mod tls;
5
6use std::{collections::HashMap, sync::Arc};
7
8use futures_util::{
9 future::{BoxFuture, Shared},
10 FutureExt,
11};
12use spin_factor_variables::VariablesFactor;
13use spin_factor_wasi::{SocketAddrUse, WasiFactor};
14use spin_factors::{
15 anyhow::{self, Context},
16 ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors,
17};
18use url::Url;
19
20use crate::{runtime_config::RuntimeConfig, tls::TlsClientConfigs};
21
22pub use crate::allowed_hosts::{
23 allowed_outbound_hosts, is_service_chaining_host, parse_service_chaining_target,
24 validate_service_chaining_for_components, AllowedHostConfig, AllowedHostsConfig, HostConfig,
25 OutboundUrl, SERVICE_CHAINING_DOMAIN_SUFFIX,
26};
27pub use crate::blocked_networks::BlockedNetworks;
28pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig};
29
30pub type SharedFutureResult<T> = Shared<BoxFuture<'static, Result<Arc<T>, Arc<anyhow::Error>>>>;
31
32#[derive(Default)]
33pub struct OutboundNetworkingFactor {
34 disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
35}
36
37impl OutboundNetworkingFactor {
38 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn set_disallowed_host_handler(&mut self, handler: impl DisallowedHostHandler + 'static) {
45 self.disallowed_host_handler = Some(Arc::new(handler));
46 }
47}
48
49impl Factor for OutboundNetworkingFactor {
50 type RuntimeConfig = RuntimeConfig;
51 type AppState = AppState;
52 type InstanceBuilder = InstanceBuilder;
53
54 fn configure_app<T: RuntimeFactors>(
55 &self,
56 mut ctx: ConfigureAppContext<T, Self>,
57 ) -> anyhow::Result<Self::AppState> {
58 let component_allowed_hosts = ctx
60 .app()
61 .components()
62 .map(|component| {
63 Ok((
64 component.id().to_string(),
65 allowed_outbound_hosts(&component)?
66 .into_boxed_slice()
67 .into(),
68 ))
69 })
70 .collect::<anyhow::Result<_>>()?;
71
72 let RuntimeConfig {
73 client_tls_configs,
74 blocked_ip_networks: block_networks,
75 block_private_networks,
76 } = ctx.take_runtime_config().unwrap_or_default();
77
78 let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks);
79 let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?;
80
81 Ok(AppState {
82 component_allowed_hosts,
83 blocked_networks,
84 tls_client_configs,
85 })
86 }
87
88 fn prepare<T: RuntimeFactors>(
89 &self,
90 mut ctx: PrepareContext<T, Self>,
91 ) -> anyhow::Result<Self::InstanceBuilder> {
92 let hosts = ctx
93 .app_state()
94 .component_allowed_hosts
95 .get(ctx.app_component().id())
96 .cloned()
97 .context("missing component allowed hosts")?;
98 let resolver = ctx
99 .instance_builder::<VariablesFactor>()?
100 .expression_resolver()
101 .clone();
102 let allowed_hosts_future = async move {
103 let prepared = resolver.prepare().await.inspect_err(|err| {
104 tracing::error!(
105 %err, "error.type" = "variable_resolution_failed",
106 "Error resolving variables when checking request against allowed outbound hosts",
107 );
108 })?;
109 AllowedHostsConfig::parse(&hosts, &prepared).inspect_err(|err| {
110 tracing::error!(
111 %err, "error.type" = "invalid_allowed_hosts",
112 "Error parsing allowed outbound hosts",
113 );
114 })
115 }
116 .map(|res| res.map(Arc::new).map_err(Arc::new))
117 .boxed()
118 .shared();
119 let allowed_hosts = OutboundAllowedHosts {
120 allowed_hosts_future: allowed_hosts_future.clone(),
121 disallowed_host_handler: self.disallowed_host_handler.clone(),
122 };
123 let blocked_networks = ctx.app_state().blocked_networks.clone();
124
125 match ctx.instance_builder::<WasiFactor>() {
126 Ok(wasi_builder) => {
127 let allowed_hosts = allowed_hosts.clone();
129 wasi_builder.outbound_socket_addr_check(move |addr, addr_use| {
130 let allowed_hosts = allowed_hosts.clone();
131 let blocked_networks = blocked_networks.clone();
132 async move {
133 let scheme = match addr_use {
134 SocketAddrUse::TcpBind => return false,
135 SocketAddrUse::TcpConnect => "tcp",
136 SocketAddrUse::UdpBind
137 | SocketAddrUse::UdpConnect
138 | SocketAddrUse::UdpOutgoingDatagram => "udp",
139 };
140 if !allowed_hosts
141 .check_url(&addr.to_string(), scheme)
142 .await
143 .unwrap_or(
144 false,
146 )
147 {
148 return false;
149 }
150 if blocked_networks.is_blocked(&addr) {
151 tracing::error!(
152 "error.type" = "destination_ip_prohibited",
153 ?addr,
154 "destination IP prohibited by runtime config"
155 );
156 return false;
157 }
158 true
159 }
160 });
161 }
162 Err(Error::NoSuchFactor(_)) => (), Err(err) => return Err(err.into()),
164 }
165
166 let component_tls_configs = ctx
167 .app_state()
168 .tls_client_configs
169 .get_component_tls_configs(ctx.app_component().id());
170
171 Ok(InstanceBuilder {
172 allowed_hosts,
173 blocked_networks: ctx.app_state().blocked_networks.clone(),
174 component_tls_client_configs: component_tls_configs,
175 })
176 }
177}
178
179pub struct AppState {
180 component_allowed_hosts: HashMap<String, Arc<[String]>>,
182 blocked_networks: BlockedNetworks,
184 tls_client_configs: TlsClientConfigs,
186}
187
188pub struct InstanceBuilder {
189 allowed_hosts: OutboundAllowedHosts,
190 blocked_networks: BlockedNetworks,
191 component_tls_client_configs: ComponentTlsClientConfigs,
192}
193
194impl InstanceBuilder {
195 pub fn allowed_hosts(&self) -> OutboundAllowedHosts {
196 self.allowed_hosts.clone()
197 }
198
199 pub fn blocked_networks(&self) -> BlockedNetworks {
200 self.blocked_networks.clone()
201 }
202
203 pub fn component_tls_configs(&self) -> ComponentTlsClientConfigs {
204 self.component_tls_client_configs.clone()
205 }
206}
207
208impl FactorInstanceBuilder for InstanceBuilder {
209 type InstanceState = ();
210
211 fn build(self) -> anyhow::Result<Self::InstanceState> {
212 Ok(())
213 }
214}
215
216#[derive(Clone)]
218pub struct OutboundAllowedHosts {
219 allowed_hosts_future: SharedFutureResult<AllowedHostsConfig>,
220 disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
221}
222
223impl OutboundAllowedHosts {
224 pub async fn check_url(&self, url: &str, scheme: &str) -> anyhow::Result<bool> {
229 tracing::debug!("Checking outbound networking request to '{url}'");
230 let url = match OutboundUrl::parse(url, scheme) {
231 Ok(url) => url,
232 Err(err) => {
233 tracing::warn!(%err,
234 "A component tried to make a request to a url that could not be parsed: {url}",
235 );
236 return Ok(false);
237 }
238 };
239
240 let allowed_hosts = self.resolve().await?;
241 let is_allowed = allowed_hosts.allows(&url);
242 if !is_allowed {
243 tracing::debug!("Disallowed outbound networking request to '{url}'");
244 self.report_disallowed_host(url.scheme(), &url.authority());
245 }
246 Ok(is_allowed)
247 }
248
249 pub async fn check_relative_url(&self, schemes: &[&str]) -> anyhow::Result<bool> {
254 tracing::debug!("Checking relative outbound networking request with schemes {schemes:?}");
255 let allowed_hosts = self.resolve().await?;
256 let is_allowed = allowed_hosts.allows_relative_url(schemes);
257 if !is_allowed {
258 tracing::debug!(
259 "Disallowed relative outbound networking request with schemes {schemes:?}"
260 );
261 let scheme = schemes.first().unwrap_or(&"");
262 self.report_disallowed_host(scheme, "self");
263 }
264 Ok(is_allowed)
265 }
266
267 async fn resolve(&self) -> anyhow::Result<Arc<AllowedHostsConfig>> {
268 self.allowed_hosts_future
269 .clone()
270 .await
271 .map_err(anyhow::Error::msg)
272 }
273
274 fn report_disallowed_host(&self, scheme: &str, authority: &str) {
275 if let Some(handler) = &self.disallowed_host_handler {
276 handler.handle_disallowed_host(scheme, authority);
277 }
278 }
279}
280
/// Callback notified when an outbound request fails the allowed-hosts check.
///
/// Every `Fn(&str, &str) + Send + Sync` closure implements this trait via the blanket
/// impl below.
pub trait DisallowedHostHandler: Send + Sync {
    /// Called with the `scheme` and `authority` of the disallowed request.
    fn handle_disallowed_host(&self, scheme: &str, authority: &str);
}

impl<F> DisallowedHostHandler for F
where
    F: Fn(&str, &str) + Send + Sync,
{
    fn handle_disallowed_host(&self, scheme: &str, authority: &str) {
        (self)(scheme, authority)
    }
}
290
291pub fn record_address_fields(address: &str) {
302 if let Ok(url) = Url::parse(address) {
303 let span = tracing::Span::current();
304 span.record("db.address", url.host_str().unwrap_or_default());
305 span.record("server.port", url.port().unwrap_or_default());
306 span.record("db.namespace", url.path().trim_start_matches('/'));
307 }
308}