spin_factor_outbound_networking/
lib.rs

1mod allowed_hosts;
2pub mod runtime_config;
3mod tls;
4
5use std::{collections::HashMap, sync::Arc};
6
7use futures_util::FutureExt as _;
8use spin_factor_variables::VariablesFactor;
9use spin_factor_wasi::{SocketAddrUse, WasiFactor};
10use spin_factors::{
11    anyhow::{self, Context},
12    ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors,
13};
14use spin_outbound_networking_config::allowed_hosts::{DisallowedHostHandler, OutboundAllowedHosts};
15use url::Url;
16
17use crate::{
18    allowed_hosts::allowed_outbound_hosts, runtime_config::RuntimeConfig, tls::TlsClientConfigs,
19};
20pub use allowed_hosts::validate_service_chaining_for_components;
21
22pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig};
23use config::allowed_hosts::AllowedHostsConfig;
24use config::blocked_networks::BlockedNetworks;
25pub use spin_outbound_networking_config as config;
26
27#[derive(Default)]
28pub struct OutboundNetworkingFactor {
29    disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
30}
31
32impl OutboundNetworkingFactor {
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Sets a handler to be called when a request is disallowed by an
38    /// instance's configured `allowed_outbound_hosts`.
39    pub fn set_disallowed_host_handler(&mut self, handler: impl DisallowedHostHandler + 'static) {
40        self.disallowed_host_handler = Some(Arc::new(handler));
41    }
42}
43
44impl Factor for OutboundNetworkingFactor {
45    type RuntimeConfig = RuntimeConfig;
46    type AppState = AppState;
47    type InstanceBuilder = InstanceBuilder;
48
49    fn configure_app<T: RuntimeFactors>(
50        &self,
51        mut ctx: ConfigureAppContext<T, Self>,
52    ) -> anyhow::Result<Self::AppState> {
53        // Extract allowed_outbound_hosts for all components
54        let component_allowed_hosts = ctx
55            .app()
56            .components()
57            .map(|component| {
58                Ok((
59                    component.id().to_string(),
60                    allowed_outbound_hosts(&component)?
61                        .into_boxed_slice()
62                        .into(),
63                ))
64            })
65            .collect::<anyhow::Result<_>>()?;
66
67        let RuntimeConfig {
68            client_tls_configs,
69            blocked_ip_networks: block_networks,
70            block_private_networks,
71        } = ctx.take_runtime_config().unwrap_or_default();
72
73        let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks);
74        let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?;
75
76        Ok(AppState {
77            component_allowed_hosts,
78            blocked_networks,
79            tls_client_configs,
80        })
81    }
82
83    fn prepare<T: RuntimeFactors>(
84        &self,
85        mut ctx: PrepareContext<T, Self>,
86    ) -> anyhow::Result<Self::InstanceBuilder> {
87        let hosts = ctx
88            .app_state()
89            .component_allowed_hosts
90            .get(ctx.app_component().id())
91            .cloned()
92            .context("missing component allowed hosts")?;
93        let resolver = ctx
94            .instance_builder::<VariablesFactor>()?
95            .expression_resolver()
96            .clone();
97        let allowed_hosts_future = async move {
98            let prepared = resolver.prepare().await.inspect_err(|err| {
99                tracing::error!(
100                    %err, "error.type" = "variable_resolution_failed",
101                    "Error resolving variables when checking request against allowed outbound hosts",
102                );
103            })?;
104            AllowedHostsConfig::parse(&hosts, &prepared).inspect_err(|err| {
105                tracing::error!(
106                    %err, "error.type" = "invalid_allowed_hosts",
107                    "Error parsing allowed outbound hosts",
108                );
109            })
110        }
111        .map(|res| res.map(Arc::new).map_err(Arc::new))
112        .boxed()
113        .shared();
114        let allowed_hosts = OutboundAllowedHosts::new(
115            allowed_hosts_future.clone(),
116            self.disallowed_host_handler.clone(),
117        );
118        let blocked_networks = ctx.app_state().blocked_networks.clone();
119
120        match ctx.instance_builder::<WasiFactor>() {
121            Ok(wasi_builder) => {
122                // Update Wasi socket allowed ports
123                let allowed_hosts = allowed_hosts.clone();
124                wasi_builder.outbound_socket_addr_check(move |addr, addr_use| {
125                    let allowed_hosts = allowed_hosts.clone();
126                    let blocked_networks = blocked_networks.clone();
127                    async move {
128                        let scheme = match addr_use {
129                            SocketAddrUse::TcpBind => return false,
130                            SocketAddrUse::TcpConnect => "tcp",
131                            SocketAddrUse::UdpBind
132                            | SocketAddrUse::UdpConnect
133                            | SocketAddrUse::UdpOutgoingDatagram => "udp",
134                        };
135                        if !allowed_hosts
136                            .check_url(&addr.to_string(), scheme)
137                            .await
138                            .unwrap_or(
139                                // TODO: should this trap (somehow)?
140                                false,
141                            )
142                        {
143                            return false;
144                        }
145                        if blocked_networks.is_blocked(&addr) {
146                            tracing::error!(
147                                "error.type" = "destination_ip_prohibited",
148                                ?addr,
149                                "destination IP prohibited by runtime config"
150                            );
151                            return false;
152                        }
153                        true
154                    }
155                });
156            }
157            Err(Error::NoSuchFactor(_)) => (), // no WasiFactor to configure; that's OK
158            Err(err) => return Err(err.into()),
159        }
160
161        let component_tls_configs = ctx
162            .app_state()
163            .tls_client_configs
164            .get_component_tls_configs(ctx.app_component().id());
165
166        Ok(InstanceBuilder {
167            allowed_hosts,
168            blocked_networks: ctx.app_state().blocked_networks.clone(),
169            component_tls_client_configs: component_tls_configs,
170        })
171    }
172}
173
174pub struct AppState {
175    /// Component ID -> Allowed host list
176    component_allowed_hosts: HashMap<String, Arc<[String]>>,
177    /// Blocked IP networks
178    blocked_networks: BlockedNetworks,
179    /// TLS client configs
180    tls_client_configs: TlsClientConfigs,
181}
182
183pub struct InstanceBuilder {
184    allowed_hosts: OutboundAllowedHosts,
185    blocked_networks: BlockedNetworks,
186    component_tls_client_configs: ComponentTlsClientConfigs,
187}
188
189impl InstanceBuilder {
190    pub fn allowed_hosts(&self) -> OutboundAllowedHosts {
191        self.allowed_hosts.clone()
192    }
193
194    pub fn blocked_networks(&self) -> BlockedNetworks {
195        self.blocked_networks.clone()
196    }
197
198    pub fn component_tls_configs(&self) -> ComponentTlsClientConfigs {
199        self.component_tls_client_configs.clone()
200    }
201}
202
203impl FactorInstanceBuilder for InstanceBuilder {
204    type InstanceState = ();
205
206    fn build(self) -> anyhow::Result<Self::InstanceState> {
207        Ok(())
208    }
209}
210
211/// Records the address host, port, and database as fields on the current tracing span.
212///
213/// This should only be called from within a function that has been instrumented with a span.
214///
215/// The following fields must be pre-declared as empty on the span or they will not show up.
216/// ```
217/// use tracing::field::Empty;
218/// #[tracing::instrument(fields(db.address = Empty, server.port = Empty, db.namespace = Empty))]
219/// fn open() {}
220/// ```
221pub fn record_address_fields(address: &str) {
222    if let Ok(url) = Url::parse(address) {
223        let span = tracing::Span::current();
224        span.record("db.address", url.host_str().unwrap_or_default());
225        span.record("server.port", url.port().unwrap_or_default());
226        span.record("db.namespace", url.path().trim_start_matches('/'));
227    }
228}