1use std::{
2 error::Error,
3 future::Future,
4 io::IoSlice,
5 net::SocketAddr,
6 pin::Pin,
7 sync::Arc,
8 task::{self, Context, Poll},
9 time::Duration,
10};
11
12use bytes::Bytes;
13use http::{header::HOST, uri::Scheme, Uri};
14use http_body::{Body, Frame, SizeHint};
15use http_body_util::{combinators::BoxBody, BodyExt};
16use hyper_util::{
17 client::legacy::{
18 connect::{Connected, Connection},
19 Client,
20 },
21 rt::{TokioExecutor, TokioIo},
22};
23use spin_factor_outbound_networking::{
24 config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
25 ComponentTlsClientConfigs, TlsClientConfig,
26};
27use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
28use tokio::{
29 io::{AsyncRead, AsyncWrite, ReadBuf},
30 net::TcpStream,
31 sync::{OwnedSemaphorePermit, Semaphore},
32 time::timeout,
33};
34use tokio_rustls::client::TlsStream;
35use tower_service::Service;
36use tracing::{field::Empty, instrument, Instrument};
37use wasmtime::component::HasData;
38use wasmtime_wasi::TrappableError;
39use wasmtime_wasi_http::{
40 bindings::http::types::{self as p2_types, ErrorCode},
41 body::HyperOutgoingBody,
42 p3::{self, bindings::http::types as p3_types},
43 types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
44 HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
45};
46
47use crate::{
48 intercept::{InterceptOutcome, OutboundHttpInterceptor},
49 wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
50};
51
/// Fallback used by the p3 `send_request` path for the connect, first-byte,
/// and between-bytes timeouts when the guest's request options omit one
/// (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
53
/// Marker type that selects `WasiHttpImpl<WasiHttpImplInner>` as the host data
/// for the wasi:http p2 `outgoing_handler` and `types` bindings registered in
/// `add_to_linker`.
pub(crate) struct HasHttp;

impl HasData for HasHttp {
    // Each store borrow is projected to a view over the instance state and its
    // resource table that lives as long as that borrow.
    type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
}
59
60impl p3::WasiHttpCtx for InstanceState {
61 fn send_request(
62 &mut self,
63 request: http::Request<BoxBody<Bytes, p3_types::ErrorCode>>,
64 options: Option<p3::RequestOptions>,
65 fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
66 ) -> Box<
67 dyn Future<
68 Output = Result<
69 (
70 http::Response<BoxBody<Bytes, p3_types::ErrorCode>>,
71 Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
72 ),
73 TrappableError<p3_types::ErrorCode>,
74 >,
75 > + Send,
76 > {
77 _ = fut;
86
87 let request_sender = RequestSender {
88 allowed_hosts: self.allowed_hosts.clone(),
89 component_tls_configs: self.component_tls_configs.clone(),
90 request_interceptor: self.request_interceptor.clone(),
91 self_request_origin: self.self_request_origin.clone(),
92 blocked_networks: self.blocked_networks.clone(),
93 http_clients: self.wasi_http_clients.clone(),
94 concurrent_outbound_connections_semaphore: self
95 .concurrent_outbound_connections_semaphore
96 .clone(),
97 };
98 let config = OutgoingRequestConfig {
99 use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
100 connect_timeout: options
101 .and_then(|v| v.connect_timeout)
102 .unwrap_or(DEFAULT_TIMEOUT),
103 first_byte_timeout: options
104 .and_then(|v| v.first_byte_timeout)
105 .unwrap_or(DEFAULT_TIMEOUT),
106 between_bytes_timeout: options
107 .and_then(|v| v.between_bytes_timeout)
108 .unwrap_or(DEFAULT_TIMEOUT),
109 };
110 Box::new(async {
111 match request_sender
112 .send(
113 request.map(|body| body.map_err(p3_to_p2_error_code).boxed()),
114 config,
115 )
116 .await
117 {
118 Ok(IncomingResponse {
119 resp,
120 between_bytes_timeout,
121 ..
122 }) => Ok((
123 resp.map(|body| {
124 BetweenBytesTimeoutBody {
125 body,
126 sleep: None,
127 timeout: between_bytes_timeout,
128 }
129 .boxed()
130 }),
131 Box::new(async {
132 Ok(())
136 }) as Box<dyn Future<Output = _> + Send>,
137 )),
138 Err(http_error) => match http_error.downcast() {
139 Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
140 Err(trap) => Err(TrappableError::trap(trap)),
141 },
142 }
143 })
144 }
145}
146
pin_project_lite::pin_project! {
    // Response-body adapter: yields `ConnectionReadTimeout` when the wrapped
    // body stays pending longer than `timeout`, and converts the wrapped
    // body's p2 error codes to p3 error codes.
    struct BetweenBytesTimeoutBody<B> {
        // The wrapped response body (p2 error type).
        #[pin]
        body: B,
        // Timer armed when `body` returns `Pending`; cleared when `body` yields.
        #[pin]
        sleep: Option<tokio::time::Sleep>,
        // Maximum time `body` may stay pending before the timeout error.
        timeout: Duration,
    }
}
156
157impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
158 type Data = B::Data;
159 type Error = p3_types::ErrorCode;
160
161 fn poll_frame(
162 self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
165 let mut me = self.project();
166 match me.body.poll_frame(cx) {
167 Poll::Ready(value) => {
168 me.sleep.as_mut().set(None);
169 Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
170 }
171 Poll::Pending => {
172 if me.sleep.is_none() {
173 me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
174 }
175 task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
176 Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
177 }
178 }
179 }
180
181 fn is_end_stream(&self) -> bool {
182 self.body.is_end_stream()
183 }
184
185 fn size_hint(&self) -> SizeHint {
186 self.body.size_hint()
187 }
188}
189
190pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
191where
192 C: spin_factors::InitContext<OutboundHttpFactor>,
193{
194 let linker = ctx.linker();
195
196 fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
197 where
198 C: spin_factors::InitContext<OutboundHttpFactor>,
199 {
200 let (state, table) = C::get_data_with_table(store);
201 WasiHttpImpl(WasiHttpImplInner { state, table })
202 }
203
204 let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
205 wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
206 linker, get_http,
207 )?;
208 wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
209 linker,
210 &Default::default(),
211 get_http,
212 )?;
213
214 fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
215 where
216 C: spin_factors::InitContext<OutboundHttpFactor>,
217 {
218 let (state, table) = C::get_data_with_table(store);
219 p3::WasiHttpCtxView { ctx: state, table }
220 }
221
222 let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
223 p3::bindings::http::handler::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
224 p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
225
226 wasi_2023_10_18::add_to_linker(linker, get_http)?;
227 wasi_2023_11_10::add_to_linker(linker, get_http)?;
228
229 Ok(())
230}
231
232impl OutboundHttpFactor {
233 pub fn get_wasi_http_impl(
234 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
235 ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
236 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
237 Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
238 }
239
240 pub fn get_wasi_p3_http_impl(
241 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
242 ) -> Option<p3::WasiHttpCtxView<'_>> {
243 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
244 Some(p3::WasiHttpCtxView { ctx: state, table })
245 }
246}
247
/// Borrowed view over an instance's outbound-HTTP state and its resource
/// table; implements `WasiHttpView` for the p2 bindings.
pub(crate) struct WasiHttpImplInner<'a> {
    state: &'a mut InstanceState,
    table: &'a mut ResourceTable,
}

/// Request type handed to `WasiHttpView::send_request` and processed by
/// `RequestSender`.
type OutgoingRequest = http::Request<HyperOutgoingBody>;
254
255impl WasiHttpView for WasiHttpImplInner<'_> {
256 fn ctx(&mut self) -> &mut WasiHttpCtx {
257 &mut self.state.wasi_http_ctx
258 }
259
260 fn table(&mut self) -> &mut ResourceTable {
261 self.table
262 }
263
264 #[instrument(
265 name = "spin_outbound_http.send_request",
266 skip_all,
267 fields(
268 otel.kind = "client",
269 url.full = Empty,
270 http.request.method = %request.method(),
271 otel.name = %request.method(),
272 http.response.status_code = Empty,
273 server.address = Empty,
274 server.port = Empty,
275 )
276 )]
277 fn send_request(
278 &mut self,
279 request: OutgoingRequest,
280 config: OutgoingRequestConfig,
281 ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
282 let request_sender = RequestSender {
283 allowed_hosts: self.state.allowed_hosts.clone(),
284 component_tls_configs: self.state.component_tls_configs.clone(),
285 request_interceptor: self.state.request_interceptor.clone(),
286 self_request_origin: self.state.self_request_origin.clone(),
287 blocked_networks: self.state.blocked_networks.clone(),
288 http_clients: self.state.wasi_http_clients.clone(),
289 concurrent_outbound_connections_semaphore: self
290 .state
291 .concurrent_outbound_connections_semaphore
292 .clone(),
293 };
294 Ok(HostFutureIncomingResponse::Pending(
295 wasmtime_wasi::runtime::spawn(
296 async {
297 match request_sender.send(request, config).await {
298 Ok(resp) => Ok(Ok(resp)),
299 Err(http_error) => match http_error.downcast() {
300 Ok(error_code) => Ok(Err(error_code)),
301 Err(trap) => Err(trap),
302 },
303 }
304 }
305 .in_current_span(),
306 ),
307 ))
308 }
309}
310
/// Owned snapshot of the per-instance settings needed to send one outbound
/// request, cloned out of the instance state so the send can proceed without
/// borrowing that state.
struct RequestSender {
    /// Destinations the component is permitted to contact.
    allowed_hosts: OutboundAllowedHosts,
    /// Networks whose addresses are removed from the connect candidates.
    blocked_networks: BlockedNetworks,
    /// Per-host TLS client configuration.
    component_tls_configs: ComponentTlsClientConfigs,
    /// Origin that self requests are rewritten to; `None` if not available.
    self_request_origin: Option<SelfRequestOrigin>,
    /// Optional hook that can rewrite a request or answer it directly.
    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
    /// Shared hyper clients (HTTP/1, h2c, HTTPS).
    http_clients: HttpClients,
    /// Optional limit on concurrently open outbound connections.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
320
321impl RequestSender {
322 async fn send(
323 self,
324 mut request: OutgoingRequest,
325 mut config: OutgoingRequestConfig,
326 ) -> Result<IncomingResponse, HttpError> {
327 self.prepare_request(&mut request, &mut config).await?;
328
329 spin_telemetry::inject_trace_context(&mut request);
331
332 let mut override_connect_addr = None;
334 if let Some(interceptor) = &self.request_interceptor {
335 let intercept_request = std::mem::take(&mut request).into();
336 match interceptor.intercept(intercept_request).await? {
337 InterceptOutcome::Continue(mut req) => {
338 override_connect_addr = req.override_connect_addr.take();
339 request = req.into_hyper_request();
340 }
341 InterceptOutcome::Complete(resp) => {
342 let resp = IncomingResponse {
343 resp,
344 worker: None,
345 between_bytes_timeout: config.between_bytes_timeout,
346 };
347 return Ok(resp);
348 }
349 }
350 }
351
352 let span = tracing::Span::current();
354 if let Some(addr) = override_connect_addr {
355 span.record("server.address", addr.ip().to_string());
356 span.record("server.port", addr.port());
357 } else if let Some(authority) = request.uri().authority() {
358 span.record("server.address", authority.host());
359 if let Some(port) = authority.port_u16() {
360 span.record("server.port", port);
361 }
362 }
363
364 Ok(self
365 .send_request(request, config, override_connect_addr)
366 .await?)
367 }
368
369 async fn prepare_request(
370 &self,
371 request: &mut OutgoingRequest,
372 config: &mut OutgoingRequestConfig,
373 ) -> Result<(), ErrorCode> {
374 let uri = request.uri_mut();
378 if uri
379 .authority()
380 .is_some_and(|authority| authority.host().is_empty())
381 {
382 let mut builder = http::uri::Builder::new();
383 if let Some(paq) = uri.path_and_query() {
384 builder = builder.path_and_query(paq.clone());
385 }
386 *uri = builder.build().unwrap();
387 }
388 tracing::Span::current().record("url.full", uri.to_string());
389
390 let is_self_request = match request.uri().authority() {
391 Some(authority) => authority.host() == "self.alt",
393 None => true,
395 };
396
397 let is_allowed = if is_self_request {
399 self.allowed_hosts
400 .check_relative_url(&["http", "https"])
401 .await
402 .unwrap_or(false)
403 } else {
404 self.allowed_hosts
405 .check_url(&request.uri().to_string(), "https")
406 .await
407 .unwrap_or(false)
408 };
409 if !is_allowed {
410 return Err(ErrorCode::HttpRequestDenied);
411 }
412
413 if is_self_request {
414 let Some(origin) = self.self_request_origin.as_ref() else {
416 tracing::error!(
417 "Couldn't handle outbound HTTP request to relative URI; no origin set"
418 );
419 return Err(ErrorCode::HttpRequestUriInvalid);
420 };
421
422 config.use_tls = origin.use_tls();
423
424 request.headers_mut().insert(HOST, origin.host_header());
425
426 let path_and_query = request.uri().path_and_query().cloned();
427 *request.uri_mut() = origin.clone().into_uri(path_and_query);
428 }
429
430 request.headers_mut().remove(HOST);
437 Ok(())
438 }
439
440 async fn send_request(
441 self,
442 request: OutgoingRequest,
443 config: OutgoingRequestConfig,
444 override_connect_addr: Option<SocketAddr>,
445 ) -> Result<IncomingResponse, ErrorCode> {
446 let OutgoingRequestConfig {
447 use_tls,
448 connect_timeout,
449 first_byte_timeout,
450 between_bytes_timeout,
451 } = config;
452
453 let tls_client_config = if use_tls {
454 let host = request.uri().host().unwrap_or_default();
455 Some(self.component_tls_configs.get_client_config(host).clone())
456 } else {
457 None
458 };
459
460 let resp = CONNECT_OPTIONS.scope(
461 ConnectOptions {
462 blocked_networks: self.blocked_networks,
463 connect_timeout,
464 tls_client_config,
465 override_connect_addr,
466 concurrent_outbound_connections_semaphore: self
467 .concurrent_outbound_connections_semaphore,
468 },
469 async move {
470 if use_tls {
471 self.http_clients.https.request(request).await
472 } else {
473 let h2c_prior_knowledge_host =
475 std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
476 let use_h2c = h2c_prior_knowledge_host.as_deref()
477 == request.uri().authority().map(|a| a.as_str());
478
479 if use_h2c {
480 self.http_clients.http2.request(request).await
481 } else {
482 self.http_clients.http1.request(request).await
483 }
484 }
485 },
486 );
487
488 let resp = timeout(first_byte_timeout, resp)
489 .await
490 .map_err(|_| ErrorCode::ConnectionReadTimeout)?
491 .map_err(hyper_legacy_request_error)?
492 .map(|body| body.map_err(hyper_request_error).boxed());
493
494 tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
495
496 Ok(IncomingResponse {
497 resp,
498 worker: None,
499 between_bytes_timeout,
500 })
501 }
502}
503
/// Plaintext client whose connections are opened by `HttpConnector`.
type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
/// TLS client whose connections are opened by `HttpsConnector`.
type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;

/// The hyper clients shared by outbound requests.
#[derive(Clone)]
pub(super) struct HttpClients {
    /// Plaintext client with default protocol settings (used for HTTP/1).
    http1: HttpClient,
    /// Plaintext client restricted to HTTP/2 (h2c with prior knowledge).
    http2: HttpClient,
    /// TLS client; the protocol follows the ALPN result reported by the stream.
    https: HttpsClient,
}
516
517impl HttpClients {
518 pub(super) fn new(enable_pooling: bool) -> Self {
519 let builder = move || {
520 let mut builder = Client::builder(TokioExecutor::new());
521 if !enable_pooling {
522 builder.pool_max_idle_per_host(0);
523 }
524 builder
525 };
526 Self {
527 http1: builder().build(HttpConnector),
528 http2: builder().http2_only(true).build(HttpConnector),
529 https: builder().build(HttpsConnector),
530 }
531 }
532}
533
tokio::task_local! {
    /// Connection settings for the request currently being sent. The value
    /// is scoped around each request by `RequestSender::send_request`.
    ///
    /// The connectors are invoked by hyper's `Client` as `Service<Uri>`, so
    /// they receive only a `Uri`. They read the remaining per-request
    /// settings from this task-local instead. NOTE(review): presumably not
    /// consulted at all when hyper reuses a pooled connection.
    static CONNECT_OPTIONS: ConnectOptions;
}
544
/// Per-request settings consulted when a new outbound connection is opened.
#[derive(Clone)]
struct ConnectOptions {
    /// Networks whose addresses are removed before connecting.
    blocked_networks: BlockedNetworks,
    /// Timeout applied to `TcpStream::connect` only. DNS resolution and
    /// waiting for a connection permit are not covered.
    connect_timeout: Duration,
    /// TLS settings; `Some` whenever the request uses TLS.
    tls_client_config: Option<TlsClientConfig>,
    /// Address to connect to instead of resolving the URI's authority.
    override_connect_addr: Option<SocketAddr>,
    /// Optional limit on concurrently open outbound connections.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
558
559impl ConnectOptions {
560 async fn connect_tcp(
562 &self,
563 uri: &Uri,
564 default_port: u16,
565 ) -> Result<PermittedTcpStream, ErrorCode> {
566 let mut socket_addrs = match self.override_connect_addr {
567 Some(override_connect_addr) => vec![override_connect_addr],
568 None => {
569 let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
570
571 let host_and_port = if authority.port().is_some() {
572 authority.as_str().to_string()
573 } else {
574 format!("{}:{}", authority.as_str(), default_port)
575 };
576
577 let socket_addrs = tokio::net::lookup_host(&host_and_port)
578 .await
579 .map_err(|err| {
580 tracing::debug!(?host_and_port, ?err, "Error resolving host");
581 dns_error("address not available".into(), 0)
582 })?
583 .collect::<Vec<_>>();
584 tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
585 socket_addrs
586 }
587 };
588
589 let blocked_addrs = self.blocked_networks.remove_blocked(&mut socket_addrs);
591 if socket_addrs.is_empty() && !blocked_addrs.is_empty() {
592 tracing::error!(
593 "error.type" = "destination_ip_prohibited",
594 ?blocked_addrs,
595 "all destination IP(s) prohibited by runtime config"
596 );
597 return Err(ErrorCode::DestinationIpProhibited);
598 }
599
600 let permit = match &self.concurrent_outbound_connections_semaphore {
602 Some(s) => s.clone().acquire_owned().await.ok(),
603 None => None,
604 };
605
606 let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
607 .await
608 .map_err(|_| ErrorCode::ConnectionTimeout)?
609 .map_err(|err| match err.kind() {
610 std::io::ErrorKind::AddrNotAvailable => {
611 dns_error("address not available".into(), 0)
612 }
613 _ => ErrorCode::ConnectionRefused,
614 })?;
615 Ok(PermittedTcpStream {
616 inner: stream,
617 _permit: permit,
618 })
619 }
620
621 async fn connect_tls(
623 &self,
624 uri: &Uri,
625 default_port: u16,
626 ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
627 let tcp_stream = self.connect_tcp(uri, default_port).await?;
628
629 let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
630 tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
631
632 let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
633 let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
634 .map_err(|e| {
635 tracing::warn!("dns lookup error: {e:?}");
636 dns_error("invalid dns name".into(), 0)
637 })?
638 .to_owned();
639 connector.connect(domain, tcp_stream).await.map_err(|e| {
640 tracing::warn!("tls protocol error: {e:?}");
641 ErrorCode::TlsProtocolError
642 })
643 }
644}
645
646#[derive(Clone)]
648struct HttpConnector;
649
650impl HttpConnector {
651 async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
652 let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
653 Ok(TokioIo::new(stream))
654 }
655}
656
657impl Service<Uri> for HttpConnector {
658 type Response = TokioIo<PermittedTcpStream>;
659 type Error = ErrorCode;
660 type Future =
661 Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
662
663 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
664 Poll::Ready(Ok(()))
665 }
666
667 fn call(&mut self, uri: Uri) -> Self::Future {
668 Box::pin(async move { Self::connect(uri).await })
669 }
670}
671
672#[derive(Clone)]
674struct HttpsConnector;
675
676impl HttpsConnector {
677 async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
678 let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
679 Ok(TokioIo::new(RustlsStream(stream)))
680 }
681}
682
683impl Service<Uri> for HttpsConnector {
684 type Response = TokioIo<RustlsStream>;
685 type Error = ErrorCode;
686 type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
687
688 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
689 Poll::Ready(Ok(()))
690 }
691
692 fn call(&mut self, uri: Uri) -> Self::Future {
693 Box::pin(async move { Self::connect(uri).await })
694 }
695}
696
697struct RustlsStream(TlsStream<PermittedTcpStream>);
698
699impl Connection for RustlsStream {
700 fn connected(&self) -> Connected {
701 if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
702 self.0.get_ref().0.connected().negotiated_h2()
703 } else {
704 self.0.get_ref().0.connected()
705 }
706 }
707}
708
709impl AsyncRead for RustlsStream {
710 fn poll_read(
711 self: Pin<&mut Self>,
712 cx: &mut Context<'_>,
713 buf: &mut ReadBuf<'_>,
714 ) -> Poll<Result<(), std::io::Error>> {
715 Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
716 }
717}
718
719impl AsyncWrite for RustlsStream {
720 fn poll_write(
721 self: Pin<&mut Self>,
722 cx: &mut Context<'_>,
723 buf: &[u8],
724 ) -> Poll<Result<usize, std::io::Error>> {
725 Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
726 }
727
728 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
729 Pin::new(&mut self.get_mut().0).poll_flush(cx)
730 }
731
732 fn poll_shutdown(
733 self: Pin<&mut Self>,
734 cx: &mut Context<'_>,
735 ) -> Poll<Result<(), std::io::Error>> {
736 Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
737 }
738
739 fn poll_write_vectored(
740 self: Pin<&mut Self>,
741 cx: &mut Context<'_>,
742 bufs: &[IoSlice<'_>],
743 ) -> Poll<Result<usize, std::io::Error>> {
744 Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
745 }
746
747 fn is_write_vectored(&self) -> bool {
748 self.0.is_write_vectored()
749 }
750}
751
752struct PermittedTcpStream {
754 inner: TcpStream,
756 _permit: Option<OwnedSemaphorePermit>,
761}
762
763impl Connection for PermittedTcpStream {
764 fn connected(&self) -> Connected {
765 self.inner.connected()
766 }
767}
768
769impl AsyncRead for PermittedTcpStream {
770 fn poll_read(
771 self: Pin<&mut Self>,
772 cx: &mut Context<'_>,
773 buf: &mut ReadBuf<'_>,
774 ) -> Poll<std::io::Result<()>> {
775 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
776 }
777}
778
779impl AsyncWrite for PermittedTcpStream {
780 fn poll_write(
781 self: Pin<&mut Self>,
782 cx: &mut Context<'_>,
783 buf: &[u8],
784 ) -> Poll<Result<usize, std::io::Error>> {
785 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
786 }
787
788 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
789 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
790 }
791
792 fn poll_shutdown(
793 self: Pin<&mut Self>,
794 cx: &mut Context<'_>,
795 ) -> Poll<Result<(), std::io::Error>> {
796 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
797 }
798}
799
800fn hyper_request_error(err: hyper::Error) -> ErrorCode {
802 if let Some(cause) = err.source() {
804 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
805 return err.clone();
806 }
807 }
808
809 tracing::warn!("hyper request error: {err:?}");
810
811 ErrorCode::HttpProtocolError
812}
813
814fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
816 if let Some(cause) = err.source() {
818 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
819 return err.clone();
820 }
821 }
822
823 tracing::warn!("hyper request error: {err:?}");
824
825 ErrorCode::HttpProtocolError
826}
827
828fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
829 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
830 rcode: Some(rcode),
831 info_code: Some(info_code),
832 })
833}
834
835pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
838 match code {
839 p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
840 p2_types::ErrorCode::DnsError(payload) => {
841 p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
842 rcode: payload.rcode,
843 info_code: payload.info_code,
844 })
845 }
846 p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
847 p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
848 p2_types::ErrorCode::DestinationIpProhibited => {
849 p3_types::ErrorCode::DestinationIpProhibited
850 }
851 p2_types::ErrorCode::DestinationIpUnroutable => {
852 p3_types::ErrorCode::DestinationIpUnroutable
853 }
854 p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
855 p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
856 p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
857 p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
858 p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
859 p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
860 p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
861 p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
862 p2_types::ErrorCode::TlsAlertReceived(payload) => {
863 p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
864 alert_id: payload.alert_id,
865 alert_message: payload.alert_message,
866 })
867 }
868 p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
869 p2_types::ErrorCode::HttpRequestLengthRequired => {
870 p3_types::ErrorCode::HttpRequestLengthRequired
871 }
872 p2_types::ErrorCode::HttpRequestBodySize(payload) => {
873 p3_types::ErrorCode::HttpRequestBodySize(payload)
874 }
875 p2_types::ErrorCode::HttpRequestMethodInvalid => {
876 p3_types::ErrorCode::HttpRequestMethodInvalid
877 }
878 p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
879 p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
880 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
881 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
882 }
883 p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
884 p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
885 p3_types::FieldSizePayload {
886 field_name: payload.field_name,
887 field_size: payload.field_size,
888 }
889 }))
890 }
891 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
892 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
893 }
894 p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
895 p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
896 field_name: payload.field_name,
897 field_size: payload.field_size,
898 })
899 }
900 p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
901 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
902 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
903 }
904 p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
905 p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
906 field_name: payload.field_name,
907 field_size: payload.field_size,
908 })
909 }
910 p2_types::ErrorCode::HttpResponseBodySize(payload) => {
911 p3_types::ErrorCode::HttpResponseBodySize(payload)
912 }
913 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
914 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
915 }
916 p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
917 p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
918 field_name: payload.field_name,
919 field_size: payload.field_size,
920 })
921 }
922 p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
923 p3_types::ErrorCode::HttpResponseTransferCoding(payload)
924 }
925 p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
926 p3_types::ErrorCode::HttpResponseContentCoding(payload)
927 }
928 p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
929 p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
930 p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
931 p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
932 p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
933 p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
934 }
935}
936
937pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
940 match code {
941 p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
942 p3_types::ErrorCode::DnsError(payload) => {
943 p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
944 rcode: payload.rcode,
945 info_code: payload.info_code,
946 })
947 }
948 p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
949 p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
950 p3_types::ErrorCode::DestinationIpProhibited => {
951 p2_types::ErrorCode::DestinationIpProhibited
952 }
953 p3_types::ErrorCode::DestinationIpUnroutable => {
954 p2_types::ErrorCode::DestinationIpUnroutable
955 }
956 p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
957 p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
958 p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
959 p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
960 p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
961 p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
962 p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
963 p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
964 p3_types::ErrorCode::TlsAlertReceived(payload) => {
965 p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
966 alert_id: payload.alert_id,
967 alert_message: payload.alert_message,
968 })
969 }
970 p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
971 p3_types::ErrorCode::HttpRequestLengthRequired => {
972 p2_types::ErrorCode::HttpRequestLengthRequired
973 }
974 p3_types::ErrorCode::HttpRequestBodySize(payload) => {
975 p2_types::ErrorCode::HttpRequestBodySize(payload)
976 }
977 p3_types::ErrorCode::HttpRequestMethodInvalid => {
978 p2_types::ErrorCode::HttpRequestMethodInvalid
979 }
980 p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
981 p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
982 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
983 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
984 }
985 p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
986 p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
987 p2_types::FieldSizePayload {
988 field_name: payload.field_name,
989 field_size: payload.field_size,
990 }
991 }))
992 }
993 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
994 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
995 }
996 p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
997 p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
998 field_name: payload.field_name,
999 field_size: payload.field_size,
1000 })
1001 }
1002 p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1003 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1004 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1005 }
1006 p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1007 p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1008 field_name: payload.field_name,
1009 field_size: payload.field_size,
1010 })
1011 }
1012 p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1013 p2_types::ErrorCode::HttpResponseBodySize(payload)
1014 }
1015 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1016 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1017 }
1018 p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1019 p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1020 field_name: payload.field_name,
1021 field_size: payload.field_size,
1022 })
1023 }
1024 p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1025 p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1026 }
1027 p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1028 p2_types::ErrorCode::HttpResponseContentCoding(payload)
1029 }
1030 p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1031 p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1032 p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1033 p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1034 p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1035 p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1036 }
1037}