1use std::{
2 error::Error,
3 future::Future,
4 io::IoSlice,
5 net::SocketAddr,
6 ops::DerefMut,
7 pin::Pin,
8 sync::{Arc, Mutex},
9 task::{self, Context, Poll},
10 time::Duration,
11};
12
13use bytes::{Buf, Bytes};
14use futures::channel::oneshot;
15use http::{
16 HeaderMap, Uri,
17 header::{CONTENT_LENGTH, HOST},
18 uri::Scheme,
19};
20use http_body::{Body, Frame, SizeHint};
21use http_body_util::{BodyExt, combinators::UnsyncBoxBody};
22use hyper_util::{
23 client::legacy::{
24 Client,
25 connect::{Connected, Connection},
26 },
27 rt::{TokioExecutor, TokioIo},
28};
29use opentelemetry_semantic_conventions::attribute as otel_attribute;
30use spin_factor_outbound_networking::{
31 ComponentTlsClientConfigs, TlsClientConfig,
32 config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
33};
34use spin_factors::RuntimeFactorsInstanceState;
35use tokio::{
36 io::{AsyncRead, AsyncWrite, ReadBuf},
37 net::TcpStream,
38 time::timeout,
39};
40use tokio_rustls::client::TlsStream;
41use tower_service::Service;
42use tracing::{Instrument, Span, field::Empty, instrument};
43use wasmtime::component::HasData;
44use wasmtime_wasi::TrappableError;
45use wasmtime_wasi_http::{
46 p2::{
47 self, HttpError, WasiHttpCtxView,
48 bindings::http::types::{self as p2_types, ErrorCode},
49 body::HyperOutgoingBody,
50 types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
51 },
52 p3::{self, bindings::http::types as p3_types},
53};
54
55use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore};
56
57use crate::{
58 InstanceHttpHooks, OutboundHttpFactor, SelfRequestOrigin,
59 intercept::{InterceptOutcome, OutboundHttpInterceptor},
60 wasi_2023_10_18, wasi_2023_11_10,
61};
62
63use tracing_opentelemetry::OpenTelemetrySpanExt as _;
64
/// Fallback (ten minutes) used for each p3 request timeout the guest leaves unset:
/// connect, first byte, and between bytes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
66
67pub struct MutexBody<T>(Mutex<T>);
68
69impl<T> MutexBody<T> {
70 pub fn new(body: T) -> Self {
71 Self(Mutex::new(body))
72 }
73}
74
75impl<T: Body + Unpin> Body for MutexBody<T> {
76 type Data = T::Data;
77 type Error = T::Error;
78
79 fn poll_frame(
80 self: Pin<&mut Self>,
81 cx: &mut Context<'_>,
82 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
83 Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
84 }
85
86 fn is_end_stream(&self) -> bool {
87 self.0.lock().unwrap().is_end_stream()
88 }
89
90 fn size_hint(&self) -> SizeHint {
91 self.0.lock().unwrap().size_hint()
92 }
93}
94
95pub struct NotifyOnDropBody<B> {
98 body: B,
99 _tx: oneshot::Sender<()>,
100}
101
102impl<B> NotifyOnDropBody<B> {
103 pub fn new(body: B, tx: oneshot::Sender<()>) -> Self {
104 Self { body, _tx: tx }
105 }
106}
107
108impl<B: Body + Unpin> Body for NotifyOnDropBody<B> {
109 type Data = B::Data;
110 type Error = B::Error;
111
112 fn poll_frame(
113 mut self: Pin<&mut Self>,
114 cx: &mut Context<'_>,
115 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
116 Pin::new(&mut self.body).poll_frame(cx)
117 }
118
119 fn is_end_stream(&self) -> bool {
120 self.body.is_end_stream()
121 }
122
123 fn size_hint(&self) -> SizeHint {
124 self.body.size_hint()
125 }
126}
127
128pub(crate) struct HasHttp;
129
130impl HasData for HasHttp {
131 type Data<'a> = WasiHttpCtxView<'a>;
132}
133
134impl p3::WasiHttpHooks for InstanceHttpHooks {
135 #[instrument(
136 name = "spin_outbound_http.send_request",
137 skip_all,
138 fields(
139 otel.kind = "client",
140 {otel_attribute::URL_FULL} = Empty,
141 {otel_attribute::HTTP_REQUEST_METHOD} = %request.method(),
142 otel.name = %request.method(),
143 http.response.body.size = Empty,
145 {otel_attribute::HTTP_RESPONSE_STATUS_CODE} = Empty,
146 {otel_attribute::SERVER_ADDRESS} = Empty,
147 {otel_attribute::SERVER_PORT} = Empty,
148 )
149 )]
150 #[allow(clippy::type_complexity)]
151 fn send_request(
152 &mut self,
153 request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
154 options: Option<p3::RequestOptions>,
155 fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
156 ) -> Box<
157 dyn Future<
158 Output = Result<
159 (
160 http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
161 Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
162 ),
163 TrappableError<p3_types::ErrorCode>,
164 >,
165 > + Send,
166 > {
167 self.otel.reparent_tracing_span();
168
169 _ = fut;
178
179 let request = request.map(|body| MutexBody::new(body).boxed());
180
181 let request_sender = RequestSender {
182 allowed_hosts: self.allowed_hosts.clone(),
183 component_tls_configs: self.component_tls_configs.clone(),
184 request_interceptor: self.request_interceptor.clone(),
185 self_request_origin: self.self_request_origin.clone(),
186 blocked_networks: self.blocked_networks.clone(),
187 http_clients: self.wasi_http_clients.clone(),
188 semaphore: self.semaphore.clone(),
189 };
190 let config = OutgoingRequestConfig {
191 use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
192 connect_timeout: options
193 .and_then(|v| v.connect_timeout)
194 .unwrap_or(DEFAULT_TIMEOUT),
195 first_byte_timeout: options
196 .and_then(|v| v.first_byte_timeout)
197 .unwrap_or(DEFAULT_TIMEOUT),
198 between_bytes_timeout: options
199 .and_then(|v| v.between_bytes_timeout)
200 .unwrap_or(DEFAULT_TIMEOUT),
201 };
202 Box::new(
203 async {
204 match request_sender
205 .send(
206 request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
207 config,
208 )
209 .await
210 {
211 Ok(IncomingResponse {
212 resp,
213 between_bytes_timeout,
214 ..
215 }) => Ok((
216 resp.map(|body| {
217 BetweenBytesTimeoutBody {
218 body: Some(body),
219 sleep: None,
220 timeout: between_bytes_timeout,
221 byte_count: 0,
222 span: Some(Span::current()),
223 }
224 .boxed_unsync()
225 }),
226 Box::new(async {
227 Ok(())
231 }) as Box<dyn Future<Output = _> + Send>,
232 )),
233 Err(http_error) => match http_error.downcast() {
234 Ok(error_code) => {
235 Err(TrappableError::from(p2_to_p3_error_code(error_code)))
236 }
237 Err(trap) => Err(TrappableError::trap(trap)),
238 },
239 }
240 }
241 .in_current_span(),
242 )
243 }
244}
245
246pin_project_lite::pin_project! {
247 struct BetweenBytesTimeoutBody<B> {
248 body: Option<B>,
250 #[pin]
251 sleep: Option<tokio::time::Sleep>,
252 timeout: Duration,
253 byte_count: u64,
254 span: Option<Span>,
255 }
256}
257
258impl<B: Body<Error = p2_types::ErrorCode> + Unpin> Body for BetweenBytesTimeoutBody<B> {
259 type Data = B::Data;
260 type Error = p3_types::ErrorCode;
261
262 fn poll_frame(
263 self: Pin<&mut Self>,
264 cx: &mut Context<'_>,
265 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
266 let mut me = self.project();
267
268 let poll_result = {
270 let Some(body) = me.body.as_mut() else {
271 return Poll::Ready(None);
273 };
274 Pin::new(body).poll_frame(cx)
275 };
276
277 let mut record_body_size_once = |body_size: u64| {
278 if let Some(span) = me.span.take() {
279 span.record("http.response.body.size", body_size);
283 }
284 };
285 match poll_result {
286 Poll::Ready(value) => {
287 me.sleep.as_mut().set(None);
288
289 match &value {
290 Some(Ok(frame)) => {
291 if let Some(data) = frame.data_ref() {
292 *me.byte_count += data.remaining() as u64;
293 }
294 if me.body.as_ref().is_some_and(|b| b.is_end_stream()) {
295 record_body_size_once(*me.byte_count);
296 }
297 }
298 None => {
299 record_body_size_once(*me.byte_count);
300 }
301 Some(Err(e)) => {
302 tracing::warn!("error reading response body: {e:?}");
303 }
304 }
305
306 Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
307 }
308 Poll::Pending => {
309 if me.sleep.is_none() {
310 me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
311 }
312 task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
313
314 *me.body = None;
317 record_body_size_once(*me.byte_count);
318
319 Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
320 }
321 }
322 }
323
324 fn is_end_stream(&self) -> bool {
325 self.body.as_ref().is_none_or(|b| b.is_end_stream())
326 }
327
328 fn size_hint(&self) -> SizeHint {
329 self.body
330 .as_ref()
331 .map(|b| b.size_hint())
332 .unwrap_or_default()
333 }
334}
335
336pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
337where
338 C: spin_factors::InitContext<OutboundHttpFactor>,
339{
340 let linker = ctx.linker();
341
342 fn get_http<C>(store: &mut C::StoreData) -> WasiHttpCtxView<'_>
343 where
344 C: spin_factors::InitContext<OutboundHttpFactor>,
345 {
346 let (state, table) = C::get_data_with_table(store);
347 let ctx = &mut state.wasi_http_ctx;
348 WasiHttpCtxView {
349 ctx,
350 table,
351 hooks: &mut state.hooks,
352 }
353 }
354
355 let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpCtxView<'_>;
356 wasmtime_wasi_http::p2::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
357 linker, get_http,
358 )?;
359 wasmtime_wasi_http::p2::bindings::http::types::add_to_linker::<_, HasHttp>(
360 linker,
361 &Default::default(),
362 get_http,
363 )?;
364
365 fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
366 where
367 C: spin_factors::InitContext<OutboundHttpFactor>,
368 {
369 let (state, table) = C::get_data_with_table(store);
370 let ctx = &mut state.wasi_http_ctx;
371 p3::WasiHttpCtxView {
372 ctx,
373 table,
374 hooks: &mut state.hooks,
375 }
376 }
377
378 let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
379 p3::bindings::http::client::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
380 p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
381
382 wasi_2023_10_18::add_to_linker(linker, get_http)?;
383 wasi_2023_11_10::add_to_linker(linker, get_http)?;
384
385 Ok(())
386}
387
388impl OutboundHttpFactor {
389 pub fn get_wasi_http_impl(
390 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
391 ) -> Option<WasiHttpCtxView<'_>> {
392 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
393 let ctx = &mut state.wasi_http_ctx;
394 Some(WasiHttpCtxView {
395 ctx,
396 table,
397 hooks: &mut state.hooks,
398 })
399 }
400
401 pub fn get_wasi_p3_http_impl(
402 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
403 ) -> Option<p3::WasiHttpCtxView<'_>> {
404 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
405 let ctx = &mut state.wasi_http_ctx;
406 Some(p3::WasiHttpCtxView {
407 ctx,
408 table,
409 hooks: &mut state.hooks,
410 })
411 }
412}
413
414type OutgoingRequest = http::Request<HyperOutgoingBody>;
415
416impl p2::WasiHttpHooks for InstanceHttpHooks {
417 #[instrument(
418 name = "spin_outbound_http.send_request",
419 skip_all,
420 fields(
421 otel.kind = "client",
422 {otel_attribute::URL_FULL} = Empty,
423 {otel_attribute::HTTP_REQUEST_METHOD} = %request.method(),
424 otel.name = %request.method(),
425 {otel_attribute::HTTP_RESPONSE_STATUS_CODE} = Empty,
426 {otel_attribute::SERVER_ADDRESS} = Empty,
427 {otel_attribute::SERVER_PORT} = Empty,
428 )
429 )]
430 fn send_request(
431 &mut self,
432 request: OutgoingRequest,
433 config: OutgoingRequestConfig,
434 ) -> Result<wasmtime_wasi_http::p2::types::HostFutureIncomingResponse, HttpError> {
435 self.otel.reparent_tracing_span();
436
437 let request_sender = RequestSender {
438 allowed_hosts: self.allowed_hosts.clone(),
439 component_tls_configs: self.component_tls_configs.clone(),
440 request_interceptor: self.request_interceptor.clone(),
441 self_request_origin: self.self_request_origin.clone(),
442 blocked_networks: self.blocked_networks.clone(),
443 http_clients: self.wasi_http_clients.clone(),
444 semaphore: self.semaphore.clone(),
445 };
446 Ok(HostFutureIncomingResponse::Pending(
447 wasmtime_wasi::runtime::spawn(
448 async {
449 match request_sender.send(request, config).await {
450 Ok(resp) => Ok(Ok(resp)),
451 Err(http_error) => match http_error.downcast() {
452 Ok(error_code) => Ok(Err(error_code)),
453 Err(trap) => Err(trap),
454 },
455 }
456 }
457 .in_current_span(),
458 ),
459 ))
460 }
461}
462
463struct RequestSender {
464 allowed_hosts: OutboundAllowedHosts,
465 blocked_networks: BlockedNetworks,
466 component_tls_configs: ComponentTlsClientConfigs,
467 self_request_origin: Option<SelfRequestOrigin>,
468 request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
469 http_clients: HttpClients,
470 semaphore: ConnectionSemaphore,
471}
472
473impl RequestSender {
474 async fn send(
475 self,
476 mut request: OutgoingRequest,
477 mut config: OutgoingRequestConfig,
478 ) -> Result<IncomingResponse, HttpError> {
479 self.prepare_request(&mut request, &mut config).await?;
480
481 spin_telemetry::inject_trace_context(&mut request);
483
484 let mut override_connect_addr = None;
486 if let Some(interceptor) = &self.request_interceptor {
487 let intercept_request = std::mem::take(&mut request).into();
488 match interceptor.intercept(intercept_request).await? {
489 InterceptOutcome::Continue(mut req) => {
490 override_connect_addr = req.override_connect_addr.take();
491 request = req.into_hyper_request();
492 }
493 InterceptOutcome::Complete(resp) => {
494 let resp = IncomingResponse {
495 resp,
496 worker: None,
497 between_bytes_timeout: config.between_bytes_timeout,
498 };
499 return Ok(resp);
500 }
501 }
502 }
503
504 let span = tracing::Span::current();
506 if let Some(addr) = override_connect_addr {
507 span.record(otel_attribute::SERVER_ADDRESS, addr.ip().to_string());
508 span.record(otel_attribute::SERVER_PORT, addr.port());
509 } else if let Some(authority) = request.uri().authority() {
510 span.record(otel_attribute::SERVER_ADDRESS, authority.host());
511 if let Some(port) = authority.port_u16() {
512 span.record(otel_attribute::SERVER_PORT, port);
513 }
514 }
515
516 record_content_length_header(
517 &span,
518 request.headers(),
519 "http.request.header.content-length",
520 );
521
522 Ok(self
523 .send_request(request, config, override_connect_addr)
524 .await?)
525 }
526
527 async fn prepare_request(
528 &self,
529 request: &mut OutgoingRequest,
530 config: &mut OutgoingRequestConfig,
531 ) -> Result<(), ErrorCode> {
532 let uri = request.uri_mut();
536 if uri
537 .authority()
538 .is_some_and(|authority| authority.host().is_empty())
539 {
540 let mut builder = http::uri::Builder::new();
541 if let Some(paq) = uri.path_and_query() {
542 builder = builder.path_and_query(paq.clone());
543 }
544 *uri = builder.build().unwrap();
545 }
546 tracing::Span::current().record(otel_attribute::URL_FULL, uri.to_string());
547
548 let is_self_request = match request.uri().authority() {
549 Some(authority) => authority.host() == "self.alt",
551 None => true,
553 };
554
555 let is_allowed = if is_self_request {
557 self.allowed_hosts
558 .check_relative_url(&["http", "https"])
559 .await
560 .unwrap_or(false)
561 } else {
562 self.allowed_hosts
563 .check_url(&request.uri().to_string(), "https")
564 .await
565 .unwrap_or(false)
566 };
567 if !is_allowed {
568 return Err(ErrorCode::HttpRequestDenied);
569 }
570
571 if is_self_request {
572 let Some(origin) = self.self_request_origin.as_ref() else {
574 tracing::error!(
575 "Couldn't handle outbound HTTP request to relative URI; no origin set"
576 );
577 return Err(ErrorCode::HttpRequestUriInvalid);
578 };
579
580 config.use_tls = origin.use_tls();
581
582 request.headers_mut().insert(HOST, origin.host_header());
583
584 let path_and_query = request.uri().path_and_query().cloned();
585 *request.uri_mut() = origin.clone().into_uri(path_and_query);
586 }
587
588 request.headers_mut().remove(HOST);
595 Ok(())
596 }
597
598 async fn send_request(
599 self,
600 request: OutgoingRequest,
601 config: OutgoingRequestConfig,
602 override_connect_addr: Option<SocketAddr>,
603 ) -> Result<IncomingResponse, ErrorCode> {
604 let OutgoingRequestConfig {
605 use_tls,
606 connect_timeout,
607 first_byte_timeout,
608 between_bytes_timeout,
609 } = config;
610
611 let tls_client_config = if use_tls {
612 let host = request.uri().host().unwrap_or_default();
613 Some(self.component_tls_configs.get_client_config(host).clone())
614 } else {
615 None
616 };
617
618 let resp = CONNECT_OPTIONS.scope(
619 ConnectOptions {
620 blocked_networks: self.blocked_networks,
621 connect_timeout,
622 tls_client_config,
623 override_connect_addr,
624 semaphore: self.semaphore,
625 },
626 async move {
627 if use_tls {
628 self.http_clients.https.request(request).await
629 } else {
630 let h2c_prior_knowledge_host =
632 std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
633 let use_h2c = h2c_prior_knowledge_host.as_deref()
634 == request.uri().authority().map(|a| a.as_str());
635
636 if use_h2c {
637 self.http_clients.http2.request(request).await
638 } else {
639 self.http_clients.http1.request(request).await
640 }
641 }
642 },
643 );
644
645 let resp = timeout(first_byte_timeout, resp)
646 .await
647 .map_err(|_| ErrorCode::ConnectionReadTimeout)?
648 .map_err(hyper_legacy_request_error)?
649 .map(|body| body.map_err(hyper_request_error).boxed_unsync());
650
651 let span = tracing::Span::current();
652 span.record(
653 otel_attribute::HTTP_RESPONSE_STATUS_CODE,
654 resp.status().as_u16(),
655 );
656
657 record_content_length_header(&span, resp.headers(), "http.response.header.content-length");
658
659 Ok(IncomingResponse {
660 resp,
661 worker: None,
662 between_bytes_timeout,
663 })
664 }
665}
666
667type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
668type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
669
670#[derive(Clone)]
671pub(super) struct HttpClients {
672 http1: HttpClient,
674 http2: HttpClient,
676 https: HttpsClient,
678}
679
680impl HttpClients {
681 pub(super) fn new(enable_pooling: bool) -> Self {
682 let builder = move || {
683 let mut builder = Client::builder(TokioExecutor::new());
684 if !enable_pooling {
685 builder.pool_max_idle_per_host(0);
686 }
687 builder
688 };
689 Self {
690 http1: builder().build(HttpConnector),
691 http2: builder().http2_only(true).build(HttpConnector),
692 https: builder().build(HttpsConnector),
693 }
694 }
695}
696
697tokio::task_local! {
698 static CONNECT_OPTIONS: ConnectOptions;
706}
707
708#[derive(Clone)]
709struct ConnectOptions {
710 blocked_networks: BlockedNetworks,
712 connect_timeout: Duration,
714 tls_client_config: Option<TlsClientConfig>,
716 override_connect_addr: Option<SocketAddr>,
718 semaphore: ConnectionSemaphore,
720}
721
722impl ConnectOptions {
723 async fn connect_tcp(
725 &self,
726 uri: &Uri,
727 default_port: u16,
728 ) -> Result<PermittedTcpStream, ErrorCode> {
729 let mut socket_addrs = match self.override_connect_addr {
730 Some(override_connect_addr) => vec![override_connect_addr],
731 None => {
732 let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
733
734 let host_and_port = if authority.port().is_some() {
735 authority.as_str().to_string()
736 } else {
737 format!("{}:{}", authority.as_str(), default_port)
738 };
739
740 let socket_addrs = tokio::net::lookup_host(&host_and_port)
741 .await
742 .map_err(|err| {
743 tracing::debug!(?host_and_port, ?err, "Error resolving host");
744 dns_error("address not available".into(), 0)
745 })?
746 .collect::<Vec<_>>();
747 tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
748 socket_addrs
749 }
750 };
751
752 crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
754
755 let connect = async {
756 let permit = self.semaphore.acquire().await;
758 (TcpStream::connect(&*socket_addrs).await, permit)
759 };
760
761 let (stream, permit) = timeout(self.connect_timeout, connect)
764 .await
765 .map_err(|_| ErrorCode::ConnectionTimeout)?;
766 let permit = permit.map_err(|_| ErrorCode::ConnectionLimitReached)?;
767 let stream = stream.map_err(|err| match err.kind() {
768 std::io::ErrorKind::AddrNotAvailable => dns_error("address not available".into(), 0),
769 _ => ErrorCode::ConnectionRefused,
770 })?;
771 Ok(PermittedTcpStream {
772 inner: stream,
773 _permit: permit,
774 })
775 }
776
777 async fn connect_tls(
779 &self,
780 uri: &Uri,
781 default_port: u16,
782 ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
783 let tcp_stream = self.connect_tcp(uri, default_port).await?;
784
785 let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
786 tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
787
788 let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
789 let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
790 .map_err(|e| {
791 tracing::warn!("dns lookup error: {e:?}");
792 dns_error("invalid dns name".into(), 0)
793 })?
794 .to_owned();
795 connector.connect(domain, tcp_stream).await.map_err(|e| {
796 tracing::warn!("tls protocol error: {e:?}");
797 ErrorCode::TlsProtocolError
798 })
799 }
800}
801
802#[derive(Clone)]
804struct HttpConnector;
805
806impl HttpConnector {
807 async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
808 let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
809 Ok(TokioIo::new(stream))
810 }
811}
812
813impl Service<Uri> for HttpConnector {
814 type Response = TokioIo<PermittedTcpStream>;
815 type Error = ErrorCode;
816 type Future =
817 Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
818
819 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
820 Poll::Ready(Ok(()))
821 }
822
823 fn call(&mut self, uri: Uri) -> Self::Future {
824 Box::pin(async move { Self::connect(uri).await })
825 }
826}
827
828#[derive(Clone)]
830struct HttpsConnector;
831
832impl HttpsConnector {
833 async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
834 let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
835 Ok(TokioIo::new(RustlsStream(stream)))
836 }
837}
838
839impl Service<Uri> for HttpsConnector {
840 type Response = TokioIo<RustlsStream>;
841 type Error = ErrorCode;
842 type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
843
844 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
845 Poll::Ready(Ok(()))
846 }
847
848 fn call(&mut self, uri: Uri) -> Self::Future {
849 Box::pin(async move { Self::connect(uri).await })
850 }
851}
852
853struct RustlsStream(TlsStream<PermittedTcpStream>);
854
855impl Connection for RustlsStream {
856 fn connected(&self) -> Connected {
857 if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
858 self.0.get_ref().0.connected().negotiated_h2()
859 } else {
860 self.0.get_ref().0.connected()
861 }
862 }
863}
864
865impl AsyncRead for RustlsStream {
866 fn poll_read(
867 self: Pin<&mut Self>,
868 cx: &mut Context<'_>,
869 buf: &mut ReadBuf<'_>,
870 ) -> Poll<Result<(), std::io::Error>> {
871 Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
872 }
873}
874
875impl AsyncWrite for RustlsStream {
876 fn poll_write(
877 self: Pin<&mut Self>,
878 cx: &mut Context<'_>,
879 buf: &[u8],
880 ) -> Poll<Result<usize, std::io::Error>> {
881 Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
882 }
883
884 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
885 Pin::new(&mut self.get_mut().0).poll_flush(cx)
886 }
887
888 fn poll_shutdown(
889 self: Pin<&mut Self>,
890 cx: &mut Context<'_>,
891 ) -> Poll<Result<(), std::io::Error>> {
892 Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
893 }
894
895 fn poll_write_vectored(
896 self: Pin<&mut Self>,
897 cx: &mut Context<'_>,
898 bufs: &[IoSlice<'_>],
899 ) -> Poll<Result<usize, std::io::Error>> {
900 Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
901 }
902
903 fn is_write_vectored(&self) -> bool {
904 self.0.is_write_vectored()
905 }
906}
907
908struct PermittedTcpStream {
910 inner: TcpStream,
912 _permit: ConnectionPermit,
917}
918
919impl Connection for PermittedTcpStream {
920 fn connected(&self) -> Connected {
921 self.inner.connected()
922 }
923}
924
925impl AsyncRead for PermittedTcpStream {
926 fn poll_read(
927 self: Pin<&mut Self>,
928 cx: &mut Context<'_>,
929 buf: &mut ReadBuf<'_>,
930 ) -> Poll<std::io::Result<()>> {
931 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
932 }
933}
934
935impl AsyncWrite for PermittedTcpStream {
936 fn poll_write(
937 self: Pin<&mut Self>,
938 cx: &mut Context<'_>,
939 buf: &[u8],
940 ) -> Poll<Result<usize, std::io::Error>> {
941 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
942 }
943
944 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
945 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
946 }
947
948 fn poll_shutdown(
949 self: Pin<&mut Self>,
950 cx: &mut Context<'_>,
951 ) -> Poll<Result<(), std::io::Error>> {
952 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
953 }
954}
955
956fn hyper_request_error(err: hyper::Error) -> ErrorCode {
958 if let Some(cause) = err.source()
960 && let Some(err) = cause.downcast_ref::<ErrorCode>()
961 {
962 return err.clone();
963 }
964
965 tracing::warn!("hyper request error: {err:?}");
966
967 ErrorCode::HttpProtocolError
968}
969
970fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
972 if let Some(cause) = err.source()
974 && let Some(err) = cause.downcast_ref::<ErrorCode>()
975 {
976 return err.clone();
977 }
978
979 tracing::warn!("hyper request error: {err:?}");
980
981 ErrorCode::HttpProtocolError
982}
983
984fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
985 ErrorCode::DnsError(
986 wasmtime_wasi_http::p2::bindings::http::types::DnsErrorPayload {
987 rcode: Some(rcode),
988 info_code: Some(info_code),
989 },
990 )
991}
992
993pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
996 match code {
997 p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
998 p2_types::ErrorCode::DnsError(payload) => {
999 p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
1000 rcode: payload.rcode,
1001 info_code: payload.info_code,
1002 })
1003 }
1004 p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
1005 p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
1006 p2_types::ErrorCode::DestinationIpProhibited => {
1007 p3_types::ErrorCode::DestinationIpProhibited
1008 }
1009 p2_types::ErrorCode::DestinationIpUnroutable => {
1010 p3_types::ErrorCode::DestinationIpUnroutable
1011 }
1012 p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
1013 p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
1014 p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
1015 p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
1016 p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
1017 p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
1018 p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
1019 p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
1020 p2_types::ErrorCode::TlsAlertReceived(payload) => {
1021 p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
1022 alert_id: payload.alert_id,
1023 alert_message: payload.alert_message,
1024 })
1025 }
1026 p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
1027 p2_types::ErrorCode::HttpRequestLengthRequired => {
1028 p3_types::ErrorCode::HttpRequestLengthRequired
1029 }
1030 p2_types::ErrorCode::HttpRequestBodySize(payload) => {
1031 p3_types::ErrorCode::HttpRequestBodySize(payload)
1032 }
1033 p2_types::ErrorCode::HttpRequestMethodInvalid => {
1034 p3_types::ErrorCode::HttpRequestMethodInvalid
1035 }
1036 p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
1037 p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
1038 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1039 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1040 }
1041 p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1042 p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1043 p3_types::FieldSizePayload {
1044 field_name: payload.field_name,
1045 field_size: payload.field_size,
1046 }
1047 }))
1048 }
1049 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1050 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1051 }
1052 p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1053 p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
1054 field_name: payload.field_name,
1055 field_size: payload.field_size,
1056 })
1057 }
1058 p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
1059 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1060 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1061 }
1062 p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1063 p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
1064 field_name: payload.field_name,
1065 field_size: payload.field_size,
1066 })
1067 }
1068 p2_types::ErrorCode::HttpResponseBodySize(payload) => {
1069 p3_types::ErrorCode::HttpResponseBodySize(payload)
1070 }
1071 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1072 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1073 }
1074 p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1075 p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
1076 field_name: payload.field_name,
1077 field_size: payload.field_size,
1078 })
1079 }
1080 p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1081 p3_types::ErrorCode::HttpResponseTransferCoding(payload)
1082 }
1083 p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
1084 p3_types::ErrorCode::HttpResponseContentCoding(payload)
1085 }
1086 p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
1087 p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
1088 p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
1089 p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
1090 p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
1091 p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
1092 }
1093}
1094
1095pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
1098 match code {
1099 p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
1100 p3_types::ErrorCode::DnsError(payload) => {
1101 p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
1102 rcode: payload.rcode,
1103 info_code: payload.info_code,
1104 })
1105 }
1106 p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
1107 p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
1108 p3_types::ErrorCode::DestinationIpProhibited => {
1109 p2_types::ErrorCode::DestinationIpProhibited
1110 }
1111 p3_types::ErrorCode::DestinationIpUnroutable => {
1112 p2_types::ErrorCode::DestinationIpUnroutable
1113 }
1114 p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
1115 p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
1116 p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
1117 p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
1118 p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
1119 p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
1120 p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
1121 p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
1122 p3_types::ErrorCode::TlsAlertReceived(payload) => {
1123 p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
1124 alert_id: payload.alert_id,
1125 alert_message: payload.alert_message,
1126 })
1127 }
1128 p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
1129 p3_types::ErrorCode::HttpRequestLengthRequired => {
1130 p2_types::ErrorCode::HttpRequestLengthRequired
1131 }
1132 p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1133 p2_types::ErrorCode::HttpRequestBodySize(payload)
1134 }
1135 p3_types::ErrorCode::HttpRequestMethodInvalid => {
1136 p2_types::ErrorCode::HttpRequestMethodInvalid
1137 }
1138 p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1139 p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1140 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1141 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1142 }
1143 p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1144 p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1145 p2_types::FieldSizePayload {
1146 field_name: payload.field_name,
1147 field_size: payload.field_size,
1148 }
1149 }))
1150 }
1151 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1152 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1153 }
1154 p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1155 p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1156 field_name: payload.field_name,
1157 field_size: payload.field_size,
1158 })
1159 }
1160 p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1161 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1162 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1163 }
1164 p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1165 p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1166 field_name: payload.field_name,
1167 field_size: payload.field_size,
1168 })
1169 }
1170 p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1171 p2_types::ErrorCode::HttpResponseBodySize(payload)
1172 }
1173 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1174 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1175 }
1176 p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1177 p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1178 field_name: payload.field_name,
1179 field_size: payload.field_size,
1180 })
1181 }
1182 p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1183 p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1184 }
1185 p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1186 p2_types::ErrorCode::HttpResponseContentCoding(payload)
1187 }
1188 p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1189 p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1190 p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1191 p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1192 p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1193 p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1194 }
1195}
1196
1197fn record_content_length_header(span: &Span, headers: &HeaderMap, attr_name: &'static str) {
1198 if let Some(content_length) = headers.get(CONTENT_LENGTH)
1199 && let Ok(size_str) = content_length.to_str()
1200 {
1201 span.set_attribute(attr_name, size_str.to_string());
1202 }
1203}
1204
#[cfg(test)]
mod tests {
    use super::*;

    /// The connect timeout must also bound the wait for a connection permit.
    ///
    /// The semaphore's single permit is already held, so `connect_tcp` should
    /// fail with `ConnectionTimeout` instead of waiting on the permit.
    #[tokio::test]
    async fn connect_timeout_applies_to_permit_acquisition() {
        // Take the only permit and keep it alive for the whole test, so any
        // further acquisition attempt has to wait.
        let semaphore = ConnectionSemaphore::new(None, Some(1), "test", None);
        let _exhausting_permit = semaphore
            .try_acquire()
            .expect("exhausting the single permit");

        let options = ConnectOptions {
            blocked_networks: BlockedNetworks::default(),
            connect_timeout: Duration::from_millis(50),
            tls_client_config: None,
            override_connect_addr: Some("127.0.0.1:1".parse().unwrap()),
            semaphore,
        };

        let uri = Uri::from_static("http://test.example");
        let outcome = options.connect_tcp(&uri, 80).await;

        assert!(
            matches!(outcome, Err(ErrorCode::ConnectionTimeout)),
            "expected ConnectionTimeout"
        );
    }
}