Skip to main content

spin_factor_outbound_http/
wasi.rs

1use std::{
2    error::Error,
3    future::Future,
4    io::IoSlice,
5    net::SocketAddr,
6    ops::DerefMut,
7    pin::Pin,
8    sync::{Arc, Mutex},
9    task::{self, Context, Poll},
10    time::Duration,
11};
12
13use bytes::Bytes;
14use http::{header::HOST, uri::Scheme, Uri};
15use http_body::{Body, Frame, SizeHint};
16use http_body_util::{combinators::UnsyncBoxBody, BodyExt};
17use hyper_util::{
18    client::legacy::{
19        connect::{Connected, Connection},
20        Client,
21    },
22    rt::{TokioExecutor, TokioIo},
23};
24use spin_factor_outbound_networking::{
25    config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
26    ComponentTlsClientConfigs, TlsClientConfig,
27};
28use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
29use tokio::{
30    io::{AsyncRead, AsyncWrite, ReadBuf},
31    net::TcpStream,
32    sync::{OwnedSemaphorePermit, Semaphore},
33    time::timeout,
34};
35use tokio_rustls::client::TlsStream;
36use tower_service::Service;
37use tracing::{field::Empty, instrument, Instrument};
38use wasmtime::component::HasData;
39use wasmtime_wasi::TrappableError;
40use wasmtime_wasi_http::{
41    bindings::http::types::{self as p2_types, ErrorCode},
42    body::HyperOutgoingBody,
43    p3::{self, bindings::http::types as p3_types},
44    types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
45    HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
46};
47
48use crate::{
49    intercept::{InterceptOutcome, OutboundHttpInterceptor},
50    wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
51};
52
/// Fallback for the connect, first-byte, and between-bytes timeouts of
/// wasi-http p3 requests whose request options leave them unset.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);
54
55pub struct MutexBody<T>(Mutex<T>);
56
57impl<T> MutexBody<T> {
58    pub fn new(body: T) -> Self {
59        Self(Mutex::new(body))
60    }
61}
62
63impl<T: Body + Unpin> Body for MutexBody<T> {
64    type Data = T::Data;
65    type Error = T::Error;
66
67    fn poll_frame(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
71        Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
72    }
73
74    fn is_end_stream(&self) -> bool {
75        self.0.lock().unwrap().is_end_stream()
76    }
77
78    fn size_hint(&self) -> SizeHint {
79        self.0.lock().unwrap().size_hint()
80    }
81}
82
/// [`HasData`] marker passed to the wasi-http p2 `outgoing_handler` and
/// `types` `add_to_linker` calls. It projects store data to a
/// [`WasiHttpImpl`] over a [`WasiHttpImplInner`].
pub(crate) struct HasHttp;

impl HasData for HasHttp {
    type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
}
88
89impl p3::WasiHttpCtx for InstanceState {
90    #[instrument(
91        name = "spin_outbound_http.send_request",
92        skip_all,
93        fields(
94            otel.kind = "client",
95            url.full = Empty,
96            http.request.method = %request.method(),
97            otel.name = %request.method(),
98            http.response.status_code = Empty,
99            server.address = Empty,
100            server.port = Empty,
101        )
102    )]
103    #[allow(clippy::type_complexity)]
104    fn send_request(
105        &mut self,
106        request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
107        options: Option<p3::RequestOptions>,
108        fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
109    ) -> Box<
110        dyn Future<
111                Output = Result<
112                    (
113                        http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
114                        Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
115                    ),
116                    TrappableError<p3_types::ErrorCode>,
117                >,
118            > + Send,
119    > {
120        self.otel.reparent_tracing_span();
121
122        // If the caller (i.e. the guest) has trouble consuming the response
123        // (e.g. encountering a network error while forwarding it on to some
124        // other place), it can report that error to us via `fut`.  However,
125        // there's nothing we'll be able to do with it here, so we ignore it.
126        // Presumably the guest will also drop the body stream and trailers
127        // future if it encounters such an error while those things are still
128        // arriving, which Hyper will deal with as appropriate (e.g. closing the
129        // connection).
130        _ = fut;
131
132        let request = request.map(|body| MutexBody::new(body).boxed());
133
134        let request_sender = RequestSender {
135            allowed_hosts: self.allowed_hosts.clone(),
136            component_tls_configs: self.component_tls_configs.clone(),
137            request_interceptor: self.request_interceptor.clone(),
138            self_request_origin: self.self_request_origin.clone(),
139            blocked_networks: self.blocked_networks.clone(),
140            http_clients: self.wasi_http_clients.clone(),
141            concurrent_outbound_connections_semaphore: self
142                .concurrent_outbound_connections_semaphore
143                .clone(),
144        };
145        let config = OutgoingRequestConfig {
146            use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
147            connect_timeout: options
148                .and_then(|v| v.connect_timeout)
149                .unwrap_or(DEFAULT_TIMEOUT),
150            first_byte_timeout: options
151                .and_then(|v| v.first_byte_timeout)
152                .unwrap_or(DEFAULT_TIMEOUT),
153            between_bytes_timeout: options
154                .and_then(|v| v.between_bytes_timeout)
155                .unwrap_or(DEFAULT_TIMEOUT),
156        };
157        Box::new(async {
158            match request_sender
159                .send(
160                    request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
161                    config,
162                )
163                .await
164            {
165                Ok(IncomingResponse {
166                    resp,
167                    between_bytes_timeout,
168                    ..
169                }) => Ok((
170                    resp.map(|body| {
171                        BetweenBytesTimeoutBody {
172                            body,
173                            sleep: None,
174                            timeout: between_bytes_timeout,
175                        }
176                        .boxed_unsync()
177                    }),
178                    Box::new(async {
179                        // TODO: Can we plumb connection errors through to here, or
180                        // will `hyper_util::client::legacy::Client` pass them all
181                        // via the response body?
182                        Ok(())
183                    }) as Box<dyn Future<Output = _> + Send>,
184                )),
185                Err(http_error) => match http_error.downcast() {
186                    Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
187                    Err(trap) => Err(TrappableError::trap(trap)),
188                },
189            }
190        })
191    }
192}
193
pin_project_lite::pin_project! {
    // Response-body wrapper that converts the inner body's p2 error codes to
    // p3 error codes and yields `ConnectionReadTimeout` if the inner body stays
    // pending for longer than `timeout` between frames.
    struct BetweenBytesTimeoutBody<B> {
        // The wrapped response body.
        #[pin]
        body: B,
        // Deadline armed when `body` returns `Pending`; cleared whenever
        // `body` returns `Ready`.
        #[pin]
        sleep: Option<tokio::time::Sleep>,
        // How long `body` may stay pending before the timeout error is yielded.
        timeout: Duration,
    }
}
203
204impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
205    type Data = B::Data;
206    type Error = p3_types::ErrorCode;
207
208    fn poll_frame(
209        self: Pin<&mut Self>,
210        cx: &mut Context<'_>,
211    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
212        let mut me = self.project();
213        match me.body.poll_frame(cx) {
214            Poll::Ready(value) => {
215                me.sleep.as_mut().set(None);
216                Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
217            }
218            Poll::Pending => {
219                if me.sleep.is_none() {
220                    me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
221                }
222                task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
223                Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
224            }
225        }
226    }
227
228    fn is_end_stream(&self) -> bool {
229        self.body.is_end_stream()
230    }
231
232    fn size_hint(&self) -> SizeHint {
233        self.body.size_hint()
234    }
235}
236
237pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
238where
239    C: spin_factors::InitContext<OutboundHttpFactor>,
240{
241    let linker = ctx.linker();
242
243    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
244    where
245        C: spin_factors::InitContext<OutboundHttpFactor>,
246    {
247        let (state, table) = C::get_data_with_table(store);
248        WasiHttpImpl(WasiHttpImplInner { state, table })
249    }
250
251    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
252    wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
253        linker, get_http,
254    )?;
255    wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
256        linker,
257        &Default::default(),
258        get_http,
259    )?;
260
261    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
262    where
263        C: spin_factors::InitContext<OutboundHttpFactor>,
264    {
265        let (state, table) = C::get_data_with_table(store);
266        p3::WasiHttpCtxView { ctx: state, table }
267    }
268
269    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
270    p3::bindings::http::client::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
271    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
272
273    wasi_2023_10_18::add_to_linker(linker, get_http)?;
274    wasi_2023_11_10::add_to_linker(linker, get_http)?;
275
276    Ok(())
277}
278
279impl OutboundHttpFactor {
280    pub fn get_wasi_http_impl(
281        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
282    ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
283        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
284        Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
285    }
286
287    pub fn get_wasi_p3_http_impl(
288        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
289    ) -> Option<p3::WasiHttpCtxView<'_>> {
290        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
291        Some(p3::WasiHttpCtxView { ctx: state, table })
292    }
293}
294
/// Borrowed view over an instance's outbound-HTTP state and its resource
/// table. It implements [`WasiHttpView`] for the p2 bindings and the dated
/// 2023 bindings.
pub(crate) struct WasiHttpImplInner<'a> {
    /// This instance's outbound-HTTP factor state.
    state: &'a mut InstanceState,
    /// The resource table paired with `state`.
    table: &'a mut ResourceTable,
}

/// A wasi-http p2 outgoing request, as received by `WasiHttpView::send_request`.
type OutgoingRequest = http::Request<HyperOutgoingBody>;
301
302impl WasiHttpView for WasiHttpImplInner<'_> {
303    fn ctx(&mut self) -> &mut WasiHttpCtx {
304        &mut self.state.wasi_http_ctx
305    }
306
307    fn table(&mut self) -> &mut ResourceTable {
308        self.table
309    }
310
311    #[instrument(
312        name = "spin_outbound_http.send_request",
313        skip_all,
314        fields(
315            otel.kind = "client",
316            url.full = Empty,
317            http.request.method = %request.method(),
318            otel.name = %request.method(),
319            http.response.status_code = Empty,
320            server.address = Empty,
321            server.port = Empty,
322        )
323    )]
324    fn send_request(
325        &mut self,
326        request: OutgoingRequest,
327        config: OutgoingRequestConfig,
328    ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
329        self.state.otel.reparent_tracing_span();
330
331        let request_sender = RequestSender {
332            allowed_hosts: self.state.allowed_hosts.clone(),
333            component_tls_configs: self.state.component_tls_configs.clone(),
334            request_interceptor: self.state.request_interceptor.clone(),
335            self_request_origin: self.state.self_request_origin.clone(),
336            blocked_networks: self.state.blocked_networks.clone(),
337            http_clients: self.state.wasi_http_clients.clone(),
338            concurrent_outbound_connections_semaphore: self
339                .state
340                .concurrent_outbound_connections_semaphore
341                .clone(),
342        };
343        Ok(HostFutureIncomingResponse::Pending(
344            wasmtime_wasi::runtime::spawn(
345                async {
346                    match request_sender.send(request, config).await {
347                        Ok(resp) => Ok(Ok(resp)),
348                        Err(http_error) => match http_error.downcast() {
349                            Ok(error_code) => Ok(Err(error_code)),
350                            Err(trap) => Err(trap),
351                        },
352                    }
353                }
354                .in_current_span(),
355            ),
356        ))
357    }
358}
359
/// A per-request snapshot of an instance's outbound-HTTP configuration.
/// It is cloned from the instance state and consumed by
/// [`RequestSender::send`].
struct RequestSender {
    /// Destinations the component may contact (`allowed_outbound_hosts`).
    allowed_hosts: OutboundAllowedHosts,
    /// Networks whose addresses are removed from connection candidates.
    blocked_networks: BlockedNetworks,
    /// TLS client configurations, looked up by host.
    component_tls_configs: ComponentTlsClientConfigs,
    /// Origin used to rewrite self requests; when `None`, self requests are rejected.
    self_request_origin: Option<SelfRequestOrigin>,
    /// Optional hook that may rewrite a request or answer it itself.
    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
    /// The hyper clients used to send requests.
    http_clients: HttpClients,
    /// Optional limit on the number of concurrent outbound connections.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
369
370impl RequestSender {
371    async fn send(
372        self,
373        mut request: OutgoingRequest,
374        mut config: OutgoingRequestConfig,
375    ) -> Result<IncomingResponse, HttpError> {
376        self.prepare_request(&mut request, &mut config).await?;
377
378        // If the current span has opentelemetry trace context, inject it into the request
379        spin_telemetry::inject_trace_context(&mut request);
380
381        // Run any configured request interceptor
382        let mut override_connect_addr = None;
383        if let Some(interceptor) = &self.request_interceptor {
384            let intercept_request = std::mem::take(&mut request).into();
385            match interceptor.intercept(intercept_request).await? {
386                InterceptOutcome::Continue(mut req) => {
387                    override_connect_addr = req.override_connect_addr.take();
388                    request = req.into_hyper_request();
389                }
390                InterceptOutcome::Complete(resp) => {
391                    let resp = IncomingResponse {
392                        resp,
393                        worker: None,
394                        between_bytes_timeout: config.between_bytes_timeout,
395                    };
396                    return Ok(resp);
397                }
398            }
399        }
400
401        // Backfill span fields after potentially updating the URL in the interceptor
402        let span = tracing::Span::current();
403        if let Some(addr) = override_connect_addr {
404            span.record("server.address", addr.ip().to_string());
405            span.record("server.port", addr.port());
406        } else if let Some(authority) = request.uri().authority() {
407            span.record("server.address", authority.host());
408            if let Some(port) = authority.port_u16() {
409                span.record("server.port", port);
410            }
411        }
412
413        Ok(self
414            .send_request(request, config, override_connect_addr)
415            .await?)
416    }
417
418    async fn prepare_request(
419        &self,
420        request: &mut OutgoingRequest,
421        config: &mut OutgoingRequestConfig,
422    ) -> Result<(), ErrorCode> {
423        // wasmtime-wasi-http fills in scheme and authority for relative URLs
424        // (e.g. https://:443/<path>), which makes them hard to reason about.
425        // Undo that here.
426        let uri = request.uri_mut();
427        if uri
428            .authority()
429            .is_some_and(|authority| authority.host().is_empty())
430        {
431            let mut builder = http::uri::Builder::new();
432            if let Some(paq) = uri.path_and_query() {
433                builder = builder.path_and_query(paq.clone());
434            }
435            *uri = builder.build().unwrap();
436        }
437        tracing::Span::current().record("url.full", uri.to_string());
438
439        let is_self_request = match request.uri().authority() {
440            // Some SDKs require an authority, so we support e.g. http://self.alt/self-request
441            Some(authority) => authority.host() == "self.alt",
442            // Otherwise self requests have no authority
443            None => true,
444        };
445
446        // Enforce allowed_outbound_hosts
447        let is_allowed = if is_self_request {
448            self.allowed_hosts
449                .check_relative_url(&["http", "https"])
450                .await
451                .unwrap_or(false)
452        } else {
453            self.allowed_hosts
454                .check_url(&request.uri().to_string(), "https")
455                .await
456                .unwrap_or(false)
457        };
458        if !is_allowed {
459            return Err(ErrorCode::HttpRequestDenied);
460        }
461
462        if is_self_request {
463            // Replace the authority with the "self request origin"
464            let Some(origin) = self.self_request_origin.as_ref() else {
465                tracing::error!(
466                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
467                );
468                return Err(ErrorCode::HttpRequestUriInvalid);
469            };
470
471            config.use_tls = origin.use_tls();
472
473            request.headers_mut().insert(HOST, origin.host_header());
474
475            let path_and_query = request.uri().path_and_query().cloned();
476            *request.uri_mut() = origin.clone().into_uri(path_and_query);
477        }
478
479        // Some servers (looking at you nginx) don't like a host header even though
480        // http/2 allows it: https://github.com/hyperium/hyper/issues/3298.
481        //
482        // Note that we do this _before_ invoking the request interceptor.  It may
483        // decide to add the `host` header back in, regardless of the nginx bug, in
484        // which case we'll let it do so without interferring.
485        request.headers_mut().remove(HOST);
486        Ok(())
487    }
488
489    async fn send_request(
490        self,
491        request: OutgoingRequest,
492        config: OutgoingRequestConfig,
493        override_connect_addr: Option<SocketAddr>,
494    ) -> Result<IncomingResponse, ErrorCode> {
495        let OutgoingRequestConfig {
496            use_tls,
497            connect_timeout,
498            first_byte_timeout,
499            between_bytes_timeout,
500        } = config;
501
502        let tls_client_config = if use_tls {
503            let host = request.uri().host().unwrap_or_default();
504            Some(self.component_tls_configs.get_client_config(host).clone())
505        } else {
506            None
507        };
508
509        let resp = CONNECT_OPTIONS.scope(
510            ConnectOptions {
511                blocked_networks: self.blocked_networks,
512                connect_timeout,
513                tls_client_config,
514                override_connect_addr,
515                concurrent_outbound_connections_semaphore: self
516                    .concurrent_outbound_connections_semaphore,
517            },
518            async move {
519                if use_tls {
520                    self.http_clients.https.request(request).await
521                } else {
522                    // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
523                    let h2c_prior_knowledge_host =
524                        std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
525                    let use_h2c = h2c_prior_knowledge_host.as_deref()
526                        == request.uri().authority().map(|a| a.as_str());
527
528                    if use_h2c {
529                        self.http_clients.http2.request(request).await
530                    } else {
531                        self.http_clients.http1.request(request).await
532                    }
533                }
534            },
535        );
536
537        let resp = timeout(first_byte_timeout, resp)
538            .await
539            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
540            .map_err(hyper_legacy_request_error)?
541            .map(|body| body.map_err(hyper_request_error).boxed_unsync());
542
543        tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
544
545        Ok(IncomingResponse {
546            resp,
547            worker: None,
548            between_bytes_timeout,
549        })
550    }
551}
552
553type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
554type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
555
556#[derive(Clone)]
557pub(super) struct HttpClients {
558    /// Used for non-TLS HTTP/1 connections.
559    http1: HttpClient,
560    /// Used for non-TLS HTTP/2 connections (e.g. when h2 prior knowledge is available).
561    http2: HttpClient,
562    /// Used for HTTP-over-TLS connections, using ALPN to negotiate the HTTP version.
563    https: HttpsClient,
564}
565
566impl HttpClients {
567    pub(super) fn new(enable_pooling: bool) -> Self {
568        let builder = move || {
569            let mut builder = Client::builder(TokioExecutor::new());
570            if !enable_pooling {
571                builder.pool_max_idle_per_host(0);
572            }
573            builder
574        };
575        Self {
576            http1: builder().build(HttpConnector),
577            http2: builder().http2_only(true).build(HttpConnector),
578            https: builder().build(HttpsConnector),
579        }
580    }
581}
582
tokio::task_local! {
    /// The options used when establishing a new connection.
    ///
    /// We must use task-local variables for these config options when using
    /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
    /// them through as parameters: the connectors (`HttpConnector`/`HttpsConnector`)
    /// only receive the target `Uri`. Moreover, if there's already a pooled
    /// connection ready, we'll reuse that and ignore these options anyway.
    ///
    /// The value is installed with `CONNECT_OPTIONS.scope(..)` around a single
    /// request's future and dropped when that future completes. The connectors
    /// read it with `CONNECT_OPTIONS.get()`, which tokio documents as panicking
    /// when called outside such a scope.
    static CONNECT_OPTIONS: ConnectOptions;
}
593
/// Per-request settings that the connectors consult when they open a new
/// connection. Made available through the `CONNECT_OPTIONS` task-local.
#[derive(Clone)]
struct ConnectOptions {
    /// The blocked networks configuration.
    blocked_networks: BlockedNetworks,
    /// Timeout for establishing a TCP connection.
    connect_timeout: Duration,
    /// TLS client configuration to use, if any.
    tls_client_config: Option<TlsClientConfig>,
    /// If set, override the address to connect to instead of using the given `uri`'s authority.
    override_connect_addr: Option<SocketAddr>,
    /// A semaphore to limit the number of concurrent outbound connections.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
607
608impl ConnectOptions {
609    /// Establish a TCP connection to the given URI and default port.
610    async fn connect_tcp(
611        &self,
612        uri: &Uri,
613        default_port: u16,
614    ) -> Result<PermittedTcpStream, ErrorCode> {
615        let mut socket_addrs = match self.override_connect_addr {
616            Some(override_connect_addr) => vec![override_connect_addr],
617            None => {
618                let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
619
620                let host_and_port = if authority.port().is_some() {
621                    authority.as_str().to_string()
622                } else {
623                    format!("{}:{}", authority.as_str(), default_port)
624                };
625
626                let socket_addrs = tokio::net::lookup_host(&host_and_port)
627                    .await
628                    .map_err(|err| {
629                        tracing::debug!(?host_and_port, ?err, "Error resolving host");
630                        dns_error("address not available".into(), 0)
631                    })?
632                    .collect::<Vec<_>>();
633                tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
634                socket_addrs
635            }
636        };
637
638        // Remove blocked IPs
639        crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
640
641        // If we're limiting concurrent outbound requests, acquire a permit
642
643        let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore(
644            "wasi",
645            &self.concurrent_outbound_connections_semaphore,
646        )
647        .await;
648
649        let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
650            .await
651            .map_err(|_| ErrorCode::ConnectionTimeout)?
652            .map_err(|err| match err.kind() {
653                std::io::ErrorKind::AddrNotAvailable => {
654                    dns_error("address not available".into(), 0)
655                }
656                _ => ErrorCode::ConnectionRefused,
657            })?;
658        Ok(PermittedTcpStream {
659            inner: stream,
660            _permit: permit,
661        })
662    }
663
664    /// Establish a TLS connection to the given URI and default port.
665    async fn connect_tls(
666        &self,
667        uri: &Uri,
668        default_port: u16,
669    ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
670        let tcp_stream = self.connect_tcp(uri, default_port).await?;
671
672        let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
673        tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
674
675        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
676        let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
677            .map_err(|e| {
678                tracing::warn!("dns lookup error: {e:?}");
679                dns_error("invalid dns name".into(), 0)
680            })?
681            .to_owned();
682        connector.connect(domain, tcp_stream).await.map_err(|e| {
683            tracing::warn!("tls protocol error: {e:?}");
684            ErrorCode::TlsProtocolError
685        })
686    }
687}
688
689/// A connector the uses `ConnectOptions`
690#[derive(Clone)]
691struct HttpConnector;
692
693impl HttpConnector {
694    async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
695        let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
696        Ok(TokioIo::new(stream))
697    }
698}
699
700impl Service<Uri> for HttpConnector {
701    type Response = TokioIo<PermittedTcpStream>;
702    type Error = ErrorCode;
703    type Future =
704        Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
705
706    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
707        Poll::Ready(Ok(()))
708    }
709
710    fn call(&mut self, uri: Uri) -> Self::Future {
711        Box::pin(async move { Self::connect(uri).await })
712    }
713}
714
715/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
716#[derive(Clone)]
717struct HttpsConnector;
718
719impl HttpsConnector {
720    async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
721        let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
722        Ok(TokioIo::new(RustlsStream(stream)))
723    }
724}
725
726impl Service<Uri> for HttpsConnector {
727    type Response = TokioIo<RustlsStream>;
728    type Error = ErrorCode;
729    type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
730
731    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
732        Poll::Ready(Ok(()))
733    }
734
735    fn call(&mut self, uri: Uri) -> Self::Future {
736        Box::pin(async move { Self::connect(uri).await })
737    }
738}
739
740struct RustlsStream(TlsStream<PermittedTcpStream>);
741
742impl Connection for RustlsStream {
743    fn connected(&self) -> Connected {
744        if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
745            self.0.get_ref().0.connected().negotiated_h2()
746        } else {
747            self.0.get_ref().0.connected()
748        }
749    }
750}
751
752impl AsyncRead for RustlsStream {
753    fn poll_read(
754        self: Pin<&mut Self>,
755        cx: &mut Context<'_>,
756        buf: &mut ReadBuf<'_>,
757    ) -> Poll<Result<(), std::io::Error>> {
758        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
759    }
760}
761
762impl AsyncWrite for RustlsStream {
763    fn poll_write(
764        self: Pin<&mut Self>,
765        cx: &mut Context<'_>,
766        buf: &[u8],
767    ) -> Poll<Result<usize, std::io::Error>> {
768        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
769    }
770
771    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
772        Pin::new(&mut self.get_mut().0).poll_flush(cx)
773    }
774
775    fn poll_shutdown(
776        self: Pin<&mut Self>,
777        cx: &mut Context<'_>,
778    ) -> Poll<Result<(), std::io::Error>> {
779        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
780    }
781
782    fn poll_write_vectored(
783        self: Pin<&mut Self>,
784        cx: &mut Context<'_>,
785        bufs: &[IoSlice<'_>],
786    ) -> Poll<Result<usize, std::io::Error>> {
787        Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
788    }
789
790    fn is_write_vectored(&self) -> bool {
791        self.0.is_write_vectored()
792    }
793}
794
795/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
796struct PermittedTcpStream {
797    /// The wrapped TCP stream.
798    inner: TcpStream,
799    /// A permit indicating that this stream is allowed to exist.
800    ///
801    /// When this stream is dropped, the permit is also dropped, allowing another
802    /// connection to be established.
803    _permit: Option<OwnedSemaphorePermit>,
804}
805
806impl Connection for PermittedTcpStream {
807    fn connected(&self) -> Connected {
808        self.inner.connected()
809    }
810}
811
812impl AsyncRead for PermittedTcpStream {
813    fn poll_read(
814        self: Pin<&mut Self>,
815        cx: &mut Context<'_>,
816        buf: &mut ReadBuf<'_>,
817    ) -> Poll<std::io::Result<()>> {
818        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
819    }
820}
821
822impl AsyncWrite for PermittedTcpStream {
823    fn poll_write(
824        self: Pin<&mut Self>,
825        cx: &mut Context<'_>,
826        buf: &[u8],
827    ) -> Poll<Result<usize, std::io::Error>> {
828        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
829    }
830
831    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
832        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
833    }
834
835    fn poll_shutdown(
836        self: Pin<&mut Self>,
837        cx: &mut Context<'_>,
838    ) -> Poll<Result<(), std::io::Error>> {
839        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
840    }
841}
842
843/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
844fn hyper_request_error(err: hyper::Error) -> ErrorCode {
845    // If there's a source, we might be able to extract a wasi-http error from it.
846    if let Some(cause) = err.source() {
847        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
848            return err.clone();
849        }
850    }
851
852    tracing::warn!("hyper request error: {err:?}");
853
854    ErrorCode::HttpProtocolError
855}
856
857/// Translate a [`hyper_util::client::legacy::Error`] to a wasi-http `ErrorCode` in the context of a request.
858fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
859    // If there's a source, we might be able to extract a wasi-http error from it.
860    if let Some(cause) = err.source() {
861        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
862            return err.clone();
863        }
864    }
865
866    tracing::warn!("hyper request error: {err:?}");
867
868    ErrorCode::HttpProtocolError
869}
870
871fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
872    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
873        rcode: Some(rcode),
874        info_code: Some(info_code),
875    })
876}
877
878// TODO: Remove this (and uses of it) once
879// https://github.com/spinframework/spin/issues/3274 has been addressed.
880pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
881    match code {
882        p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
883        p2_types::ErrorCode::DnsError(payload) => {
884            p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
885                rcode: payload.rcode,
886                info_code: payload.info_code,
887            })
888        }
889        p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
890        p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
891        p2_types::ErrorCode::DestinationIpProhibited => {
892            p3_types::ErrorCode::DestinationIpProhibited
893        }
894        p2_types::ErrorCode::DestinationIpUnroutable => {
895            p3_types::ErrorCode::DestinationIpUnroutable
896        }
897        p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
898        p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
899        p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
900        p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
901        p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
902        p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
903        p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
904        p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
905        p2_types::ErrorCode::TlsAlertReceived(payload) => {
906            p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
907                alert_id: payload.alert_id,
908                alert_message: payload.alert_message,
909            })
910        }
911        p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
912        p2_types::ErrorCode::HttpRequestLengthRequired => {
913            p3_types::ErrorCode::HttpRequestLengthRequired
914        }
915        p2_types::ErrorCode::HttpRequestBodySize(payload) => {
916            p3_types::ErrorCode::HttpRequestBodySize(payload)
917        }
918        p2_types::ErrorCode::HttpRequestMethodInvalid => {
919            p3_types::ErrorCode::HttpRequestMethodInvalid
920        }
921        p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
922        p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
923        p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
924            p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
925        }
926        p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
927            p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
928                p3_types::FieldSizePayload {
929                    field_name: payload.field_name,
930                    field_size: payload.field_size,
931                }
932            }))
933        }
934        p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
935            p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
936        }
937        p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
938            p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
939                field_name: payload.field_name,
940                field_size: payload.field_size,
941            })
942        }
943        p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
944        p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
945            p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
946        }
947        p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
948            p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
949                field_name: payload.field_name,
950                field_size: payload.field_size,
951            })
952        }
953        p2_types::ErrorCode::HttpResponseBodySize(payload) => {
954            p3_types::ErrorCode::HttpResponseBodySize(payload)
955        }
956        p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
957            p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
958        }
959        p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
960            p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
961                field_name: payload.field_name,
962                field_size: payload.field_size,
963            })
964        }
965        p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
966            p3_types::ErrorCode::HttpResponseTransferCoding(payload)
967        }
968        p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
969            p3_types::ErrorCode::HttpResponseContentCoding(payload)
970        }
971        p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
972        p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
973        p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
974        p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
975        p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
976        p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
977    }
978}
979
980// TODO: Remove this (and uses of it) once
981// https://github.com/spinframework/spin/issues/3274 has been addressed.
982pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
983    match code {
984        p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
985        p3_types::ErrorCode::DnsError(payload) => {
986            p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
987                rcode: payload.rcode,
988                info_code: payload.info_code,
989            })
990        }
991        p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
992        p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
993        p3_types::ErrorCode::DestinationIpProhibited => {
994            p2_types::ErrorCode::DestinationIpProhibited
995        }
996        p3_types::ErrorCode::DestinationIpUnroutable => {
997            p2_types::ErrorCode::DestinationIpUnroutable
998        }
999        p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
1000        p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
1001        p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
1002        p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
1003        p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
1004        p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
1005        p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
1006        p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
1007        p3_types::ErrorCode::TlsAlertReceived(payload) => {
1008            p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
1009                alert_id: payload.alert_id,
1010                alert_message: payload.alert_message,
1011            })
1012        }
1013        p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
1014        p3_types::ErrorCode::HttpRequestLengthRequired => {
1015            p2_types::ErrorCode::HttpRequestLengthRequired
1016        }
1017        p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1018            p2_types::ErrorCode::HttpRequestBodySize(payload)
1019        }
1020        p3_types::ErrorCode::HttpRequestMethodInvalid => {
1021            p2_types::ErrorCode::HttpRequestMethodInvalid
1022        }
1023        p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1024        p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1025        p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1026            p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1027        }
1028        p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1029            p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1030                p2_types::FieldSizePayload {
1031                    field_name: payload.field_name,
1032                    field_size: payload.field_size,
1033                }
1034            }))
1035        }
1036        p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1037            p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1038        }
1039        p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1040            p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1041                field_name: payload.field_name,
1042                field_size: payload.field_size,
1043            })
1044        }
1045        p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1046        p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1047            p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1048        }
1049        p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1050            p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1051                field_name: payload.field_name,
1052                field_size: payload.field_size,
1053            })
1054        }
1055        p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1056            p2_types::ErrorCode::HttpResponseBodySize(payload)
1057        }
1058        p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1059            p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1060        }
1061        p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1062            p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1063                field_name: payload.field_name,
1064                field_size: payload.field_size,
1065            })
1066        }
1067        p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1068            p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1069        }
1070        p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1071            p2_types::ErrorCode::HttpResponseContentCoding(payload)
1072        }
1073        p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1074        p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1075        p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1076        p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1077        p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1078        p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1079    }
1080}