spin_factor_outbound_http/
wasi.rs

1use std::{
2    error::Error,
3    future::Future,
4    io::IoSlice,
5    net::SocketAddr,
6    pin::Pin,
7    sync::Arc,
8    task::{self, Context, Poll},
9    time::Duration,
10};
11
12use bytes::Bytes;
13use http::{header::HOST, uri::Scheme, Uri};
14use http_body::{Body, Frame, SizeHint};
15use http_body_util::{combinators::BoxBody, BodyExt};
16use hyper_util::{
17    client::legacy::{
18        connect::{Connected, Connection},
19        Client,
20    },
21    rt::{TokioExecutor, TokioIo},
22};
23use spin_factor_outbound_networking::{
24    config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
25    ComponentTlsClientConfigs, TlsClientConfig,
26};
27use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
28use tokio::{
29    io::{AsyncRead, AsyncWrite, ReadBuf},
30    net::TcpStream,
31    sync::{OwnedSemaphorePermit, Semaphore},
32    time::timeout,
33};
34use tokio_rustls::client::TlsStream;
35use tower_service::Service;
36use tracing::{field::Empty, instrument, Instrument};
37use wasmtime::component::HasData;
38use wasmtime_wasi::TrappableError;
39use wasmtime_wasi_http::{
40    bindings::http::types::{self as p2_types, ErrorCode},
41    body::HyperOutgoingBody,
42    p3::{self, bindings::http::types as p3_types},
43    types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
44    HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
45};
46
47use crate::{
48    intercept::{InterceptOutcome, OutboundHttpInterceptor},
49    wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
50};
51
/// Fallback used for any connect / first-byte / between-bytes timeout the caller leaves unset
/// (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
53
54pub(crate) struct HasHttp;
55
56impl HasData for HasHttp {
57    type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
58}
59
60impl p3::WasiHttpCtx for InstanceState {
61    fn send_request(
62        &mut self,
63        request: http::Request<BoxBody<Bytes, p3_types::ErrorCode>>,
64        options: Option<p3::RequestOptions>,
65        fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
66    ) -> Box<
67        dyn Future<
68                Output = Result<
69                    (
70                        http::Response<BoxBody<Bytes, p3_types::ErrorCode>>,
71                        Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
72                    ),
73                    TrappableError<p3_types::ErrorCode>,
74                >,
75            > + Send,
76    > {
77        // If the caller (i.e. the guest) has trouble consuming the response
78        // (e.g. encountering a network error while forwarding it on to some
79        // other place), it can report that error to us via `fut`.  However,
80        // there's nothing we'll be able to do with it here, so we ignore it.
81        // Presumably the guest will also drop the body stream and trailers
82        // future if it encounters such an error while those things are still
83        // arriving, which Hyper will deal with as appropriate (e.g. closing the
84        // connection).
85        _ = fut;
86
87        let request_sender = RequestSender {
88            allowed_hosts: self.allowed_hosts.clone(),
89            component_tls_configs: self.component_tls_configs.clone(),
90            request_interceptor: self.request_interceptor.clone(),
91            self_request_origin: self.self_request_origin.clone(),
92            blocked_networks: self.blocked_networks.clone(),
93            http_clients: self.wasi_http_clients.clone(),
94            concurrent_outbound_connections_semaphore: self
95                .concurrent_outbound_connections_semaphore
96                .clone(),
97        };
98        let config = OutgoingRequestConfig {
99            use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
100            connect_timeout: options
101                .and_then(|v| v.connect_timeout)
102                .unwrap_or(DEFAULT_TIMEOUT),
103            first_byte_timeout: options
104                .and_then(|v| v.first_byte_timeout)
105                .unwrap_or(DEFAULT_TIMEOUT),
106            between_bytes_timeout: options
107                .and_then(|v| v.between_bytes_timeout)
108                .unwrap_or(DEFAULT_TIMEOUT),
109        };
110        Box::new(async {
111            match request_sender
112                .send(
113                    request.map(|body| body.map_err(p3_to_p2_error_code).boxed()),
114                    config,
115                )
116                .await
117            {
118                Ok(IncomingResponse {
119                    resp,
120                    between_bytes_timeout,
121                    ..
122                }) => Ok((
123                    resp.map(|body| {
124                        BetweenBytesTimeoutBody {
125                            body,
126                            sleep: None,
127                            timeout: between_bytes_timeout,
128                        }
129                        .boxed()
130                    }),
131                    Box::new(async {
132                        // TODO: Can we plumb connection errors through to here, or
133                        // will `hyper_util::client::legacy::Client` pass them all
134                        // via the response body?
135                        Ok(())
136                    }) as Box<dyn Future<Output = _> + Send>,
137                )),
138                Err(http_error) => match http_error.downcast() {
139                    Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
140                    Err(trap) => Err(TrappableError::trap(trap)),
141                },
142            }
143        })
144    }
145}
146
147pin_project_lite::pin_project! {
148    struct BetweenBytesTimeoutBody<B> {
149        #[pin]
150        body: B,
151        #[pin]
152        sleep: Option<tokio::time::Sleep>,
153        timeout: Duration,
154    }
155}
156
157impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
158    type Data = B::Data;
159    type Error = p3_types::ErrorCode;
160
161    fn poll_frame(
162        self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
165        let mut me = self.project();
166        match me.body.poll_frame(cx) {
167            Poll::Ready(value) => {
168                me.sleep.as_mut().set(None);
169                Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
170            }
171            Poll::Pending => {
172                if me.sleep.is_none() {
173                    me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
174                }
175                task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
176                Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
177            }
178        }
179    }
180
181    fn is_end_stream(&self) -> bool {
182        self.body.is_end_stream()
183    }
184
185    fn size_hint(&self) -> SizeHint {
186        self.body.size_hint()
187    }
188}
189
190pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
191where
192    C: spin_factors::InitContext<OutboundHttpFactor>,
193{
194    let linker = ctx.linker();
195
196    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
197    where
198        C: spin_factors::InitContext<OutboundHttpFactor>,
199    {
200        let (state, table) = C::get_data_with_table(store);
201        WasiHttpImpl(WasiHttpImplInner { state, table })
202    }
203
204    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
205    wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
206        linker, get_http,
207    )?;
208    wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
209        linker,
210        &Default::default(),
211        get_http,
212    )?;
213
214    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
215    where
216        C: spin_factors::InitContext<OutboundHttpFactor>,
217    {
218        let (state, table) = C::get_data_with_table(store);
219        p3::WasiHttpCtxView { ctx: state, table }
220    }
221
222    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
223    p3::bindings::http::handler::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
224    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
225
226    wasi_2023_10_18::add_to_linker(linker, get_http)?;
227    wasi_2023_11_10::add_to_linker(linker, get_http)?;
228
229    Ok(())
230}
231
232impl OutboundHttpFactor {
233    pub fn get_wasi_http_impl(
234        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
235    ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
236        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
237        Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
238    }
239
240    pub fn get_wasi_p3_http_impl(
241        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
242    ) -> Option<p3::WasiHttpCtxView<'_>> {
243        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
244        Some(p3::WasiHttpCtxView { ctx: state, table })
245    }
246}
247
248pub(crate) struct WasiHttpImplInner<'a> {
249    state: &'a mut InstanceState,
250    table: &'a mut ResourceTable,
251}
252
253type OutgoingRequest = http::Request<HyperOutgoingBody>;
254
255impl WasiHttpView for WasiHttpImplInner<'_> {
256    fn ctx(&mut self) -> &mut WasiHttpCtx {
257        &mut self.state.wasi_http_ctx
258    }
259
260    fn table(&mut self) -> &mut ResourceTable {
261        self.table
262    }
263
264    #[instrument(
265        name = "spin_outbound_http.send_request",
266        skip_all,
267        fields(
268            otel.kind = "client",
269            url.full = Empty,
270            http.request.method = %request.method(),
271            otel.name = %request.method(),
272            http.response.status_code = Empty,
273            server.address = Empty,
274            server.port = Empty,
275        )
276    )]
277    fn send_request(
278        &mut self,
279        request: OutgoingRequest,
280        config: OutgoingRequestConfig,
281    ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
282        let request_sender = RequestSender {
283            allowed_hosts: self.state.allowed_hosts.clone(),
284            component_tls_configs: self.state.component_tls_configs.clone(),
285            request_interceptor: self.state.request_interceptor.clone(),
286            self_request_origin: self.state.self_request_origin.clone(),
287            blocked_networks: self.state.blocked_networks.clone(),
288            http_clients: self.state.wasi_http_clients.clone(),
289            concurrent_outbound_connections_semaphore: self
290                .state
291                .concurrent_outbound_connections_semaphore
292                .clone(),
293        };
294        Ok(HostFutureIncomingResponse::Pending(
295            wasmtime_wasi::runtime::spawn(
296                async {
297                    match request_sender.send(request, config).await {
298                        Ok(resp) => Ok(Ok(resp)),
299                        Err(http_error) => match http_error.downcast() {
300                            Ok(error_code) => Ok(Err(error_code)),
301                            Err(trap) => Err(trap),
302                        },
303                    }
304                }
305                .in_current_span(),
306            ),
307        ))
308    }
309}
310
311struct RequestSender {
312    allowed_hosts: OutboundAllowedHosts,
313    blocked_networks: BlockedNetworks,
314    component_tls_configs: ComponentTlsClientConfigs,
315    self_request_origin: Option<SelfRequestOrigin>,
316    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
317    http_clients: HttpClients,
318    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
319}
320
321impl RequestSender {
322    async fn send(
323        self,
324        mut request: OutgoingRequest,
325        mut config: OutgoingRequestConfig,
326    ) -> Result<IncomingResponse, HttpError> {
327        self.prepare_request(&mut request, &mut config).await?;
328
329        // If the current span has opentelemetry trace context, inject it into the request
330        spin_telemetry::inject_trace_context(&mut request);
331
332        // Run any configured request interceptor
333        let mut override_connect_addr = None;
334        if let Some(interceptor) = &self.request_interceptor {
335            let intercept_request = std::mem::take(&mut request).into();
336            match interceptor.intercept(intercept_request).await? {
337                InterceptOutcome::Continue(mut req) => {
338                    override_connect_addr = req.override_connect_addr.take();
339                    request = req.into_hyper_request();
340                }
341                InterceptOutcome::Complete(resp) => {
342                    let resp = IncomingResponse {
343                        resp,
344                        worker: None,
345                        between_bytes_timeout: config.between_bytes_timeout,
346                    };
347                    return Ok(resp);
348                }
349            }
350        }
351
352        // Backfill span fields after potentially updating the URL in the interceptor
353        let span = tracing::Span::current();
354        if let Some(addr) = override_connect_addr {
355            span.record("server.address", addr.ip().to_string());
356            span.record("server.port", addr.port());
357        } else if let Some(authority) = request.uri().authority() {
358            span.record("server.address", authority.host());
359            if let Some(port) = authority.port_u16() {
360                span.record("server.port", port);
361            }
362        }
363
364        Ok(self
365            .send_request(request, config, override_connect_addr)
366            .await?)
367    }
368
369    async fn prepare_request(
370        &self,
371        request: &mut OutgoingRequest,
372        config: &mut OutgoingRequestConfig,
373    ) -> Result<(), ErrorCode> {
374        // wasmtime-wasi-http fills in scheme and authority for relative URLs
375        // (e.g. https://:443/<path>), which makes them hard to reason about.
376        // Undo that here.
377        let uri = request.uri_mut();
378        if uri
379            .authority()
380            .is_some_and(|authority| authority.host().is_empty())
381        {
382            let mut builder = http::uri::Builder::new();
383            if let Some(paq) = uri.path_and_query() {
384                builder = builder.path_and_query(paq.clone());
385            }
386            *uri = builder.build().unwrap();
387        }
388        tracing::Span::current().record("url.full", uri.to_string());
389
390        let is_self_request = match request.uri().authority() {
391            // Some SDKs require an authority, so we support e.g. http://self.alt/self-request
392            Some(authority) => authority.host() == "self.alt",
393            // Otherwise self requests have no authority
394            None => true,
395        };
396
397        // Enforce allowed_outbound_hosts
398        let is_allowed = if is_self_request {
399            self.allowed_hosts
400                .check_relative_url(&["http", "https"])
401                .await
402                .unwrap_or(false)
403        } else {
404            self.allowed_hosts
405                .check_url(&request.uri().to_string(), "https")
406                .await
407                .unwrap_or(false)
408        };
409        if !is_allowed {
410            return Err(ErrorCode::HttpRequestDenied);
411        }
412
413        if is_self_request {
414            // Replace the authority with the "self request origin"
415            let Some(origin) = self.self_request_origin.as_ref() else {
416                tracing::error!(
417                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
418                );
419                return Err(ErrorCode::HttpRequestUriInvalid);
420            };
421
422            config.use_tls = origin.use_tls();
423
424            request.headers_mut().insert(HOST, origin.host_header());
425
426            let path_and_query = request.uri().path_and_query().cloned();
427            *request.uri_mut() = origin.clone().into_uri(path_and_query);
428        }
429
430        // Some servers (looking at you nginx) don't like a host header even though
431        // http/2 allows it: https://github.com/hyperium/hyper/issues/3298.
432        //
433        // Note that we do this _before_ invoking the request interceptor.  It may
434        // decide to add the `host` header back in, regardless of the nginx bug, in
435        // which case we'll let it do so without interferring.
436        request.headers_mut().remove(HOST);
437        Ok(())
438    }
439
440    async fn send_request(
441        self,
442        request: OutgoingRequest,
443        config: OutgoingRequestConfig,
444        override_connect_addr: Option<SocketAddr>,
445    ) -> Result<IncomingResponse, ErrorCode> {
446        let OutgoingRequestConfig {
447            use_tls,
448            connect_timeout,
449            first_byte_timeout,
450            between_bytes_timeout,
451        } = config;
452
453        let tls_client_config = if use_tls {
454            let host = request.uri().host().unwrap_or_default();
455            Some(self.component_tls_configs.get_client_config(host).clone())
456        } else {
457            None
458        };
459
460        let resp = CONNECT_OPTIONS.scope(
461            ConnectOptions {
462                blocked_networks: self.blocked_networks,
463                connect_timeout,
464                tls_client_config,
465                override_connect_addr,
466                concurrent_outbound_connections_semaphore: self
467                    .concurrent_outbound_connections_semaphore,
468            },
469            async move {
470                if use_tls {
471                    self.http_clients.https.request(request).await
472                } else {
473                    // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
474                    let h2c_prior_knowledge_host =
475                        std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
476                    let use_h2c = h2c_prior_knowledge_host.as_deref()
477                        == request.uri().authority().map(|a| a.as_str());
478
479                    if use_h2c {
480                        self.http_clients.http2.request(request).await
481                    } else {
482                        self.http_clients.http1.request(request).await
483                    }
484                }
485            },
486        );
487
488        let resp = timeout(first_byte_timeout, resp)
489            .await
490            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
491            .map_err(hyper_legacy_request_error)?
492            .map(|body| body.map_err(hyper_request_error).boxed());
493
494        tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
495
496        Ok(IncomingResponse {
497            resp,
498            worker: None,
499            between_bytes_timeout,
500        })
501    }
502}
503
504type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
505type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
506
507#[derive(Clone)]
508pub(super) struct HttpClients {
509    /// Used for non-TLS HTTP/1 connections.
510    http1: HttpClient,
511    /// Used for non-TLS HTTP/2 connections (e.g. when h2 prior knowledge is available).
512    http2: HttpClient,
513    /// Used for HTTP-over-TLS connections, using ALPN to negotiate the HTTP version.
514    https: HttpsClient,
515}
516
517impl HttpClients {
518    pub(super) fn new(enable_pooling: bool) -> Self {
519        let builder = move || {
520            let mut builder = Client::builder(TokioExecutor::new());
521            if !enable_pooling {
522                builder.pool_max_idle_per_host(0);
523            }
524            builder
525        };
526        Self {
527            http1: builder().build(HttpConnector),
528            http2: builder().http2_only(true).build(HttpConnector),
529            https: builder().build(HttpsConnector),
530        }
531    }
532}
533
534tokio::task_local! {
535    /// The options used when establishing a new connection.
536    ///
537    /// We must use task-local variables for these config options when using
538    /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
539    /// them through as parameters.  Moreover, if there's already a pooled connection
540    /// ready, we'll reuse that and ignore these options anyway. After each connection
541    /// is established, the options are dropped.
542    static CONNECT_OPTIONS: ConnectOptions;
543}
544
545#[derive(Clone)]
546struct ConnectOptions {
547    /// The blocked networks configuration.
548    blocked_networks: BlockedNetworks,
549    /// Timeout for establishing a TCP connection.
550    connect_timeout: Duration,
551    /// TLS client configuration to use, if any.
552    tls_client_config: Option<TlsClientConfig>,
553    /// If set, override the address to connect to instead of using the given `uri`'s authority.
554    override_connect_addr: Option<SocketAddr>,
555    /// A semaphore to limit the number of concurrent outbound connections.
556    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
557}
558
559impl ConnectOptions {
560    /// Establish a TCP connection to the given URI and default port.
561    async fn connect_tcp(
562        &self,
563        uri: &Uri,
564        default_port: u16,
565    ) -> Result<PermittedTcpStream, ErrorCode> {
566        let mut socket_addrs = match self.override_connect_addr {
567            Some(override_connect_addr) => vec![override_connect_addr],
568            None => {
569                let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
570
571                let host_and_port = if authority.port().is_some() {
572                    authority.as_str().to_string()
573                } else {
574                    format!("{}:{}", authority.as_str(), default_port)
575                };
576
577                let socket_addrs = tokio::net::lookup_host(&host_and_port)
578                    .await
579                    .map_err(|err| {
580                        tracing::debug!(?host_and_port, ?err, "Error resolving host");
581                        dns_error("address not available".into(), 0)
582                    })?
583                    .collect::<Vec<_>>();
584                tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
585                socket_addrs
586            }
587        };
588
589        // Remove blocked IPs
590        let blocked_addrs = self.blocked_networks.remove_blocked(&mut socket_addrs);
591        if socket_addrs.is_empty() && !blocked_addrs.is_empty() {
592            tracing::error!(
593                "error.type" = "destination_ip_prohibited",
594                ?blocked_addrs,
595                "all destination IP(s) prohibited by runtime config"
596            );
597            return Err(ErrorCode::DestinationIpProhibited);
598        }
599
600        // If we're limiting concurrent outbound requests, acquire a permit
601        let permit = match &self.concurrent_outbound_connections_semaphore {
602            Some(s) => s.clone().acquire_owned().await.ok(),
603            None => None,
604        };
605
606        let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
607            .await
608            .map_err(|_| ErrorCode::ConnectionTimeout)?
609            .map_err(|err| match err.kind() {
610                std::io::ErrorKind::AddrNotAvailable => {
611                    dns_error("address not available".into(), 0)
612                }
613                _ => ErrorCode::ConnectionRefused,
614            })?;
615        Ok(PermittedTcpStream {
616            inner: stream,
617            _permit: permit,
618        })
619    }
620
621    /// Establish a TLS connection to the given URI and default port.
622    async fn connect_tls(
623        &self,
624        uri: &Uri,
625        default_port: u16,
626    ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
627        let tcp_stream = self.connect_tcp(uri, default_port).await?;
628
629        let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
630        tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
631
632        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
633        let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
634            .map_err(|e| {
635                tracing::warn!("dns lookup error: {e:?}");
636                dns_error("invalid dns name".into(), 0)
637            })?
638            .to_owned();
639        connector.connect(domain, tcp_stream).await.map_err(|e| {
640            tracing::warn!("tls protocol error: {e:?}");
641            ErrorCode::TlsProtocolError
642        })
643    }
644}
645
646/// A connector the uses `ConnectOptions`
647#[derive(Clone)]
648struct HttpConnector;
649
650impl HttpConnector {
651    async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
652        let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
653        Ok(TokioIo::new(stream))
654    }
655}
656
657impl Service<Uri> for HttpConnector {
658    type Response = TokioIo<PermittedTcpStream>;
659    type Error = ErrorCode;
660    type Future =
661        Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
662
663    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
664        Poll::Ready(Ok(()))
665    }
666
667    fn call(&mut self, uri: Uri) -> Self::Future {
668        Box::pin(async move { Self::connect(uri).await })
669    }
670}
671
672/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
673#[derive(Clone)]
674struct HttpsConnector;
675
676impl HttpsConnector {
677    async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
678        let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
679        Ok(TokioIo::new(RustlsStream(stream)))
680    }
681}
682
683impl Service<Uri> for HttpsConnector {
684    type Response = TokioIo<RustlsStream>;
685    type Error = ErrorCode;
686    type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
687
688    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
689        Poll::Ready(Ok(()))
690    }
691
692    fn call(&mut self, uri: Uri) -> Self::Future {
693        Box::pin(async move { Self::connect(uri).await })
694    }
695}
696
697struct RustlsStream(TlsStream<PermittedTcpStream>);
698
699impl Connection for RustlsStream {
700    fn connected(&self) -> Connected {
701        if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
702            self.0.get_ref().0.connected().negotiated_h2()
703        } else {
704            self.0.get_ref().0.connected()
705        }
706    }
707}
708
709impl AsyncRead for RustlsStream {
710    fn poll_read(
711        self: Pin<&mut Self>,
712        cx: &mut Context<'_>,
713        buf: &mut ReadBuf<'_>,
714    ) -> Poll<Result<(), std::io::Error>> {
715        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
716    }
717}
718
719impl AsyncWrite for RustlsStream {
720    fn poll_write(
721        self: Pin<&mut Self>,
722        cx: &mut Context<'_>,
723        buf: &[u8],
724    ) -> Poll<Result<usize, std::io::Error>> {
725        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
726    }
727
728    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
729        Pin::new(&mut self.get_mut().0).poll_flush(cx)
730    }
731
732    fn poll_shutdown(
733        self: Pin<&mut Self>,
734        cx: &mut Context<'_>,
735    ) -> Poll<Result<(), std::io::Error>> {
736        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
737    }
738
739    fn poll_write_vectored(
740        self: Pin<&mut Self>,
741        cx: &mut Context<'_>,
742        bufs: &[IoSlice<'_>],
743    ) -> Poll<Result<usize, std::io::Error>> {
744        Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
745    }
746
747    fn is_write_vectored(&self) -> bool {
748        self.0.is_write_vectored()
749    }
750}
751
752/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
753struct PermittedTcpStream {
754    /// The wrapped TCP stream.
755    inner: TcpStream,
756    /// A permit indicating that this stream is allowed to exist.
757    ///
758    /// When this stream is dropped, the permit is also dropped, allowing another
759    /// connection to be established.
760    _permit: Option<OwnedSemaphorePermit>,
761}
762
763impl Connection for PermittedTcpStream {
764    fn connected(&self) -> Connected {
765        self.inner.connected()
766    }
767}
768
769impl AsyncRead for PermittedTcpStream {
770    fn poll_read(
771        self: Pin<&mut Self>,
772        cx: &mut Context<'_>,
773        buf: &mut ReadBuf<'_>,
774    ) -> Poll<std::io::Result<()>> {
775        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
776    }
777}
778
779impl AsyncWrite for PermittedTcpStream {
780    fn poll_write(
781        self: Pin<&mut Self>,
782        cx: &mut Context<'_>,
783        buf: &[u8],
784    ) -> Poll<Result<usize, std::io::Error>> {
785        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
786    }
787
788    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
789        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
790    }
791
792    fn poll_shutdown(
793        self: Pin<&mut Self>,
794        cx: &mut Context<'_>,
795    ) -> Poll<Result<(), std::io::Error>> {
796        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
797    }
798}
799
800/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
801fn hyper_request_error(err: hyper::Error) -> ErrorCode {
802    // If there's a source, we might be able to extract a wasi-http error from it.
803    if let Some(cause) = err.source() {
804        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
805            return err.clone();
806        }
807    }
808
809    tracing::warn!("hyper request error: {err:?}");
810
811    ErrorCode::HttpProtocolError
812}
813
814/// Translate a [`hyper_util::client::legacy::Error`] to a wasi-http `ErrorCode` in the context of a request.
815fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
816    // If there's a source, we might be able to extract a wasi-http error from it.
817    if let Some(cause) = err.source() {
818        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
819            return err.clone();
820        }
821    }
822
823    tracing::warn!("hyper request error: {err:?}");
824
825    ErrorCode::HttpProtocolError
826}
827
828fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
829    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
830        rcode: Some(rcode),
831        info_code: Some(info_code),
832    })
833}
834
835// TODO: Remove this (and uses of it) once
836// https://github.com/spinframework/spin/issues/3274 has been addressed.
837pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
838    match code {
839        p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
840        p2_types::ErrorCode::DnsError(payload) => {
841            p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
842                rcode: payload.rcode,
843                info_code: payload.info_code,
844            })
845        }
846        p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
847        p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
848        p2_types::ErrorCode::DestinationIpProhibited => {
849            p3_types::ErrorCode::DestinationIpProhibited
850        }
851        p2_types::ErrorCode::DestinationIpUnroutable => {
852            p3_types::ErrorCode::DestinationIpUnroutable
853        }
854        p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
855        p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
856        p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
857        p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
858        p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
859        p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
860        p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
861        p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
862        p2_types::ErrorCode::TlsAlertReceived(payload) => {
863            p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
864                alert_id: payload.alert_id,
865                alert_message: payload.alert_message,
866            })
867        }
868        p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
869        p2_types::ErrorCode::HttpRequestLengthRequired => {
870            p3_types::ErrorCode::HttpRequestLengthRequired
871        }
872        p2_types::ErrorCode::HttpRequestBodySize(payload) => {
873            p3_types::ErrorCode::HttpRequestBodySize(payload)
874        }
875        p2_types::ErrorCode::HttpRequestMethodInvalid => {
876            p3_types::ErrorCode::HttpRequestMethodInvalid
877        }
878        p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
879        p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
880        p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
881            p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
882        }
883        p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
884            p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
885                p3_types::FieldSizePayload {
886                    field_name: payload.field_name,
887                    field_size: payload.field_size,
888                }
889            }))
890        }
891        p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
892            p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
893        }
894        p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
895            p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
896                field_name: payload.field_name,
897                field_size: payload.field_size,
898            })
899        }
900        p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
901        p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
902            p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
903        }
904        p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
905            p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
906                field_name: payload.field_name,
907                field_size: payload.field_size,
908            })
909        }
910        p2_types::ErrorCode::HttpResponseBodySize(payload) => {
911            p3_types::ErrorCode::HttpResponseBodySize(payload)
912        }
913        p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
914            p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
915        }
916        p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
917            p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
918                field_name: payload.field_name,
919                field_size: payload.field_size,
920            })
921        }
922        p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
923            p3_types::ErrorCode::HttpResponseTransferCoding(payload)
924        }
925        p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
926            p3_types::ErrorCode::HttpResponseContentCoding(payload)
927        }
928        p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
929        p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
930        p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
931        p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
932        p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
933        p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
934    }
935}
936
937// TODO: Remove this (and uses of it) once
938// https://github.com/spinframework/spin/issues/3274 has been addressed.
939pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
940    match code {
941        p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
942        p3_types::ErrorCode::DnsError(payload) => {
943            p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
944                rcode: payload.rcode,
945                info_code: payload.info_code,
946            })
947        }
948        p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
949        p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
950        p3_types::ErrorCode::DestinationIpProhibited => {
951            p2_types::ErrorCode::DestinationIpProhibited
952        }
953        p3_types::ErrorCode::DestinationIpUnroutable => {
954            p2_types::ErrorCode::DestinationIpUnroutable
955        }
956        p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
957        p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
958        p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
959        p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
960        p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
961        p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
962        p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
963        p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
964        p3_types::ErrorCode::TlsAlertReceived(payload) => {
965            p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
966                alert_id: payload.alert_id,
967                alert_message: payload.alert_message,
968            })
969        }
970        p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
971        p3_types::ErrorCode::HttpRequestLengthRequired => {
972            p2_types::ErrorCode::HttpRequestLengthRequired
973        }
974        p3_types::ErrorCode::HttpRequestBodySize(payload) => {
975            p2_types::ErrorCode::HttpRequestBodySize(payload)
976        }
977        p3_types::ErrorCode::HttpRequestMethodInvalid => {
978            p2_types::ErrorCode::HttpRequestMethodInvalid
979        }
980        p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
981        p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
982        p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
983            p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
984        }
985        p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
986            p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
987                p2_types::FieldSizePayload {
988                    field_name: payload.field_name,
989                    field_size: payload.field_size,
990                }
991            }))
992        }
993        p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
994            p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
995        }
996        p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
997            p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
998                field_name: payload.field_name,
999                field_size: payload.field_size,
1000            })
1001        }
1002        p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1003        p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1004            p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1005        }
1006        p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1007            p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1008                field_name: payload.field_name,
1009                field_size: payload.field_size,
1010            })
1011        }
1012        p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1013            p2_types::ErrorCode::HttpResponseBodySize(payload)
1014        }
1015        p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1016            p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1017        }
1018        p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1019            p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1020                field_name: payload.field_name,
1021                field_size: payload.field_size,
1022            })
1023        }
1024        p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1025            p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1026        }
1027        p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1028            p2_types::ErrorCode::HttpResponseContentCoding(payload)
1029        }
1030        p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1031        p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1032        p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1033        p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1034        p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1035        p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1036    }
1037}