1use std::{
2 error::Error,
3 future::Future,
4 io::IoSlice,
5 net::SocketAddr,
6 ops::DerefMut,
7 pin::Pin,
8 sync::{Arc, Mutex},
9 task::{self, Context, Poll},
10 time::Duration,
11};
12
13use bytes::Bytes;
14use futures::channel::oneshot;
15use http::{
16 HeaderMap, Uri,
17 header::{CONTENT_LENGTH, HOST},
18 uri::Scheme,
19};
20use http_body::{Body, Frame, SizeHint};
21use http_body_util::{BodyExt, combinators::UnsyncBoxBody};
22use hyper_util::{
23 client::legacy::{
24 Client,
25 connect::{Connected, Connection},
26 },
27 rt::{TokioExecutor, TokioIo},
28};
29use spin_factor_outbound_networking::{
30 ComponentTlsClientConfigs, TlsClientConfig,
31 config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
32};
33use spin_factors::RuntimeFactorsInstanceState;
34use tokio::{
35 io::{AsyncRead, AsyncWrite, ReadBuf},
36 net::TcpStream,
37 sync::{OwnedSemaphorePermit, Semaphore},
38 time::timeout,
39};
40use tokio_rustls::client::TlsStream;
41use tower_service::Service;
42use tracing::{Instrument, Span, field::Empty, instrument};
43use wasmtime::component::HasData;
44use wasmtime_wasi::TrappableError;
45use wasmtime_wasi_http::{
46 p2::{
47 self, HttpError, WasiHttpCtxView,
48 bindings::http::types::{self as p2_types, ErrorCode},
49 body::HyperOutgoingBody,
50 types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
51 },
52 p3::{self, bindings::http::types as p3_types},
53};
54
55use crate::{
56 InstanceHttpHooks, OutboundHttpFactor, SelfRequestOrigin,
57 intercept::{InterceptOutcome, OutboundHttpInterceptor},
58 wasi_2023_10_18, wasi_2023_11_10,
59};
60
61use tracing_opentelemetry::OpenTelemetrySpanExt as _;
62
/// Ten minutes: the fallback used by the `wasi:http@0.3` path for any of the connect,
/// first-byte and between-bytes timeouts the guest did not specify.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
64
/// A body wrapped in a [`Mutex`].
///
/// `Mutex<T>` is `Sync` whenever `T: Send`, so this lets a body that is only `Send` be
/// used where a `Sync` body is required.
pub struct MutexBody<T>(Mutex<T>);

impl<T> MutexBody<T> {
    /// Wraps `body` in a fresh, unlocked mutex.
    pub fn new(body: T) -> Self {
        let guarded = Mutex::new(body);
        MutexBody(guarded)
    }
}
72
73impl<T: Body + Unpin> Body for MutexBody<T> {
74 type Data = T::Data;
75 type Error = T::Error;
76
77 fn poll_frame(
78 self: Pin<&mut Self>,
79 cx: &mut Context<'_>,
80 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
81 Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
82 }
83
84 fn is_end_stream(&self) -> bool {
85 self.0.lock().unwrap().is_end_stream()
86 }
87
88 fn size_hint(&self) -> SizeHint {
89 self.0.lock().unwrap().size_hint()
90 }
91}
92
93pub struct NotifyOnDropBody<B> {
96 body: B,
97 _tx: oneshot::Sender<()>,
98}
99
100impl<B> NotifyOnDropBody<B> {
101 pub fn new(body: B, tx: oneshot::Sender<()>) -> Self {
102 Self { body, _tx: tx }
103 }
104}
105
106impl<B: Body + Unpin> Body for NotifyOnDropBody<B> {
107 type Data = B::Data;
108 type Error = B::Error;
109
110 fn poll_frame(
111 mut self: Pin<&mut Self>,
112 cx: &mut Context<'_>,
113 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
114 Pin::new(&mut self.body).poll_frame(cx)
115 }
116
117 fn is_end_stream(&self) -> bool {
118 self.body.is_end_stream()
119 }
120
121 fn size_hint(&self) -> SizeHint {
122 self.body.size_hint()
123 }
124}
125
/// Marker type passed as the `HasData` parameter when registering the `wasi:http@0.2`
/// bindings in `add_to_linker`; it selects [`WasiHttpCtxView`] as the host data view.
pub(crate) struct HasHttp;

impl HasData for HasHttp {
    type Data<'a> = WasiHttpCtxView<'a>;
}
131
132impl p3::WasiHttpHooks for InstanceHttpHooks {
133 #[instrument(
134 name = "spin_outbound_http.send_request",
135 skip_all,
136 fields(
137 otel.kind = "client",
138 url.full = Empty,
139 http.request.method = %request.method(),
140 otel.name = %request.method(),
141 http.response.status_code = Empty,
142 server.address = Empty,
143 server.port = Empty,
144 )
145 )]
146 #[allow(clippy::type_complexity)]
147 fn send_request(
148 &mut self,
149 request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
150 options: Option<p3::RequestOptions>,
151 fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
152 ) -> Box<
153 dyn Future<
154 Output = Result<
155 (
156 http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
157 Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
158 ),
159 TrappableError<p3_types::ErrorCode>,
160 >,
161 > + Send,
162 > {
163 self.otel.reparent_tracing_span();
164
165 _ = fut;
174
175 let request = request.map(|body| MutexBody::new(body).boxed());
176
177 let request_sender = RequestSender {
178 allowed_hosts: self.allowed_hosts.clone(),
179 component_tls_configs: self.component_tls_configs.clone(),
180 request_interceptor: self.request_interceptor.clone(),
181 self_request_origin: self.self_request_origin.clone(),
182 blocked_networks: self.blocked_networks.clone(),
183 http_clients: self.wasi_http_clients.clone(),
184 concurrent_outbound_connections_semaphore: self
185 .concurrent_outbound_connections_semaphore
186 .clone(),
187 };
188 let config = OutgoingRequestConfig {
189 use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
190 connect_timeout: options
191 .and_then(|v| v.connect_timeout)
192 .unwrap_or(DEFAULT_TIMEOUT),
193 first_byte_timeout: options
194 .and_then(|v| v.first_byte_timeout)
195 .unwrap_or(DEFAULT_TIMEOUT),
196 between_bytes_timeout: options
197 .and_then(|v| v.between_bytes_timeout)
198 .unwrap_or(DEFAULT_TIMEOUT),
199 };
200 Box::new(async {
201 match request_sender
202 .send(
203 request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
204 config,
205 )
206 .await
207 {
208 Ok(IncomingResponse {
209 resp,
210 between_bytes_timeout,
211 ..
212 }) => Ok((
213 resp.map(|body| {
214 BetweenBytesTimeoutBody {
215 body,
216 sleep: None,
217 timeout: between_bytes_timeout,
218 }
219 .boxed_unsync()
220 }),
221 Box::new(async {
222 Ok(())
226 }) as Box<dyn Future<Output = _> + Send>,
227 )),
228 Err(http_error) => match http_error.downcast() {
229 Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
230 Err(trap) => Err(TrappableError::trap(trap)),
231 },
232 }
233 })
234 }
235}
236
pin_project_lite::pin_project! {
    // Response-body adapter for the `wasi:http@0.3` path.
    // It fails with `ConnectionReadTimeout` if the inner body stays pending for longer than
    // `timeout` between frames. It also converts the inner body's p2 error codes to p3.
    struct BetweenBytesTimeoutBody<B> {
        // The wrapped response body.
        #[pin]
        body: B,
        // Timer armed on the first `Pending` poll after a frame; cleared whenever a frame
        // (or end/error) arrives.
        #[pin]
        sleep: Option<tokio::time::Sleep>,
        // Maximum allowed gap between frames.
        timeout: Duration,
    }
}
246
247impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
248 type Data = B::Data;
249 type Error = p3_types::ErrorCode;
250
251 fn poll_frame(
252 self: Pin<&mut Self>,
253 cx: &mut Context<'_>,
254 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
255 let mut me = self.project();
256 match me.body.poll_frame(cx) {
257 Poll::Ready(value) => {
258 me.sleep.as_mut().set(None);
259 Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
260 }
261 Poll::Pending => {
262 if me.sleep.is_none() {
263 me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
264 }
265 task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
266 Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
267 }
268 }
269 }
270
271 fn is_end_stream(&self) -> bool {
272 self.body.is_end_stream()
273 }
274
275 fn size_hint(&self) -> SizeHint {
276 self.body.size_hint()
277 }
278}
279
/// Registers every `wasi:http` interface version this factor serves with the component
/// linker. That covers:
/// - the 0.2 `outgoing-handler` and `types` interfaces;
/// - the 0.3 `client` and `types` interfaces;
/// - the two dated snapshot modules, `wasi_2023_10_18` and `wasi_2023_11_10`.
///
/// # Errors
/// Propagates the first error returned by any of the individual `add_to_linker` calls.
pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
where
    C: spin_factors::InitContext<OutboundHttpFactor>,
{
    let linker = ctx.linker();

    // Projects the store data onto the 0.2 view: the factor's wasi-http context, the
    // resource table, and the per-instance hooks.
    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpCtxView<'_>
    where
        C: spin_factors::InitContext<OutboundHttpFactor>,
    {
        let (state, table) = C::get_data_with_table(store);
        let ctx = &mut state.wasi_http_ctx;
        WasiHttpCtxView {
            ctx,
            table,
            hooks: &mut state.hooks,
        }
    }

    // Coerce to a plain fn pointer so one getter value can be reused for several
    // registrations below (0.2 and both dated snapshots).
    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpCtxView<'_>;
    wasmtime_wasi_http::p2::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
        linker, get_http,
    )?;
    // NOTE(review): the `&Default::default()` argument is presumably the bindings' link
    // options — confirm against the generated `add_to_linker` signature.
    wasmtime_wasi_http::p2::bindings::http::types::add_to_linker::<_, HasHttp>(
        linker,
        &Default::default(),
        get_http,
    )?;

    // Same projection as `get_http`, but producing the 0.3 view type.
    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
    where
        C: spin_factors::InitContext<OutboundHttpFactor>,
    {
        let (state, table) = C::get_data_with_table(store);
        let ctx = &mut state.wasi_http_ctx;
        p3::WasiHttpCtxView {
            ctx,
            table,
            hooks: &mut state.hooks,
        }
    }

    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
    p3::bindings::http::client::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;

    // The dated snapshot modules reuse the 0.2 view getter.
    wasi_2023_10_18::add_to_linker(linker, get_http)?;
    wasi_2023_11_10::add_to_linker(linker, get_http)?;

    Ok(())
}
331
332impl OutboundHttpFactor {
333 pub fn get_wasi_http_impl(
334 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
335 ) -> Option<WasiHttpCtxView<'_>> {
336 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
337 let ctx = &mut state.wasi_http_ctx;
338 Some(WasiHttpCtxView {
339 ctx,
340 table,
341 hooks: &mut state.hooks,
342 })
343 }
344
345 pub fn get_wasi_p3_http_impl(
346 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
347 ) -> Option<p3::WasiHttpCtxView<'_>> {
348 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
349 let ctx = &mut state.wasi_http_ctx;
350 Some(p3::WasiHttpCtxView {
351 ctx,
352 table,
353 hooks: &mut state.hooks,
354 })
355 }
356}
357
/// The hyper request type that flows through [`RequestSender`] (the `wasi:http@0.2` shape;
/// the 0.3 path converts into it).
type OutgoingRequest = http::Request<HyperOutgoingBody>;
359
360impl p2::WasiHttpHooks for InstanceHttpHooks {
361 #[instrument(
362 name = "spin_outbound_http.send_request",
363 skip_all,
364 fields(
365 otel.kind = "client",
366 url.full = Empty,
367 http.request.method = %request.method(),
368 otel.name = %request.method(),
369 http.response.status_code = Empty,
370 server.address = Empty,
371 server.port = Empty,
372 )
373 )]
374 fn send_request(
375 &mut self,
376 request: OutgoingRequest,
377 config: OutgoingRequestConfig,
378 ) -> Result<wasmtime_wasi_http::p2::types::HostFutureIncomingResponse, HttpError> {
379 self.otel.reparent_tracing_span();
380
381 let request_sender = RequestSender {
382 allowed_hosts: self.allowed_hosts.clone(),
383 component_tls_configs: self.component_tls_configs.clone(),
384 request_interceptor: self.request_interceptor.clone(),
385 self_request_origin: self.self_request_origin.clone(),
386 blocked_networks: self.blocked_networks.clone(),
387 http_clients: self.wasi_http_clients.clone(),
388 concurrent_outbound_connections_semaphore: self
389 .concurrent_outbound_connections_semaphore
390 .clone(),
391 };
392 Ok(HostFutureIncomingResponse::Pending(
393 wasmtime_wasi::runtime::spawn(
394 async {
395 match request_sender.send(request, config).await {
396 Ok(resp) => Ok(Ok(resp)),
397 Err(http_error) => match http_error.downcast() {
398 Ok(error_code) => Ok(Err(error_code)),
399 Err(trap) => Err(trap),
400 },
401 }
402 }
403 .in_current_span(),
404 ),
405 ))
406 }
407}
408
/// Per-request snapshot of the instance's outbound-HTTP policy and shared clients.
///
/// Both `send_request` hook implementations build one; [`RequestSender::send`] consumes it.
struct RequestSender {
    /// Allow-list consulted before any request is sent.
    allowed_hosts: OutboundAllowedHosts,
    /// Networks whose addresses must not be connected to.
    blocked_networks: BlockedNetworks,
    /// Per-host TLS client configurations.
    component_tls_configs: ComponentTlsClientConfigs,
    /// Origin that relative / `self.alt` requests are rewritten to, if any.
    self_request_origin: Option<SelfRequestOrigin>,
    /// Optional hook that may rewrite or fully answer a request.
    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
    /// Shared hyper clients.
    http_clients: HttpClients,
    /// Optional limit on concurrent outbound connections.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
418
419impl RequestSender {
420 async fn send(
421 self,
422 mut request: OutgoingRequest,
423 mut config: OutgoingRequestConfig,
424 ) -> Result<IncomingResponse, HttpError> {
425 self.prepare_request(&mut request, &mut config).await?;
426
427 spin_telemetry::inject_trace_context(&mut request);
429
430 let mut override_connect_addr = None;
432 if let Some(interceptor) = &self.request_interceptor {
433 let intercept_request = std::mem::take(&mut request).into();
434 match interceptor.intercept(intercept_request).await? {
435 InterceptOutcome::Continue(mut req) => {
436 override_connect_addr = req.override_connect_addr.take();
437 request = req.into_hyper_request();
438 }
439 InterceptOutcome::Complete(resp) => {
440 let resp = IncomingResponse {
441 resp,
442 worker: None,
443 between_bytes_timeout: config.between_bytes_timeout,
444 };
445 return Ok(resp);
446 }
447 }
448 }
449
450 let span = tracing::Span::current();
452 if let Some(addr) = override_connect_addr {
453 span.record("server.address", addr.ip().to_string());
454 span.record("server.port", addr.port());
455 } else if let Some(authority) = request.uri().authority() {
456 span.record("server.address", authority.host());
457 if let Some(port) = authority.port_u16() {
458 span.record("server.port", port);
459 }
460 }
461
462 record_content_length_header(
463 &span,
464 request.headers(),
465 "http.request.header.content-length",
466 );
467
468 Ok(self
469 .send_request(request, config, override_connect_addr)
470 .await?)
471 }
472
473 async fn prepare_request(
474 &self,
475 request: &mut OutgoingRequest,
476 config: &mut OutgoingRequestConfig,
477 ) -> Result<(), ErrorCode> {
478 let uri = request.uri_mut();
482 if uri
483 .authority()
484 .is_some_and(|authority| authority.host().is_empty())
485 {
486 let mut builder = http::uri::Builder::new();
487 if let Some(paq) = uri.path_and_query() {
488 builder = builder.path_and_query(paq.clone());
489 }
490 *uri = builder.build().unwrap();
491 }
492 tracing::Span::current().record("url.full", uri.to_string());
493
494 let is_self_request = match request.uri().authority() {
495 Some(authority) => authority.host() == "self.alt",
497 None => true,
499 };
500
501 let is_allowed = if is_self_request {
503 self.allowed_hosts
504 .check_relative_url(&["http", "https"])
505 .await
506 .unwrap_or(false)
507 } else {
508 self.allowed_hosts
509 .check_url(&request.uri().to_string(), "https")
510 .await
511 .unwrap_or(false)
512 };
513 if !is_allowed {
514 return Err(ErrorCode::HttpRequestDenied);
515 }
516
517 if is_self_request {
518 let Some(origin) = self.self_request_origin.as_ref() else {
520 tracing::error!(
521 "Couldn't handle outbound HTTP request to relative URI; no origin set"
522 );
523 return Err(ErrorCode::HttpRequestUriInvalid);
524 };
525
526 config.use_tls = origin.use_tls();
527
528 request.headers_mut().insert(HOST, origin.host_header());
529
530 let path_and_query = request.uri().path_and_query().cloned();
531 *request.uri_mut() = origin.clone().into_uri(path_and_query);
532 }
533
534 request.headers_mut().remove(HOST);
541 Ok(())
542 }
543
544 async fn send_request(
545 self,
546 request: OutgoingRequest,
547 config: OutgoingRequestConfig,
548 override_connect_addr: Option<SocketAddr>,
549 ) -> Result<IncomingResponse, ErrorCode> {
550 let OutgoingRequestConfig {
551 use_tls,
552 connect_timeout,
553 first_byte_timeout,
554 between_bytes_timeout,
555 } = config;
556
557 let tls_client_config = if use_tls {
558 let host = request.uri().host().unwrap_or_default();
559 Some(self.component_tls_configs.get_client_config(host).clone())
560 } else {
561 None
562 };
563
564 let resp = CONNECT_OPTIONS.scope(
565 ConnectOptions {
566 blocked_networks: self.blocked_networks,
567 connect_timeout,
568 tls_client_config,
569 override_connect_addr,
570 concurrent_outbound_connections_semaphore: self
571 .concurrent_outbound_connections_semaphore,
572 },
573 async move {
574 if use_tls {
575 self.http_clients.https.request(request).await
576 } else {
577 let h2c_prior_knowledge_host =
579 std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
580 let use_h2c = h2c_prior_knowledge_host.as_deref()
581 == request.uri().authority().map(|a| a.as_str());
582
583 if use_h2c {
584 self.http_clients.http2.request(request).await
585 } else {
586 self.http_clients.http1.request(request).await
587 }
588 }
589 },
590 );
591
592 let resp = timeout(first_byte_timeout, resp)
593 .await
594 .map_err(|_| ErrorCode::ConnectionReadTimeout)?
595 .map_err(hyper_legacy_request_error)?
596 .map(|body| body.map_err(hyper_request_error).boxed_unsync());
597
598 let span = tracing::Span::current();
599 span.record("http.response.status_code", resp.status().as_u16());
600
601 record_content_length_header(&span, resp.headers(), "http.response.header.content-length");
602
603 Ok(IncomingResponse {
604 resp,
605 worker: None,
606 between_bytes_timeout,
607 })
608 }
609}
610
/// Hyper client for plaintext connections made by [`HttpConnector`].
type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
/// Hyper client for TLS connections made by [`HttpsConnector`].
type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
613
/// The hyper clients used for outbound requests, one per transport mode.
#[derive(Clone)]
pub(super) struct HttpClients {
    /// Plaintext client used when h2c prior knowledge does not apply.
    http1: HttpClient,
    /// Plaintext HTTP/2-only client, used for h2c "prior knowledge" when the request's
    /// authority matches `SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE`.
    http2: HttpClient,
    /// TLS client; ALPN advertises both `h2` and `http/1.1`.
    https: HttpsClient,
}
623
624impl HttpClients {
625 pub(super) fn new(enable_pooling: bool) -> Self {
626 let builder = move || {
627 let mut builder = Client::builder(TokioExecutor::new());
628 if !enable_pooling {
629 builder.pool_max_idle_per_host(0);
630 }
631 builder
632 };
633 Self {
634 http1: builder().build(HttpConnector),
635 http2: builder().http2_only(true).build(HttpConnector),
636 https: builder().build(HttpsConnector),
637 }
638 }
639}
640
tokio::task_local! {
    // Per-request connection settings. `RequestSender::send_request` scopes a value here
    // around the hyper request. `HttpConnector` / `HttpsConnector` only receive a `Uri`,
    // so they read their settings from this task-local; `.get()` outside that scope panics.
    static CONNECT_OPTIONS: ConnectOptions;
}
651
/// Connection settings for one outbound request.
///
/// Installed in [`CONNECT_OPTIONS`] while the request is in flight so the connectors can
/// read them.
#[derive(Clone)]
struct ConnectOptions {
    /// Networks whose resolved addresses are filtered out before connecting.
    blocked_networks: BlockedNetworks,
    /// Upper bound on establishing the TCP connection.
    connect_timeout: Duration,
    /// TLS settings for the request's host; only `Some` for HTTPS requests.
    tls_client_config: Option<TlsClientConfig>,
    /// When set, connect here instead of resolving the URI's authority.
    override_connect_addr: Option<SocketAddr>,
    /// Optional limit on concurrent outbound connections.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
665
666impl ConnectOptions {
667 async fn connect_tcp(
669 &self,
670 uri: &Uri,
671 default_port: u16,
672 ) -> Result<PermittedTcpStream, ErrorCode> {
673 let mut socket_addrs = match self.override_connect_addr {
674 Some(override_connect_addr) => vec![override_connect_addr],
675 None => {
676 let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
677
678 let host_and_port = if authority.port().is_some() {
679 authority.as_str().to_string()
680 } else {
681 format!("{}:{}", authority.as_str(), default_port)
682 };
683
684 let socket_addrs = tokio::net::lookup_host(&host_and_port)
685 .await
686 .map_err(|err| {
687 tracing::debug!(?host_and_port, ?err, "Error resolving host");
688 dns_error("address not available".into(), 0)
689 })?
690 .collect::<Vec<_>>();
691 tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
692 socket_addrs
693 }
694 };
695
696 crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
698
699 let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore(
702 "wasi",
703 &self.concurrent_outbound_connections_semaphore,
704 )
705 .await;
706
707 let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
708 .await
709 .map_err(|_| ErrorCode::ConnectionTimeout)?
710 .map_err(|err| match err.kind() {
711 std::io::ErrorKind::AddrNotAvailable => {
712 dns_error("address not available".into(), 0)
713 }
714 _ => ErrorCode::ConnectionRefused,
715 })?;
716 Ok(PermittedTcpStream {
717 inner: stream,
718 _permit: permit,
719 })
720 }
721
722 async fn connect_tls(
724 &self,
725 uri: &Uri,
726 default_port: u16,
727 ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
728 let tcp_stream = self.connect_tcp(uri, default_port).await?;
729
730 let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
731 tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
732
733 let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
734 let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
735 .map_err(|e| {
736 tracing::warn!("dns lookup error: {e:?}");
737 dns_error("invalid dns name".into(), 0)
738 })?
739 .to_owned();
740 connector.connect(domain, tcp_stream).await.map_err(|e| {
741 tracing::warn!("tls protocol error: {e:?}");
742 ErrorCode::TlsProtocolError
743 })
744 }
745}
746
747#[derive(Clone)]
749struct HttpConnector;
750
751impl HttpConnector {
752 async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
753 let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
754 Ok(TokioIo::new(stream))
755 }
756}
757
758impl Service<Uri> for HttpConnector {
759 type Response = TokioIo<PermittedTcpStream>;
760 type Error = ErrorCode;
761 type Future =
762 Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
763
764 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
765 Poll::Ready(Ok(()))
766 }
767
768 fn call(&mut self, uri: Uri) -> Self::Future {
769 Box::pin(async move { Self::connect(uri).await })
770 }
771}
772
773#[derive(Clone)]
775struct HttpsConnector;
776
777impl HttpsConnector {
778 async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
779 let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
780 Ok(TokioIo::new(RustlsStream(stream)))
781 }
782}
783
784impl Service<Uri> for HttpsConnector {
785 type Response = TokioIo<RustlsStream>;
786 type Error = ErrorCode;
787 type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
788
789 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
790 Poll::Ready(Ok(()))
791 }
792
793 fn call(&mut self, uri: Uri) -> Self::Future {
794 Box::pin(async move { Self::connect(uri).await })
795 }
796}
797
/// Newtype over a client TLS stream so this module can implement hyper-util's `Connection`
/// (reporting ALPN-negotiated HTTP/2) and the tokio I/O traits for it.
struct RustlsStream(TlsStream<PermittedTcpStream>);
799
800impl Connection for RustlsStream {
801 fn connected(&self) -> Connected {
802 if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
803 self.0.get_ref().0.connected().negotiated_h2()
804 } else {
805 self.0.get_ref().0.connected()
806 }
807 }
808}
809
810impl AsyncRead for RustlsStream {
811 fn poll_read(
812 self: Pin<&mut Self>,
813 cx: &mut Context<'_>,
814 buf: &mut ReadBuf<'_>,
815 ) -> Poll<Result<(), std::io::Error>> {
816 Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
817 }
818}
819
820impl AsyncWrite for RustlsStream {
821 fn poll_write(
822 self: Pin<&mut Self>,
823 cx: &mut Context<'_>,
824 buf: &[u8],
825 ) -> Poll<Result<usize, std::io::Error>> {
826 Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
827 }
828
829 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
830 Pin::new(&mut self.get_mut().0).poll_flush(cx)
831 }
832
833 fn poll_shutdown(
834 self: Pin<&mut Self>,
835 cx: &mut Context<'_>,
836 ) -> Poll<Result<(), std::io::Error>> {
837 Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
838 }
839
840 fn poll_write_vectored(
841 self: Pin<&mut Self>,
842 cx: &mut Context<'_>,
843 bufs: &[IoSlice<'_>],
844 ) -> Poll<Result<usize, std::io::Error>> {
845 Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
846 }
847
848 fn is_write_vectored(&self) -> bool {
849 self.0.is_write_vectored()
850 }
851}
852
/// A TCP stream paired with the optional concurrency permit acquired for it.
///
/// The permit is held for as long as the stream is alive.
struct PermittedTcpStream {
    /// The connected socket. Declared before `_permit`, so it is dropped (closed) first;
    /// do not reorder these fields.
    inner: TcpStream,
    /// Released when this stream is dropped; `None` when no limit is configured.
    _permit: Option<OwnedSemaphorePermit>,
}
863
864impl Connection for PermittedTcpStream {
865 fn connected(&self) -> Connected {
866 self.inner.connected()
867 }
868}
869
870impl AsyncRead for PermittedTcpStream {
871 fn poll_read(
872 self: Pin<&mut Self>,
873 cx: &mut Context<'_>,
874 buf: &mut ReadBuf<'_>,
875 ) -> Poll<std::io::Result<()>> {
876 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
877 }
878}
879
880impl AsyncWrite for PermittedTcpStream {
881 fn poll_write(
882 self: Pin<&mut Self>,
883 cx: &mut Context<'_>,
884 buf: &[u8],
885 ) -> Poll<Result<usize, std::io::Error>> {
886 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
887 }
888
889 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
890 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
891 }
892
893 fn poll_shutdown(
894 self: Pin<&mut Self>,
895 cx: &mut Context<'_>,
896 ) -> Poll<Result<(), std::io::Error>> {
897 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
898 }
899}
900
901fn hyper_request_error(err: hyper::Error) -> ErrorCode {
903 if let Some(cause) = err.source()
905 && let Some(err) = cause.downcast_ref::<ErrorCode>()
906 {
907 return err.clone();
908 }
909
910 tracing::warn!("hyper request error: {err:?}");
911
912 ErrorCode::HttpProtocolError
913}
914
915fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
917 if let Some(cause) = err.source()
919 && let Some(err) = cause.downcast_ref::<ErrorCode>()
920 {
921 return err.clone();
922 }
923
924 tracing::warn!("hyper request error: {err:?}");
925
926 ErrorCode::HttpProtocolError
927}
928
929fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
930 ErrorCode::DnsError(
931 wasmtime_wasi_http::p2::bindings::http::types::DnsErrorPayload {
932 rcode: Some(rcode),
933 info_code: Some(info_code),
934 },
935 )
936}
937
938pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
941 match code {
942 p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
943 p2_types::ErrorCode::DnsError(payload) => {
944 p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
945 rcode: payload.rcode,
946 info_code: payload.info_code,
947 })
948 }
949 p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
950 p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
951 p2_types::ErrorCode::DestinationIpProhibited => {
952 p3_types::ErrorCode::DestinationIpProhibited
953 }
954 p2_types::ErrorCode::DestinationIpUnroutable => {
955 p3_types::ErrorCode::DestinationIpUnroutable
956 }
957 p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
958 p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
959 p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
960 p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
961 p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
962 p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
963 p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
964 p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
965 p2_types::ErrorCode::TlsAlertReceived(payload) => {
966 p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
967 alert_id: payload.alert_id,
968 alert_message: payload.alert_message,
969 })
970 }
971 p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
972 p2_types::ErrorCode::HttpRequestLengthRequired => {
973 p3_types::ErrorCode::HttpRequestLengthRequired
974 }
975 p2_types::ErrorCode::HttpRequestBodySize(payload) => {
976 p3_types::ErrorCode::HttpRequestBodySize(payload)
977 }
978 p2_types::ErrorCode::HttpRequestMethodInvalid => {
979 p3_types::ErrorCode::HttpRequestMethodInvalid
980 }
981 p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
982 p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
983 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
984 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
985 }
986 p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
987 p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
988 p3_types::FieldSizePayload {
989 field_name: payload.field_name,
990 field_size: payload.field_size,
991 }
992 }))
993 }
994 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
995 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
996 }
997 p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
998 p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
999 field_name: payload.field_name,
1000 field_size: payload.field_size,
1001 })
1002 }
1003 p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
1004 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1005 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1006 }
1007 p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1008 p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
1009 field_name: payload.field_name,
1010 field_size: payload.field_size,
1011 })
1012 }
1013 p2_types::ErrorCode::HttpResponseBodySize(payload) => {
1014 p3_types::ErrorCode::HttpResponseBodySize(payload)
1015 }
1016 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1017 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1018 }
1019 p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1020 p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
1021 field_name: payload.field_name,
1022 field_size: payload.field_size,
1023 })
1024 }
1025 p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1026 p3_types::ErrorCode::HttpResponseTransferCoding(payload)
1027 }
1028 p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
1029 p3_types::ErrorCode::HttpResponseContentCoding(payload)
1030 }
1031 p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
1032 p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
1033 p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
1034 p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
1035 p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
1036 p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
1037 }
1038}
1039
1040pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
1043 match code {
1044 p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
1045 p3_types::ErrorCode::DnsError(payload) => {
1046 p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
1047 rcode: payload.rcode,
1048 info_code: payload.info_code,
1049 })
1050 }
1051 p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
1052 p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
1053 p3_types::ErrorCode::DestinationIpProhibited => {
1054 p2_types::ErrorCode::DestinationIpProhibited
1055 }
1056 p3_types::ErrorCode::DestinationIpUnroutable => {
1057 p2_types::ErrorCode::DestinationIpUnroutable
1058 }
1059 p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
1060 p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
1061 p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
1062 p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
1063 p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
1064 p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
1065 p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
1066 p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
1067 p3_types::ErrorCode::TlsAlertReceived(payload) => {
1068 p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
1069 alert_id: payload.alert_id,
1070 alert_message: payload.alert_message,
1071 })
1072 }
1073 p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
1074 p3_types::ErrorCode::HttpRequestLengthRequired => {
1075 p2_types::ErrorCode::HttpRequestLengthRequired
1076 }
1077 p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1078 p2_types::ErrorCode::HttpRequestBodySize(payload)
1079 }
1080 p3_types::ErrorCode::HttpRequestMethodInvalid => {
1081 p2_types::ErrorCode::HttpRequestMethodInvalid
1082 }
1083 p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1084 p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1085 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1086 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1087 }
1088 p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1089 p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1090 p2_types::FieldSizePayload {
1091 field_name: payload.field_name,
1092 field_size: payload.field_size,
1093 }
1094 }))
1095 }
1096 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1097 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1098 }
1099 p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1100 p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1101 field_name: payload.field_name,
1102 field_size: payload.field_size,
1103 })
1104 }
1105 p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1106 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1107 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1108 }
1109 p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1110 p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1111 field_name: payload.field_name,
1112 field_size: payload.field_size,
1113 })
1114 }
1115 p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1116 p2_types::ErrorCode::HttpResponseBodySize(payload)
1117 }
1118 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1119 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1120 }
1121 p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1122 p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1123 field_name: payload.field_name,
1124 field_size: payload.field_size,
1125 })
1126 }
1127 p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1128 p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1129 }
1130 p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1131 p2_types::ErrorCode::HttpResponseContentCoding(payload)
1132 }
1133 p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1134 p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1135 p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1136 p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1137 p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1138 p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1139 }
1140}
1141
1142fn record_content_length_header(span: &Span, headers: &HeaderMap, attr_name: &'static str) {
1143 if let Some(content_length) = headers.get(CONTENT_LENGTH)
1144 && let Ok(size_str) = content_length.to_str()
1145 {
1146 span.set_attribute(attr_name, size_str.to_string());
1147 }
1148}