spin_factor_outbound_networking/
lib.rs1mod allowed_hosts;
2pub mod runtime_config;
3mod tls;
4
5use std::{collections::HashMap, sync::Arc};
6
7use futures_util::FutureExt as _;
8use opentelemetry_semantic_conventions::attribute::SERVER_PORT;
9use spin_factor_variables::VariablesFactor;
10use spin_factor_wasi::{SocketAddrUse, SocketPermitState, WasiFactor};
11use spin_factors::{
12 ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors,
13 anyhow::{self, Context},
14};
15use spin_outbound_networking_config::allowed_hosts::{DisallowedHostHandler, OutboundAllowedHosts};
16use tokio::sync::Semaphore;
17use url::Url;
18
19use crate::{
20 allowed_hosts::allowed_outbound_hosts, runtime_config::RuntimeConfig, tls::TlsClientConfigs,
21};
22pub use allowed_hosts::validate_service_chaining_for_components;
23pub use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore};
24
25pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig};
26use config::allowed_hosts::AllowedHostsConfig;
27use config::blocked_networks::BlockedNetworks;
28pub use spin_outbound_networking_config as config;
29
30#[derive(Default)]
31pub struct OutboundNetworkingFactor {
32 disallowed_host_handler: Option<Arc<dyn DisallowedHostHandler>>,
33}
34
35impl OutboundNetworkingFactor {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn set_disallowed_host_handler(&mut self, handler: impl DisallowedHostHandler + 'static) {
43 self.disallowed_host_handler = Some(Arc::new(handler));
44 }
45}
46
47impl Factor for OutboundNetworkingFactor {
48 type RuntimeConfig = RuntimeConfig;
49 type AppState = AppState;
50 type InstanceBuilder = InstanceBuilder;
51
52 fn configure_app<T: RuntimeFactors>(
53 &self,
54 mut ctx: ConfigureAppContext<T, Self>,
55 ) -> anyhow::Result<Self::AppState> {
56 let component_allowed_hosts = ctx
58 .app()
59 .components()
60 .map(|component| {
61 Ok((
62 component.id().to_string(),
63 allowed_outbound_hosts(&component)?
64 .into_boxed_slice()
65 .into(),
66 ))
67 })
68 .collect::<anyhow::Result<_>>()?;
69
70 let RuntimeConfig {
71 client_tls_configs,
72 blocked_ip_networks: block_networks,
73 block_private_networks,
74 max_socket_connections,
75 max_total_connections,
76 wait_timeout,
77 } = ctx.take_runtime_config().unwrap_or_default();
78
79 let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks);
80 let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?;
81 let global_connection_semaphore =
82 max_total_connections.map(|n| Arc::new(Semaphore::new(n)));
83
84 if let (Some(socket_cap), Some(global_cap)) =
85 (max_socket_connections, max_total_connections)
86 && socket_cap > global_cap
87 {
88 tracing::warn!(
89 "outbound_networking max_socket_connections ({socket_cap}) exceeds \
90 max_total_connections ({global_cap}); the global limit will be the effective \
91 cap for TCP/UDP sockets"
92 );
93 }
94
95 let socket_connection_semaphore =
96 if max_socket_connections.is_some() || global_connection_semaphore.is_some() {
97 Some(ConnectionSemaphore::new(
98 global_connection_semaphore.clone(),
99 max_socket_connections,
100 "wasi-sockets",
101 wait_timeout,
102 ))
103 } else {
104 None
105 };
106
107 Ok(AppState {
108 component_allowed_hosts,
109 blocked_networks,
110 tls_client_configs,
111 socket_connection_semaphore,
112 global_connection_semaphore,
113 max_total_connections,
114 })
115 }
116
117 fn prepare<T: RuntimeFactors>(
118 &self,
119 mut ctx: PrepareContext<T, Self>,
120 ) -> anyhow::Result<Self::InstanceBuilder> {
121 let hosts = ctx
122 .app_state()
123 .component_allowed_hosts
124 .get(ctx.app_component().id())
125 .cloned()
126 .context("missing component allowed hosts")?;
127 let resolver = ctx
128 .instance_builder::<VariablesFactor>()?
129 .expression_resolver()
130 .clone();
131 let component_ids = ctx
132 .app_component()
133 .app
134 .components()
135 .map(|c| c.id().to_string())
136 .collect::<Vec<_>>();
137 let allowed_hosts_future = async move {
138 let prepared = resolver.prepare().await.inspect_err(|err| {
139 tracing::error!(
140 %err, "error.type" = "variable_resolution_failed",
141 "Error resolving variables when checking request against allowed outbound hosts",
142 );
143 })?;
144 AllowedHostsConfig::parse(&hosts, &prepared, &component_ids).inspect_err(|err| {
145 tracing::error!(
146 %err, "error.type" = "invalid_allowed_hosts",
147 "Error parsing allowed outbound hosts",
148 );
149 })
150 }
151 .map(|res| res.map(Arc::new).map_err(Arc::new))
152 .boxed()
153 .shared();
154 let allowed_hosts = OutboundAllowedHosts::new(
155 allowed_hosts_future.clone(),
156 self.disallowed_host_handler.clone(),
157 );
158 let blocked_networks = ctx.app_state().blocked_networks.clone();
159 let permit_state = ctx
160 .app_state()
161 .socket_connection_semaphore
162 .clone()
163 .map(SocketPermitState::new);
164
165 match ctx.instance_builder::<WasiFactor>() {
166 Ok(wasi_builder) => {
167 if let Some(state) = permit_state {
168 wasi_builder.set_socket_permit_state(state);
169 }
170
171 let allowed_hosts = allowed_hosts.clone();
172 wasi_builder.outbound_socket_addr_check(move |addr, addr_use| {
173 let allowed_hosts = allowed_hosts.clone();
174 let blocked_networks = blocked_networks.clone();
175 async move {
176 let scheme = match addr_use {
177 SocketAddrUse::TcpBind => return false,
178 SocketAddrUse::TcpConnect => "tcp",
179 SocketAddrUse::UdpBind
180 | SocketAddrUse::UdpConnect
181 | SocketAddrUse::UdpOutgoingDatagram => "udp",
182 };
183 if !allowed_hosts
184 .check_url(&addr.to_string(), scheme)
185 .await
186 .unwrap_or(
187 false,
189 )
190 {
191 return false;
192 }
193 if blocked_networks.is_blocked(&addr) {
194 tracing::error!(
195 "error.type" = "destination_ip_prohibited",
196 ?addr,
197 "destination IP prohibited by runtime config"
198 );
199 return false;
200 }
201 true
202 }
203 });
204 }
205 Err(Error::NoSuchFactor(_)) => (), Err(err) => return Err(err.into()),
207 }
208
209 let component_tls_configs = ctx
210 .app_state()
211 .tls_client_configs
212 .get_component_tls_configs(ctx.app_component().id());
213
214 Ok(InstanceBuilder {
215 allowed_hosts,
216 blocked_networks: ctx.app_state().blocked_networks.clone(),
217 component_tls_client_configs: component_tls_configs,
218 })
219 }
220}
221
222pub struct AppState {
223 component_allowed_hosts: HashMap<String, Arc<[String]>>,
225 blocked_networks: BlockedNetworks,
227 tls_client_configs: TlsClientConfigs,
229 socket_connection_semaphore: Option<ConnectionSemaphore>,
232 global_connection_semaphore: Option<Arc<Semaphore>>,
235 max_total_connections: Option<usize>,
237}
238
239pub fn build_connection_semaphore(
245 networking: Option<&AppState>,
246 factor: &'static str,
247 factor_limit: Option<usize>,
248 wait_timeout: Option<std::time::Duration>,
249) -> ConnectionSemaphore {
250 if let (Some(per_factor), Some(global_limit)) = (
251 factor_limit,
252 networking.and_then(|n| n.max_total_connections),
253 ) && per_factor > global_limit
254 {
255 tracing::warn!(
256 "outbound_{factor} max_connections ({per_factor}) exceeds global \
257 max_total_connections ({global_limit}); the global limit will be the \
258 effective cap"
259 );
260 }
261 ConnectionSemaphore::new(
262 networking.and_then(|n| n.global_connection_semaphore.clone()),
263 factor_limit,
264 factor,
265 wait_timeout,
266 )
267}
268
269pub struct InstanceBuilder {
270 allowed_hosts: OutboundAllowedHosts,
271 blocked_networks: BlockedNetworks,
272 component_tls_client_configs: ComponentTlsClientConfigs,
273}
274
275impl InstanceBuilder {
276 pub fn allowed_hosts(&self) -> OutboundAllowedHosts {
277 self.allowed_hosts.clone()
278 }
279
280 pub fn blocked_networks(&self) -> BlockedNetworks {
281 self.blocked_networks.clone()
282 }
283
284 pub fn component_tls_configs(&self) -> ComponentTlsClientConfigs {
285 self.component_tls_client_configs.clone()
286 }
287}
288
289impl FactorInstanceBuilder for InstanceBuilder {
290 type InstanceState = ();
291
292 fn build(self) -> anyhow::Result<Self::InstanceState> {
293 Ok(())
294 }
295}
296
297pub fn record_address_fields(address: &str) {
308 if let Ok(url) = Url::parse(address) {
309 let span = tracing::Span::current();
310 span.record("db.address", url.host_str().unwrap_or_default());
313 span.record(SERVER_PORT, url.port().unwrap_or_default());
314 span.record("db.namespace", url.path().trim_start_matches('/'));
315 }
316}