1use std::{
2 error::Error,
3 future::Future,
4 io::IoSlice,
5 net::SocketAddr,
6 pin::Pin,
7 sync::Arc,
8 task::{self, Context, Poll},
9 time::Duration,
10};
11
12use bytes::Bytes;
13use http::{header::HOST, uri::Scheme, Uri};
14use http_body::{Body, Frame, SizeHint};
15use http_body_util::{combinators::BoxBody, BodyExt};
16use hyper_util::{
17 client::legacy::{
18 connect::{Connected, Connection},
19 Client,
20 },
21 rt::{TokioExecutor, TokioIo},
22};
23use spin_factor_outbound_networking::{
24 config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
25 ComponentTlsClientConfigs, TlsClientConfig,
26};
27use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
28use tokio::{
29 io::{AsyncRead, AsyncWrite, ReadBuf},
30 net::TcpStream,
31 sync::{OwnedSemaphorePermit, Semaphore},
32 time::timeout,
33};
34use tokio_rustls::client::TlsStream;
35use tower_service::Service;
36use tracing::{field::Empty, instrument, Instrument};
37use wasmtime::component::HasData;
38use wasmtime_wasi::TrappableError;
39use wasmtime_wasi_http::{
40 bindings::http::types::{self as p2_types, ErrorCode},
41 body::HyperOutgoingBody,
42 p3::{self, bindings::http::types as p3_types},
43 types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
44 HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
45};
46
47use crate::{
48 intercept::{InterceptOutcome, OutboundHttpInterceptor},
49 wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
50};
51
/// Fallback applied to the connect, first-byte and between-bytes timeouts
/// when the caller's request options leave them unset (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
53
54pub(crate) struct HasHttp;
55
56impl HasData for HasHttp {
57 type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
58}
59
60impl p3::WasiHttpCtx for InstanceState {
61 fn send_request(
62 &mut self,
63 request: http::Request<BoxBody<Bytes, p3_types::ErrorCode>>,
64 options: Option<p3::RequestOptions>,
65 fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
66 ) -> Box<
67 dyn Future<
68 Output = Result<
69 (
70 http::Response<BoxBody<Bytes, p3_types::ErrorCode>>,
71 Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
72 ),
73 TrappableError<p3_types::ErrorCode>,
74 >,
75 > + Send,
76 > {
77 _ = fut;
86
87 let request_sender = RequestSender {
88 allowed_hosts: self.allowed_hosts.clone(),
89 component_tls_configs: self.component_tls_configs.clone(),
90 request_interceptor: self.request_interceptor.clone(),
91 self_request_origin: self.self_request_origin.clone(),
92 blocked_networks: self.blocked_networks.clone(),
93 http_clients: self.wasi_http_clients.clone(),
94 concurrent_outbound_connections_semaphore: self
95 .concurrent_outbound_connections_semaphore
96 .clone(),
97 };
98 let config = OutgoingRequestConfig {
99 use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
100 connect_timeout: options
101 .and_then(|v| v.connect_timeout)
102 .unwrap_or(DEFAULT_TIMEOUT),
103 first_byte_timeout: options
104 .and_then(|v| v.first_byte_timeout)
105 .unwrap_or(DEFAULT_TIMEOUT),
106 between_bytes_timeout: options
107 .and_then(|v| v.between_bytes_timeout)
108 .unwrap_or(DEFAULT_TIMEOUT),
109 };
110 Box::new(async {
111 match request_sender
112 .send(
113 request.map(|body| body.map_err(p3_to_p2_error_code).boxed()),
114 config,
115 )
116 .await
117 {
118 Ok(IncomingResponse {
119 resp,
120 between_bytes_timeout,
121 ..
122 }) => Ok((
123 resp.map(|body| {
124 BetweenBytesTimeoutBody {
125 body,
126 sleep: None,
127 timeout: between_bytes_timeout,
128 }
129 .boxed()
130 }),
131 Box::new(async {
132 Ok(())
136 }) as Box<dyn Future<Output = _> + Send>,
137 )),
138 Err(http_error) => match http_error.downcast() {
139 Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
140 Err(trap) => Err(TrappableError::trap(trap)),
141 },
142 }
143 })
144 }
145}
146
147pin_project_lite::pin_project! {
148 struct BetweenBytesTimeoutBody<B> {
149 #[pin]
150 body: B,
151 #[pin]
152 sleep: Option<tokio::time::Sleep>,
153 timeout: Duration,
154 }
155}
156
157impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
158 type Data = B::Data;
159 type Error = p3_types::ErrorCode;
160
161 fn poll_frame(
162 self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
165 let mut me = self.project();
166 match me.body.poll_frame(cx) {
167 Poll::Ready(value) => {
168 me.sleep.as_mut().set(None);
169 Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
170 }
171 Poll::Pending => {
172 if me.sleep.is_none() {
173 me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
174 }
175 task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
176 Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
177 }
178 }
179 }
180
181 fn is_end_stream(&self) -> bool {
182 self.body.is_end_stream()
183 }
184
185 fn size_hint(&self) -> SizeHint {
186 self.body.size_hint()
187 }
188}
189
190pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
191where
192 C: spin_factors::InitContext<OutboundHttpFactor>,
193{
194 let linker = ctx.linker();
195
196 fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
197 where
198 C: spin_factors::InitContext<OutboundHttpFactor>,
199 {
200 let (state, table) = C::get_data_with_table(store);
201 WasiHttpImpl(WasiHttpImplInner { state, table })
202 }
203
204 let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
205 wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
206 linker, get_http,
207 )?;
208 wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
209 linker,
210 &Default::default(),
211 get_http,
212 )?;
213
214 fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
215 where
216 C: spin_factors::InitContext<OutboundHttpFactor>,
217 {
218 let (state, table) = C::get_data_with_table(store);
219 p3::WasiHttpCtxView { ctx: state, table }
220 }
221
222 let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
223 p3::bindings::http::handler::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
224 p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
225
226 wasi_2023_10_18::add_to_linker(linker, get_http)?;
227 wasi_2023_11_10::add_to_linker(linker, get_http)?;
228
229 Ok(())
230}
231
232impl OutboundHttpFactor {
233 pub fn get_wasi_http_impl(
234 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
235 ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
236 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
237 Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
238 }
239
240 pub fn get_wasi_p3_http_impl(
241 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
242 ) -> Option<p3::WasiHttpCtxView<'_>> {
243 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
244 Some(p3::WasiHttpCtxView { ctx: state, table })
245 }
246}
247
248pub(crate) struct WasiHttpImplInner<'a> {
249 state: &'a mut InstanceState,
250 table: &'a mut ResourceTable,
251}
252
253type OutgoingRequest = http::Request<HyperOutgoingBody>;
254
255impl WasiHttpView for WasiHttpImplInner<'_> {
256 fn ctx(&mut self) -> &mut WasiHttpCtx {
257 &mut self.state.wasi_http_ctx
258 }
259
260 fn table(&mut self) -> &mut ResourceTable {
261 self.table
262 }
263
264 #[instrument(
265 name = "spin_outbound_http.send_request",
266 skip_all,
267 fields(
268 otel.kind = "client",
269 url.full = Empty,
270 http.request.method = %request.method(),
271 otel.name = %request.method(),
272 http.response.status_code = Empty,
273 server.address = Empty,
274 server.port = Empty,
275 )
276 )]
277 fn send_request(
278 &mut self,
279 request: OutgoingRequest,
280 config: OutgoingRequestConfig,
281 ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
282 let request_sender = RequestSender {
283 allowed_hosts: self.state.allowed_hosts.clone(),
284 component_tls_configs: self.state.component_tls_configs.clone(),
285 request_interceptor: self.state.request_interceptor.clone(),
286 self_request_origin: self.state.self_request_origin.clone(),
287 blocked_networks: self.state.blocked_networks.clone(),
288 http_clients: self.state.wasi_http_clients.clone(),
289 concurrent_outbound_connections_semaphore: self
290 .state
291 .concurrent_outbound_connections_semaphore
292 .clone(),
293 };
294 Ok(HostFutureIncomingResponse::Pending(
295 wasmtime_wasi::runtime::spawn(
296 async {
297 match request_sender.send(request, config).await {
298 Ok(resp) => Ok(Ok(resp)),
299 Err(http_error) => match http_error.downcast() {
300 Ok(error_code) => Ok(Err(error_code)),
301 Err(trap) => Err(trap),
302 },
303 }
304 }
305 .in_current_span(),
306 ),
307 ))
308 }
309}
310
311struct RequestSender {
312 allowed_hosts: OutboundAllowedHosts,
313 blocked_networks: BlockedNetworks,
314 component_tls_configs: ComponentTlsClientConfigs,
315 self_request_origin: Option<SelfRequestOrigin>,
316 request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
317 http_clients: HttpClients,
318 concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
319}
320
321impl RequestSender {
322 async fn send(
323 self,
324 mut request: OutgoingRequest,
325 mut config: OutgoingRequestConfig,
326 ) -> Result<IncomingResponse, HttpError> {
327 self.prepare_request(&mut request, &mut config).await?;
328
329 spin_telemetry::inject_trace_context(&mut request);
331
332 let mut override_connect_addr = None;
334 if let Some(interceptor) = &self.request_interceptor {
335 let intercept_request = std::mem::take(&mut request).into();
336 match interceptor.intercept(intercept_request).await? {
337 InterceptOutcome::Continue(mut req) => {
338 override_connect_addr = req.override_connect_addr.take();
339 request = req.into_hyper_request();
340 }
341 InterceptOutcome::Complete(resp) => {
342 let resp = IncomingResponse {
343 resp,
344 worker: None,
345 between_bytes_timeout: config.between_bytes_timeout,
346 };
347 return Ok(resp);
348 }
349 }
350 }
351
352 let span = tracing::Span::current();
354 if let Some(addr) = override_connect_addr {
355 span.record("server.address", addr.ip().to_string());
356 span.record("server.port", addr.port());
357 } else if let Some(authority) = request.uri().authority() {
358 span.record("server.address", authority.host());
359 if let Some(port) = authority.port_u16() {
360 span.record("server.port", port);
361 }
362 }
363
364 Ok(self
365 .send_request(request, config, override_connect_addr)
366 .await?)
367 }
368
369 async fn prepare_request(
370 &self,
371 request: &mut OutgoingRequest,
372 config: &mut OutgoingRequestConfig,
373 ) -> Result<(), ErrorCode> {
374 let uri = request.uri_mut();
378 if uri
379 .authority()
380 .is_some_and(|authority| authority.host().is_empty())
381 {
382 let mut builder = http::uri::Builder::new();
383 if let Some(paq) = uri.path_and_query() {
384 builder = builder.path_and_query(paq.clone());
385 }
386 *uri = builder.build().unwrap();
387 }
388 tracing::Span::current().record("url.full", uri.to_string());
389
390 let is_self_request = match request.uri().authority() {
391 Some(authority) => authority.host() == "self.alt",
393 None => true,
395 };
396
397 let is_allowed = if is_self_request {
399 self.allowed_hosts
400 .check_relative_url(&["http", "https"])
401 .await
402 .unwrap_or(false)
403 } else {
404 self.allowed_hosts
405 .check_url(&request.uri().to_string(), "https")
406 .await
407 .unwrap_or(false)
408 };
409 if !is_allowed {
410 return Err(ErrorCode::HttpRequestDenied);
411 }
412
413 if is_self_request {
414 let Some(origin) = self.self_request_origin.as_ref() else {
416 tracing::error!(
417 "Couldn't handle outbound HTTP request to relative URI; no origin set"
418 );
419 return Err(ErrorCode::HttpRequestUriInvalid);
420 };
421
422 config.use_tls = origin.use_tls();
423
424 request.headers_mut().insert(HOST, origin.host_header());
425
426 let path_and_query = request.uri().path_and_query().cloned();
427 *request.uri_mut() = origin.clone().into_uri(path_and_query);
428 }
429
430 request.headers_mut().remove(HOST);
437 Ok(())
438 }
439
440 async fn send_request(
441 self,
442 request: OutgoingRequest,
443 config: OutgoingRequestConfig,
444 override_connect_addr: Option<SocketAddr>,
445 ) -> Result<IncomingResponse, ErrorCode> {
446 let OutgoingRequestConfig {
447 use_tls,
448 connect_timeout,
449 first_byte_timeout,
450 between_bytes_timeout,
451 } = config;
452
453 let tls_client_config = if use_tls {
454 let host = request.uri().host().unwrap_or_default();
455 Some(self.component_tls_configs.get_client_config(host).clone())
456 } else {
457 None
458 };
459
460 let resp = CONNECT_OPTIONS.scope(
461 ConnectOptions {
462 blocked_networks: self.blocked_networks,
463 connect_timeout,
464 tls_client_config,
465 override_connect_addr,
466 concurrent_outbound_connections_semaphore: self
467 .concurrent_outbound_connections_semaphore,
468 },
469 async move {
470 if use_tls {
471 self.http_clients.https.request(request).await
472 } else {
473 let h2c_prior_knowledge_host =
475 std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
476 let use_h2c = h2c_prior_knowledge_host.as_deref()
477 == request.uri().authority().map(|a| a.as_str());
478
479 if use_h2c {
480 self.http_clients.http2.request(request).await
481 } else {
482 self.http_clients.http1.request(request).await
483 }
484 }
485 },
486 );
487
488 let resp = timeout(first_byte_timeout, resp)
489 .await
490 .map_err(|_| ErrorCode::ConnectionReadTimeout)?
491 .map_err(hyper_legacy_request_error)?
492 .map(|body| body.map_err(hyper_request_error).boxed());
493
494 tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
495
496 Ok(IncomingResponse {
497 resp,
498 worker: None,
499 between_bytes_timeout,
500 })
501 }
502}
503
504type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
505type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
506
507#[derive(Clone)]
508pub(super) struct HttpClients {
509 http1: HttpClient,
511 http2: HttpClient,
513 https: HttpsClient,
515}
516
517impl HttpClients {
518 pub(super) fn new(enable_pooling: bool) -> Self {
519 let builder = move || {
520 let mut builder = Client::builder(TokioExecutor::new());
521 if !enable_pooling {
522 builder.pool_max_idle_per_host(0);
523 }
524 builder
525 };
526 Self {
527 http1: builder().build(HttpConnector),
528 http2: builder().http2_only(true).build(HttpConnector),
529 https: builder().build(HttpsConnector),
530 }
531 }
532}
533
534tokio::task_local! {
535 static CONNECT_OPTIONS: ConnectOptions;
543}
544
545#[derive(Clone)]
546struct ConnectOptions {
547 blocked_networks: BlockedNetworks,
549 connect_timeout: Duration,
551 tls_client_config: Option<TlsClientConfig>,
553 override_connect_addr: Option<SocketAddr>,
555 concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
557}
558
559impl ConnectOptions {
560 async fn connect_tcp(
562 &self,
563 uri: &Uri,
564 default_port: u16,
565 ) -> Result<PermittedTcpStream, ErrorCode> {
566 let mut socket_addrs = match self.override_connect_addr {
567 Some(override_connect_addr) => vec![override_connect_addr],
568 None => {
569 let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
570
571 let host_and_port = if authority.port().is_some() {
572 authority.as_str().to_string()
573 } else {
574 format!("{}:{}", authority.as_str(), default_port)
575 };
576
577 let socket_addrs = tokio::net::lookup_host(&host_and_port)
578 .await
579 .map_err(|err| {
580 tracing::debug!(?host_and_port, ?err, "Error resolving host");
581 dns_error("address not available".into(), 0)
582 })?
583 .collect::<Vec<_>>();
584 tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
585 socket_addrs
586 }
587 };
588
589 crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
591
592 let permit = match &self.concurrent_outbound_connections_semaphore {
594 Some(s) => s.clone().acquire_owned().await.ok(),
595 None => None,
596 };
597
598 let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
599 .await
600 .map_err(|_| ErrorCode::ConnectionTimeout)?
601 .map_err(|err| match err.kind() {
602 std::io::ErrorKind::AddrNotAvailable => {
603 dns_error("address not available".into(), 0)
604 }
605 _ => ErrorCode::ConnectionRefused,
606 })?;
607 Ok(PermittedTcpStream {
608 inner: stream,
609 _permit: permit,
610 })
611 }
612
613 async fn connect_tls(
615 &self,
616 uri: &Uri,
617 default_port: u16,
618 ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
619 let tcp_stream = self.connect_tcp(uri, default_port).await?;
620
621 let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
622 tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
623
624 let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
625 let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
626 .map_err(|e| {
627 tracing::warn!("dns lookup error: {e:?}");
628 dns_error("invalid dns name".into(), 0)
629 })?
630 .to_owned();
631 connector.connect(domain, tcp_stream).await.map_err(|e| {
632 tracing::warn!("tls protocol error: {e:?}");
633 ErrorCode::TlsProtocolError
634 })
635 }
636}
637
638#[derive(Clone)]
640struct HttpConnector;
641
642impl HttpConnector {
643 async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
644 let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
645 Ok(TokioIo::new(stream))
646 }
647}
648
649impl Service<Uri> for HttpConnector {
650 type Response = TokioIo<PermittedTcpStream>;
651 type Error = ErrorCode;
652 type Future =
653 Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
654
655 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
656 Poll::Ready(Ok(()))
657 }
658
659 fn call(&mut self, uri: Uri) -> Self::Future {
660 Box::pin(async move { Self::connect(uri).await })
661 }
662}
663
664#[derive(Clone)]
666struct HttpsConnector;
667
668impl HttpsConnector {
669 async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
670 let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
671 Ok(TokioIo::new(RustlsStream(stream)))
672 }
673}
674
675impl Service<Uri> for HttpsConnector {
676 type Response = TokioIo<RustlsStream>;
677 type Error = ErrorCode;
678 type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
679
680 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
681 Poll::Ready(Ok(()))
682 }
683
684 fn call(&mut self, uri: Uri) -> Self::Future {
685 Box::pin(async move { Self::connect(uri).await })
686 }
687}
688
689struct RustlsStream(TlsStream<PermittedTcpStream>);
690
691impl Connection for RustlsStream {
692 fn connected(&self) -> Connected {
693 if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
694 self.0.get_ref().0.connected().negotiated_h2()
695 } else {
696 self.0.get_ref().0.connected()
697 }
698 }
699}
700
701impl AsyncRead for RustlsStream {
702 fn poll_read(
703 self: Pin<&mut Self>,
704 cx: &mut Context<'_>,
705 buf: &mut ReadBuf<'_>,
706 ) -> Poll<Result<(), std::io::Error>> {
707 Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
708 }
709}
710
711impl AsyncWrite for RustlsStream {
712 fn poll_write(
713 self: Pin<&mut Self>,
714 cx: &mut Context<'_>,
715 buf: &[u8],
716 ) -> Poll<Result<usize, std::io::Error>> {
717 Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
718 }
719
720 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
721 Pin::new(&mut self.get_mut().0).poll_flush(cx)
722 }
723
724 fn poll_shutdown(
725 self: Pin<&mut Self>,
726 cx: &mut Context<'_>,
727 ) -> Poll<Result<(), std::io::Error>> {
728 Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
729 }
730
731 fn poll_write_vectored(
732 self: Pin<&mut Self>,
733 cx: &mut Context<'_>,
734 bufs: &[IoSlice<'_>],
735 ) -> Poll<Result<usize, std::io::Error>> {
736 Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
737 }
738
739 fn is_write_vectored(&self) -> bool {
740 self.0.is_write_vectored()
741 }
742}
743
744struct PermittedTcpStream {
746 inner: TcpStream,
748 _permit: Option<OwnedSemaphorePermit>,
753}
754
755impl Connection for PermittedTcpStream {
756 fn connected(&self) -> Connected {
757 self.inner.connected()
758 }
759}
760
761impl AsyncRead for PermittedTcpStream {
762 fn poll_read(
763 self: Pin<&mut Self>,
764 cx: &mut Context<'_>,
765 buf: &mut ReadBuf<'_>,
766 ) -> Poll<std::io::Result<()>> {
767 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
768 }
769}
770
771impl AsyncWrite for PermittedTcpStream {
772 fn poll_write(
773 self: Pin<&mut Self>,
774 cx: &mut Context<'_>,
775 buf: &[u8],
776 ) -> Poll<Result<usize, std::io::Error>> {
777 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
778 }
779
780 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
781 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
782 }
783
784 fn poll_shutdown(
785 self: Pin<&mut Self>,
786 cx: &mut Context<'_>,
787 ) -> Poll<Result<(), std::io::Error>> {
788 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
789 }
790}
791
792fn hyper_request_error(err: hyper::Error) -> ErrorCode {
794 if let Some(cause) = err.source() {
796 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
797 return err.clone();
798 }
799 }
800
801 tracing::warn!("hyper request error: {err:?}");
802
803 ErrorCode::HttpProtocolError
804}
805
806fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
808 if let Some(cause) = err.source() {
810 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
811 return err.clone();
812 }
813 }
814
815 tracing::warn!("hyper request error: {err:?}");
816
817 ErrorCode::HttpProtocolError
818}
819
820fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
821 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
822 rcode: Some(rcode),
823 info_code: Some(info_code),
824 })
825}
826
827pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
830 match code {
831 p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
832 p2_types::ErrorCode::DnsError(payload) => {
833 p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
834 rcode: payload.rcode,
835 info_code: payload.info_code,
836 })
837 }
838 p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
839 p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
840 p2_types::ErrorCode::DestinationIpProhibited => {
841 p3_types::ErrorCode::DestinationIpProhibited
842 }
843 p2_types::ErrorCode::DestinationIpUnroutable => {
844 p3_types::ErrorCode::DestinationIpUnroutable
845 }
846 p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
847 p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
848 p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
849 p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
850 p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
851 p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
852 p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
853 p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
854 p2_types::ErrorCode::TlsAlertReceived(payload) => {
855 p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
856 alert_id: payload.alert_id,
857 alert_message: payload.alert_message,
858 })
859 }
860 p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
861 p2_types::ErrorCode::HttpRequestLengthRequired => {
862 p3_types::ErrorCode::HttpRequestLengthRequired
863 }
864 p2_types::ErrorCode::HttpRequestBodySize(payload) => {
865 p3_types::ErrorCode::HttpRequestBodySize(payload)
866 }
867 p2_types::ErrorCode::HttpRequestMethodInvalid => {
868 p3_types::ErrorCode::HttpRequestMethodInvalid
869 }
870 p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
871 p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
872 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
873 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
874 }
875 p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
876 p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
877 p3_types::FieldSizePayload {
878 field_name: payload.field_name,
879 field_size: payload.field_size,
880 }
881 }))
882 }
883 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
884 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
885 }
886 p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
887 p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
888 field_name: payload.field_name,
889 field_size: payload.field_size,
890 })
891 }
892 p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
893 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
894 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
895 }
896 p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
897 p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
898 field_name: payload.field_name,
899 field_size: payload.field_size,
900 })
901 }
902 p2_types::ErrorCode::HttpResponseBodySize(payload) => {
903 p3_types::ErrorCode::HttpResponseBodySize(payload)
904 }
905 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
906 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
907 }
908 p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
909 p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
910 field_name: payload.field_name,
911 field_size: payload.field_size,
912 })
913 }
914 p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
915 p3_types::ErrorCode::HttpResponseTransferCoding(payload)
916 }
917 p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
918 p3_types::ErrorCode::HttpResponseContentCoding(payload)
919 }
920 p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
921 p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
922 p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
923 p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
924 p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
925 p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
926 }
927}
928
929pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
932 match code {
933 p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
934 p3_types::ErrorCode::DnsError(payload) => {
935 p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
936 rcode: payload.rcode,
937 info_code: payload.info_code,
938 })
939 }
940 p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
941 p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
942 p3_types::ErrorCode::DestinationIpProhibited => {
943 p2_types::ErrorCode::DestinationIpProhibited
944 }
945 p3_types::ErrorCode::DestinationIpUnroutable => {
946 p2_types::ErrorCode::DestinationIpUnroutable
947 }
948 p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
949 p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
950 p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
951 p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
952 p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
953 p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
954 p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
955 p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
956 p3_types::ErrorCode::TlsAlertReceived(payload) => {
957 p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
958 alert_id: payload.alert_id,
959 alert_message: payload.alert_message,
960 })
961 }
962 p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
963 p3_types::ErrorCode::HttpRequestLengthRequired => {
964 p2_types::ErrorCode::HttpRequestLengthRequired
965 }
966 p3_types::ErrorCode::HttpRequestBodySize(payload) => {
967 p2_types::ErrorCode::HttpRequestBodySize(payload)
968 }
969 p3_types::ErrorCode::HttpRequestMethodInvalid => {
970 p2_types::ErrorCode::HttpRequestMethodInvalid
971 }
972 p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
973 p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
974 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
975 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
976 }
977 p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
978 p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
979 p2_types::FieldSizePayload {
980 field_name: payload.field_name,
981 field_size: payload.field_size,
982 }
983 }))
984 }
985 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
986 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
987 }
988 p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
989 p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
990 field_name: payload.field_name,
991 field_size: payload.field_size,
992 })
993 }
994 p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
995 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
996 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
997 }
998 p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
999 p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1000 field_name: payload.field_name,
1001 field_size: payload.field_size,
1002 })
1003 }
1004 p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1005 p2_types::ErrorCode::HttpResponseBodySize(payload)
1006 }
1007 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1008 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1009 }
1010 p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1011 p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1012 field_name: payload.field_name,
1013 field_size: payload.field_size,
1014 })
1015 }
1016 p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1017 p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1018 }
1019 p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1020 p2_types::ErrorCode::HttpResponseContentCoding(payload)
1021 }
1022 p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1023 p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1024 p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1025 p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1026 p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1027 p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1028 }
1029}