Skip to main content

spin_factor_outbound_http/
wasi.rs

1use std::{
2    error::Error,
3    future::Future,
4    io::IoSlice,
5    net::SocketAddr,
6    ops::DerefMut,
7    pin::Pin,
8    sync::{Arc, Mutex},
9    task::{self, Context, Poll},
10    time::Duration,
11};
12
13use bytes::{Buf, Bytes};
14use futures::channel::oneshot;
15use http::{
16    HeaderMap, Uri,
17    header::{CONTENT_LENGTH, HOST},
18    uri::Scheme,
19};
20use http_body::{Body, Frame, SizeHint};
21use http_body_util::{BodyExt, combinators::UnsyncBoxBody};
22use hyper_util::{
23    client::legacy::{
24        Client,
25        connect::{Connected, Connection},
26    },
27    rt::{TokioExecutor, TokioIo},
28};
29use opentelemetry_semantic_conventions::attribute as otel_attribute;
30use spin_factor_outbound_networking::{
31    ComponentTlsClientConfigs, TlsClientConfig,
32    config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
33};
34use spin_factors::RuntimeFactorsInstanceState;
35use tokio::{
36    io::{AsyncRead, AsyncWrite, ReadBuf},
37    net::TcpStream,
38    time::timeout,
39};
40use tokio_rustls::client::TlsStream;
41use tower_service::Service;
42use tracing::{Instrument, Span, field::Empty, instrument};
43use wasmtime::component::HasData;
44use wasmtime_wasi::TrappableError;
45use wasmtime_wasi_http::{
46    p2::{
47        self, HttpError, WasiHttpCtxView,
48        bindings::http::types::{self as p2_types, ErrorCode},
49        body::HyperOutgoingBody,
50        types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
51    },
52    p3::{self, bindings::http::types as p3_types},
53};
54
55use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore};
56
57use crate::{
58    InstanceHttpHooks, OutboundHttpFactor, SelfRequestOrigin,
59    intercept::{InterceptOutcome, OutboundHttpInterceptor},
60    wasi_2023_10_18, wasi_2023_11_10,
61};
62
63use tracing_opentelemetry::OpenTelemetrySpanExt as _;
64
65const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);
66
67pub struct MutexBody<T>(Mutex<T>);
68
69impl<T> MutexBody<T> {
70    pub fn new(body: T) -> Self {
71        Self(Mutex::new(body))
72    }
73}
74
75impl<T: Body + Unpin> Body for MutexBody<T> {
76    type Data = T::Data;
77    type Error = T::Error;
78
79    fn poll_frame(
80        self: Pin<&mut Self>,
81        cx: &mut Context<'_>,
82    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
83        Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
84    }
85
86    fn is_end_stream(&self) -> bool {
87        self.0.lock().unwrap().is_end_stream()
88    }
89
90    fn size_hint(&self) -> SizeHint {
91        self.0.lock().unwrap().size_hint()
92    }
93}
94
95/// Body which, when dropped, will notify the receiver corresponding to the
96/// specified sender.
97pub struct NotifyOnDropBody<B> {
98    body: B,
99    _tx: oneshot::Sender<()>,
100}
101
102impl<B> NotifyOnDropBody<B> {
103    pub fn new(body: B, tx: oneshot::Sender<()>) -> Self {
104        Self { body, _tx: tx }
105    }
106}
107
108impl<B: Body + Unpin> Body for NotifyOnDropBody<B> {
109    type Data = B::Data;
110    type Error = B::Error;
111
112    fn poll_frame(
113        mut self: Pin<&mut Self>,
114        cx: &mut Context<'_>,
115    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
116        Pin::new(&mut self.body).poll_frame(cx)
117    }
118
119    fn is_end_stream(&self) -> bool {
120        self.body.is_end_stream()
121    }
122
123    fn size_hint(&self) -> SizeHint {
124        self.body.size_hint()
125    }
126}
127
128pub(crate) struct HasHttp;
129
130impl HasData for HasHttp {
131    type Data<'a> = WasiHttpCtxView<'a>;
132}
133
134impl p3::WasiHttpHooks for InstanceHttpHooks {
135    #[instrument(
136        name = "spin_outbound_http.send_request",
137        skip_all,
138        fields(
139            otel.kind = "client",
140            {otel_attribute::URL_FULL} = Empty,
141            {otel_attribute::HTTP_REQUEST_METHOD} = %request.method(),
142            otel.name = %request.method(),
143            // Incubating convention; not yet a stable `opentelemetry_semantic_conventions` constant.
144            http.response.body.size = Empty,
145            {otel_attribute::HTTP_RESPONSE_STATUS_CODE} = Empty,
146            {otel_attribute::SERVER_ADDRESS} = Empty,
147            {otel_attribute::SERVER_PORT} = Empty,
148        )
149    )]
150    #[allow(clippy::type_complexity)]
151    fn send_request(
152        &mut self,
153        request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
154        options: Option<p3::RequestOptions>,
155        fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
156    ) -> Box<
157        dyn Future<
158                Output = Result<
159                    (
160                        http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
161                        Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
162                    ),
163                    TrappableError<p3_types::ErrorCode>,
164                >,
165            > + Send,
166    > {
167        self.otel.reparent_tracing_span();
168
169        // If the caller (i.e. the guest) has trouble consuming the response
170        // (e.g. encountering a network error while forwarding it on to some
171        // other place), it can report that error to us via `fut`.  However,
172        // there's nothing we'll be able to do with it here, so we ignore it.
173        // Presumably the guest will also drop the body stream and trailers
174        // future if it encounters such an error while those things are still
175        // arriving, which Hyper will deal with as appropriate (e.g. closing the
176        // connection).
177        _ = fut;
178
179        let request = request.map(|body| MutexBody::new(body).boxed());
180
181        let request_sender = RequestSender {
182            allowed_hosts: self.allowed_hosts.clone(),
183            component_tls_configs: self.component_tls_configs.clone(),
184            request_interceptor: self.request_interceptor.clone(),
185            self_request_origin: self.self_request_origin.clone(),
186            blocked_networks: self.blocked_networks.clone(),
187            http_clients: self.wasi_http_clients.clone(),
188            semaphore: self.semaphore.clone(),
189        };
190        let config = OutgoingRequestConfig {
191            use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
192            connect_timeout: options
193                .and_then(|v| v.connect_timeout)
194                .unwrap_or(DEFAULT_TIMEOUT),
195            first_byte_timeout: options
196                .and_then(|v| v.first_byte_timeout)
197                .unwrap_or(DEFAULT_TIMEOUT),
198            between_bytes_timeout: options
199                .and_then(|v| v.between_bytes_timeout)
200                .unwrap_or(DEFAULT_TIMEOUT),
201        };
202        Box::new(
203            async {
204                match request_sender
205                    .send(
206                        request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
207                        config,
208                    )
209                    .await
210                {
211                    Ok(IncomingResponse {
212                        resp,
213                        between_bytes_timeout,
214                        ..
215                    }) => Ok((
216                        resp.map(|body| {
217                            BetweenBytesTimeoutBody {
218                                body: Some(body),
219                                sleep: None,
220                                timeout: between_bytes_timeout,
221                                byte_count: 0,
222                                span: Some(Span::current()),
223                            }
224                            .boxed_unsync()
225                        }),
226                        Box::new(async {
227                            // TODO: Can we plumb connection errors through to here, or
228                            // will `hyper_util::client::legacy::Client` pass them all
229                            // via the response body?
230                            Ok(())
231                        }) as Box<dyn Future<Output = _> + Send>,
232                    )),
233                    Err(http_error) => match http_error.downcast() {
234                        Ok(error_code) => {
235                            Err(TrappableError::from(p2_to_p3_error_code(error_code)))
236                        }
237                        Err(trap) => Err(TrappableError::trap(trap)),
238                    },
239                }
240            }
241            .in_current_span(),
242        )
243    }
244}
245
246pin_project_lite::pin_project! {
247    struct BetweenBytesTimeoutBody<B> {
248        // Wrapping `body` in `Option` lets us take it out (dropping the underlying connection) when the timeout fires.
249        body: Option<B>,
250        #[pin]
251        sleep: Option<tokio::time::Sleep>,
252        timeout: Duration,
253        byte_count: u64,
254        span: Option<Span>,
255    }
256}
257
258impl<B: Body<Error = p2_types::ErrorCode> + Unpin> Body for BetweenBytesTimeoutBody<B> {
259    type Data = B::Data;
260    type Error = p3_types::ErrorCode;
261
262    fn poll_frame(
263        self: Pin<&mut Self>,
264        cx: &mut Context<'_>,
265    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
266        let mut me = self.project();
267
268        // Scope the mutable borrow so we can touch `me.body` again below.
269        let poll_result = {
270            let Some(body) = me.body.as_mut() else {
271                // Body already dropped by a previous timeout fire.
272                return Poll::Ready(None);
273            };
274            Pin::new(body).poll_frame(cx)
275        };
276
277        let mut record_body_size_once = |body_size: u64| {
278            if let Some(span) = me.span.take() {
279                // `http.response.body.size` is incubating (behind semconv_experimental)
280                // in opentelemetry-semantic-conventions 0.28. Leave as literal to avoid
281                // enabling the experimental feature.
282                span.record("http.response.body.size", body_size);
283            }
284        };
285        match poll_result {
286            Poll::Ready(value) => {
287                me.sleep.as_mut().set(None);
288
289                match &value {
290                    Some(Ok(frame)) => {
291                        if let Some(data) = frame.data_ref() {
292                            *me.byte_count += data.remaining() as u64;
293                        }
294                        if me.body.as_ref().is_some_and(|b| b.is_end_stream()) {
295                            record_body_size_once(*me.byte_count);
296                        }
297                    }
298                    None => {
299                        record_body_size_once(*me.byte_count);
300                    }
301                    Some(Err(e)) => {
302                        tracing::warn!("error reading response body: {e:?}");
303                    }
304                }
305
306                Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
307            }
308            Poll::Pending => {
309                if me.sleep.is_none() {
310                    me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
311                }
312                task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
313
314                // Drop the inner body immediately to free resources (like sockets)
315                // rather than waiting for the guest to release the resource.
316                *me.body = None;
317                record_body_size_once(*me.byte_count);
318
319                Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
320            }
321        }
322    }
323
324    fn is_end_stream(&self) -> bool {
325        self.body.as_ref().is_none_or(|b| b.is_end_stream())
326    }
327
328    fn size_hint(&self) -> SizeHint {
329        self.body
330            .as_ref()
331            .map(|b| b.size_hint())
332            .unwrap_or_default()
333    }
334}
335
336pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
337where
338    C: spin_factors::InitContext<OutboundHttpFactor>,
339{
340    let linker = ctx.linker();
341
342    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpCtxView<'_>
343    where
344        C: spin_factors::InitContext<OutboundHttpFactor>,
345    {
346        let (state, table) = C::get_data_with_table(store);
347        let ctx = &mut state.wasi_http_ctx;
348        WasiHttpCtxView {
349            ctx,
350            table,
351            hooks: &mut state.hooks,
352        }
353    }
354
355    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpCtxView<'_>;
356    wasmtime_wasi_http::p2::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
357        linker, get_http,
358    )?;
359    wasmtime_wasi_http::p2::bindings::http::types::add_to_linker::<_, HasHttp>(
360        linker,
361        &Default::default(),
362        get_http,
363    )?;
364
365    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
366    where
367        C: spin_factors::InitContext<OutboundHttpFactor>,
368    {
369        let (state, table) = C::get_data_with_table(store);
370        let ctx = &mut state.wasi_http_ctx;
371        p3::WasiHttpCtxView {
372            ctx,
373            table,
374            hooks: &mut state.hooks,
375        }
376    }
377
378    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
379    p3::bindings::http::client::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
380    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
381
382    wasi_2023_10_18::add_to_linker(linker, get_http)?;
383    wasi_2023_11_10::add_to_linker(linker, get_http)?;
384
385    Ok(())
386}
387
388impl OutboundHttpFactor {
389    pub fn get_wasi_http_impl(
390        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
391    ) -> Option<WasiHttpCtxView<'_>> {
392        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
393        let ctx = &mut state.wasi_http_ctx;
394        Some(WasiHttpCtxView {
395            ctx,
396            table,
397            hooks: &mut state.hooks,
398        })
399    }
400
401    pub fn get_wasi_p3_http_impl(
402        runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
403    ) -> Option<p3::WasiHttpCtxView<'_>> {
404        let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
405        let ctx = &mut state.wasi_http_ctx;
406        Some(p3::WasiHttpCtxView {
407            ctx,
408            table,
409            hooks: &mut state.hooks,
410        })
411    }
412}
413
414type OutgoingRequest = http::Request<HyperOutgoingBody>;
415
416impl p2::WasiHttpHooks for InstanceHttpHooks {
417    #[instrument(
418        name = "spin_outbound_http.send_request",
419        skip_all,
420        fields(
421            otel.kind = "client",
422            {otel_attribute::URL_FULL} = Empty,
423            {otel_attribute::HTTP_REQUEST_METHOD} = %request.method(),
424            otel.name = %request.method(),
425            {otel_attribute::HTTP_RESPONSE_STATUS_CODE} = Empty,
426            {otel_attribute::SERVER_ADDRESS} = Empty,
427            {otel_attribute::SERVER_PORT} = Empty,
428        )
429    )]
430    fn send_request(
431        &mut self,
432        request: OutgoingRequest,
433        config: OutgoingRequestConfig,
434    ) -> Result<wasmtime_wasi_http::p2::types::HostFutureIncomingResponse, HttpError> {
435        self.otel.reparent_tracing_span();
436
437        let request_sender = RequestSender {
438            allowed_hosts: self.allowed_hosts.clone(),
439            component_tls_configs: self.component_tls_configs.clone(),
440            request_interceptor: self.request_interceptor.clone(),
441            self_request_origin: self.self_request_origin.clone(),
442            blocked_networks: self.blocked_networks.clone(),
443            http_clients: self.wasi_http_clients.clone(),
444            semaphore: self.semaphore.clone(),
445        };
446        Ok(HostFutureIncomingResponse::Pending(
447            wasmtime_wasi::runtime::spawn(
448                async {
449                    match request_sender.send(request, config).await {
450                        Ok(resp) => Ok(Ok(resp)),
451                        Err(http_error) => match http_error.downcast() {
452                            Ok(error_code) => Ok(Err(error_code)),
453                            Err(trap) => Err(trap),
454                        },
455                    }
456                }
457                .in_current_span(),
458            ),
459        ))
460    }
461}
462
463struct RequestSender {
464    allowed_hosts: OutboundAllowedHosts,
465    blocked_networks: BlockedNetworks,
466    component_tls_configs: ComponentTlsClientConfigs,
467    self_request_origin: Option<SelfRequestOrigin>,
468    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
469    http_clients: HttpClients,
470    semaphore: ConnectionSemaphore,
471}
472
473impl RequestSender {
474    async fn send(
475        self,
476        mut request: OutgoingRequest,
477        mut config: OutgoingRequestConfig,
478    ) -> Result<IncomingResponse, HttpError> {
479        self.prepare_request(&mut request, &mut config).await?;
480
481        // If the current span has opentelemetry trace context, inject it into the request
482        spin_telemetry::inject_trace_context(&mut request);
483
484        // Run any configured request interceptor
485        let mut override_connect_addr = None;
486        if let Some(interceptor) = &self.request_interceptor {
487            let intercept_request = std::mem::take(&mut request).into();
488            match interceptor.intercept(intercept_request).await? {
489                InterceptOutcome::Continue(mut req) => {
490                    override_connect_addr = req.override_connect_addr.take();
491                    request = req.into_hyper_request();
492                }
493                InterceptOutcome::Complete(resp) => {
494                    let resp = IncomingResponse {
495                        resp,
496                        worker: None,
497                        between_bytes_timeout: config.between_bytes_timeout,
498                    };
499                    return Ok(resp);
500                }
501            }
502        }
503
504        // Backfill span fields after potentially updating the URL in the interceptor
505        let span = tracing::Span::current();
506        if let Some(addr) = override_connect_addr {
507            span.record(otel_attribute::SERVER_ADDRESS, addr.ip().to_string());
508            span.record(otel_attribute::SERVER_PORT, addr.port());
509        } else if let Some(authority) = request.uri().authority() {
510            span.record(otel_attribute::SERVER_ADDRESS, authority.host());
511            if let Some(port) = authority.port_u16() {
512                span.record(otel_attribute::SERVER_PORT, port);
513            }
514        }
515
516        record_content_length_header(
517            &span,
518            request.headers(),
519            "http.request.header.content-length",
520        );
521
522        Ok(self
523            .send_request(request, config, override_connect_addr)
524            .await?)
525    }
526
527    async fn prepare_request(
528        &self,
529        request: &mut OutgoingRequest,
530        config: &mut OutgoingRequestConfig,
531    ) -> Result<(), ErrorCode> {
532        // wasmtime-wasi-http fills in scheme and authority for relative URLs
533        // (e.g. https://:443/<path>), which makes them hard to reason about.
534        // Undo that here.
535        let uri = request.uri_mut();
536        if uri
537            .authority()
538            .is_some_and(|authority| authority.host().is_empty())
539        {
540            let mut builder = http::uri::Builder::new();
541            if let Some(paq) = uri.path_and_query() {
542                builder = builder.path_and_query(paq.clone());
543            }
544            *uri = builder.build().unwrap();
545        }
546        tracing::Span::current().record(otel_attribute::URL_FULL, uri.to_string());
547
548        let is_self_request = match request.uri().authority() {
549            // Some SDKs require an authority, so we support e.g. http://self.alt/self-request
550            Some(authority) => authority.host() == "self.alt",
551            // Otherwise self requests have no authority
552            None => true,
553        };
554
555        // Enforce allowed_outbound_hosts
556        let is_allowed = if is_self_request {
557            self.allowed_hosts
558                .check_relative_url(&["http", "https"])
559                .await
560                .unwrap_or(false)
561        } else {
562            self.allowed_hosts
563                .check_url(&request.uri().to_string(), "https")
564                .await
565                .unwrap_or(false)
566        };
567        if !is_allowed {
568            return Err(ErrorCode::HttpRequestDenied);
569        }
570
571        if is_self_request {
572            // Replace the authority with the "self request origin"
573            let Some(origin) = self.self_request_origin.as_ref() else {
574                tracing::error!(
575                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
576                );
577                return Err(ErrorCode::HttpRequestUriInvalid);
578            };
579
580            config.use_tls = origin.use_tls();
581
582            request.headers_mut().insert(HOST, origin.host_header());
583
584            let path_and_query = request.uri().path_and_query().cloned();
585            *request.uri_mut() = origin.clone().into_uri(path_and_query);
586        }
587
588        // Some servers (looking at you nginx) don't like a host header even though
589        // http/2 allows it: https://github.com/hyperium/hyper/issues/3298.
590        //
591        // Note that we do this _before_ invoking the request interceptor.  It may
592        // decide to add the `host` header back in, regardless of the nginx bug, in
593        // which case we'll let it do so without interferring.
594        request.headers_mut().remove(HOST);
595        Ok(())
596    }
597
598    async fn send_request(
599        self,
600        request: OutgoingRequest,
601        config: OutgoingRequestConfig,
602        override_connect_addr: Option<SocketAddr>,
603    ) -> Result<IncomingResponse, ErrorCode> {
604        let OutgoingRequestConfig {
605            use_tls,
606            connect_timeout,
607            first_byte_timeout,
608            between_bytes_timeout,
609        } = config;
610
611        let tls_client_config = if use_tls {
612            let host = request.uri().host().unwrap_or_default();
613            Some(self.component_tls_configs.get_client_config(host).clone())
614        } else {
615            None
616        };
617
618        let resp = CONNECT_OPTIONS.scope(
619            ConnectOptions {
620                blocked_networks: self.blocked_networks,
621                connect_timeout,
622                tls_client_config,
623                override_connect_addr,
624                semaphore: self.semaphore,
625            },
626            async move {
627                if use_tls {
628                    self.http_clients.https.request(request).await
629                } else {
630                    // For development purposes, allow configuring plaintext HTTP/2 for a specific host.
631                    let h2c_prior_knowledge_host =
632                        std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
633                    let use_h2c = h2c_prior_knowledge_host.as_deref()
634                        == request.uri().authority().map(|a| a.as_str());
635
636                    if use_h2c {
637                        self.http_clients.http2.request(request).await
638                    } else {
639                        self.http_clients.http1.request(request).await
640                    }
641                }
642            },
643        );
644
645        let resp = timeout(first_byte_timeout, resp)
646            .await
647            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
648            .map_err(hyper_legacy_request_error)?
649            .map(|body| body.map_err(hyper_request_error).boxed_unsync());
650
651        let span = tracing::Span::current();
652        span.record(
653            otel_attribute::HTTP_RESPONSE_STATUS_CODE,
654            resp.status().as_u16(),
655        );
656
657        record_content_length_header(&span, resp.headers(), "http.response.header.content-length");
658
659        Ok(IncomingResponse {
660            resp,
661            worker: None,
662            between_bytes_timeout,
663        })
664    }
665}
666
667type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
668type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
669
670#[derive(Clone)]
671pub(super) struct HttpClients {
672    /// Used for non-TLS HTTP/1 connections.
673    http1: HttpClient,
674    /// Used for non-TLS HTTP/2 connections (e.g. when h2 prior knowledge is available).
675    http2: HttpClient,
676    /// Used for HTTP-over-TLS connections, using ALPN to negotiate the HTTP version.
677    https: HttpsClient,
678}
679
680impl HttpClients {
681    pub(super) fn new(enable_pooling: bool) -> Self {
682        let builder = move || {
683            let mut builder = Client::builder(TokioExecutor::new());
684            if !enable_pooling {
685                builder.pool_max_idle_per_host(0);
686            }
687            builder
688        };
689        Self {
690            http1: builder().build(HttpConnector),
691            http2: builder().http2_only(true).build(HttpConnector),
692            https: builder().build(HttpsConnector),
693        }
694    }
695}
696
697tokio::task_local! {
698    /// The options used when establishing a new connection.
699    ///
700    /// We must use task-local variables for these config options when using
701    /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
702    /// them through as parameters.  Moreover, if there's already a pooled connection
703    /// ready, we'll reuse that and ignore these options anyway. After each connection
704    /// is established, the options are dropped.
705    static CONNECT_OPTIONS: ConnectOptions;
706}
707
708#[derive(Clone)]
709struct ConnectOptions {
710    /// The blocked networks configuration.
711    blocked_networks: BlockedNetworks,
712    /// Timeout for establishing a TCP connection.
713    connect_timeout: Duration,
714    /// TLS client configuration to use, if any.
715    tls_client_config: Option<TlsClientConfig>,
716    /// If set, override the address to connect to instead of using the given `uri`'s authority.
717    override_connect_addr: Option<SocketAddr>,
718    /// Semaphore to limit concurrent outbound connections.
719    semaphore: ConnectionSemaphore,
720}
721
722impl ConnectOptions {
723    /// Establish a TCP connection to the given URI and default port.
724    async fn connect_tcp(
725        &self,
726        uri: &Uri,
727        default_port: u16,
728    ) -> Result<PermittedTcpStream, ErrorCode> {
729        let mut socket_addrs = match self.override_connect_addr {
730            Some(override_connect_addr) => vec![override_connect_addr],
731            None => {
732                let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
733
734                let host_and_port = if authority.port().is_some() {
735                    authority.as_str().to_string()
736                } else {
737                    format!("{}:{}", authority.as_str(), default_port)
738                };
739
740                let socket_addrs = tokio::net::lookup_host(&host_and_port)
741                    .await
742                    .map_err(|err| {
743                        tracing::debug!(?host_and_port, ?err, "Error resolving host");
744                        dns_error("address not available".into(), 0)
745                    })?
746                    .collect::<Vec<_>>();
747                tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
748                socket_addrs
749            }
750        };
751
752        // Remove blocked IPs
753        crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
754
755        let connect = async {
756            // If we're limiting concurrent outbound requests, acquire a permit
757            let permit = self.semaphore.acquire().await;
758            (TcpStream::connect(&*socket_addrs).await, permit)
759        };
760
761        // Make sure that the connect timeout applies to both acquiring the outbound request permit and establishing the TCP connection,
762        // since acquiring the permit could potentially take a long time if there are many outbound requests happening.
763        let (stream, permit) = timeout(self.connect_timeout, connect)
764            .await
765            .map_err(|_| ErrorCode::ConnectionTimeout)?;
766        let permit = permit.map_err(|_| ErrorCode::ConnectionLimitReached)?;
767        let stream = stream.map_err(|err| match err.kind() {
768            std::io::ErrorKind::AddrNotAvailable => dns_error("address not available".into(), 0),
769            _ => ErrorCode::ConnectionRefused,
770        })?;
771        Ok(PermittedTcpStream {
772            inner: stream,
773            _permit: permit,
774        })
775    }
776
777    /// Establish a TLS connection to the given URI and default port.
778    async fn connect_tls(
779        &self,
780        uri: &Uri,
781        default_port: u16,
782    ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
783        let tcp_stream = self.connect_tcp(uri, default_port).await?;
784
785        let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
786        tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
787
788        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
789        let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
790            .map_err(|e| {
791                tracing::warn!("dns lookup error: {e:?}");
792                dns_error("invalid dns name".into(), 0)
793            })?
794            .to_owned();
795        connector.connect(domain, tcp_stream).await.map_err(|e| {
796            tracing::warn!("tls protocol error: {e:?}");
797            ErrorCode::TlsProtocolError
798        })
799    }
800}
801
802/// A connector the uses `ConnectOptions`
803#[derive(Clone)]
804struct HttpConnector;
805
806impl HttpConnector {
807    async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
808        let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
809        Ok(TokioIo::new(stream))
810    }
811}
812
813impl Service<Uri> for HttpConnector {
814    type Response = TokioIo<PermittedTcpStream>;
815    type Error = ErrorCode;
816    type Future =
817        Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
818
819    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
820        Poll::Ready(Ok(()))
821    }
822
823    fn call(&mut self, uri: Uri) -> Self::Future {
824        Box::pin(async move { Self::connect(uri).await })
825    }
826}
827
828/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
829#[derive(Clone)]
830struct HttpsConnector;
831
832impl HttpsConnector {
833    async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
834        let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
835        Ok(TokioIo::new(RustlsStream(stream)))
836    }
837}
838
839impl Service<Uri> for HttpsConnector {
840    type Response = TokioIo<RustlsStream>;
841    type Error = ErrorCode;
842    type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
843
844    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
845        Poll::Ready(Ok(()))
846    }
847
848    fn call(&mut self, uri: Uri) -> Self::Future {
849        Box::pin(async move { Self::connect(uri).await })
850    }
851}
852
853struct RustlsStream(TlsStream<PermittedTcpStream>);
854
855impl Connection for RustlsStream {
856    fn connected(&self) -> Connected {
857        if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
858            self.0.get_ref().0.connected().negotiated_h2()
859        } else {
860            self.0.get_ref().0.connected()
861        }
862    }
863}
864
865impl AsyncRead for RustlsStream {
866    fn poll_read(
867        self: Pin<&mut Self>,
868        cx: &mut Context<'_>,
869        buf: &mut ReadBuf<'_>,
870    ) -> Poll<Result<(), std::io::Error>> {
871        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
872    }
873}
874
875impl AsyncWrite for RustlsStream {
876    fn poll_write(
877        self: Pin<&mut Self>,
878        cx: &mut Context<'_>,
879        buf: &[u8],
880    ) -> Poll<Result<usize, std::io::Error>> {
881        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
882    }
883
884    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
885        Pin::new(&mut self.get_mut().0).poll_flush(cx)
886    }
887
888    fn poll_shutdown(
889        self: Pin<&mut Self>,
890        cx: &mut Context<'_>,
891    ) -> Poll<Result<(), std::io::Error>> {
892        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
893    }
894
895    fn poll_write_vectored(
896        self: Pin<&mut Self>,
897        cx: &mut Context<'_>,
898        bufs: &[IoSlice<'_>],
899    ) -> Poll<Result<usize, std::io::Error>> {
900        Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
901    }
902
903    fn is_write_vectored(&self) -> bool {
904        self.0.is_write_vectored()
905    }
906}
907
908/// A TCP stream that holds a permit indicating that it is allowed to exist.
909struct PermittedTcpStream {
910    /// The wrapped TCP stream.
911    inner: TcpStream,
912    /// A permit indicating that this stream is allowed to exist.
913    ///
914    /// When this stream is dropped, the permit is also dropped, allowing another
915    /// connection to be established.
916    _permit: ConnectionPermit,
917}
918
919impl Connection for PermittedTcpStream {
920    fn connected(&self) -> Connected {
921        self.inner.connected()
922    }
923}
924
925impl AsyncRead for PermittedTcpStream {
926    fn poll_read(
927        self: Pin<&mut Self>,
928        cx: &mut Context<'_>,
929        buf: &mut ReadBuf<'_>,
930    ) -> Poll<std::io::Result<()>> {
931        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
932    }
933}
934
935impl AsyncWrite for PermittedTcpStream {
936    fn poll_write(
937        self: Pin<&mut Self>,
938        cx: &mut Context<'_>,
939        buf: &[u8],
940    ) -> Poll<Result<usize, std::io::Error>> {
941        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
942    }
943
944    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
945        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
946    }
947
948    fn poll_shutdown(
949        self: Pin<&mut Self>,
950        cx: &mut Context<'_>,
951    ) -> Poll<Result<(), std::io::Error>> {
952        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
953    }
954}
955
956/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
957fn hyper_request_error(err: hyper::Error) -> ErrorCode {
958    // If there's a source, we might be able to extract a wasi-http error from it.
959    if let Some(cause) = err.source()
960        && let Some(err) = cause.downcast_ref::<ErrorCode>()
961    {
962        return err.clone();
963    }
964
965    tracing::warn!("hyper request error: {err:?}");
966
967    ErrorCode::HttpProtocolError
968}
969
970/// Translate a [`hyper_util::client::legacy::Error`] to a wasi-http `ErrorCode` in the context of a request.
971fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
972    // If there's a source, we might be able to extract a wasi-http error from it.
973    if let Some(cause) = err.source()
974        && let Some(err) = cause.downcast_ref::<ErrorCode>()
975    {
976        return err.clone();
977    }
978
979    tracing::warn!("hyper request error: {err:?}");
980
981    ErrorCode::HttpProtocolError
982}
983
984fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
985    ErrorCode::DnsError(
986        wasmtime_wasi_http::p2::bindings::http::types::DnsErrorPayload {
987            rcode: Some(rcode),
988            info_code: Some(info_code),
989        },
990    )
991}
992
993// TODO: Remove this (and uses of it) once
994// https://github.com/spinframework/spin/issues/3274 has been addressed.
995pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
996    match code {
997        p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
998        p2_types::ErrorCode::DnsError(payload) => {
999            p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
1000                rcode: payload.rcode,
1001                info_code: payload.info_code,
1002            })
1003        }
1004        p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
1005        p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
1006        p2_types::ErrorCode::DestinationIpProhibited => {
1007            p3_types::ErrorCode::DestinationIpProhibited
1008        }
1009        p2_types::ErrorCode::DestinationIpUnroutable => {
1010            p3_types::ErrorCode::DestinationIpUnroutable
1011        }
1012        p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
1013        p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
1014        p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
1015        p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
1016        p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
1017        p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
1018        p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
1019        p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
1020        p2_types::ErrorCode::TlsAlertReceived(payload) => {
1021            p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
1022                alert_id: payload.alert_id,
1023                alert_message: payload.alert_message,
1024            })
1025        }
1026        p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
1027        p2_types::ErrorCode::HttpRequestLengthRequired => {
1028            p3_types::ErrorCode::HttpRequestLengthRequired
1029        }
1030        p2_types::ErrorCode::HttpRequestBodySize(payload) => {
1031            p3_types::ErrorCode::HttpRequestBodySize(payload)
1032        }
1033        p2_types::ErrorCode::HttpRequestMethodInvalid => {
1034            p3_types::ErrorCode::HttpRequestMethodInvalid
1035        }
1036        p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
1037        p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
1038        p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1039            p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1040        }
1041        p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1042            p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1043                p3_types::FieldSizePayload {
1044                    field_name: payload.field_name,
1045                    field_size: payload.field_size,
1046                }
1047            }))
1048        }
1049        p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1050            p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1051        }
1052        p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1053            p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
1054                field_name: payload.field_name,
1055                field_size: payload.field_size,
1056            })
1057        }
1058        p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
1059        p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1060            p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1061        }
1062        p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1063            p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
1064                field_name: payload.field_name,
1065                field_size: payload.field_size,
1066            })
1067        }
1068        p2_types::ErrorCode::HttpResponseBodySize(payload) => {
1069            p3_types::ErrorCode::HttpResponseBodySize(payload)
1070        }
1071        p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1072            p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1073        }
1074        p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1075            p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
1076                field_name: payload.field_name,
1077                field_size: payload.field_size,
1078            })
1079        }
1080        p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1081            p3_types::ErrorCode::HttpResponseTransferCoding(payload)
1082        }
1083        p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
1084            p3_types::ErrorCode::HttpResponseContentCoding(payload)
1085        }
1086        p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
1087        p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
1088        p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
1089        p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
1090        p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
1091        p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
1092    }
1093}
1094
1095// TODO: Remove this (and uses of it) once
1096// https://github.com/spinframework/spin/issues/3274 has been addressed.
1097pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
1098    match code {
1099        p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
1100        p3_types::ErrorCode::DnsError(payload) => {
1101            p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
1102                rcode: payload.rcode,
1103                info_code: payload.info_code,
1104            })
1105        }
1106        p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
1107        p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
1108        p3_types::ErrorCode::DestinationIpProhibited => {
1109            p2_types::ErrorCode::DestinationIpProhibited
1110        }
1111        p3_types::ErrorCode::DestinationIpUnroutable => {
1112            p2_types::ErrorCode::DestinationIpUnroutable
1113        }
1114        p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
1115        p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
1116        p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
1117        p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
1118        p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
1119        p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
1120        p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
1121        p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
1122        p3_types::ErrorCode::TlsAlertReceived(payload) => {
1123            p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
1124                alert_id: payload.alert_id,
1125                alert_message: payload.alert_message,
1126            })
1127        }
1128        p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
1129        p3_types::ErrorCode::HttpRequestLengthRequired => {
1130            p2_types::ErrorCode::HttpRequestLengthRequired
1131        }
1132        p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1133            p2_types::ErrorCode::HttpRequestBodySize(payload)
1134        }
1135        p3_types::ErrorCode::HttpRequestMethodInvalid => {
1136            p2_types::ErrorCode::HttpRequestMethodInvalid
1137        }
1138        p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1139        p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1140        p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1141            p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1142        }
1143        p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1144            p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1145                p2_types::FieldSizePayload {
1146                    field_name: payload.field_name,
1147                    field_size: payload.field_size,
1148                }
1149            }))
1150        }
1151        p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1152            p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1153        }
1154        p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1155            p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1156                field_name: payload.field_name,
1157                field_size: payload.field_size,
1158            })
1159        }
1160        p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1161        p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1162            p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1163        }
1164        p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1165            p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1166                field_name: payload.field_name,
1167                field_size: payload.field_size,
1168            })
1169        }
1170        p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1171            p2_types::ErrorCode::HttpResponseBodySize(payload)
1172        }
1173        p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1174            p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1175        }
1176        p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1177            p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1178                field_name: payload.field_name,
1179                field_size: payload.field_size,
1180            })
1181        }
1182        p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1183            p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1184        }
1185        p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1186            p2_types::ErrorCode::HttpResponseContentCoding(payload)
1187        }
1188        p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1189        p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1190        p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1191        p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1192        p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1193        p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1194    }
1195}
1196
1197fn record_content_length_header(span: &Span, headers: &HeaderMap, attr_name: &'static str) {
1198    if let Some(content_length) = headers.get(CONTENT_LENGTH)
1199        && let Ok(size_str) = content_length.to_str()
1200    {
1201        span.set_attribute(attr_name, size_str.to_string());
1202    }
1203}
1204
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test: the connect deadline must also bound the wait for a
    /// connection permit, not only the TCP handshake. With every permit taken,
    /// `connect_tcp` used to hang forever. It must now fail with
    /// `ConnectionTimeout` once the configured deadline passes.
    #[tokio::test]
    async fn connect_timeout_applies_to_permit_acquisition() {
        // A single permit, taken immediately: no outbound connection slot is free.
        let semaphore = ConnectionSemaphore::new(None, Some(1), "test", None);
        let _occupied = semaphore
            .try_acquire()
            .expect("exhausting the single permit");

        // Nothing is blocked, the deadline is tiny so the test is quick, and
        // DNS is bypassed by pinning the connect address.
        let options = ConnectOptions {
            blocked_networks: BlockedNetworks::default(),
            connect_timeout: Duration::from_millis(50),
            tls_client_config: None,
            override_connect_addr: Some("127.0.0.1:1".parse().unwrap()),
            semaphore,
        };

        // The call must give up while waiting for a permit instead of blocking.
        let uri = Uri::from_static("http://test.example");
        let outcome = options.connect_tcp(&uri, 80).await;

        assert!(
            matches!(outcome, Err(ErrorCode::ConnectionTimeout)),
            "expected ConnectionTimeout"
        );
    }
}