spin_factor_outbound_http/
wasi.rs

1use std::{error::Error, sync::Arc};
2
3use anyhow::Context;
4use http::{header::HOST, Request};
5use http_body_util::BodyExt;
6use spin_factor_outbound_networking::{
7    BlockedNetworks, ComponentTlsClientConfigs, OutboundAllowedHosts, TlsClientConfig,
8};
9use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
10use tokio::{net::TcpStream, time::timeout};
11use tracing::{field::Empty, instrument, Instrument};
12use wasmtime_wasi::p2::{IoImpl, IoView};
13use wasmtime_wasi_http::{
14    bindings::http::types::ErrorCode,
15    body::HyperOutgoingBody,
16    io::TokioIo,
17    types::{HostFutureIncomingResponse, IncomingResponse},
18    WasiHttpCtx, WasiHttpImpl, WasiHttpView,
19};
20
21use crate::{
22    intercept::{InterceptOutcome, OutboundHttpInterceptor},
23    wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
24};
25
26pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
27where
28    C: spin_factors::InitContext<OutboundHttpFactor>,
29{
30    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
31    where
32        C: spin_factors::InitContext<OutboundHttpFactor>,
33    {
34        let (state, table) = C::get_data_with_table(store);
35        WasiHttpImpl(IoImpl(WasiHttpImplInner { state, table }))
36    }
37    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
38    let linker = ctx.linker();
39    wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker_get_host(linker, get_http)?;
40    wasmtime_wasi_http::bindings::http::types::add_to_linker_get_host(linker, get_http)?;
41
42    wasi_2023_10_18::add_to_linker(linker, get_http)?;
43    wasi_2023_11_10::add_to_linker(linker, get_http)?;
44
45    Ok(())
46}
47
48impl OutboundHttpFactor {
49    pub fn get_wasi_http_impl(
50        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
51    ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
52        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
53        Some(WasiHttpImpl(IoImpl(WasiHttpImplInner { state, table })))
54    }
55}
56
/// Borrowed view over an instance's outbound-HTTP state and its resource table.
///
/// This is the host type wrapped in `WasiHttpImpl(IoImpl(..))` by the linker getters and
/// by `OutboundHttpFactor::get_wasi_http_impl`.
pub(crate) struct WasiHttpImplInner<'a> {
    /// Per-instance outbound HTTP state: the wasi-http context plus the policy inputs
    /// (allowed hosts, TLS configs, interceptor, self-request origin, blocked networks)
    /// read by `send_request`.
    state: &'a mut InstanceState,
    /// Resource table exposed through `IoView::table`.
    table: &'a mut ResourceTable,
}
61
62impl IoView for WasiHttpImplInner<'_> {
63    fn table(&mut self) -> &mut ResourceTable {
64        self.table
65    }
66}
67
68impl WasiHttpView for WasiHttpImplInner<'_> {
69    fn ctx(&mut self) -> &mut WasiHttpCtx {
70        &mut self.state.wasi_http_ctx
71    }
72
73    #[instrument(
74        name = "spin_outbound_http.send_request",
75        skip_all,
76        fields(
77            otel.kind = "client",
78            url.full = Empty,
79            http.request.method = %request.method(),
80            otel.name = %request.method(),
81            http.response.status_code = Empty,
82            server.address = Empty,
83            server.port = Empty,
84        ),
85    )]
86    fn send_request(
87        &mut self,
88        request: Request<wasmtime_wasi_http::body::HyperOutgoingBody>,
89        config: wasmtime_wasi_http::types::OutgoingRequestConfig,
90    ) -> wasmtime_wasi_http::HttpResult<wasmtime_wasi_http::types::HostFutureIncomingResponse> {
91        Ok(HostFutureIncomingResponse::Pending(
92            wasmtime_wasi::runtime::spawn(
93                send_request_impl(
94                    request,
95                    config,
96                    self.state.allowed_hosts.clone(),
97                    self.state.component_tls_configs.clone(),
98                    self.state.request_interceptor.clone(),
99                    self.state.self_request_origin.clone(),
100                    self.state.blocked_networks.clone(),
101                )
102                .in_current_span(),
103            ),
104        ))
105    }
106}
107
108async fn send_request_impl(
109    mut request: Request<wasmtime_wasi_http::body::HyperOutgoingBody>,
110    mut config: wasmtime_wasi_http::types::OutgoingRequestConfig,
111    outbound_allowed_hosts: OutboundAllowedHosts,
112    component_tls_configs: ComponentTlsClientConfigs,
113    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
114    self_request_origin: Option<SelfRequestOrigin>,
115    blocked_networks: BlockedNetworks,
116) -> anyhow::Result<Result<IncomingResponse, ErrorCode>> {
117    // wasmtime-wasi-http fills in scheme and authority for relative URLs
118    // (e.g. https://:443/<path>), which makes them hard to reason about.
119    // Undo that here.
120    let uri = request.uri_mut();
121    if uri
122        .authority()
123        .is_some_and(|authority| authority.host().is_empty())
124    {
125        let mut builder = http::uri::Builder::new();
126        if let Some(paq) = uri.path_and_query() {
127            builder = builder.path_and_query(paq.clone());
128        }
129        *uri = builder.build().unwrap();
130    }
131    let span = tracing::Span::current();
132    span.record("url.full", uri.to_string());
133
134    spin_telemetry::inject_trace_context(&mut request);
135
136    let host = request.uri().host().unwrap_or_default();
137    let tls_client_config = component_tls_configs.get_client_config(host).clone();
138
139    let is_self_request = request
140        .uri()
141        .authority()
142        .is_some_and(|a| a.host() == "self.alt");
143
144    if request.uri().authority().is_some() && !is_self_request {
145        // Absolute URI
146        let is_allowed = outbound_allowed_hosts
147            .check_url(&request.uri().to_string(), "https")
148            .await
149            .unwrap_or(false);
150        if !is_allowed {
151            return Ok(Err(ErrorCode::HttpRequestDenied));
152        }
153    } else {
154        // Relative URI ("self" request)
155        let is_allowed = outbound_allowed_hosts
156            .check_relative_url(&["http", "https"])
157            .await
158            .unwrap_or(false);
159        if !is_allowed {
160            return Ok(Err(ErrorCode::HttpRequestDenied));
161        }
162
163        let Some(origin) = self_request_origin else {
164            tracing::error!("Couldn't handle outbound HTTP request to relative URI; no origin set");
165            return Ok(Err(ErrorCode::HttpRequestUriInvalid));
166        };
167
168        config.use_tls = origin.use_tls();
169
170        request.headers_mut().insert(HOST, origin.host_header());
171
172        let path_and_query = request.uri().path_and_query().cloned();
173        *request.uri_mut() = origin.into_uri(path_and_query);
174    }
175
176    if let Some(interceptor) = request_interceptor {
177        let intercept_request = std::mem::take(&mut request).into();
178        match interceptor.intercept(intercept_request).await? {
179            InterceptOutcome::Continue(req) => {
180                request = req.into_hyper_request();
181            }
182            InterceptOutcome::Complete(resp) => {
183                let resp = IncomingResponse {
184                    resp,
185                    worker: None,
186                    between_bytes_timeout: config.between_bytes_timeout,
187                };
188                return Ok(Ok(resp));
189            }
190        }
191    }
192
193    let authority = request.uri().authority().context("authority not set")?;
194    span.record("server.address", authority.host());
195    if let Some(port) = authority.port() {
196        span.record("server.port", port.as_u16());
197    }
198
199    Ok(send_request_handler(request, config, tls_client_config, blocked_networks).await)
200}
201
202/// This is a fork of wasmtime_wasi_http::default_send_request_handler function
203/// forked from bytecodealliance/wasmtime commit-sha 29a76b68200fcfa69c8fb18ce6c850754279a05b
204/// This fork provides the ability to configure client cert auth for mTLS
205async fn send_request_handler(
206    mut request: http::Request<HyperOutgoingBody>,
207    wasmtime_wasi_http::types::OutgoingRequestConfig {
208        use_tls,
209        connect_timeout,
210        first_byte_timeout,
211        between_bytes_timeout,
212    }: wasmtime_wasi_http::types::OutgoingRequestConfig,
213    tls_client_config: TlsClientConfig,
214    blocked_networks: BlockedNetworks,
215) -> Result<wasmtime_wasi_http::types::IncomingResponse, ErrorCode> {
216    let authority_str = if let Some(authority) = request.uri().authority() {
217        if authority.port().is_some() {
218            authority.to_string()
219        } else {
220            let port = if use_tls { 443 } else { 80 };
221            format!("{}:{port}", authority)
222        }
223    } else {
224        return Err(ErrorCode::HttpRequestUriInvalid);
225    };
226
227    // Resolve the authority to IP addresses
228    let mut socket_addrs = tokio::net::lookup_host(&authority_str)
229        .await
230        .map_err(|_| dns_error("address not available".into(), 0))?
231        .collect::<Vec<_>>();
232
233    // Remove blocked IPs
234    let blocked_addrs = blocked_networks.remove_blocked(&mut socket_addrs);
235    if socket_addrs.is_empty() && !blocked_addrs.is_empty() {
236        tracing::error!(
237            "error.type" = "destination_ip_prohibited",
238            ?blocked_addrs,
239            "all destination IP(s) prohibited by runtime config"
240        );
241        return Err(ErrorCode::DestinationIpProhibited);
242    }
243
244    let tcp_stream = timeout(connect_timeout, TcpStream::connect(socket_addrs.as_slice()))
245        .await
246        .map_err(|_| ErrorCode::ConnectionTimeout)?
247        .map_err(|err| match err.kind() {
248            std::io::ErrorKind::AddrNotAvailable => dns_error("address not available".into(), 0),
249            _ => ErrorCode::ConnectionRefused,
250        })?;
251
252    let (mut sender, worker) = if use_tls {
253        #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
254        {
255            return Err(ErrorCode::InternalError(Some(
256                "unsupported architecture for SSL".to_string(),
257            )));
258        }
259
260        #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
261        {
262            use rustls::pki_types::ServerName;
263            let connector = tokio_rustls::TlsConnector::from(tls_client_config.inner());
264            let mut parts = authority_str.split(':');
265            let host = parts.next().unwrap_or(&authority_str);
266            let domain = ServerName::try_from(host)
267                .map_err(|e| {
268                    tracing::warn!("dns lookup error: {e:?}");
269                    dns_error("invalid dns name".to_string(), 0)
270                })?
271                .to_owned();
272            let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
273                tracing::warn!("tls protocol error: {e:?}");
274                ErrorCode::TlsProtocolError
275            })?;
276            let stream = TokioIo::new(stream);
277
278            let (sender, conn) = timeout(
279                connect_timeout,
280                hyper::client::conn::http1::handshake(stream),
281            )
282            .await
283            .map_err(|_| ErrorCode::ConnectionTimeout)?
284            .map_err(hyper_request_error)?;
285
286            let worker = wasmtime_wasi::runtime::spawn(async move {
287                match conn.await {
288                    Ok(()) => {}
289                    // TODO: shouldn't throw away this error and ideally should
290                    // surface somewhere.
291                    Err(e) => tracing::warn!("dropping error {e}"),
292                }
293            });
294
295            (sender, worker)
296        }
297    } else {
298        let tcp_stream = TokioIo::new(tcp_stream);
299        let (sender, conn) = timeout(
300            connect_timeout,
301            // TODO: we should plumb the builder through the http context, and use it here
302            hyper::client::conn::http1::handshake(tcp_stream),
303        )
304        .await
305        .map_err(|_| ErrorCode::ConnectionTimeout)?
306        .map_err(hyper_request_error)?;
307
308        let worker = wasmtime_wasi::runtime::spawn(async move {
309            match conn.await {
310                Ok(()) => {}
311                // TODO: same as above, shouldn't throw this error away.
312                Err(e) => tracing::warn!("dropping error {e}"),
313            }
314        });
315
316        (sender, worker)
317    };
318
319    // at this point, the request contains the scheme and the authority, but
320    // the http packet should only include those if addressing a proxy, so
321    // remove them here, since SendRequest::send_request does not do it for us
322    *request.uri_mut() = http::Uri::builder()
323        .path_and_query(
324            request
325                .uri()
326                .path_and_query()
327                .map(|p| p.as_str())
328                .unwrap_or("/"),
329        )
330        .build()
331        .expect("comes from valid request");
332
333    let resp = timeout(first_byte_timeout, sender.send_request(request))
334        .await
335        .map_err(|_| ErrorCode::ConnectionReadTimeout)?
336        .map_err(hyper_request_error)?
337        .map(|body| body.map_err(hyper_request_error).boxed());
338
339    tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
340
341    Ok(wasmtime_wasi_http::types::IncomingResponse {
342        resp,
343        worker: Some(worker),
344        between_bytes_timeout,
345    })
346}
347
348/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
349fn hyper_request_error(err: hyper::Error) -> ErrorCode {
350    // If there's a source, we might be able to extract a wasi-http error from it.
351    if let Some(cause) = err.source() {
352        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
353            return err.clone();
354        }
355    }
356
357    tracing::warn!("hyper request error: {err:?}");
358
359    ErrorCode::HttpProtocolError
360}
361
362fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
363    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
364        rcode: Some(rcode),
365        info_code: Some(info_code),
366    })
367}