spin_factor_outbound_http/
wasi.rs

1use std::{
2    error::Error,
3    future::Future,
4    io::IoSlice,
5    net::SocketAddr,
6    ops::DerefMut,
7    pin::Pin,
8    sync::{Arc, Mutex},
9    task::{self, Context, Poll},
10    time::Duration,
11};
12
13use bytes::Bytes;
14use http::{header::HOST, uri::Scheme, Uri};
15use http_body::{Body, Frame, SizeHint};
16use http_body_util::{combinators::UnsyncBoxBody, BodyExt};
17use hyper_util::{
18    client::legacy::{
19        connect::{Connected, Connection},
20        Client,
21    },
22    rt::{TokioExecutor, TokioIo},
23};
24use spin_factor_outbound_networking::{
25    config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
26    ComponentTlsClientConfigs, TlsClientConfig,
27};
28use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
29use tokio::{
30    io::{AsyncRead, AsyncWrite, ReadBuf},
31    net::TcpStream,
32    sync::{OwnedSemaphorePermit, Semaphore},
33    time::timeout,
34};
35use tokio_rustls::client::TlsStream;
36use tower_service::Service;
37use tracing::{field::Empty, instrument, Instrument};
38use wasmtime::component::HasData;
39use wasmtime_wasi::TrappableError;
40use wasmtime_wasi_http::{
41    bindings::http::types::{self as p2_types, ErrorCode},
42    body::HyperOutgoingBody,
43    p3::{self, bindings::http::types as p3_types},
44    types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
45    HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
46};
47
48use crate::{
49    intercept::{InterceptOutcome, OutboundHttpInterceptor},
50    wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
51};
52
/// Fallback used for every p3 request timeout (connect, first byte, and
/// between bytes) that the guest leaves unspecified: ten minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
54
55pub struct MutexBody<T>(Mutex<T>);
56
57impl<T> MutexBody<T> {
58    pub fn new(body: T) -> Self {
59        Self(Mutex::new(body))
60    }
61}
62
63impl<T: Body + Unpin> Body for MutexBody<T> {
64    type Data = T::Data;
65    type Error = T::Error;
66
67    fn poll_frame(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
71        Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
72    }
73
74    fn is_end_stream(&self) -> bool {
75        self.0.lock().unwrap().is_end_stream()
76    }
77
78    fn size_hint(&self) -> SizeHint {
79        self.0.lock().unwrap().size_hint()
80    }
81}
82
83pub(crate) struct HasHttp;
84
85impl HasData for HasHttp {
86    type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
87}
88
89impl p3::WasiHttpCtx for InstanceState {
90    fn send_request(
91        &mut self,
92        request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
93        options: Option<p3::RequestOptions>,
94        fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
95    ) -> Box<
96        dyn Future<
97                Output = Result<
98                    (
99                        http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
100                        Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
101                    ),
102                    TrappableError<p3_types::ErrorCode>,
103                >,
104            > + Send,
105    > {
106        // If the caller (i.e. the guest) has trouble consuming the response
107        // (e.g. encountering a network error while forwarding it on to some
108        // other place), it can report that error to us via `fut`.  However,
109        // there's nothing we'll be able to do with it here, so we ignore it.
110        // Presumably the guest will also drop the body stream and trailers
111        // future if it encounters such an error while those things are still
112        // arriving, which Hyper will deal with as appropriate (e.g. closing the
113        // connection).
114        _ = fut;
115
116        let request = request.map(|body| MutexBody::new(body).boxed());
117
118        let request_sender = RequestSender {
119            allowed_hosts: self.allowed_hosts.clone(),
120            component_tls_configs: self.component_tls_configs.clone(),
121            request_interceptor: self.request_interceptor.clone(),
122            self_request_origin: self.self_request_origin.clone(),
123            blocked_networks: self.blocked_networks.clone(),
124            http_clients: self.wasi_http_clients.clone(),
125            concurrent_outbound_connections_semaphore: self
126                .concurrent_outbound_connections_semaphore
127                .clone(),
128        };
129        let config = OutgoingRequestConfig {
130            use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
131            connect_timeout: options
132                .and_then(|v| v.connect_timeout)
133                .unwrap_or(DEFAULT_TIMEOUT),
134            first_byte_timeout: options
135                .and_then(|v| v.first_byte_timeout)
136                .unwrap_or(DEFAULT_TIMEOUT),
137            between_bytes_timeout: options
138                .and_then(|v| v.between_bytes_timeout)
139                .unwrap_or(DEFAULT_TIMEOUT),
140        };
141        Box::new(async {
142            match request_sender
143                .send(
144                    request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
145                    config,
146                )
147                .await
148            {
149                Ok(IncomingResponse {
150                    resp,
151                    between_bytes_timeout,
152                    ..
153                }) => Ok((
154                    resp.map(|body| {
155                        BetweenBytesTimeoutBody {
156                            body,
157                            sleep: None,
158                            timeout: between_bytes_timeout,
159                        }
160                        .boxed_unsync()
161                    }),
162                    Box::new(async {
163                        // TODO: Can we plumb connection errors through to here, or
164                        // will `hyper_util::client::legacy::Client` pass them all
165                        // via the response body?
166                        Ok(())
167                    }) as Box<dyn Future<Output = _> + Send>,
168                )),
169                Err(http_error) => match http_error.downcast() {
170                    Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
171                    Err(trap) => Err(TrappableError::trap(trap)),
172                },
173            }
174        })
175    }
176}
177
178pin_project_lite::pin_project! {
179    struct BetweenBytesTimeoutBody<B> {
180        #[pin]
181        body: B,
182        #[pin]
183        sleep: Option<tokio::time::Sleep>,
184        timeout: Duration,
185    }
186}
187
188impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
189    type Data = B::Data;
190    type Error = p3_types::ErrorCode;
191
192    fn poll_frame(
193        self: Pin<&mut Self>,
194        cx: &mut Context<'_>,
195    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
196        let mut me = self.project();
197        match me.body.poll_frame(cx) {
198            Poll::Ready(value) => {
199                me.sleep.as_mut().set(None);
200                Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
201            }
202            Poll::Pending => {
203                if me.sleep.is_none() {
204                    me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
205                }
206                task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
207                Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
208            }
209        }
210    }
211
212    fn is_end_stream(&self) -> bool {
213        self.body.is_end_stream()
214    }
215
216    fn size_hint(&self) -> SizeHint {
217        self.body.size_hint()
218    }
219}
220
221pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
222where
223    C: spin_factors::InitContext<OutboundHttpFactor>,
224{
225    let linker = ctx.linker();
226
227    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
228    where
229        C: spin_factors::InitContext<OutboundHttpFactor>,
230    {
231        let (state, table) = C::get_data_with_table(store);
232        WasiHttpImpl(WasiHttpImplInner { state, table })
233    }
234
235    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
236    wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
237        linker, get_http,
238    )?;
239    wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
240        linker,
241        &Default::default(),
242        get_http,
243    )?;
244
245    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
246    where
247        C: spin_factors::InitContext<OutboundHttpFactor>,
248    {
249        let (state, table) = C::get_data_with_table(store);
250        p3::WasiHttpCtxView { ctx: state, table }
251    }
252
253    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
254    p3::bindings::http::handler::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
255    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
256
257    wasi_2023_10_18::add_to_linker(linker, get_http)?;
258    wasi_2023_11_10::add_to_linker(linker, get_http)?;
259
260    Ok(())
261}
262
263impl OutboundHttpFactor {
264    pub fn get_wasi_http_impl(
265        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
266    ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
267        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
268        Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
269    }
270
271    pub fn get_wasi_p3_http_impl(
272        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
273    ) -> Option<p3::WasiHttpCtxView<'_>> {
274        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
275        Some(p3::WasiHttpCtxView { ctx: state, table })
276    }
277}
278
279pub(crate) struct WasiHttpImplInner<'a> {
280    state: &'a mut InstanceState,
281    table: &'a mut ResourceTable,
282}
283
284type OutgoingRequest = http::Request<HyperOutgoingBody>;
285
286impl WasiHttpView for WasiHttpImplInner<'_> {
287    fn ctx(&mut self) -> &mut WasiHttpCtx {
288        &mut self.state.wasi_http_ctx
289    }
290
291    fn table(&mut self) -> &mut ResourceTable {
292        self.table
293    }
294
295    #[instrument(
296        name = "spin_outbound_http.send_request",
297        skip_all,
298        fields(
299            otel.kind = "client",
300            url.full = Empty,
301            http.request.method = %request.method(),
302            otel.name = %request.method(),
303            http.response.status_code = Empty,
304            server.address = Empty,
305            server.port = Empty,
306        )
307    )]
308    fn send_request(
309        &mut self,
310        request: OutgoingRequest,
311        config: OutgoingRequestConfig,
312    ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
313        let request_sender = RequestSender {
314            allowed_hosts: self.state.allowed_hosts.clone(),
315            component_tls_configs: self.state.component_tls_configs.clone(),
316            request_interceptor: self.state.request_interceptor.clone(),
317            self_request_origin: self.state.self_request_origin.clone(),
318            blocked_networks: self.state.blocked_networks.clone(),
319            http_clients: self.state.wasi_http_clients.clone(),
320            concurrent_outbound_connections_semaphore: self
321                .state
322                .concurrent_outbound_connections_semaphore
323                .clone(),
324        };
325        Ok(HostFutureIncomingResponse::Pending(
326            wasmtime_wasi::runtime::spawn(
327                async {
328                    match request_sender.send(request, config).await {
329                        Ok(resp) => Ok(Ok(resp)),
330                        Err(http_error) => match http_error.downcast() {
331                            Ok(error_code) => Ok(Err(error_code)),
332                            Err(trap) => Err(trap),
333                        },
334                    }
335                }
336                .in_current_span(),
337            ),
338        ))
339    }
340}
341
342struct RequestSender {
343    allowed_hosts: OutboundAllowedHosts,
344    blocked_networks: BlockedNetworks,
345    component_tls_configs: ComponentTlsClientConfigs,
346    self_request_origin: Option<SelfRequestOrigin>,
347    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
348    http_clients: HttpClients,
349    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
350}
351
352impl RequestSender {
353    async fn send(
354        self,
355        mut request: OutgoingRequest,
356        mut config: OutgoingRequestConfig,
357    ) -> Result<IncomingResponse, HttpError> {
358        self.prepare_request(&mut request, &mut config).await?;
359
360        // If the current span has opentelemetry trace context, inject it into the request
361        spin_telemetry::inject_trace_context(&mut request);
362
363        // Run any configured request interceptor
364        let mut override_connect_addr = None;
365        if let Some(interceptor) = &self.request_interceptor {
366            let intercept_request = std::mem::take(&mut request).into();
367            match interceptor.intercept(intercept_request).await? {
368                InterceptOutcome::Continue(mut req) => {
369                    override_connect_addr = req.override_connect_addr.take();
370                    request = req.into_hyper_request();
371                }
372                InterceptOutcome::Complete(resp) => {
373                    let resp = IncomingResponse {
374                        resp,
375                        worker: None,
376                        between_bytes_timeout: config.between_bytes_timeout,
377                    };
378                    return Ok(resp);
379                }
380            }
381        }
382
383        // Backfill span fields after potentially updating the URL in the interceptor
384        let span = tracing::Span::current();
385        if let Some(addr) = override_connect_addr {
386            span.record("server.address", addr.ip().to_string());
387            span.record("server.port", addr.port());
388        } else if let Some(authority) = request.uri().authority() {
389            span.record("server.address", authority.host());
390            if let Some(port) = authority.port_u16() {
391                span.record("server.port", port);
392            }
393        }
394
395        Ok(self
396            .send_request(request, config, override_connect_addr)
397            .await?)
398    }
399
400    async fn prepare_request(
401        &self,
402        request: &mut OutgoingRequest,
403        config: &mut OutgoingRequestConfig,
404    ) -> Result<(), ErrorCode> {
405        // wasmtime-wasi-http fills in scheme and authority for relative URLs
406        // (e.g. https://:443/<path>), which makes them hard to reason about.
407        // Undo that here.
408        let uri = request.uri_mut();
409        if uri
410            .authority()
411            .is_some_and(|authority| authority.host().is_empty())
412        {
413            let mut builder = http::uri::Builder::new();
414            if let Some(paq) = uri.path_and_query() {
415                builder = builder.path_and_query(paq.clone());
416            }
417            *uri = builder.build().unwrap();
418        }
419        tracing::Span::current().record("url.full", uri.to_string());
420
421        let is_self_request = match request.uri().authority() {
422            // Some SDKs require an authority, so we support e.g. http://self.alt/self-request
423            Some(authority) => authority.host() == "self.alt",
424            // Otherwise self requests have no authority
425            None => true,
426        };
427
428        // Enforce allowed_outbound_hosts
429        let is_allowed = if is_self_request {
430            self.allowed_hosts
431                .check_relative_url(&["http", "https"])
432                .await
433                .unwrap_or(false)
434        } else {
435            self.allowed_hosts
436                .check_url(&request.uri().to_string(), "https")
437                .await
438                .unwrap_or(false)
439        };
440        if !is_allowed {
441            return Err(ErrorCode::HttpRequestDenied);
442        }
443
444        if is_self_request {
445            // Replace the authority with the "self request origin"
446            let Some(origin) = self.self_request_origin.as_ref() else {
447                tracing::error!(
448                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
449                );
450                return Err(ErrorCode::HttpRequestUriInvalid);
451            };
452
453            config.use_tls = origin.use_tls();
454
455            request.headers_mut().insert(HOST, origin.host_header());
456
457            let path_and_query = request.uri().path_and_query().cloned();
458            *request.uri_mut() = origin.clone().into_uri(path_and_query);
459        }
460
461        // Some servers (looking at you nginx) don't like a host header even though
462        // http/2 allows it: https://github.com/hyperium/hyper/issues/3298.
463        //
464        // Note that we do this _before_ invoking the request interceptor.  It may
465        // decide to add the `host` header back in, regardless of the nginx bug, in
466        // which case we'll let it do so without interferring.
467        request.headers_mut().remove(HOST);
468        Ok(())
469    }
470
471    async fn send_request(
472        self,
473        request: OutgoingRequest,
474        config: OutgoingRequestConfig,
475        override_connect_addr: Option<SocketAddr>,
476    ) -> Result<IncomingResponse, ErrorCode> {
477        let OutgoingRequestConfig {
478            use_tls,
479            connect_timeout,
480            first_byte_timeout,
481            between_bytes_timeout,
482        } = config;
483
484        let tls_client_config = if use_tls {
485            let host = request.uri().host().unwrap_or_default();
486            Some(self.component_tls_configs.get_client_config(host).clone())
487        } else {
488            None
489        };
490
491        let resp = CONNECT_OPTIONS.scope(
492            ConnectOptions {
493                blocked_networks: self.blocked_networks,
494                connect_timeout,
495                tls_client_config,
496                override_connect_addr,
497                concurrent_outbound_connections_semaphore: self
498                    .concurrent_outbound_connections_semaphore,
499            },
500            async move {
501                if use_tls {
502                    self.http_clients.https.request(request).await
503                } else {
504                    // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
505                    let h2c_prior_knowledge_host =
506                        std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
507                    let use_h2c = h2c_prior_knowledge_host.as_deref()
508                        == request.uri().authority().map(|a| a.as_str());
509
510                    if use_h2c {
511                        self.http_clients.http2.request(request).await
512                    } else {
513                        self.http_clients.http1.request(request).await
514                    }
515                }
516            },
517        );
518
519        let resp = timeout(first_byte_timeout, resp)
520            .await
521            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
522            .map_err(hyper_legacy_request_error)?
523            .map(|body| body.map_err(hyper_request_error).boxed_unsync());
524
525        tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
526
527        Ok(IncomingResponse {
528            resp,
529            worker: None,
530            between_bytes_timeout,
531        })
532    }
533}
534
535type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
536type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
537
538#[derive(Clone)]
539pub(super) struct HttpClients {
540    /// Used for non-TLS HTTP/1 connections.
541    http1: HttpClient,
542    /// Used for non-TLS HTTP/2 connections (e.g. when h2 prior knowledge is available).
543    http2: HttpClient,
544    /// Used for HTTP-over-TLS connections, using ALPN to negotiate the HTTP version.
545    https: HttpsClient,
546}
547
548impl HttpClients {
549    pub(super) fn new(enable_pooling: bool) -> Self {
550        let builder = move || {
551            let mut builder = Client::builder(TokioExecutor::new());
552            if !enable_pooling {
553                builder.pool_max_idle_per_host(0);
554            }
555            builder
556        };
557        Self {
558            http1: builder().build(HttpConnector),
559            http2: builder().http2_only(true).build(HttpConnector),
560            https: builder().build(HttpsConnector),
561        }
562    }
563}
564
565tokio::task_local! {
566    /// The options used when establishing a new connection.
567    ///
568    /// We must use task-local variables for these config options when using
569    /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
570    /// them through as parameters.  Moreover, if there's already a pooled connection
571    /// ready, we'll reuse that and ignore these options anyway. After each connection
572    /// is established, the options are dropped.
573    static CONNECT_OPTIONS: ConnectOptions;
574}
575
576#[derive(Clone)]
577struct ConnectOptions {
578    /// The blocked networks configuration.
579    blocked_networks: BlockedNetworks,
580    /// Timeout for establishing a TCP connection.
581    connect_timeout: Duration,
582    /// TLS client configuration to use, if any.
583    tls_client_config: Option<TlsClientConfig>,
584    /// If set, override the address to connect to instead of using the given `uri`'s authority.
585    override_connect_addr: Option<SocketAddr>,
586    /// A semaphore to limit the number of concurrent outbound connections.
587    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
588}
589
590impl ConnectOptions {
591    /// Establish a TCP connection to the given URI and default port.
592    async fn connect_tcp(
593        &self,
594        uri: &Uri,
595        default_port: u16,
596    ) -> Result<PermittedTcpStream, ErrorCode> {
597        let mut socket_addrs = match self.override_connect_addr {
598            Some(override_connect_addr) => vec![override_connect_addr],
599            None => {
600                let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
601
602                let host_and_port = if authority.port().is_some() {
603                    authority.as_str().to_string()
604                } else {
605                    format!("{}:{}", authority.as_str(), default_port)
606                };
607
608                let socket_addrs = tokio::net::lookup_host(&host_and_port)
609                    .await
610                    .map_err(|err| {
611                        tracing::debug!(?host_and_port, ?err, "Error resolving host");
612                        dns_error("address not available".into(), 0)
613                    })?
614                    .collect::<Vec<_>>();
615                tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
616                socket_addrs
617            }
618        };
619
620        // Remove blocked IPs
621        crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
622
623        // If we're limiting concurrent outbound requests, acquire a permit
624
625        let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore(
626            "wasi",
627            &self.concurrent_outbound_connections_semaphore,
628        )
629        .await;
630
631        let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
632            .await
633            .map_err(|_| ErrorCode::ConnectionTimeout)?
634            .map_err(|err| match err.kind() {
635                std::io::ErrorKind::AddrNotAvailable => {
636                    dns_error("address not available".into(), 0)
637                }
638                _ => ErrorCode::ConnectionRefused,
639            })?;
640        Ok(PermittedTcpStream {
641            inner: stream,
642            _permit: permit,
643        })
644    }
645
646    /// Establish a TLS connection to the given URI and default port.
647    async fn connect_tls(
648        &self,
649        uri: &Uri,
650        default_port: u16,
651    ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
652        let tcp_stream = self.connect_tcp(uri, default_port).await?;
653
654        let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
655        tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
656
657        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
658        let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
659            .map_err(|e| {
660                tracing::warn!("dns lookup error: {e:?}");
661                dns_error("invalid dns name".into(), 0)
662            })?
663            .to_owned();
664        connector.connect(domain, tcp_stream).await.map_err(|e| {
665            tracing::warn!("tls protocol error: {e:?}");
666            ErrorCode::TlsProtocolError
667        })
668    }
669}
670
671/// A connector the uses `ConnectOptions`
672#[derive(Clone)]
673struct HttpConnector;
674
675impl HttpConnector {
676    async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
677        let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
678        Ok(TokioIo::new(stream))
679    }
680}
681
682impl Service<Uri> for HttpConnector {
683    type Response = TokioIo<PermittedTcpStream>;
684    type Error = ErrorCode;
685    type Future =
686        Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
687
688    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
689        Poll::Ready(Ok(()))
690    }
691
692    fn call(&mut self, uri: Uri) -> Self::Future {
693        Box::pin(async move { Self::connect(uri).await })
694    }
695}
696
697/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
698#[derive(Clone)]
699struct HttpsConnector;
700
701impl HttpsConnector {
702    async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
703        let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
704        Ok(TokioIo::new(RustlsStream(stream)))
705    }
706}
707
708impl Service<Uri> for HttpsConnector {
709    type Response = TokioIo<RustlsStream>;
710    type Error = ErrorCode;
711    type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
712
713    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
714        Poll::Ready(Ok(()))
715    }
716
717    fn call(&mut self, uri: Uri) -> Self::Future {
718        Box::pin(async move { Self::connect(uri).await })
719    }
720}
721
722struct RustlsStream(TlsStream<PermittedTcpStream>);
723
724impl Connection for RustlsStream {
725    fn connected(&self) -> Connected {
726        if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
727            self.0.get_ref().0.connected().negotiated_h2()
728        } else {
729            self.0.get_ref().0.connected()
730        }
731    }
732}
733
734impl AsyncRead for RustlsStream {
735    fn poll_read(
736        self: Pin<&mut Self>,
737        cx: &mut Context<'_>,
738        buf: &mut ReadBuf<'_>,
739    ) -> Poll<Result<(), std::io::Error>> {
740        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
741    }
742}
743
744impl AsyncWrite for RustlsStream {
745    fn poll_write(
746        self: Pin<&mut Self>,
747        cx: &mut Context<'_>,
748        buf: &[u8],
749    ) -> Poll<Result<usize, std::io::Error>> {
750        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
751    }
752
753    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
754        Pin::new(&mut self.get_mut().0).poll_flush(cx)
755    }
756
757    fn poll_shutdown(
758        self: Pin<&mut Self>,
759        cx: &mut Context<'_>,
760    ) -> Poll<Result<(), std::io::Error>> {
761        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
762    }
763
764    fn poll_write_vectored(
765        self: Pin<&mut Self>,
766        cx: &mut Context<'_>,
767        bufs: &[IoSlice<'_>],
768    ) -> Poll<Result<usize, std::io::Error>> {
769        Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
770    }
771
772    fn is_write_vectored(&self) -> bool {
773        self.0.is_write_vectored()
774    }
775}
776
777/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
778struct PermittedTcpStream {
779    /// The wrapped TCP stream.
780    inner: TcpStream,
781    /// A permit indicating that this stream is allowed to exist.
782    ///
783    /// When this stream is dropped, the permit is also dropped, allowing another
784    /// connection to be established.
785    _permit: Option<OwnedSemaphorePermit>,
786}
787
788impl Connection for PermittedTcpStream {
789    fn connected(&self) -> Connected {
790        self.inner.connected()
791    }
792}
793
794impl AsyncRead for PermittedTcpStream {
795    fn poll_read(
796        self: Pin<&mut Self>,
797        cx: &mut Context<'_>,
798        buf: &mut ReadBuf<'_>,
799    ) -> Poll<std::io::Result<()>> {
800        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
801    }
802}
803
804impl AsyncWrite for PermittedTcpStream {
805    fn poll_write(
806        self: Pin<&mut Self>,
807        cx: &mut Context<'_>,
808        buf: &[u8],
809    ) -> Poll<Result<usize, std::io::Error>> {
810        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
811    }
812
813    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
814        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
815    }
816
817    fn poll_shutdown(
818        self: Pin<&mut Self>,
819        cx: &mut Context<'_>,
820    ) -> Poll<Result<(), std::io::Error>> {
821        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
822    }
823}
824
825/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
826fn hyper_request_error(err: hyper::Error) -> ErrorCode {
827    // If there's a source, we might be able to extract a wasi-http error from it.
828    if let Some(cause) = err.source() {
829        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
830            return err.clone();
831        }
832    }
833
834    tracing::warn!("hyper request error: {err:?}");
835
836    ErrorCode::HttpProtocolError
837}
838
839/// Translate a [`hyper_util::client::legacy::Error`] to a wasi-http `ErrorCode` in the context of a request.
840fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
841    // If there's a source, we might be able to extract a wasi-http error from it.
842    if let Some(cause) = err.source() {
843        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
844            return err.clone();
845        }
846    }
847
848    tracing::warn!("hyper request error: {err:?}");
849
850    ErrorCode::HttpProtocolError
851}
852
853fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
854    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
855        rcode: Some(rcode),
856        info_code: Some(info_code),
857    })
858}
859
860// TODO: Remove this (and uses of it) once
861// https://github.com/spinframework/spin/issues/3274 has been addressed.
862pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
863    match code {
864        p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
865        p2_types::ErrorCode::DnsError(payload) => {
866            p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
867                rcode: payload.rcode,
868                info_code: payload.info_code,
869            })
870        }
871        p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
872        p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
873        p2_types::ErrorCode::DestinationIpProhibited => {
874            p3_types::ErrorCode::DestinationIpProhibited
875        }
876        p2_types::ErrorCode::DestinationIpUnroutable => {
877            p3_types::ErrorCode::DestinationIpUnroutable
878        }
879        p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
880        p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
881        p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
882        p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
883        p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
884        p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
885        p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
886        p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
887        p2_types::ErrorCode::TlsAlertReceived(payload) => {
888            p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
889                alert_id: payload.alert_id,
890                alert_message: payload.alert_message,
891            })
892        }
893        p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
894        p2_types::ErrorCode::HttpRequestLengthRequired => {
895            p3_types::ErrorCode::HttpRequestLengthRequired
896        }
897        p2_types::ErrorCode::HttpRequestBodySize(payload) => {
898            p3_types::ErrorCode::HttpRequestBodySize(payload)
899        }
900        p2_types::ErrorCode::HttpRequestMethodInvalid => {
901            p3_types::ErrorCode::HttpRequestMethodInvalid
902        }
903        p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
904        p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
905        p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
906            p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
907        }
908        p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
909            p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
910                p3_types::FieldSizePayload {
911                    field_name: payload.field_name,
912                    field_size: payload.field_size,
913                }
914            }))
915        }
916        p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
917            p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
918        }
919        p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
920            p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
921                field_name: payload.field_name,
922                field_size: payload.field_size,
923            })
924        }
925        p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
926        p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
927            p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
928        }
929        p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
930            p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
931                field_name: payload.field_name,
932                field_size: payload.field_size,
933            })
934        }
935        p2_types::ErrorCode::HttpResponseBodySize(payload) => {
936            p3_types::ErrorCode::HttpResponseBodySize(payload)
937        }
938        p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
939            p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
940        }
941        p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
942            p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
943                field_name: payload.field_name,
944                field_size: payload.field_size,
945            })
946        }
947        p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
948            p3_types::ErrorCode::HttpResponseTransferCoding(payload)
949        }
950        p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
951            p3_types::ErrorCode::HttpResponseContentCoding(payload)
952        }
953        p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
954        p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
955        p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
956        p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
957        p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
958        p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
959    }
960}
961
962// TODO: Remove this (and uses of it) once
963// https://github.com/spinframework/spin/issues/3274 has been addressed.
964pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
965    match code {
966        p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
967        p3_types::ErrorCode::DnsError(payload) => {
968            p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
969                rcode: payload.rcode,
970                info_code: payload.info_code,
971            })
972        }
973        p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
974        p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
975        p3_types::ErrorCode::DestinationIpProhibited => {
976            p2_types::ErrorCode::DestinationIpProhibited
977        }
978        p3_types::ErrorCode::DestinationIpUnroutable => {
979            p2_types::ErrorCode::DestinationIpUnroutable
980        }
981        p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
982        p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
983        p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
984        p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
985        p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
986        p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
987        p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
988        p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
989        p3_types::ErrorCode::TlsAlertReceived(payload) => {
990            p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
991                alert_id: payload.alert_id,
992                alert_message: payload.alert_message,
993            })
994        }
995        p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
996        p3_types::ErrorCode::HttpRequestLengthRequired => {
997            p2_types::ErrorCode::HttpRequestLengthRequired
998        }
999        p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1000            p2_types::ErrorCode::HttpRequestBodySize(payload)
1001        }
1002        p3_types::ErrorCode::HttpRequestMethodInvalid => {
1003            p2_types::ErrorCode::HttpRequestMethodInvalid
1004        }
1005        p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1006        p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1007        p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1008            p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1009        }
1010        p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1011            p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1012                p2_types::FieldSizePayload {
1013                    field_name: payload.field_name,
1014                    field_size: payload.field_size,
1015                }
1016            }))
1017        }
1018        p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1019            p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1020        }
1021        p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1022            p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1023                field_name: payload.field_name,
1024                field_size: payload.field_size,
1025            })
1026        }
1027        p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1028        p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1029            p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1030        }
1031        p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1032            p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1033                field_name: payload.field_name,
1034                field_size: payload.field_size,
1035            })
1036        }
1037        p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1038            p2_types::ErrorCode::HttpResponseBodySize(payload)
1039        }
1040        p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1041            p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1042        }
1043        p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1044            p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1045                field_name: payload.field_name,
1046                field_size: payload.field_size,
1047            })
1048        }
1049        p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1050            p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1051        }
1052        p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1053            p2_types::ErrorCode::HttpResponseContentCoding(payload)
1054        }
1055        p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1056        p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1057        p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1058        p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1059        p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1060        p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1061    }
1062}