// spin_factor_outbound_networking/lib.rs

1mod allowed_hosts;
2mod blocked_networks;
3pub mod runtime_config;
4mod tls;
5
6use std::{collections::HashMap, sync::Arc};
7
8use futures_util::{
9    future::{BoxFuture, Shared},
10    FutureExt,
11};
12use spin_factor_variables::VariablesFactor;
13use spin_factor_wasi::{SocketAddrUse, WasiFactor};
14use spin_factors::{
15    anyhow::{self, Context},
16    ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors,
17};
18use url::Url;
19
20use crate::{runtime_config::RuntimeConfig, tls::TlsClientConfigs};
21
22pub use crate::allowed_hosts::{
23    allowed_outbound_hosts, is_service_chaining_host, parse_service_chaining_target,
24    validate_service_chaining_for_components, AllowedHostConfig, AllowedHostsConfig, HostConfig,
25    OutboundUrl, SERVICE_CHAINING_DOMAIN_SUFFIX,
26};
27pub use crate::blocked_networks::BlockedNetworks;
28pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig};
29
30pub type SharedFutureResult<T> = Shared<BoxFuture<'static, Result<Arc<T>, Arc<anyhow::Error>>>>;
31
32#[derive(Default)]
33pub struct OutboundNetworkingFactor {
34    disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
35}
36
37impl OutboundNetworkingFactor {
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Sets a handler to be called when a request is disallowed by an
43    /// instance's configured `allowed_outbound_hosts`.
44    pub fn set_disallowed_host_handler(&mut self, handler: impl DisallowedHostHandler + 'static) {
45        self.disallowed_host_handler = Some(Arc::new(handler));
46    }
47}
48
49impl Factor for OutboundNetworkingFactor {
50    type RuntimeConfig = RuntimeConfig;
51    type AppState = AppState;
52    type InstanceBuilder = InstanceBuilder;
53
54    fn configure_app<T: RuntimeFactors>(
55        &self,
56        mut ctx: ConfigureAppContext<T, Self>,
57    ) -> anyhow::Result<Self::AppState> {
58        // Extract allowed_outbound_hosts for all components
59        let component_allowed_hosts = ctx
60            .app()
61            .components()
62            .map(|component| {
63                Ok((
64                    component.id().to_string(),
65                    allowed_outbound_hosts(&component)?
66                        .into_boxed_slice()
67                        .into(),
68                ))
69            })
70            .collect::<anyhow::Result<_>>()?;
71
72        let RuntimeConfig {
73            client_tls_configs,
74            blocked_ip_networks: block_networks,
75            block_private_networks,
76        } = ctx.take_runtime_config().unwrap_or_default();
77
78        let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks);
79        let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?;
80
81        Ok(AppState {
82            component_allowed_hosts,
83            blocked_networks,
84            tls_client_configs,
85        })
86    }
87
88    fn prepare<T: RuntimeFactors>(
89        &self,
90        mut ctx: PrepareContext<T, Self>,
91    ) -> anyhow::Result<Self::InstanceBuilder> {
92        let hosts = ctx
93            .app_state()
94            .component_allowed_hosts
95            .get(ctx.app_component().id())
96            .cloned()
97            .context("missing component allowed hosts")?;
98        let resolver = ctx
99            .instance_builder::<VariablesFactor>()?
100            .expression_resolver()
101            .clone();
102        let allowed_hosts_future = async move {
103            let prepared = resolver.prepare().await.inspect_err(|err| {
104                tracing::error!(
105                    %err, "error.type" = "variable_resolution_failed",
106                    "Error resolving variables when checking request against allowed outbound hosts",
107                );
108            })?;
109            AllowedHostsConfig::parse(&hosts, &prepared).inspect_err(|err| {
110                tracing::error!(
111                    %err, "error.type" = "invalid_allowed_hosts",
112                    "Error parsing allowed outbound hosts",
113                );
114            })
115        }
116        .map(|res| res.map(Arc::new).map_err(Arc::new))
117        .boxed()
118        .shared();
119        let allowed_hosts = OutboundAllowedHosts {
120            allowed_hosts_future: allowed_hosts_future.clone(),
121            disallowed_host_handler: self.disallowed_host_handler.clone(),
122        };
123        let blocked_networks = ctx.app_state().blocked_networks.clone();
124
125        match ctx.instance_builder::<WasiFactor>() {
126            Ok(wasi_builder) => {
127                // Update Wasi socket allowed ports
128                let allowed_hosts = allowed_hosts.clone();
129                wasi_builder.outbound_socket_addr_check(move |addr, addr_use| {
130                    let allowed_hosts = allowed_hosts.clone();
131                    let blocked_networks = blocked_networks.clone();
132                    async move {
133                        let scheme = match addr_use {
134                            SocketAddrUse::TcpBind => return false,
135                            SocketAddrUse::TcpConnect => "tcp",
136                            SocketAddrUse::UdpBind
137                            | SocketAddrUse::UdpConnect
138                            | SocketAddrUse::UdpOutgoingDatagram => "udp",
139                        };
140                        if !allowed_hosts
141                            .check_url(&addr.to_string(), scheme)
142                            .await
143                            .unwrap_or(
144                                // TODO: should this trap (somehow)?
145                                false,
146                            )
147                        {
148                            return false;
149                        }
150                        if blocked_networks.is_blocked(&addr) {
151                            tracing::error!(
152                                "error.type" = "destination_ip_prohibited",
153                                ?addr,
154                                "destination IP prohibited by runtime config"
155                            );
156                            return false;
157                        }
158                        true
159                    }
160                });
161            }
162            Err(Error::NoSuchFactor(_)) => (), // no WasiFactor to configure; that's OK
163            Err(err) => return Err(err.into()),
164        }
165
166        let component_tls_configs = ctx
167            .app_state()
168            .tls_client_configs
169            .get_component_tls_configs(ctx.app_component().id());
170
171        Ok(InstanceBuilder {
172            allowed_hosts,
173            blocked_networks: ctx.app_state().blocked_networks.clone(),
174            component_tls_client_configs: component_tls_configs,
175        })
176    }
177}
178
179pub struct AppState {
180    /// Component ID -> Allowed host list
181    component_allowed_hosts: HashMap<String, Arc<[String]>>,
182    /// Blocked IP networks
183    blocked_networks: BlockedNetworks,
184    /// TLS client configs
185    tls_client_configs: TlsClientConfigs,
186}
187
188pub struct InstanceBuilder {
189    allowed_hosts: OutboundAllowedHosts,
190    blocked_networks: BlockedNetworks,
191    component_tls_client_configs: ComponentTlsClientConfigs,
192}
193
194impl InstanceBuilder {
195    pub fn allowed_hosts(&self) -> OutboundAllowedHosts {
196        self.allowed_hosts.clone()
197    }
198
199    pub fn blocked_networks(&self) -> BlockedNetworks {
200        self.blocked_networks.clone()
201    }
202
203    pub fn component_tls_configs(&self) -> ComponentTlsClientConfigs {
204        self.component_tls_client_configs.clone()
205    }
206}
207
208impl FactorInstanceBuilder for InstanceBuilder {
209    type InstanceState = ();
210
211    fn build(self) -> anyhow::Result<Self::InstanceState> {
212        Ok(())
213    }
214}
215
216/// A check for whether a URL is allowed by the outbound networking configuration.
217#[derive(Clone)]
218pub struct OutboundAllowedHosts {
219    allowed_hosts_future: SharedFutureResult<AllowedHostsConfig>,
220    disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
221}
222
223impl OutboundAllowedHosts {
224    /// Checks address against allowed hosts
225    ///
226    /// Calls the [`DisallowedHostHandler`] if set and URL is disallowed.
227    /// If `url` cannot be parsed, `{scheme}://` is prepended to `url` and retried.
228    pub async fn check_url(&self, url: &str, scheme: &str) -> anyhow::Result<bool> {
229        tracing::debug!("Checking outbound networking request to '{url}'");
230        let url = match OutboundUrl::parse(url, scheme) {
231            Ok(url) => url,
232            Err(err) => {
233                tracing::warn!(%err,
234                    "A component tried to make a request to a url that could not be parsed: {url}",
235                );
236                return Ok(false);
237            }
238        };
239
240        let allowed_hosts = self.resolve().await?;
241        let is_allowed = allowed_hosts.allows(&url);
242        if !is_allowed {
243            tracing::debug!("Disallowed outbound networking request to '{url}'");
244            self.report_disallowed_host(url.scheme(), &url.authority());
245        }
246        Ok(is_allowed)
247    }
248
249    /// Checks if allowed hosts permit relative requests
250    ///
251    /// Calls the [`DisallowedHostHandler`] if set and relative requests are
252    /// disallowed.
253    pub async fn check_relative_url(&self, schemes: &[&str]) -> anyhow::Result<bool> {
254        tracing::debug!("Checking relative outbound networking request with schemes {schemes:?}");
255        let allowed_hosts = self.resolve().await?;
256        let is_allowed = allowed_hosts.allows_relative_url(schemes);
257        if !is_allowed {
258            tracing::debug!(
259                "Disallowed relative outbound networking request with schemes {schemes:?}"
260            );
261            let scheme = schemes.first().unwrap_or(&"");
262            self.report_disallowed_host(scheme, "self");
263        }
264        Ok(is_allowed)
265    }
266
267    async fn resolve(&self) -> anyhow::Result<Arc<AllowedHostsConfig>> {
268        self.allowed_hosts_future
269            .clone()
270            .await
271            .map_err(anyhow::Error::msg)
272    }
273
274    fn report_disallowed_host(&self, scheme: &str, authority: &str) {
275        if let Some(handler) = &self.disallowed_host_handler {
276            handler.handle_disallowed_host(scheme, authority);
277        }
278    }
279}
280
/// Receives notifications about outbound requests that were rejected by a
/// component's `allowed_outbound_hosts`.
pub trait DisallowedHostHandler: Send + Sync {
    /// Called with the `scheme` and `authority` of the disallowed request.
    fn handle_disallowed_host(&self, scheme: &str, authority: &str);
}
284
285impl<F: Fn(&str, &str) + Send + Sync> DisallowedHostHandler for F {
286    fn handle_disallowed_host(&self, scheme: &str, authority: &str) {
287        self(scheme, authority);
288    }
289}
290
291/// Records the address host, port, and database as fields on the current tracing span.
292///
293/// This should only be called from within a function that has been instrumented with a span.
294///
295/// The following fields must be pre-declared as empty on the span or they will not show up.
296/// ```
297/// use tracing::field::Empty;
298/// #[tracing::instrument(fields(db.address = Empty, server.port = Empty, db.namespace = Empty))]
299/// fn open() {}
300/// ```
301pub fn record_address_fields(address: &str) {
302    if let Ok(url) = Url::parse(address) {
303        let span = tracing::Span::current();
304        span.record("db.address", url.host_str().unwrap_or_default());
305        span.record("server.port", url.port().unwrap_or_default());
306        span.record("db.namespace", url.path().trim_start_matches('/'));
307    }
308}