spin_factor_outbound_networking/
lib.rs

1mod allowed_hosts;
2pub mod runtime_config;
3mod tls;
4
5use std::{collections::HashMap, sync::Arc};
6
7use futures_util::FutureExt as _;
8use spin_factor_variables::VariablesFactor;
9use spin_factor_wasi::{SocketAddrUse, WasiFactor};
10use spin_factors::{
11    anyhow::{self, Context},
12    ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors,
13};
14use spin_outbound_networking_config::allowed_hosts::{DisallowedHostHandler, OutboundAllowedHosts};
15use url::Url;
16
17use crate::{
18    allowed_hosts::allowed_outbound_hosts, runtime_config::RuntimeConfig, tls::TlsClientConfigs,
19};
20pub use allowed_hosts::validate_service_chaining_for_components;
21
22pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig};
23use config::allowed_hosts::AllowedHostsConfig;
24use config::blocked_networks::BlockedNetworks;
25pub use spin_outbound_networking_config as config;
26
27#[derive(Default)]
28pub struct OutboundNetworkingFactor {
29    disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
30}
31
32impl OutboundNetworkingFactor {
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Sets a handler to be called when a request is disallowed by an
38    /// instance's configured `allowed_outbound_hosts`.
39    pub fn set_disallowed_host_handler(&mut self, handler: impl DisallowedHostHandler + 'static) {
40        self.disallowed_host_handler = Some(Arc::new(handler));
41    }
42}
43
44impl Factor for OutboundNetworkingFactor {
45    type RuntimeConfig = RuntimeConfig;
46    type AppState = AppState;
47    type InstanceBuilder = InstanceBuilder;
48
49    fn configure_app<T: RuntimeFactors>(
50        &self,
51        mut ctx: ConfigureAppContext<T, Self>,
52    ) -> anyhow::Result<Self::AppState> {
53        // Extract allowed_outbound_hosts for all components
54        let component_allowed_hosts = ctx
55            .app()
56            .components()
57            .map(|component| {
58                Ok((
59                    component.id().to_string(),
60                    allowed_outbound_hosts(&component)?
61                        .into_boxed_slice()
62                        .into(),
63                ))
64            })
65            .collect::<anyhow::Result<_>>()?;
66
67        let RuntimeConfig {
68            client_tls_configs,
69            blocked_ip_networks: block_networks,
70            block_private_networks,
71        } = ctx.take_runtime_config().unwrap_or_default();
72
73        let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks);
74        let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?;
75
76        Ok(AppState {
77            component_allowed_hosts,
78            blocked_networks,
79            tls_client_configs,
80        })
81    }
82
83    fn prepare<T: RuntimeFactors>(
84        &self,
85        mut ctx: PrepareContext<T, Self>,
86    ) -> anyhow::Result<Self::InstanceBuilder> {
87        let hosts = ctx
88            .app_state()
89            .component_allowed_hosts
90            .get(ctx.app_component().id())
91            .cloned()
92            .context("missing component allowed hosts")?;
93        let resolver = ctx
94            .instance_builder::<VariablesFactor>()?
95            .expression_resolver()
96            .clone();
97        let component_ids = ctx
98            .app_component()
99            .app
100            .components()
101            .map(|c| c.id().to_string())
102            .collect::<Vec<_>>();
103        let allowed_hosts_future = async move {
104            let prepared = resolver.prepare().await.inspect_err(|err| {
105                tracing::error!(
106                    %err, "error.type" = "variable_resolution_failed",
107                    "Error resolving variables when checking request against allowed outbound hosts",
108                );
109            })?;
110            AllowedHostsConfig::parse(&hosts, &prepared, &component_ids).inspect_err(|err| {
111                tracing::error!(
112                    %err, "error.type" = "invalid_allowed_hosts",
113                    "Error parsing allowed outbound hosts",
114                );
115            })
116        }
117        .map(|res| res.map(Arc::new).map_err(Arc::new))
118        .boxed()
119        .shared();
120        let allowed_hosts = OutboundAllowedHosts::new(
121            allowed_hosts_future.clone(),
122            self.disallowed_host_handler.clone(),
123        );
124        let blocked_networks = ctx.app_state().blocked_networks.clone();
125
126        match ctx.instance_builder::<WasiFactor>() {
127            Ok(wasi_builder) => {
128                // Update Wasi socket allowed ports
129                let allowed_hosts = allowed_hosts.clone();
130                wasi_builder.outbound_socket_addr_check(move |addr, addr_use| {
131                    let allowed_hosts = allowed_hosts.clone();
132                    let blocked_networks = blocked_networks.clone();
133                    async move {
134                        let scheme = match addr_use {
135                            SocketAddrUse::TcpBind => return false,
136                            SocketAddrUse::TcpConnect => "tcp",
137                            SocketAddrUse::UdpBind
138                            | SocketAddrUse::UdpConnect
139                            | SocketAddrUse::UdpOutgoingDatagram => "udp",
140                        };
141                        if !allowed_hosts
142                            .check_url(&addr.to_string(), scheme)
143                            .await
144                            .unwrap_or(
145                                // TODO: should this trap (somehow)?
146                                false,
147                            )
148                        {
149                            return false;
150                        }
151                        if blocked_networks.is_blocked(&addr) {
152                            tracing::error!(
153                                "error.type" = "destination_ip_prohibited",
154                                ?addr,
155                                "destination IP prohibited by runtime config"
156                            );
157                            return false;
158                        }
159                        true
160                    }
161                });
162            }
163            Err(Error::NoSuchFactor(_)) => (), // no WasiFactor to configure; that's OK
164            Err(err) => return Err(err.into()),
165        }
166
167        let component_tls_configs = ctx
168            .app_state()
169            .tls_client_configs
170            .get_component_tls_configs(ctx.app_component().id());
171
172        Ok(InstanceBuilder {
173            allowed_hosts,
174            blocked_networks: ctx.app_state().blocked_networks.clone(),
175            component_tls_client_configs: component_tls_configs,
176        })
177    }
178}
179
/// App-wide state for [`OutboundNetworkingFactor`].
///
/// Built by `configure_app` from the app's components and the runtime config, and read by
/// `prepare` for each component instance.
pub struct AppState {
    /// Component ID -> allowed host list, as written in the component's
    /// `allowed_outbound_hosts` (variables are not yet resolved at this point)
    component_allowed_hosts: HashMap<String, Arc<[String]>>,
    /// Blocked IP networks, built from the runtime config
    blocked_networks: BlockedNetworks,
    /// TLS client configs, built from the runtime config
    tls_client_configs: TlsClientConfigs,
}
188
189pub struct InstanceBuilder {
190    allowed_hosts: OutboundAllowedHosts,
191    blocked_networks: BlockedNetworks,
192    component_tls_client_configs: ComponentTlsClientConfigs,
193}
194
195impl InstanceBuilder {
196    pub fn allowed_hosts(&self) -> OutboundAllowedHosts {
197        self.allowed_hosts.clone()
198    }
199
200    pub fn blocked_networks(&self) -> BlockedNetworks {
201        self.blocked_networks.clone()
202    }
203
204    pub fn component_tls_configs(&self) -> ComponentTlsClientConfigs {
205        self.component_tls_client_configs.clone()
206    }
207}
208
209impl FactorInstanceBuilder for InstanceBuilder {
210    type InstanceState = ();
211
212    fn build(self) -> anyhow::Result<Self::InstanceState> {
213        Ok(())
214    }
215}
216
217/// Records the address host, port, and database as fields on the current tracing span.
218///
219/// This should only be called from within a function that has been instrumented with a span.
220///
221/// The following fields must be pre-declared as empty on the span or they will not show up.
222/// ```
223/// use tracing::field::Empty;
224/// #[tracing::instrument(fields(db.address = Empty, server.port = Empty, db.namespace = Empty))]
225/// fn open() {}
226/// ```
227pub fn record_address_fields(address: &str) {
228    if let Ok(url) = Url::parse(address) {
229        let span = tracing::Span::current();
230        span.record("db.address", url.host_str().unwrap_or_default());
231        span.record("server.port", url.port().unwrap_or_default());
232        span.record("db.namespace", url.path().trim_start_matches('/'));
233    }
234}