Skip to main content

spin_factor_outbound_http/
wasi.rs

1use std::{
2    error::Error,
3    future::Future,
4    io::IoSlice,
5    net::SocketAddr,
6    ops::DerefMut,
7    pin::Pin,
8    sync::{Arc, Mutex},
9    task::{self, Context, Poll},
10    time::Duration,
11};
12
13use bytes::Bytes;
14use futures::channel::oneshot;
15use http::{
16    HeaderMap, Uri,
17    header::{CONTENT_LENGTH, HOST},
18    uri::Scheme,
19};
20use http_body::{Body, Frame, SizeHint};
21use http_body_util::{BodyExt, combinators::UnsyncBoxBody};
22use hyper_util::{
23    client::legacy::{
24        Client,
25        connect::{Connected, Connection},
26    },
27    rt::{TokioExecutor, TokioIo},
28};
29use spin_factor_outbound_networking::{
30    ComponentTlsClientConfigs, TlsClientConfig,
31    config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
32};
33use spin_factors::RuntimeFactorsInstanceState;
34use tokio::{
35    io::{AsyncRead, AsyncWrite, ReadBuf},
36    net::TcpStream,
37    sync::{OwnedSemaphorePermit, Semaphore},
38    time::timeout,
39};
40use tokio_rustls::client::TlsStream;
41use tower_service::Service;
42use tracing::{Instrument, Span, field::Empty, instrument};
43use wasmtime::component::HasData;
44use wasmtime_wasi::TrappableError;
45use wasmtime_wasi_http::{
46    p2::{
47        self, HttpError, WasiHttpCtxView,
48        bindings::http::types::{self as p2_types, ErrorCode},
49        body::HyperOutgoingBody,
50        types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
51    },
52    p3::{self, bindings::http::types as p3_types},
53};
54
55use crate::{
56    InstanceHttpHooks, OutboundHttpFactor, SelfRequestOrigin,
57    intercept::{InterceptOutcome, OutboundHttpInterceptor},
58    wasi_2023_10_18, wasi_2023_11_10,
59};
60
61use tracing_opentelemetry::OpenTelemetrySpanExt as _;
62
63const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);
64
65pub struct MutexBody<T>(Mutex<T>);
66
67impl<T> MutexBody<T> {
68    pub fn new(body: T) -> Self {
69        Self(Mutex::new(body))
70    }
71}
72
73impl<T: Body + Unpin> Body for MutexBody<T> {
74    type Data = T::Data;
75    type Error = T::Error;
76
77    fn poll_frame(
78        self: Pin<&mut Self>,
79        cx: &mut Context<'_>,
80    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
81        Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
82    }
83
84    fn is_end_stream(&self) -> bool {
85        self.0.lock().unwrap().is_end_stream()
86    }
87
88    fn size_hint(&self) -> SizeHint {
89        self.0.lock().unwrap().size_hint()
90    }
91}
92
93/// Body which, when dropped, will notify the receiver corresponding to the
94/// specified sender.
95pub struct NotifyOnDropBody<B> {
96    body: B,
97    _tx: oneshot::Sender<()>,
98}
99
100impl<B> NotifyOnDropBody<B> {
101    pub fn new(body: B, tx: oneshot::Sender<()>) -> Self {
102        Self { body, _tx: tx }
103    }
104}
105
106impl<B: Body + Unpin> Body for NotifyOnDropBody<B> {
107    type Data = B::Data;
108    type Error = B::Error;
109
110    fn poll_frame(
111        mut self: Pin<&mut Self>,
112        cx: &mut Context<'_>,
113    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
114        Pin::new(&mut self.body).poll_frame(cx)
115    }
116
117    fn is_end_stream(&self) -> bool {
118        self.body.is_end_stream()
119    }
120
121    fn size_hint(&self) -> SizeHint {
122        self.body.size_hint()
123    }
124}
125
126pub(crate) struct HasHttp;
127
128impl HasData for HasHttp {
129    type Data<'a> = WasiHttpCtxView<'a>;
130}
131
132impl p3::WasiHttpHooks for InstanceHttpHooks {
133    #[instrument(
134        name = "spin_outbound_http.send_request",
135        skip_all,
136        fields(
137            otel.kind = "client",
138            url.full = Empty,
139            http.request.method = %request.method(),
140            otel.name = %request.method(),
141            http.response.status_code = Empty,
142            server.address = Empty,
143            server.port = Empty,
144        )
145    )]
146    #[allow(clippy::type_complexity)]
147    fn send_request(
148        &mut self,
149        request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
150        options: Option<p3::RequestOptions>,
151        fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
152    ) -> Box<
153        dyn Future<
154                Output = Result<
155                    (
156                        http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
157                        Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
158                    ),
159                    TrappableError<p3_types::ErrorCode>,
160                >,
161            > + Send,
162    > {
163        self.otel.reparent_tracing_span();
164
165        // If the caller (i.e. the guest) has trouble consuming the response
166        // (e.g. encountering a network error while forwarding it on to some
167        // other place), it can report that error to us via `fut`.  However,
168        // there's nothing we'll be able to do with it here, so we ignore it.
169        // Presumably the guest will also drop the body stream and trailers
170        // future if it encounters such an error while those things are still
171        // arriving, which Hyper will deal with as appropriate (e.g. closing the
172        // connection).
173        _ = fut;
174
175        let request = request.map(|body| MutexBody::new(body).boxed());
176
177        let request_sender = RequestSender {
178            allowed_hosts: self.allowed_hosts.clone(),
179            component_tls_configs: self.component_tls_configs.clone(),
180            request_interceptor: self.request_interceptor.clone(),
181            self_request_origin: self.self_request_origin.clone(),
182            blocked_networks: self.blocked_networks.clone(),
183            http_clients: self.wasi_http_clients.clone(),
184            concurrent_outbound_connections_semaphore: self
185                .concurrent_outbound_connections_semaphore
186                .clone(),
187        };
188        let config = OutgoingRequestConfig {
189            use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
190            connect_timeout: options
191                .and_then(|v| v.connect_timeout)
192                .unwrap_or(DEFAULT_TIMEOUT),
193            first_byte_timeout: options
194                .and_then(|v| v.first_byte_timeout)
195                .unwrap_or(DEFAULT_TIMEOUT),
196            between_bytes_timeout: options
197                .and_then(|v| v.between_bytes_timeout)
198                .unwrap_or(DEFAULT_TIMEOUT),
199        };
200        Box::new(async {
201            match request_sender
202                .send(
203                    request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
204                    config,
205                )
206                .await
207            {
208                Ok(IncomingResponse {
209                    resp,
210                    between_bytes_timeout,
211                    ..
212                }) => Ok((
213                    resp.map(|body| {
214                        BetweenBytesTimeoutBody {
215                            body,
216                            sleep: None,
217                            timeout: between_bytes_timeout,
218                        }
219                        .boxed_unsync()
220                    }),
221                    Box::new(async {
222                        // TODO: Can we plumb connection errors through to here, or
223                        // will `hyper_util::client::legacy::Client` pass them all
224                        // via the response body?
225                        Ok(())
226                    }) as Box<dyn Future<Output = _> + Send>,
227                )),
228                Err(http_error) => match http_error.downcast() {
229                    Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
230                    Err(trap) => Err(TrappableError::trap(trap)),
231                },
232            }
233        })
234    }
235}
236
237pin_project_lite::pin_project! {
238    struct BetweenBytesTimeoutBody<B> {
239        #[pin]
240        body: B,
241        #[pin]
242        sleep: Option<tokio::time::Sleep>,
243        timeout: Duration,
244    }
245}
246
247impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
248    type Data = B::Data;
249    type Error = p3_types::ErrorCode;
250
251    fn poll_frame(
252        self: Pin<&mut Self>,
253        cx: &mut Context<'_>,
254    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
255        let mut me = self.project();
256        match me.body.poll_frame(cx) {
257            Poll::Ready(value) => {
258                me.sleep.as_mut().set(None);
259                Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
260            }
261            Poll::Pending => {
262                if me.sleep.is_none() {
263                    me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
264                }
265                task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
266                Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
267            }
268        }
269    }
270
271    fn is_end_stream(&self) -> bool {
272        self.body.is_end_stream()
273    }
274
275    fn size_hint(&self) -> SizeHint {
276        self.body.size_hint()
277    }
278}
279
280pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
281where
282    C: spin_factors::InitContext<OutboundHttpFactor>,
283{
284    let linker = ctx.linker();
285
286    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpCtxView<'_>
287    where
288        C: spin_factors::InitContext<OutboundHttpFactor>,
289    {
290        let (state, table) = C::get_data_with_table(store);
291        let ctx = &mut state.wasi_http_ctx;
292        WasiHttpCtxView {
293            ctx,
294            table,
295            hooks: &mut state.hooks,
296        }
297    }
298
299    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpCtxView<'_>;
300    wasmtime_wasi_http::p2::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
301        linker, get_http,
302    )?;
303    wasmtime_wasi_http::p2::bindings::http::types::add_to_linker::<_, HasHttp>(
304        linker,
305        &Default::default(),
306        get_http,
307    )?;
308
309    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
310    where
311        C: spin_factors::InitContext<OutboundHttpFactor>,
312    {
313        let (state, table) = C::get_data_with_table(store);
314        let ctx = &mut state.wasi_http_ctx;
315        p3::WasiHttpCtxView {
316            ctx,
317            table,
318            hooks: &mut state.hooks,
319        }
320    }
321
322    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
323    p3::bindings::http::client::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
324    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
325
326    wasi_2023_10_18::add_to_linker(linker, get_http)?;
327    wasi_2023_11_10::add_to_linker(linker, get_http)?;
328
329    Ok(())
330}
331
332impl OutboundHttpFactor {
333    pub fn get_wasi_http_impl(
334        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
335    ) -> Option<WasiHttpCtxView<'_>> {
336        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
337        let ctx = &mut state.wasi_http_ctx;
338        Some(WasiHttpCtxView {
339            ctx,
340            table,
341            hooks: &mut state.hooks,
342        })
343    }
344
345    pub fn get_wasi_p3_http_impl(
346        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
347    ) -> Option<p3::WasiHttpCtxView<'_>> {
348        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
349        let ctx = &mut state.wasi_http_ctx;
350        Some(p3::WasiHttpCtxView {
351            ctx,
352            table,
353            hooks: &mut state.hooks,
354        })
355    }
356}
357
358type OutgoingRequest = http::Request<HyperOutgoingBody>;
359
360impl p2::WasiHttpHooks for InstanceHttpHooks {
361    #[instrument(
362        name = "spin_outbound_http.send_request",
363        skip_all,
364        fields(
365            otel.kind = "client",
366            url.full = Empty,
367            http.request.method = %request.method(),
368            otel.name = %request.method(),
369            http.response.status_code = Empty,
370            server.address = Empty,
371            server.port = Empty,
372        )
373    )]
374    fn send_request(
375        &mut self,
376        request: OutgoingRequest,
377        config: OutgoingRequestConfig,
378    ) -> Result<wasmtime_wasi_http::p2::types::HostFutureIncomingResponse, HttpError> {
379        self.otel.reparent_tracing_span();
380
381        let request_sender = RequestSender {
382            allowed_hosts: self.allowed_hosts.clone(),
383            component_tls_configs: self.component_tls_configs.clone(),
384            request_interceptor: self.request_interceptor.clone(),
385            self_request_origin: self.self_request_origin.clone(),
386            blocked_networks: self.blocked_networks.clone(),
387            http_clients: self.wasi_http_clients.clone(),
388            concurrent_outbound_connections_semaphore: self
389                .concurrent_outbound_connections_semaphore
390                .clone(),
391        };
392        Ok(HostFutureIncomingResponse::Pending(
393            wasmtime_wasi::runtime::spawn(
394                async {
395                    match request_sender.send(request, config).await {
396                        Ok(resp) => Ok(Ok(resp)),
397                        Err(http_error) => match http_error.downcast() {
398                            Ok(error_code) => Ok(Err(error_code)),
399                            Err(trap) => Err(trap),
400                        },
401                    }
402                }
403                .in_current_span(),
404            ),
405        ))
406    }
407}
408
409struct RequestSender {
410    allowed_hosts: OutboundAllowedHosts,
411    blocked_networks: BlockedNetworks,
412    component_tls_configs: ComponentTlsClientConfigs,
413    self_request_origin: Option<SelfRequestOrigin>,
414    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
415    http_clients: HttpClients,
416    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
417}
418
419impl RequestSender {
420    async fn send(
421        self,
422        mut request: OutgoingRequest,
423        mut config: OutgoingRequestConfig,
424    ) -> Result<IncomingResponse, HttpError> {
425        self.prepare_request(&mut request, &mut config).await?;
426
427        // If the current span has opentelemetry trace context, inject it into the request
428        spin_telemetry::inject_trace_context(&mut request);
429
430        // Run any configured request interceptor
431        let mut override_connect_addr = None;
432        if let Some(interceptor) = &self.request_interceptor {
433            let intercept_request = std::mem::take(&mut request).into();
434            match interceptor.intercept(intercept_request).await? {
435                InterceptOutcome::Continue(mut req) => {
436                    override_connect_addr = req.override_connect_addr.take();
437                    request = req.into_hyper_request();
438                }
439                InterceptOutcome::Complete(resp) => {
440                    let resp = IncomingResponse {
441                        resp,
442                        worker: None,
443                        between_bytes_timeout: config.between_bytes_timeout,
444                    };
445                    return Ok(resp);
446                }
447            }
448        }
449
450        // Backfill span fields after potentially updating the URL in the interceptor
451        let span = tracing::Span::current();
452        if let Some(addr) = override_connect_addr {
453            span.record("server.address", addr.ip().to_string());
454            span.record("server.port", addr.port());
455        } else if let Some(authority) = request.uri().authority() {
456            span.record("server.address", authority.host());
457            if let Some(port) = authority.port_u16() {
458                span.record("server.port", port);
459            }
460        }
461
462        record_content_length_header(
463            &span,
464            request.headers(),
465            "http.request.header.content-length",
466        );
467
468        Ok(self
469            .send_request(request, config, override_connect_addr)
470            .await?)
471    }
472
473    async fn prepare_request(
474        &self,
475        request: &mut OutgoingRequest,
476        config: &mut OutgoingRequestConfig,
477    ) -> Result<(), ErrorCode> {
478        // wasmtime-wasi-http fills in scheme and authority for relative URLs
479        // (e.g. https://:443/<path>), which makes them hard to reason about.
480        // Undo that here.
481        let uri = request.uri_mut();
482        if uri
483            .authority()
484            .is_some_and(|authority| authority.host().is_empty())
485        {
486            let mut builder = http::uri::Builder::new();
487            if let Some(paq) = uri.path_and_query() {
488                builder = builder.path_and_query(paq.clone());
489            }
490            *uri = builder.build().unwrap();
491        }
492        tracing::Span::current().record("url.full", uri.to_string());
493
494        let is_self_request = match request.uri().authority() {
495            // Some SDKs require an authority, so we support e.g. http://self.alt/self-request
496            Some(authority) => authority.host() == "self.alt",
497            // Otherwise self requests have no authority
498            None => true,
499        };
500
501        // Enforce allowed_outbound_hosts
502        let is_allowed = if is_self_request {
503            self.allowed_hosts
504                .check_relative_url(&["http", "https"])
505                .await
506                .unwrap_or(false)
507        } else {
508            self.allowed_hosts
509                .check_url(&request.uri().to_string(), "https")
510                .await
511                .unwrap_or(false)
512        };
513        if !is_allowed {
514            return Err(ErrorCode::HttpRequestDenied);
515        }
516
517        if is_self_request {
518            // Replace the authority with the "self request origin"
519            let Some(origin) = self.self_request_origin.as_ref() else {
520                tracing::error!(
521                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
522                );
523                return Err(ErrorCode::HttpRequestUriInvalid);
524            };
525
526            config.use_tls = origin.use_tls();
527
528            request.headers_mut().insert(HOST, origin.host_header());
529
530            let path_and_query = request.uri().path_and_query().cloned();
531            *request.uri_mut() = origin.clone().into_uri(path_and_query);
532        }
533
534        // Some servers (looking at you nginx) don't like a host header even though
535        // http/2 allows it: https://github.com/hyperium/hyper/issues/3298.
536        //
537        // Note that we do this _before_ invoking the request interceptor.  It may
538        // decide to add the `host` header back in, regardless of the nginx bug, in
539        // which case we'll let it do so without interferring.
540        request.headers_mut().remove(HOST);
541        Ok(())
542    }
543
544    async fn send_request(
545        self,
546        request: OutgoingRequest,
547        config: OutgoingRequestConfig,
548        override_connect_addr: Option<SocketAddr>,
549    ) -> Result<IncomingResponse, ErrorCode> {
550        let OutgoingRequestConfig {
551            use_tls,
552            connect_timeout,
553            first_byte_timeout,
554            between_bytes_timeout,
555        } = config;
556
557        let tls_client_config = if use_tls {
558            let host = request.uri().host().unwrap_or_default();
559            Some(self.component_tls_configs.get_client_config(host).clone())
560        } else {
561            None
562        };
563
564        let resp = CONNECT_OPTIONS.scope(
565            ConnectOptions {
566                blocked_networks: self.blocked_networks,
567                connect_timeout,
568                tls_client_config,
569                override_connect_addr,
570                concurrent_outbound_connections_semaphore: self
571                    .concurrent_outbound_connections_semaphore,
572            },
573            async move {
574                if use_tls {
575                    self.http_clients.https.request(request).await
576                } else {
577                    // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
578                    let h2c_prior_knowledge_host =
579                        std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
580                    let use_h2c = h2c_prior_knowledge_host.as_deref()
581                        == request.uri().authority().map(|a| a.as_str());
582
583                    if use_h2c {
584                        self.http_clients.http2.request(request).await
585                    } else {
586                        self.http_clients.http1.request(request).await
587                    }
588                }
589            },
590        );
591
592        let resp = timeout(first_byte_timeout, resp)
593            .await
594            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
595            .map_err(hyper_legacy_request_error)?
596            .map(|body| body.map_err(hyper_request_error).boxed_unsync());
597
598        let span = tracing::Span::current();
599        span.record("http.response.status_code", resp.status().as_u16());
600
601        record_content_length_header(&span, resp.headers(), "http.response.header.content-length");
602
603        Ok(IncomingResponse {
604            resp,
605            worker: None,
606            between_bytes_timeout,
607        })
608    }
609}
610
611type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
612type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
613
614#[derive(Clone)]
615pub(super) struct HttpClients {
616    /// Used for non-TLS HTTP/1 connections.
617    http1: HttpClient,
618    /// Used for non-TLS HTTP/2 connections (e.g. when h2 prior knowledge is available).
619    http2: HttpClient,
620    /// Used for HTTP-over-TLS connections, using ALPN to negotiate the HTTP version.
621    https: HttpsClient,
622}
623
624impl HttpClients {
625    pub(super) fn new(enable_pooling: bool) -> Self {
626        let builder = move || {
627            let mut builder = Client::builder(TokioExecutor::new());
628            if !enable_pooling {
629                builder.pool_max_idle_per_host(0);
630            }
631            builder
632        };
633        Self {
634            http1: builder().build(HttpConnector),
635            http2: builder().http2_only(true).build(HttpConnector),
636            https: builder().build(HttpsConnector),
637        }
638    }
639}
640
641tokio::task_local! {
642    /// The options used when establishing a new connection.
643    ///
644    /// We must use task-local variables for these config options when using
645    /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
646    /// them through as parameters.  Moreover, if there's already a pooled connection
647    /// ready, we'll reuse that and ignore these options anyway. After each connection
648    /// is established, the options are dropped.
649    static CONNECT_OPTIONS: ConnectOptions;
650}
651
652#[derive(Clone)]
653struct ConnectOptions {
654    /// The blocked networks configuration.
655    blocked_networks: BlockedNetworks,
656    /// Timeout for establishing a TCP connection.
657    connect_timeout: Duration,
658    /// TLS client configuration to use, if any.
659    tls_client_config: Option<TlsClientConfig>,
660    /// If set, override the address to connect to instead of using the given `uri`'s authority.
661    override_connect_addr: Option<SocketAddr>,
662    /// A semaphore to limit the number of concurrent outbound connections.
663    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
664}
665
666impl ConnectOptions {
667    /// Establish a TCP connection to the given URI and default port.
668    async fn connect_tcp(
669        &self,
670        uri: &Uri,
671        default_port: u16,
672    ) -> Result<PermittedTcpStream, ErrorCode> {
673        let mut socket_addrs = match self.override_connect_addr {
674            Some(override_connect_addr) => vec![override_connect_addr],
675            None => {
676                let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
677
678                let host_and_port = if authority.port().is_some() {
679                    authority.as_str().to_string()
680                } else {
681                    format!("{}:{}", authority.as_str(), default_port)
682                };
683
684                let socket_addrs = tokio::net::lookup_host(&host_and_port)
685                    .await
686                    .map_err(|err| {
687                        tracing::debug!(?host_and_port, ?err, "Error resolving host");
688                        dns_error("address not available".into(), 0)
689                    })?
690                    .collect::<Vec<_>>();
691                tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
692                socket_addrs
693            }
694        };
695
696        // Remove blocked IPs
697        crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
698
699        // If we're limiting concurrent outbound requests, acquire a permit
700
701        let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore(
702            "wasi",
703            &self.concurrent_outbound_connections_semaphore,
704        )
705        .await;
706
707        let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
708            .await
709            .map_err(|_| ErrorCode::ConnectionTimeout)?
710            .map_err(|err| match err.kind() {
711                std::io::ErrorKind::AddrNotAvailable => {
712                    dns_error("address not available".into(), 0)
713                }
714                _ => ErrorCode::ConnectionRefused,
715            })?;
716        Ok(PermittedTcpStream {
717            inner: stream,
718            _permit: permit,
719        })
720    }
721
722    /// Establish a TLS connection to the given URI and default port.
723    async fn connect_tls(
724        &self,
725        uri: &Uri,
726        default_port: u16,
727    ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
728        let tcp_stream = self.connect_tcp(uri, default_port).await?;
729
730        let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
731        tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
732
733        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
734        let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
735            .map_err(|e| {
736                tracing::warn!("dns lookup error: {e:?}");
737                dns_error("invalid dns name".into(), 0)
738            })?
739            .to_owned();
740        connector.connect(domain, tcp_stream).await.map_err(|e| {
741            tracing::warn!("tls protocol error: {e:?}");
742            ErrorCode::TlsProtocolError
743        })
744    }
745}
746
747/// A connector the uses `ConnectOptions`
748#[derive(Clone)]
749struct HttpConnector;
750
751impl HttpConnector {
752    async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
753        let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
754        Ok(TokioIo::new(stream))
755    }
756}
757
758impl Service<Uri> for HttpConnector {
759    type Response = TokioIo<PermittedTcpStream>;
760    type Error = ErrorCode;
761    type Future =
762        Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
763
764    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
765        Poll::Ready(Ok(()))
766    }
767
768    fn call(&mut self, uri: Uri) -> Self::Future {
769        Box::pin(async move { Self::connect(uri).await })
770    }
771}
772
773/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
774#[derive(Clone)]
775struct HttpsConnector;
776
777impl HttpsConnector {
778    async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
779        let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
780        Ok(TokioIo::new(RustlsStream(stream)))
781    }
782}
783
784impl Service<Uri> for HttpsConnector {
785    type Response = TokioIo<RustlsStream>;
786    type Error = ErrorCode;
787    type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
788
789    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
790        Poll::Ready(Ok(()))
791    }
792
793    fn call(&mut self, uri: Uri) -> Self::Future {
794        Box::pin(async move { Self::connect(uri).await })
795    }
796}
797
798struct RustlsStream(TlsStream<PermittedTcpStream>);
799
800impl Connection for RustlsStream {
801    fn connected(&self) -> Connected {
802        if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
803            self.0.get_ref().0.connected().negotiated_h2()
804        } else {
805            self.0.get_ref().0.connected()
806        }
807    }
808}
809
810impl AsyncRead for RustlsStream {
811    fn poll_read(
812        self: Pin<&mut Self>,
813        cx: &mut Context<'_>,
814        buf: &mut ReadBuf<'_>,
815    ) -> Poll<Result<(), std::io::Error>> {
816        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
817    }
818}
819
820impl AsyncWrite for RustlsStream {
821    fn poll_write(
822        self: Pin<&mut Self>,
823        cx: &mut Context<'_>,
824        buf: &[u8],
825    ) -> Poll<Result<usize, std::io::Error>> {
826        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
827    }
828
829    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
830        Pin::new(&mut self.get_mut().0).poll_flush(cx)
831    }
832
833    fn poll_shutdown(
834        self: Pin<&mut Self>,
835        cx: &mut Context<'_>,
836    ) -> Poll<Result<(), std::io::Error>> {
837        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
838    }
839
840    fn poll_write_vectored(
841        self: Pin<&mut Self>,
842        cx: &mut Context<'_>,
843        bufs: &[IoSlice<'_>],
844    ) -> Poll<Result<usize, std::io::Error>> {
845        Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
846    }
847
848    fn is_write_vectored(&self) -> bool {
849        self.0.is_write_vectored()
850    }
851}
852
/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
///
/// Rust drops struct fields in declaration order, so `inner` (the socket) is
/// dropped before `_permit` is released. Keep this order if the permit should
/// not be returned while the socket is still open.
struct PermittedTcpStream {
    /// The wrapped TCP stream.
    inner: TcpStream,
    /// A permit indicating that this stream is allowed to exist.
    ///
    /// When this stream is dropped, the permit is also dropped, allowing another
    /// connection to be established. `None` means the stream holds no permit
    /// (NOTE(review): presumably when no connection limit is configured — confirm
    /// at the construction site).
    _permit: Option<OwnedSemaphorePermit>,
}
863
864impl Connection for PermittedTcpStream {
865    fn connected(&self) -> Connected {
866        self.inner.connected()
867    }
868}
869
870impl AsyncRead for PermittedTcpStream {
871    fn poll_read(
872        self: Pin<&mut Self>,
873        cx: &mut Context<'_>,
874        buf: &mut ReadBuf<'_>,
875    ) -> Poll<std::io::Result<()>> {
876        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
877    }
878}
879
880impl AsyncWrite for PermittedTcpStream {
881    fn poll_write(
882        self: Pin<&mut Self>,
883        cx: &mut Context<'_>,
884        buf: &[u8],
885    ) -> Poll<Result<usize, std::io::Error>> {
886        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
887    }
888
889    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
890        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
891    }
892
893    fn poll_shutdown(
894        self: Pin<&mut Self>,
895        cx: &mut Context<'_>,
896    ) -> Poll<Result<(), std::io::Error>> {
897        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
898    }
899}
900
901/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
902fn hyper_request_error(err: hyper::Error) -> ErrorCode {
903    // If there's a source, we might be able to extract a wasi-http error from it.
904    if let Some(cause) = err.source()
905        && let Some(err) = cause.downcast_ref::<ErrorCode>()
906    {
907        return err.clone();
908    }
909
910    tracing::warn!("hyper request error: {err:?}");
911
912    ErrorCode::HttpProtocolError
913}
914
915/// Translate a [`hyper_util::client::legacy::Error`] to a wasi-http `ErrorCode` in the context of a request.
916fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
917    // If there's a source, we might be able to extract a wasi-http error from it.
918    if let Some(cause) = err.source()
919        && let Some(err) = cause.downcast_ref::<ErrorCode>()
920    {
921        return err.clone();
922    }
923
924    tracing::warn!("hyper request error: {err:?}");
925
926    ErrorCode::HttpProtocolError
927}
928
929fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
930    ErrorCode::DnsError(
931        wasmtime_wasi_http::p2::bindings::http::types::DnsErrorPayload {
932            rcode: Some(rcode),
933            info_code: Some(info_code),
934        },
935    )
936}
937
938// TODO: Remove this (and uses of it) once
939// https://github.com/spinframework/spin/issues/3274 has been addressed.
940pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
941    match code {
942        p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
943        p2_types::ErrorCode::DnsError(payload) => {
944            p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
945                rcode: payload.rcode,
946                info_code: payload.info_code,
947            })
948        }
949        p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
950        p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
951        p2_types::ErrorCode::DestinationIpProhibited => {
952            p3_types::ErrorCode::DestinationIpProhibited
953        }
954        p2_types::ErrorCode::DestinationIpUnroutable => {
955            p3_types::ErrorCode::DestinationIpUnroutable
956        }
957        p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
958        p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
959        p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
960        p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
961        p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
962        p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
963        p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
964        p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
965        p2_types::ErrorCode::TlsAlertReceived(payload) => {
966            p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
967                alert_id: payload.alert_id,
968                alert_message: payload.alert_message,
969            })
970        }
971        p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
972        p2_types::ErrorCode::HttpRequestLengthRequired => {
973            p3_types::ErrorCode::HttpRequestLengthRequired
974        }
975        p2_types::ErrorCode::HttpRequestBodySize(payload) => {
976            p3_types::ErrorCode::HttpRequestBodySize(payload)
977        }
978        p2_types::ErrorCode::HttpRequestMethodInvalid => {
979            p3_types::ErrorCode::HttpRequestMethodInvalid
980        }
981        p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
982        p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
983        p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
984            p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
985        }
986        p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
987            p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
988                p3_types::FieldSizePayload {
989                    field_name: payload.field_name,
990                    field_size: payload.field_size,
991                }
992            }))
993        }
994        p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
995            p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
996        }
997        p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
998            p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
999                field_name: payload.field_name,
1000                field_size: payload.field_size,
1001            })
1002        }
1003        p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
1004        p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1005            p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1006        }
1007        p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1008            p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
1009                field_name: payload.field_name,
1010                field_size: payload.field_size,
1011            })
1012        }
1013        p2_types::ErrorCode::HttpResponseBodySize(payload) => {
1014            p3_types::ErrorCode::HttpResponseBodySize(payload)
1015        }
1016        p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1017            p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1018        }
1019        p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1020            p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
1021                field_name: payload.field_name,
1022                field_size: payload.field_size,
1023            })
1024        }
1025        p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1026            p3_types::ErrorCode::HttpResponseTransferCoding(payload)
1027        }
1028        p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
1029            p3_types::ErrorCode::HttpResponseContentCoding(payload)
1030        }
1031        p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
1032        p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
1033        p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
1034        p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
1035        p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
1036        p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
1037    }
1038}
1039
1040// TODO: Remove this (and uses of it) once
1041// https://github.com/spinframework/spin/issues/3274 has been addressed.
1042pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
1043    match code {
1044        p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
1045        p3_types::ErrorCode::DnsError(payload) => {
1046            p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
1047                rcode: payload.rcode,
1048                info_code: payload.info_code,
1049            })
1050        }
1051        p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
1052        p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
1053        p3_types::ErrorCode::DestinationIpProhibited => {
1054            p2_types::ErrorCode::DestinationIpProhibited
1055        }
1056        p3_types::ErrorCode::DestinationIpUnroutable => {
1057            p2_types::ErrorCode::DestinationIpUnroutable
1058        }
1059        p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
1060        p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
1061        p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
1062        p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
1063        p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
1064        p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
1065        p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
1066        p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
1067        p3_types::ErrorCode::TlsAlertReceived(payload) => {
1068            p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
1069                alert_id: payload.alert_id,
1070                alert_message: payload.alert_message,
1071            })
1072        }
1073        p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
1074        p3_types::ErrorCode::HttpRequestLengthRequired => {
1075            p2_types::ErrorCode::HttpRequestLengthRequired
1076        }
1077        p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1078            p2_types::ErrorCode::HttpRequestBodySize(payload)
1079        }
1080        p3_types::ErrorCode::HttpRequestMethodInvalid => {
1081            p2_types::ErrorCode::HttpRequestMethodInvalid
1082        }
1083        p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1084        p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1085        p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1086            p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1087        }
1088        p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1089            p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1090                p2_types::FieldSizePayload {
1091                    field_name: payload.field_name,
1092                    field_size: payload.field_size,
1093                }
1094            }))
1095        }
1096        p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1097            p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1098        }
1099        p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1100            p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1101                field_name: payload.field_name,
1102                field_size: payload.field_size,
1103            })
1104        }
1105        p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1106        p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1107            p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1108        }
1109        p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1110            p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1111                field_name: payload.field_name,
1112                field_size: payload.field_size,
1113            })
1114        }
1115        p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1116            p2_types::ErrorCode::HttpResponseBodySize(payload)
1117        }
1118        p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1119            p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1120        }
1121        p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1122            p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1123                field_name: payload.field_name,
1124                field_size: payload.field_size,
1125            })
1126        }
1127        p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1128            p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1129        }
1130        p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1131            p2_types::ErrorCode::HttpResponseContentCoding(payload)
1132        }
1133        p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1134        p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1135        p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1136        p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1137        p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1138        p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1139    }
1140}
1141
1142fn record_content_length_header(span: &Span, headers: &HeaderMap, attr_name: &'static str) {
1143    if let Some(content_length) = headers.get(CONTENT_LENGTH)
1144        && let Ok(size_str) = content_length.to_str()
1145    {
1146        span.set_attribute(attr_name, size_str.to_string());
1147    }
1148}