Skip to main content

spin_factor_outbound_networking/
lib.rs

1mod allowed_hosts;
2pub mod runtime_config;
3mod tls;
4
5use std::{collections::HashMap, sync::Arc};
6
7use futures_util::FutureExt as _;
8use opentelemetry_semantic_conventions::attribute::SERVER_PORT;
9use spin_factor_variables::VariablesFactor;
10use spin_factor_wasi::{SocketAddrUse, SocketPermitState, WasiFactor};
11use spin_factors::{
12    ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors,
13    anyhow::{self, Context},
14};
15use spin_outbound_networking_config::allowed_hosts::{DisallowedHostHandler, OutboundAllowedHosts};
16use tokio::sync::Semaphore;
17use url::Url;
18
19use crate::{
20    allowed_hosts::allowed_outbound_hosts, runtime_config::RuntimeConfig, tls::TlsClientConfigs,
21};
22pub use allowed_hosts::validate_service_chaining_for_components;
23pub use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore};
24
25pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig};
26use config::allowed_hosts::AllowedHostsConfig;
27use config::blocked_networks::BlockedNetworks;
28pub use spin_outbound_networking_config as config;
29
30#[derive(Default)]
31pub struct OutboundNetworkingFactor {
32    disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
33}
34
35impl OutboundNetworkingFactor {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Sets a handler to be called when a request is disallowed by an
41    /// instance's configured `allowed_outbound_hosts`.
42    pub fn set_disallowed_host_handler(&mut self, handler: impl DisallowedHostHandler + 'static) {
43        self.disallowed_host_handler = Some(Arc::new(handler));
44    }
45}
46
47impl Factor for OutboundNetworkingFactor {
48    type RuntimeConfig = RuntimeConfig;
49    type AppState = AppState;
50    type InstanceBuilder = InstanceBuilder;
51
52    fn configure_app<T: RuntimeFactors>(
53        &self,
54        mut ctx: ConfigureAppContext<T, Self>,
55    ) -> anyhow::Result<Self::AppState> {
56        // Extract allowed_outbound_hosts for all components
57        let component_allowed_hosts = ctx
58            .app()
59            .components()
60            .map(|component| {
61                Ok((
62                    component.id().to_string(),
63                    allowed_outbound_hosts(&component)?
64                        .into_boxed_slice()
65                        .into(),
66                ))
67            })
68            .collect::<anyhow::Result<_>>()?;
69
70        let RuntimeConfig {
71            client_tls_configs,
72            blocked_ip_networks: block_networks,
73            block_private_networks,
74            max_socket_connections,
75            max_total_connections,
76            wait_timeout,
77        } = ctx.take_runtime_config().unwrap_or_default();
78
79        let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks);
80        let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?;
81        let global_connection_semaphore =
82            max_total_connections.map(|n| Arc::new(Semaphore::new(n)));
83
84        if let (Some(socket_cap), Some(global_cap)) =
85            (max_socket_connections, max_total_connections)
86            && socket_cap > global_cap
87        {
88            tracing::warn!(
89                "outbound_networking max_socket_connections ({socket_cap}) exceeds \
90                 max_total_connections ({global_cap}); the global limit will be the effective \
91                 cap for TCP/UDP sockets"
92            );
93        }
94
95        let socket_connection_semaphore =
96            if max_socket_connections.is_some() || global_connection_semaphore.is_some() {
97                Some(ConnectionSemaphore::new(
98                    global_connection_semaphore.clone(),
99                    max_socket_connections,
100                    "wasi-sockets",
101                    wait_timeout,
102                ))
103            } else {
104                None
105            };
106
107        Ok(AppState {
108            component_allowed_hosts,
109            blocked_networks,
110            tls_client_configs,
111            socket_connection_semaphore,
112            global_connection_semaphore,
113            max_total_connections,
114        })
115    }
116
117    fn prepare<T: RuntimeFactors>(
118        &self,
119        mut ctx: PrepareContext<T, Self>,
120    ) -> anyhow::Result<Self::InstanceBuilder> {
121        let hosts = ctx
122            .app_state()
123            .component_allowed_hosts
124            .get(ctx.app_component().id())
125            .cloned()
126            .context("missing component allowed hosts")?;
127        let resolver = ctx
128            .instance_builder::<VariablesFactor>()?
129            .expression_resolver()
130            .clone();
131        let component_ids = ctx
132            .app_component()
133            .app
134            .components()
135            .map(|c| c.id().to_string())
136            .collect::<Vec<_>>();
137        let allowed_hosts_future = async move {
138            let prepared = resolver.prepare().await.inspect_err(|err| {
139                tracing::error!(
140                    %err, "error.type" = "variable_resolution_failed",
141                    "Error resolving variables when checking request against allowed outbound hosts",
142                );
143            })?;
144            AllowedHostsConfig::parse(&hosts, &prepared, &component_ids).inspect_err(|err| {
145                tracing::error!(
146                    %err, "error.type" = "invalid_allowed_hosts",
147                    "Error parsing allowed outbound hosts",
148                );
149            })
150        }
151        .map(|res| res.map(Arc::new).map_err(Arc::new))
152        .boxed()
153        .shared();
154        let allowed_hosts = OutboundAllowedHosts::new(
155            allowed_hosts_future.clone(),
156            self.disallowed_host_handler.clone(),
157        );
158        let blocked_networks = ctx.app_state().blocked_networks.clone();
159        let permit_state = ctx
160            .app_state()
161            .socket_connection_semaphore
162            .clone()
163            .map(SocketPermitState::new);
164
165        match ctx.instance_builder::<WasiFactor>() {
166            Ok(wasi_builder) => {
167                if let Some(state) = permit_state {
168                    wasi_builder.set_socket_permit_state(state);
169                }
170
171                let allowed_hosts = allowed_hosts.clone();
172                wasi_builder.outbound_socket_addr_check(move |addr, addr_use| {
173                    let allowed_hosts = allowed_hosts.clone();
174                    let blocked_networks = blocked_networks.clone();
175                    async move {
176                        let scheme = match addr_use {
177                            SocketAddrUse::TcpBind => return false,
178                            SocketAddrUse::TcpConnect => "tcp",
179                            SocketAddrUse::UdpBind
180                            | SocketAddrUse::UdpConnect
181                            | SocketAddrUse::UdpOutgoingDatagram => "udp",
182                        };
183                        if !allowed_hosts
184                            .check_url(&addr.to_string(), scheme)
185                            .await
186                            .unwrap_or(
187                                // TODO: should this trap (somehow)?
188                                false,
189                            )
190                        {
191                            return false;
192                        }
193                        if blocked_networks.is_blocked(&addr) {
194                            tracing::error!(
195                                "error.type" = "destination_ip_prohibited",
196                                ?addr,
197                                "destination IP prohibited by runtime config"
198                            );
199                            return false;
200                        }
201                        true
202                    }
203                });
204            }
205            Err(Error::NoSuchFactor(_)) => (), // no WasiFactor to configure; that's OK
206            Err(err) => return Err(err.into()),
207        }
208
209        let component_tls_configs = ctx
210            .app_state()
211            .tls_client_configs
212            .get_component_tls_configs(ctx.app_component().id());
213
214        Ok(InstanceBuilder {
215            allowed_hosts,
216            blocked_networks: ctx.app_state().blocked_networks.clone(),
217            component_tls_client_configs: component_tls_configs,
218        })
219    }
220}
221
/// App-wide outbound networking state built by [`OutboundNetworkingFactor`]'s
/// `configure_app`.
///
/// `prepare` reads it for every component instance, and other outbound factors can pass it
/// to [`build_connection_semaphore`] so their limits share the global connection cap.
pub struct AppState {
    /// Component ID -> Allowed host list (raw entries; parsed per instance in `prepare`)
    component_allowed_hosts: HashMap<String, Arc<[String]>>,
    /// Blocked IP networks
    blocked_networks: BlockedNetworks,
    /// TLS client configs
    tls_client_configs: TlsClientConfigs,
    /// Pre-built semaphore for TCP/UDP socket quota enforcement (global + socket-specific).
    /// `None` means no limits are configured.
    socket_connection_semaphore: Option<ConnectionSemaphore>,
    /// App-wide semaphore capping total concurrent outbound connections across ALL types.
    /// `None` means unlimited.
    global_connection_semaphore: Option<Arc<Semaphore>>,
    /// The configured global connection limit (for warning comparisons in other factors).
    max_total_connections: Option<usize>,
}
238
239/// Builds a [`ConnectionSemaphore`] for an outbound factor, incorporating the optional global
240/// connection limit from the networking factor's app state.
241///
242/// Emits a warning when the per-factor limit exceeds the global cap (the global limit would
243/// be the effective ceiling in that case).
244pub fn build_connection_semaphore(
245    networking: Option<&AppState>,
246    factor: &'static str,
247    factor_limit: Option<usize>,
248    wait_timeout: Option<std::time::Duration>,
249) -> ConnectionSemaphore {
250    if let (Some(per_factor), Some(global_limit)) = (
251        factor_limit,
252        networking.and_then(|n| n.max_total_connections),
253    ) && per_factor > global_limit
254    {
255        tracing::warn!(
256            "outbound_{factor} max_connections ({per_factor}) exceeds global \
257             max_total_connections ({global_limit}); the global limit will be the \
258             effective cap"
259        );
260    }
261    ConnectionSemaphore::new(
262        networking.and_then(|n| n.global_connection_semaphore.clone()),
263        factor_limit,
264        factor,
265        wait_timeout,
266    )
267}
268
269pub struct InstanceBuilder {
270    allowed_hosts: OutboundAllowedHosts,
271    blocked_networks: BlockedNetworks,
272    component_tls_client_configs: ComponentTlsClientConfigs,
273}
274
275impl InstanceBuilder {
276    pub fn allowed_hosts(&self) -> OutboundAllowedHosts {
277        self.allowed_hosts.clone()
278    }
279
280    pub fn blocked_networks(&self) -> BlockedNetworks {
281        self.blocked_networks.clone()
282    }
283
284    pub fn component_tls_configs(&self) -> ComponentTlsClientConfigs {
285        self.component_tls_client_configs.clone()
286    }
287}
288
289impl FactorInstanceBuilder for InstanceBuilder {
290    type InstanceState = ();
291
292    fn build(self) -> anyhow::Result<Self::InstanceState> {
293        Ok(())
294    }
295}
296
297/// Records the address host, port, and database as fields on the current tracing span.
298///
299/// This should only be called from within a function that has been instrumented with a span.
300///
301/// The following fields must be pre-declared as empty on the span or they will not show up.
302/// ```
303/// use tracing::field::Empty;
304/// #[tracing::instrument(fields(db.address = Empty, server.port = Empty, db.namespace = Empty))]
305/// fn open() {}
306/// ```
307pub fn record_address_fields(address: &str) {
308    if let Ok(url) = Url::parse(address) {
309        let span = tracing::Span::current();
310        // `db.address` and `db.namespace` are incubating in opentelemetry-semantic-conventions 0.28.
311        // Leaving as string literals to avoid enabling the semconv_experimental feature.
312        span.record("db.address", url.host_str().unwrap_or_default());
313        span.record(SERVER_PORT, url.port().unwrap_or_default());
314        span.record("db.namespace", url.path().trim_start_matches('/'));
315    }
316}