spin_factor_outbound_http/
wasi.rs

1use std::{
2    error::Error,
3    future::Future,
4    io::IoSlice,
5    net::SocketAddr,
6    pin::Pin,
7    sync::Arc,
8    task::{self, Context, Poll},
9    time::Duration,
10};
11
12use bytes::Bytes;
13use http::{header::HOST, uri::Scheme, Uri};
14use http_body::{Body, Frame, SizeHint};
15use http_body_util::{combinators::BoxBody, BodyExt};
16use hyper_util::{
17    client::legacy::{
18        connect::{Connected, Connection},
19        Client,
20    },
21    rt::{TokioExecutor, TokioIo},
22};
23use spin_factor_outbound_networking::{
24    config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
25    ComponentTlsClientConfigs, TlsClientConfig,
26};
27use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
28use tokio::{
29    io::{AsyncRead, AsyncWrite, ReadBuf},
30    net::TcpStream,
31    sync::{OwnedSemaphorePermit, Semaphore},
32    time::timeout,
33};
34use tokio_rustls::client::TlsStream;
35use tower_service::Service;
36use tracing::{field::Empty, instrument, Instrument};
37use wasmtime::component::HasData;
38use wasmtime_wasi::TrappableError;
39use wasmtime_wasi_http::{
40    bindings::http::types::{self as p2_types, ErrorCode},
41    body::HyperOutgoingBody,
42    p3::{self, bindings::http::types as p3_types},
43    types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
44    HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
45};
46
47use crate::{
48    intercept::{InterceptOutcome, OutboundHttpInterceptor},
49    wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
50};
51
/// Fallback applied to each wasi:http p3 request timeout (connect, first byte, between
/// bytes) that the request's options leave unset: ten minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
53
54pub(crate) struct HasHttp;
55
56impl HasData for HasHttp {
57    type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
58}
59
/// Outbound requests made through the wasi:http p3 interface.
///
/// Requests go through the same [`RequestSender`] pipeline used for p2 requests; p3 error
/// codes are converted to their p2 equivalents on the way in and back on the way out.
impl p3::WasiHttpCtx for InstanceState {
    /// Sends `request`, returning the response plus a future that always resolves to `Ok(())`
    /// (see the TODO below).
    ///
    /// Timeouts not given in `options` fall back to `DEFAULT_TIMEOUT`. A wasi-http
    /// `ErrorCode` from the send is returned as a guest-visible error; any other failure is
    /// returned as a trap.
    fn send_request(
        &mut self,
        request: http::Request<BoxBody<Bytes, p3_types::ErrorCode>>,
        options: Option<p3::RequestOptions>,
        fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
    ) -> Box<
        dyn Future<
                Output = Result<
                    (
                        http::Response<BoxBody<Bytes, p3_types::ErrorCode>>,
                        Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
                    ),
                    TrappableError<p3_types::ErrorCode>,
                >,
            > + Send,
    > {
        // If the caller (i.e. the guest) has trouble consuming the response
        // (e.g. encountering a network error while forwarding it on to some
        // other place), it can report that error to us via `fut`.  However,
        // there's nothing we'll be able to do with it here, so we ignore it.
        // Presumably the guest will also drop the body stream and trailers
        // future if it encounters such an error while those things are still
        // arriving, which Hyper will deal with as appropriate (e.g. closing the
        // connection).
        _ = fut;

        // Snapshot the per-instance policy/config needed to send this request.
        let request_sender = RequestSender {
            allowed_hosts: self.allowed_hosts.clone(),
            component_tls_configs: self.component_tls_configs.clone(),
            request_interceptor: self.request_interceptor.clone(),
            self_request_origin: self.self_request_origin.clone(),
            blocked_networks: self.blocked_networks.clone(),
            http_clients: self.wasi_http_clients.clone(),
            concurrent_outbound_connections_semaphore: self
                .concurrent_outbound_connections_semaphore
                .clone(),
        };
        // Any timeout the guest did not specify falls back to `DEFAULT_TIMEOUT`.
        let config = OutgoingRequestConfig {
            use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
            connect_timeout: options
                .and_then(|v| v.connect_timeout)
                .unwrap_or(DEFAULT_TIMEOUT),
            first_byte_timeout: options
                .and_then(|v| v.first_byte_timeout)
                .unwrap_or(DEFAULT_TIMEOUT),
            between_bytes_timeout: options
                .and_then(|v| v.between_bytes_timeout)
                .unwrap_or(DEFAULT_TIMEOUT),
        };
        Box::new(async {
            match request_sender
                .send(
                    // The shared sender works with p2 error codes.
                    request.map(|body| body.map_err(p3_to_p2_error_code).boxed()),
                    config,
                )
                .await
            {
                Ok(IncomingResponse {
                    resp,
                    between_bytes_timeout,
                    ..
                }) => Ok((
                    // Enforce the between-bytes timeout on the response body and convert
                    // its errors back to p3 error codes.
                    resp.map(|body| {
                        BetweenBytesTimeoutBody {
                            body,
                            sleep: None,
                            timeout: between_bytes_timeout,
                        }
                        .boxed()
                    }),
                    Box::new(async {
                        // TODO: Can we plumb connection errors through to here, or
                        // will `hyper_util::client::legacy::Client` pass them all
                        // via the response body?
                        Ok(())
                    }) as Box<dyn Future<Output = _> + Send>,
                )),
                // A wasi-http `ErrorCode` is reported to the guest; anything else traps.
                Err(http_error) => match http_error.downcast() {
                    Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
                    Err(trap) => Err(TrappableError::trap(trap)),
                },
            }
        })
    }
}
146
147pin_project_lite::pin_project! {
148    struct BetweenBytesTimeoutBody<B> {
149        #[pin]
150        body: B,
151        #[pin]
152        sleep: Option<tokio::time::Sleep>,
153        timeout: Duration,
154    }
155}
156
157impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
158    type Data = B::Data;
159    type Error = p3_types::ErrorCode;
160
161    fn poll_frame(
162        self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
165        let mut me = self.project();
166        match me.body.poll_frame(cx) {
167            Poll::Ready(value) => {
168                me.sleep.as_mut().set(None);
169                Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
170            }
171            Poll::Pending => {
172                if me.sleep.is_none() {
173                    me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
174                }
175                task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
176                Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
177            }
178        }
179    }
180
181    fn is_end_stream(&self) -> bool {
182        self.body.is_end_stream()
183    }
184
185    fn size_hint(&self) -> SizeHint {
186        self.body.size_hint()
187    }
188}
189
190pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
191where
192    C: spin_factors::InitContext<OutboundHttpFactor>,
193{
194    let linker = ctx.linker();
195
196    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
197    where
198        C: spin_factors::InitContext<OutboundHttpFactor>,
199    {
200        let (state, table) = C::get_data_with_table(store);
201        WasiHttpImpl(WasiHttpImplInner { state, table })
202    }
203
204    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
205    wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
206        linker, get_http,
207    )?;
208    wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
209        linker,
210        &Default::default(),
211        get_http,
212    )?;
213
214    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
215    where
216        C: spin_factors::InitContext<OutboundHttpFactor>,
217    {
218        let (state, table) = C::get_data_with_table(store);
219        p3::WasiHttpCtxView { ctx: state, table }
220    }
221
222    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
223    p3::bindings::http::handler::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
224    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
225
226    wasi_2023_10_18::add_to_linker(linker, get_http)?;
227    wasi_2023_11_10::add_to_linker(linker, get_http)?;
228
229    Ok(())
230}
231
232impl OutboundHttpFactor {
233    pub fn get_wasi_http_impl(
234        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
235    ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
236        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
237        Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
238    }
239
240    pub fn get_wasi_p3_http_impl(
241        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
242    ) -> Option<p3::WasiHttpCtxView<'_>> {
243        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
244        Some(p3::WasiHttpCtxView { ctx: state, table })
245    }
246}
247
248pub(crate) struct WasiHttpImplInner<'a> {
249    state: &'a mut InstanceState,
250    table: &'a mut ResourceTable,
251}
252
253type OutgoingRequest = http::Request<HyperOutgoingBody>;
254
255impl WasiHttpView for WasiHttpImplInner<'_> {
256    fn ctx(&mut self) -> &mut WasiHttpCtx {
257        &mut self.state.wasi_http_ctx
258    }
259
260    fn table(&mut self) -> &mut ResourceTable {
261        self.table
262    }
263
264    #[instrument(
265        name = "spin_outbound_http.send_request",
266        skip_all,
267        fields(
268            otel.kind = "client",
269            url.full = Empty,
270            http.request.method = %request.method(),
271            otel.name = %request.method(),
272            http.response.status_code = Empty,
273            server.address = Empty,
274            server.port = Empty,
275        )
276    )]
277    fn send_request(
278        &mut self,
279        request: OutgoingRequest,
280        config: OutgoingRequestConfig,
281    ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
282        let request_sender = RequestSender {
283            allowed_hosts: self.state.allowed_hosts.clone(),
284            component_tls_configs: self.state.component_tls_configs.clone(),
285            request_interceptor: self.state.request_interceptor.clone(),
286            self_request_origin: self.state.self_request_origin.clone(),
287            blocked_networks: self.state.blocked_networks.clone(),
288            http_clients: self.state.wasi_http_clients.clone(),
289            concurrent_outbound_connections_semaphore: self
290                .state
291                .concurrent_outbound_connections_semaphore
292                .clone(),
293        };
294        Ok(HostFutureIncomingResponse::Pending(
295            wasmtime_wasi::runtime::spawn(
296                async {
297                    match request_sender.send(request, config).await {
298                        Ok(resp) => Ok(Ok(resp)),
299                        Err(http_error) => match http_error.downcast() {
300                            Ok(error_code) => Ok(Err(error_code)),
301                            Err(trap) => Err(trap),
302                        },
303                    }
304                }
305                .in_current_span(),
306            ),
307        ))
308    }
309}
310
311struct RequestSender {
312    allowed_hosts: OutboundAllowedHosts,
313    blocked_networks: BlockedNetworks,
314    component_tls_configs: ComponentTlsClientConfigs,
315    self_request_origin: Option<SelfRequestOrigin>,
316    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
317    http_clients: HttpClients,
318    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
319}
320
321impl RequestSender {
322    async fn send(
323        self,
324        mut request: OutgoingRequest,
325        mut config: OutgoingRequestConfig,
326    ) -> Result<IncomingResponse, HttpError> {
327        self.prepare_request(&mut request, &mut config).await?;
328
329        // If the current span has opentelemetry trace context, inject it into the request
330        spin_telemetry::inject_trace_context(&mut request);
331
332        // Run any configured request interceptor
333        let mut override_connect_addr = None;
334        if let Some(interceptor) = &self.request_interceptor {
335            let intercept_request = std::mem::take(&mut request).into();
336            match interceptor.intercept(intercept_request).await? {
337                InterceptOutcome::Continue(mut req) => {
338                    override_connect_addr = req.override_connect_addr.take();
339                    request = req.into_hyper_request();
340                }
341                InterceptOutcome::Complete(resp) => {
342                    let resp = IncomingResponse {
343                        resp,
344                        worker: None,
345                        between_bytes_timeout: config.between_bytes_timeout,
346                    };
347                    return Ok(resp);
348                }
349            }
350        }
351
352        // Backfill span fields after potentially updating the URL in the interceptor
353        let span = tracing::Span::current();
354        if let Some(addr) = override_connect_addr {
355            span.record("server.address", addr.ip().to_string());
356            span.record("server.port", addr.port());
357        } else if let Some(authority) = request.uri().authority() {
358            span.record("server.address", authority.host());
359            if let Some(port) = authority.port_u16() {
360                span.record("server.port", port);
361            }
362        }
363
364        Ok(self
365            .send_request(request, config, override_connect_addr)
366            .await?)
367    }
368
369    async fn prepare_request(
370        &self,
371        request: &mut OutgoingRequest,
372        config: &mut OutgoingRequestConfig,
373    ) -> Result<(), ErrorCode> {
374        // wasmtime-wasi-http fills in scheme and authority for relative URLs
375        // (e.g. https://:443/<path>), which makes them hard to reason about.
376        // Undo that here.
377        let uri = request.uri_mut();
378        if uri
379            .authority()
380            .is_some_and(|authority| authority.host().is_empty())
381        {
382            let mut builder = http::uri::Builder::new();
383            if let Some(paq) = uri.path_and_query() {
384                builder = builder.path_and_query(paq.clone());
385            }
386            *uri = builder.build().unwrap();
387        }
388        tracing::Span::current().record("url.full", uri.to_string());
389
390        let is_self_request = match request.uri().authority() {
391            // Some SDKs require an authority, so we support e.g. http://self.alt/self-request
392            Some(authority) => authority.host() == "self.alt",
393            // Otherwise self requests have no authority
394            None => true,
395        };
396
397        // Enforce allowed_outbound_hosts
398        let is_allowed = if is_self_request {
399            self.allowed_hosts
400                .check_relative_url(&["http", "https"])
401                .await
402                .unwrap_or(false)
403        } else {
404            self.allowed_hosts
405                .check_url(&request.uri().to_string(), "https")
406                .await
407                .unwrap_or(false)
408        };
409        if !is_allowed {
410            return Err(ErrorCode::HttpRequestDenied);
411        }
412
413        if is_self_request {
414            // Replace the authority with the "self request origin"
415            let Some(origin) = self.self_request_origin.as_ref() else {
416                tracing::error!(
417                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
418                );
419                return Err(ErrorCode::HttpRequestUriInvalid);
420            };
421
422            config.use_tls = origin.use_tls();
423
424            request.headers_mut().insert(HOST, origin.host_header());
425
426            let path_and_query = request.uri().path_and_query().cloned();
427            *request.uri_mut() = origin.clone().into_uri(path_and_query);
428        }
429
430        // Some servers (looking at you nginx) don't like a host header even though
431        // http/2 allows it: https://github.com/hyperium/hyper/issues/3298.
432        //
433        // Note that we do this _before_ invoking the request interceptor.  It may
434        // decide to add the `host` header back in, regardless of the nginx bug, in
435        // which case we'll let it do so without interferring.
436        request.headers_mut().remove(HOST);
437        Ok(())
438    }
439
440    async fn send_request(
441        self,
442        request: OutgoingRequest,
443        config: OutgoingRequestConfig,
444        override_connect_addr: Option<SocketAddr>,
445    ) -> Result<IncomingResponse, ErrorCode> {
446        let OutgoingRequestConfig {
447            use_tls,
448            connect_timeout,
449            first_byte_timeout,
450            between_bytes_timeout,
451        } = config;
452
453        let tls_client_config = if use_tls {
454            let host = request.uri().host().unwrap_or_default();
455            Some(self.component_tls_configs.get_client_config(host).clone())
456        } else {
457            None
458        };
459
460        let resp = CONNECT_OPTIONS.scope(
461            ConnectOptions {
462                blocked_networks: self.blocked_networks,
463                connect_timeout,
464                tls_client_config,
465                override_connect_addr,
466                concurrent_outbound_connections_semaphore: self
467                    .concurrent_outbound_connections_semaphore,
468            },
469            async move {
470                if use_tls {
471                    self.http_clients.https.request(request).await
472                } else {
473                    // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
474                    let h2c_prior_knowledge_host =
475                        std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
476                    let use_h2c = h2c_prior_knowledge_host.as_deref()
477                        == request.uri().authority().map(|a| a.as_str());
478
479                    if use_h2c {
480                        self.http_clients.http2.request(request).await
481                    } else {
482                        self.http_clients.http1.request(request).await
483                    }
484                }
485            },
486        );
487
488        let resp = timeout(first_byte_timeout, resp)
489            .await
490            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
491            .map_err(hyper_legacy_request_error)?
492            .map(|body| body.map_err(hyper_request_error).boxed());
493
494        tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
495
496        Ok(IncomingResponse {
497            resp,
498            worker: None,
499            between_bytes_timeout,
500        })
501    }
502}
503
504type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
505type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
506
507#[derive(Clone)]
508pub(super) struct HttpClients {
509    /// Used for non-TLS HTTP/1 connections.
510    http1: HttpClient,
511    /// Used for non-TLS HTTP/2 connections (e.g. when h2 prior knowledge is available).
512    http2: HttpClient,
513    /// Used for HTTP-over-TLS connections, using ALPN to negotiate the HTTP version.
514    https: HttpsClient,
515}
516
517impl HttpClients {
518    pub(super) fn new(enable_pooling: bool) -> Self {
519        let builder = move || {
520            let mut builder = Client::builder(TokioExecutor::new());
521            if !enable_pooling {
522                builder.pool_max_idle_per_host(0);
523            }
524            builder
525        };
526        Self {
527            http1: builder().build(HttpConnector),
528            http2: builder().http2_only(true).build(HttpConnector),
529            https: builder().build(HttpsConnector),
530        }
531    }
532}
533
534tokio::task_local! {
535    /// The options used when establishing a new connection.
536    ///
537    /// We must use task-local variables for these config options when using
538    /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
539    /// them through as parameters.  Moreover, if there's already a pooled connection
540    /// ready, we'll reuse that and ignore these options anyway. After each connection
541    /// is established, the options are dropped.
542    static CONNECT_OPTIONS: ConnectOptions;
543}
544
545#[derive(Clone)]
546struct ConnectOptions {
547    /// The blocked networks configuration.
548    blocked_networks: BlockedNetworks,
549    /// Timeout for establishing a TCP connection.
550    connect_timeout: Duration,
551    /// TLS client configuration to use, if any.
552    tls_client_config: Option<TlsClientConfig>,
553    /// If set, override the address to connect to instead of using the given `uri`'s authority.
554    override_connect_addr: Option<SocketAddr>,
555    /// A semaphore to limit the number of concurrent outbound connections.
556    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
557}
558
559impl ConnectOptions {
560    /// Establish a TCP connection to the given URI and default port.
561    async fn connect_tcp(
562        &self,
563        uri: &Uri,
564        default_port: u16,
565    ) -> Result<PermittedTcpStream, ErrorCode> {
566        let mut socket_addrs = match self.override_connect_addr {
567            Some(override_connect_addr) => vec![override_connect_addr],
568            None => {
569                let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
570
571                let host_and_port = if authority.port().is_some() {
572                    authority.as_str().to_string()
573                } else {
574                    format!("{}:{}", authority.as_str(), default_port)
575                };
576
577                let socket_addrs = tokio::net::lookup_host(&host_and_port)
578                    .await
579                    .map_err(|err| {
580                        tracing::debug!(?host_and_port, ?err, "Error resolving host");
581                        dns_error("address not available".into(), 0)
582                    })?
583                    .collect::<Vec<_>>();
584                tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
585                socket_addrs
586            }
587        };
588
589        // Remove blocked IPs
590        crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
591
592        // If we're limiting concurrent outbound requests, acquire a permit
593        let permit = match &self.concurrent_outbound_connections_semaphore {
594            Some(s) => s.clone().acquire_owned().await.ok(),
595            None => None,
596        };
597
598        let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
599            .await
600            .map_err(|_| ErrorCode::ConnectionTimeout)?
601            .map_err(|err| match err.kind() {
602                std::io::ErrorKind::AddrNotAvailable => {
603                    dns_error("address not available".into(), 0)
604                }
605                _ => ErrorCode::ConnectionRefused,
606            })?;
607        Ok(PermittedTcpStream {
608            inner: stream,
609            _permit: permit,
610        })
611    }
612
613    /// Establish a TLS connection to the given URI and default port.
614    async fn connect_tls(
615        &self,
616        uri: &Uri,
617        default_port: u16,
618    ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
619        let tcp_stream = self.connect_tcp(uri, default_port).await?;
620
621        let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
622        tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
623
624        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
625        let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
626            .map_err(|e| {
627                tracing::warn!("dns lookup error: {e:?}");
628                dns_error("invalid dns name".into(), 0)
629            })?
630            .to_owned();
631        connector.connect(domain, tcp_stream).await.map_err(|e| {
632            tracing::warn!("tls protocol error: {e:?}");
633            ErrorCode::TlsProtocolError
634        })
635    }
636}
637
638/// A connector the uses `ConnectOptions`
639#[derive(Clone)]
640struct HttpConnector;
641
642impl HttpConnector {
643    async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
644        let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
645        Ok(TokioIo::new(stream))
646    }
647}
648
649impl Service<Uri> for HttpConnector {
650    type Response = TokioIo<PermittedTcpStream>;
651    type Error = ErrorCode;
652    type Future =
653        Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
654
655    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
656        Poll::Ready(Ok(()))
657    }
658
659    fn call(&mut self, uri: Uri) -> Self::Future {
660        Box::pin(async move { Self::connect(uri).await })
661    }
662}
663
664/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
665#[derive(Clone)]
666struct HttpsConnector;
667
668impl HttpsConnector {
669    async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
670        let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
671        Ok(TokioIo::new(RustlsStream(stream)))
672    }
673}
674
675impl Service<Uri> for HttpsConnector {
676    type Response = TokioIo<RustlsStream>;
677    type Error = ErrorCode;
678    type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
679
680    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
681        Poll::Ready(Ok(()))
682    }
683
684    fn call(&mut self, uri: Uri) -> Self::Future {
685        Box::pin(async move { Self::connect(uri).await })
686    }
687}
688
689struct RustlsStream(TlsStream<PermittedTcpStream>);
690
691impl Connection for RustlsStream {
692    fn connected(&self) -> Connected {
693        if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
694            self.0.get_ref().0.connected().negotiated_h2()
695        } else {
696            self.0.get_ref().0.connected()
697        }
698    }
699}
700
701impl AsyncRead for RustlsStream {
702    fn poll_read(
703        self: Pin<&mut Self>,
704        cx: &mut Context<'_>,
705        buf: &mut ReadBuf<'_>,
706    ) -> Poll<Result<(), std::io::Error>> {
707        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
708    }
709}
710
711impl AsyncWrite for RustlsStream {
712    fn poll_write(
713        self: Pin<&mut Self>,
714        cx: &mut Context<'_>,
715        buf: &[u8],
716    ) -> Poll<Result<usize, std::io::Error>> {
717        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
718    }
719
720    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
721        Pin::new(&mut self.get_mut().0).poll_flush(cx)
722    }
723
724    fn poll_shutdown(
725        self: Pin<&mut Self>,
726        cx: &mut Context<'_>,
727    ) -> Poll<Result<(), std::io::Error>> {
728        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
729    }
730
731    fn poll_write_vectored(
732        self: Pin<&mut Self>,
733        cx: &mut Context<'_>,
734        bufs: &[IoSlice<'_>],
735    ) -> Poll<Result<usize, std::io::Error>> {
736        Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
737    }
738
739    fn is_write_vectored(&self) -> bool {
740        self.0.is_write_vectored()
741    }
742}
743
744/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
745struct PermittedTcpStream {
746    /// The wrapped TCP stream.
747    inner: TcpStream,
748    /// A permit indicating that this stream is allowed to exist.
749    ///
750    /// When this stream is dropped, the permit is also dropped, allowing another
751    /// connection to be established.
752    _permit: Option<OwnedSemaphorePermit>,
753}
754
755impl Connection for PermittedTcpStream {
756    fn connected(&self) -> Connected {
757        self.inner.connected()
758    }
759}
760
761impl AsyncRead for PermittedTcpStream {
762    fn poll_read(
763        self: Pin<&mut Self>,
764        cx: &mut Context<'_>,
765        buf: &mut ReadBuf<'_>,
766    ) -> Poll<std::io::Result<()>> {
767        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
768    }
769}
770
771impl AsyncWrite for PermittedTcpStream {
772    fn poll_write(
773        self: Pin<&mut Self>,
774        cx: &mut Context<'_>,
775        buf: &[u8],
776    ) -> Poll<Result<usize, std::io::Error>> {
777        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
778    }
779
780    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
781        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
782    }
783
784    fn poll_shutdown(
785        self: Pin<&mut Self>,
786        cx: &mut Context<'_>,
787    ) -> Poll<Result<(), std::io::Error>> {
788        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
789    }
790}
791
792/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
793fn hyper_request_error(err: hyper::Error) -> ErrorCode {
794    // If there's a source, we might be able to extract a wasi-http error from it.
795    if let Some(cause) = err.source() {
796        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
797            return err.clone();
798        }
799    }
800
801    tracing::warn!("hyper request error: {err:?}");
802
803    ErrorCode::HttpProtocolError
804}
805
806/// Translate a [`hyper_util::client::legacy::Error`] to a wasi-http `ErrorCode` in the context of a request.
807fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
808    // If there's a source, we might be able to extract a wasi-http error from it.
809    if let Some(cause) = err.source() {
810        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
811            return err.clone();
812        }
813    }
814
815    tracing::warn!("hyper request error: {err:?}");
816
817    ErrorCode::HttpProtocolError
818}
819
820fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
821    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
822        rcode: Some(rcode),
823        info_code: Some(info_code),
824    })
825}
826
827// TODO: Remove this (and uses of it) once
828// https://github.com/spinframework/spin/issues/3274 has been addressed.
829pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
830    match code {
831        p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
832        p2_types::ErrorCode::DnsError(payload) => {
833            p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
834                rcode: payload.rcode,
835                info_code: payload.info_code,
836            })
837        }
838        p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
839        p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
840        p2_types::ErrorCode::DestinationIpProhibited => {
841            p3_types::ErrorCode::DestinationIpProhibited
842        }
843        p2_types::ErrorCode::DestinationIpUnroutable => {
844            p3_types::ErrorCode::DestinationIpUnroutable
845        }
846        p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
847        p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
848        p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
849        p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
850        p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
851        p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
852        p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
853        p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
854        p2_types::ErrorCode::TlsAlertReceived(payload) => {
855            p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
856                alert_id: payload.alert_id,
857                alert_message: payload.alert_message,
858            })
859        }
860        p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
861        p2_types::ErrorCode::HttpRequestLengthRequired => {
862            p3_types::ErrorCode::HttpRequestLengthRequired
863        }
864        p2_types::ErrorCode::HttpRequestBodySize(payload) => {
865            p3_types::ErrorCode::HttpRequestBodySize(payload)
866        }
867        p2_types::ErrorCode::HttpRequestMethodInvalid => {
868            p3_types::ErrorCode::HttpRequestMethodInvalid
869        }
870        p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
871        p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
872        p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
873            p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
874        }
875        p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
876            p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
877                p3_types::FieldSizePayload {
878                    field_name: payload.field_name,
879                    field_size: payload.field_size,
880                }
881            }))
882        }
883        p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
884            p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
885        }
886        p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
887            p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
888                field_name: payload.field_name,
889                field_size: payload.field_size,
890            })
891        }
892        p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
893        p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
894            p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
895        }
896        p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
897            p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
898                field_name: payload.field_name,
899                field_size: payload.field_size,
900            })
901        }
902        p2_types::ErrorCode::HttpResponseBodySize(payload) => {
903            p3_types::ErrorCode::HttpResponseBodySize(payload)
904        }
905        p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
906            p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
907        }
908        p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
909            p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
910                field_name: payload.field_name,
911                field_size: payload.field_size,
912            })
913        }
914        p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
915            p3_types::ErrorCode::HttpResponseTransferCoding(payload)
916        }
917        p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
918            p3_types::ErrorCode::HttpResponseContentCoding(payload)
919        }
920        p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
921        p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
922        p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
923        p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
924        p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
925        p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
926    }
927}
928
929// TODO: Remove this (and uses of it) once
930// https://github.com/spinframework/spin/issues/3274 has been addressed.
931pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
932    match code {
933        p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
934        p3_types::ErrorCode::DnsError(payload) => {
935            p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
936                rcode: payload.rcode,
937                info_code: payload.info_code,
938            })
939        }
940        p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
941        p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
942        p3_types::ErrorCode::DestinationIpProhibited => {
943            p2_types::ErrorCode::DestinationIpProhibited
944        }
945        p3_types::ErrorCode::DestinationIpUnroutable => {
946            p2_types::ErrorCode::DestinationIpUnroutable
947        }
948        p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
949        p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
950        p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
951        p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
952        p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
953        p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
954        p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
955        p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
956        p3_types::ErrorCode::TlsAlertReceived(payload) => {
957            p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
958                alert_id: payload.alert_id,
959                alert_message: payload.alert_message,
960            })
961        }
962        p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
963        p3_types::ErrorCode::HttpRequestLengthRequired => {
964            p2_types::ErrorCode::HttpRequestLengthRequired
965        }
966        p3_types::ErrorCode::HttpRequestBodySize(payload) => {
967            p2_types::ErrorCode::HttpRequestBodySize(payload)
968        }
969        p3_types::ErrorCode::HttpRequestMethodInvalid => {
970            p2_types::ErrorCode::HttpRequestMethodInvalid
971        }
972        p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
973        p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
974        p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
975            p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
976        }
977        p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
978            p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
979                p2_types::FieldSizePayload {
980                    field_name: payload.field_name,
981                    field_size: payload.field_size,
982                }
983            }))
984        }
985        p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
986            p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
987        }
988        p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
989            p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
990                field_name: payload.field_name,
991                field_size: payload.field_size,
992            })
993        }
994        p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
995        p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
996            p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
997        }
998        p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
999            p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1000                field_name: payload.field_name,
1001                field_size: payload.field_size,
1002            })
1003        }
1004        p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1005            p2_types::ErrorCode::HttpResponseBodySize(payload)
1006        }
1007        p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1008            p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1009        }
1010        p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1011            p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1012                field_name: payload.field_name,
1013                field_size: payload.field_size,
1014            })
1015        }
1016        p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1017            p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1018        }
1019        p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1020            p2_types::ErrorCode::HttpResponseContentCoding(payload)
1021        }
1022        p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1023        p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1024        p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1025        p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1026        p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1027        p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1028    }
1029}