1use std::{
2 error::Error,
3 future::Future,
4 io::IoSlice,
5 net::SocketAddr,
6 ops::DerefMut,
7 pin::Pin,
8 sync::{Arc, Mutex},
9 task::{self, Context, Poll},
10 time::Duration,
11};
12
13use bytes::Bytes;
14use http::{header::HOST, uri::Scheme, Uri};
15use http_body::{Body, Frame, SizeHint};
16use http_body_util::{combinators::UnsyncBoxBody, BodyExt};
17use hyper_util::{
18 client::legacy::{
19 connect::{Connected, Connection},
20 Client,
21 },
22 rt::{TokioExecutor, TokioIo},
23};
24use spin_factor_outbound_networking::{
25 config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
26 ComponentTlsClientConfigs, TlsClientConfig,
27};
28use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
29use tokio::{
30 io::{AsyncRead, AsyncWrite, ReadBuf},
31 net::TcpStream,
32 sync::{OwnedSemaphorePermit, Semaphore},
33 time::timeout,
34};
35use tokio_rustls::client::TlsStream;
36use tower_service::Service;
37use tracing::{field::Empty, instrument, Instrument};
38use wasmtime::component::HasData;
39use wasmtime_wasi::TrappableError;
40use wasmtime_wasi_http::{
41 bindings::http::types::{self as p2_types, ErrorCode},
42 body::HyperOutgoingBody,
43 p3::{self, bindings::http::types as p3_types},
44 types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
45 HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
46};
47
48use crate::{
49 intercept::{InterceptOutcome, OutboundHttpInterceptor},
50 wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
51};
52
/// Fallback used for the connect, first-byte and between-bytes timeouts of a
/// `wasi:http@0.3` request whose options leave them unset (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
54
55pub struct MutexBody<T>(Mutex<T>);
56
57impl<T> MutexBody<T> {
58 pub fn new(body: T) -> Self {
59 Self(Mutex::new(body))
60 }
61}
62
63impl<T: Body + Unpin> Body for MutexBody<T> {
64 type Data = T::Data;
65 type Error = T::Error;
66
67 fn poll_frame(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
71 Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
72 }
73
74 fn is_end_stream(&self) -> bool {
75 self.0.lock().unwrap().is_end_stream()
76 }
77
78 fn size_hint(&self) -> SizeHint {
79 self.0.lock().unwrap().size_hint()
80 }
81}
82
83pub(crate) struct HasHttp;
84
85impl HasData for HasHttp {
86 type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
87}
88
89impl p3::WasiHttpCtx for InstanceState {
90 fn send_request(
91 &mut self,
92 request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
93 options: Option<p3::RequestOptions>,
94 fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
95 ) -> Box<
96 dyn Future<
97 Output = Result<
98 (
99 http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
100 Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
101 ),
102 TrappableError<p3_types::ErrorCode>,
103 >,
104 > + Send,
105 > {
106 _ = fut;
115
116 let request = request.map(|body| MutexBody::new(body).boxed());
117
118 let request_sender = RequestSender {
119 allowed_hosts: self.allowed_hosts.clone(),
120 component_tls_configs: self.component_tls_configs.clone(),
121 request_interceptor: self.request_interceptor.clone(),
122 self_request_origin: self.self_request_origin.clone(),
123 blocked_networks: self.blocked_networks.clone(),
124 http_clients: self.wasi_http_clients.clone(),
125 concurrent_outbound_connections_semaphore: self
126 .concurrent_outbound_connections_semaphore
127 .clone(),
128 };
129 let config = OutgoingRequestConfig {
130 use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
131 connect_timeout: options
132 .and_then(|v| v.connect_timeout)
133 .unwrap_or(DEFAULT_TIMEOUT),
134 first_byte_timeout: options
135 .and_then(|v| v.first_byte_timeout)
136 .unwrap_or(DEFAULT_TIMEOUT),
137 between_bytes_timeout: options
138 .and_then(|v| v.between_bytes_timeout)
139 .unwrap_or(DEFAULT_TIMEOUT),
140 };
141 Box::new(async {
142 match request_sender
143 .send(
144 request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
145 config,
146 )
147 .await
148 {
149 Ok(IncomingResponse {
150 resp,
151 between_bytes_timeout,
152 ..
153 }) => Ok((
154 resp.map(|body| {
155 BetweenBytesTimeoutBody {
156 body,
157 sleep: None,
158 timeout: between_bytes_timeout,
159 }
160 .boxed_unsync()
161 }),
162 Box::new(async {
163 Ok(())
167 }) as Box<dyn Future<Output = _> + Send>,
168 )),
169 Err(http_error) => match http_error.downcast() {
170 Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
171 Err(trap) => Err(TrappableError::trap(trap)),
172 },
173 }
174 })
175 }
176}
177
178pin_project_lite::pin_project! {
179 struct BetweenBytesTimeoutBody<B> {
180 #[pin]
181 body: B,
182 #[pin]
183 sleep: Option<tokio::time::Sleep>,
184 timeout: Duration,
185 }
186}
187
188impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
189 type Data = B::Data;
190 type Error = p3_types::ErrorCode;
191
192 fn poll_frame(
193 self: Pin<&mut Self>,
194 cx: &mut Context<'_>,
195 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
196 let mut me = self.project();
197 match me.body.poll_frame(cx) {
198 Poll::Ready(value) => {
199 me.sleep.as_mut().set(None);
200 Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
201 }
202 Poll::Pending => {
203 if me.sleep.is_none() {
204 me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
205 }
206 task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
207 Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
208 }
209 }
210 }
211
212 fn is_end_stream(&self) -> bool {
213 self.body.is_end_stream()
214 }
215
216 fn size_hint(&self) -> SizeHint {
217 self.body.size_hint()
218 }
219}
220
221pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
222where
223 C: spin_factors::InitContext<OutboundHttpFactor>,
224{
225 let linker = ctx.linker();
226
227 fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
228 where
229 C: spin_factors::InitContext<OutboundHttpFactor>,
230 {
231 let (state, table) = C::get_data_with_table(store);
232 WasiHttpImpl(WasiHttpImplInner { state, table })
233 }
234
235 let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
236 wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
237 linker, get_http,
238 )?;
239 wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
240 linker,
241 &Default::default(),
242 get_http,
243 )?;
244
245 fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
246 where
247 C: spin_factors::InitContext<OutboundHttpFactor>,
248 {
249 let (state, table) = C::get_data_with_table(store);
250 p3::WasiHttpCtxView { ctx: state, table }
251 }
252
253 let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
254 p3::bindings::http::handler::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
255 p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
256
257 wasi_2023_10_18::add_to_linker(linker, get_http)?;
258 wasi_2023_11_10::add_to_linker(linker, get_http)?;
259
260 Ok(())
261}
262
263impl OutboundHttpFactor {
264 pub fn get_wasi_http_impl(
265 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
266 ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
267 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
268 Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
269 }
270
271 pub fn get_wasi_p3_http_impl(
272 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
273 ) -> Option<p3::WasiHttpCtxView<'_>> {
274 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
275 Some(p3::WasiHttpCtxView { ctx: state, table })
276 }
277}
278
279pub(crate) struct WasiHttpImplInner<'a> {
280 state: &'a mut InstanceState,
281 table: &'a mut ResourceTable,
282}
283
284type OutgoingRequest = http::Request<HyperOutgoingBody>;
285
286impl WasiHttpView for WasiHttpImplInner<'_> {
287 fn ctx(&mut self) -> &mut WasiHttpCtx {
288 &mut self.state.wasi_http_ctx
289 }
290
291 fn table(&mut self) -> &mut ResourceTable {
292 self.table
293 }
294
295 #[instrument(
296 name = "spin_outbound_http.send_request",
297 skip_all,
298 fields(
299 otel.kind = "client",
300 url.full = Empty,
301 http.request.method = %request.method(),
302 otel.name = %request.method(),
303 http.response.status_code = Empty,
304 server.address = Empty,
305 server.port = Empty,
306 )
307 )]
308 fn send_request(
309 &mut self,
310 request: OutgoingRequest,
311 config: OutgoingRequestConfig,
312 ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
313 let request_sender = RequestSender {
314 allowed_hosts: self.state.allowed_hosts.clone(),
315 component_tls_configs: self.state.component_tls_configs.clone(),
316 request_interceptor: self.state.request_interceptor.clone(),
317 self_request_origin: self.state.self_request_origin.clone(),
318 blocked_networks: self.state.blocked_networks.clone(),
319 http_clients: self.state.wasi_http_clients.clone(),
320 concurrent_outbound_connections_semaphore: self
321 .state
322 .concurrent_outbound_connections_semaphore
323 .clone(),
324 };
325 Ok(HostFutureIncomingResponse::Pending(
326 wasmtime_wasi::runtime::spawn(
327 async {
328 match request_sender.send(request, config).await {
329 Ok(resp) => Ok(Ok(resp)),
330 Err(http_error) => match http_error.downcast() {
331 Ok(error_code) => Ok(Err(error_code)),
332 Err(trap) => Err(trap),
333 },
334 }
335 }
336 .in_current_span(),
337 ),
338 ))
339 }
340}
341
342struct RequestSender {
343 allowed_hosts: OutboundAllowedHosts,
344 blocked_networks: BlockedNetworks,
345 component_tls_configs: ComponentTlsClientConfigs,
346 self_request_origin: Option<SelfRequestOrigin>,
347 request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
348 http_clients: HttpClients,
349 concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
350}
351
352impl RequestSender {
353 async fn send(
354 self,
355 mut request: OutgoingRequest,
356 mut config: OutgoingRequestConfig,
357 ) -> Result<IncomingResponse, HttpError> {
358 self.prepare_request(&mut request, &mut config).await?;
359
360 spin_telemetry::inject_trace_context(&mut request);
362
363 let mut override_connect_addr = None;
365 if let Some(interceptor) = &self.request_interceptor {
366 let intercept_request = std::mem::take(&mut request).into();
367 match interceptor.intercept(intercept_request).await? {
368 InterceptOutcome::Continue(mut req) => {
369 override_connect_addr = req.override_connect_addr.take();
370 request = req.into_hyper_request();
371 }
372 InterceptOutcome::Complete(resp) => {
373 let resp = IncomingResponse {
374 resp,
375 worker: None,
376 between_bytes_timeout: config.between_bytes_timeout,
377 };
378 return Ok(resp);
379 }
380 }
381 }
382
383 let span = tracing::Span::current();
385 if let Some(addr) = override_connect_addr {
386 span.record("server.address", addr.ip().to_string());
387 span.record("server.port", addr.port());
388 } else if let Some(authority) = request.uri().authority() {
389 span.record("server.address", authority.host());
390 if let Some(port) = authority.port_u16() {
391 span.record("server.port", port);
392 }
393 }
394
395 Ok(self
396 .send_request(request, config, override_connect_addr)
397 .await?)
398 }
399
400 async fn prepare_request(
401 &self,
402 request: &mut OutgoingRequest,
403 config: &mut OutgoingRequestConfig,
404 ) -> Result<(), ErrorCode> {
405 let uri = request.uri_mut();
409 if uri
410 .authority()
411 .is_some_and(|authority| authority.host().is_empty())
412 {
413 let mut builder = http::uri::Builder::new();
414 if let Some(paq) = uri.path_and_query() {
415 builder = builder.path_and_query(paq.clone());
416 }
417 *uri = builder.build().unwrap();
418 }
419 tracing::Span::current().record("url.full", uri.to_string());
420
421 let is_self_request = match request.uri().authority() {
422 Some(authority) => authority.host() == "self.alt",
424 None => true,
426 };
427
428 let is_allowed = if is_self_request {
430 self.allowed_hosts
431 .check_relative_url(&["http", "https"])
432 .await
433 .unwrap_or(false)
434 } else {
435 self.allowed_hosts
436 .check_url(&request.uri().to_string(), "https")
437 .await
438 .unwrap_or(false)
439 };
440 if !is_allowed {
441 return Err(ErrorCode::HttpRequestDenied);
442 }
443
444 if is_self_request {
445 let Some(origin) = self.self_request_origin.as_ref() else {
447 tracing::error!(
448 "Couldn't handle outbound HTTP request to relative URI; no origin set"
449 );
450 return Err(ErrorCode::HttpRequestUriInvalid);
451 };
452
453 config.use_tls = origin.use_tls();
454
455 request.headers_mut().insert(HOST, origin.host_header());
456
457 let path_and_query = request.uri().path_and_query().cloned();
458 *request.uri_mut() = origin.clone().into_uri(path_and_query);
459 }
460
461 request.headers_mut().remove(HOST);
468 Ok(())
469 }
470
471 async fn send_request(
472 self,
473 request: OutgoingRequest,
474 config: OutgoingRequestConfig,
475 override_connect_addr: Option<SocketAddr>,
476 ) -> Result<IncomingResponse, ErrorCode> {
477 let OutgoingRequestConfig {
478 use_tls,
479 connect_timeout,
480 first_byte_timeout,
481 between_bytes_timeout,
482 } = config;
483
484 let tls_client_config = if use_tls {
485 let host = request.uri().host().unwrap_or_default();
486 Some(self.component_tls_configs.get_client_config(host).clone())
487 } else {
488 None
489 };
490
491 let resp = CONNECT_OPTIONS.scope(
492 ConnectOptions {
493 blocked_networks: self.blocked_networks,
494 connect_timeout,
495 tls_client_config,
496 override_connect_addr,
497 concurrent_outbound_connections_semaphore: self
498 .concurrent_outbound_connections_semaphore,
499 },
500 async move {
501 if use_tls {
502 self.http_clients.https.request(request).await
503 } else {
504 let h2c_prior_knowledge_host =
506 std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
507 let use_h2c = h2c_prior_knowledge_host.as_deref()
508 == request.uri().authority().map(|a| a.as_str());
509
510 if use_h2c {
511 self.http_clients.http2.request(request).await
512 } else {
513 self.http_clients.http1.request(request).await
514 }
515 }
516 },
517 );
518
519 let resp = timeout(first_byte_timeout, resp)
520 .await
521 .map_err(|_| ErrorCode::ConnectionReadTimeout)?
522 .map_err(hyper_legacy_request_error)?
523 .map(|body| body.map_err(hyper_request_error).boxed_unsync());
524
525 tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
526
527 Ok(IncomingResponse {
528 resp,
529 worker: None,
530 between_bytes_timeout,
531 })
532 }
533}
534
535type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
536type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;
537
538#[derive(Clone)]
539pub(super) struct HttpClients {
540 http1: HttpClient,
542 http2: HttpClient,
544 https: HttpsClient,
546}
547
548impl HttpClients {
549 pub(super) fn new(enable_pooling: bool) -> Self {
550 let builder = move || {
551 let mut builder = Client::builder(TokioExecutor::new());
552 if !enable_pooling {
553 builder.pool_max_idle_per_host(0);
554 }
555 builder
556 };
557 Self {
558 http1: builder().build(HttpConnector),
559 http2: builder().http2_only(true).build(HttpConnector),
560 https: builder().build(HttpsConnector),
561 }
562 }
563}
564
565tokio::task_local! {
566 static CONNECT_OPTIONS: ConnectOptions;
574}
575
576#[derive(Clone)]
577struct ConnectOptions {
578 blocked_networks: BlockedNetworks,
580 connect_timeout: Duration,
582 tls_client_config: Option<TlsClientConfig>,
584 override_connect_addr: Option<SocketAddr>,
586 concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
588}
589
590impl ConnectOptions {
591 async fn connect_tcp(
593 &self,
594 uri: &Uri,
595 default_port: u16,
596 ) -> Result<PermittedTcpStream, ErrorCode> {
597 let mut socket_addrs = match self.override_connect_addr {
598 Some(override_connect_addr) => vec![override_connect_addr],
599 None => {
600 let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
601
602 let host_and_port = if authority.port().is_some() {
603 authority.as_str().to_string()
604 } else {
605 format!("{}:{}", authority.as_str(), default_port)
606 };
607
608 let socket_addrs = tokio::net::lookup_host(&host_and_port)
609 .await
610 .map_err(|err| {
611 tracing::debug!(?host_and_port, ?err, "Error resolving host");
612 dns_error("address not available".into(), 0)
613 })?
614 .collect::<Vec<_>>();
615 tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
616 socket_addrs
617 }
618 };
619
620 crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
622
623 let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore(
626 "wasi",
627 &self.concurrent_outbound_connections_semaphore,
628 )
629 .await;
630
631 let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
632 .await
633 .map_err(|_| ErrorCode::ConnectionTimeout)?
634 .map_err(|err| match err.kind() {
635 std::io::ErrorKind::AddrNotAvailable => {
636 dns_error("address not available".into(), 0)
637 }
638 _ => ErrorCode::ConnectionRefused,
639 })?;
640 Ok(PermittedTcpStream {
641 inner: stream,
642 _permit: permit,
643 })
644 }
645
646 async fn connect_tls(
648 &self,
649 uri: &Uri,
650 default_port: u16,
651 ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
652 let tcp_stream = self.connect_tcp(uri, default_port).await?;
653
654 let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
655 tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
656
657 let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
658 let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
659 .map_err(|e| {
660 tracing::warn!("dns lookup error: {e:?}");
661 dns_error("invalid dns name".into(), 0)
662 })?
663 .to_owned();
664 connector.connect(domain, tcp_stream).await.map_err(|e| {
665 tracing::warn!("tls protocol error: {e:?}");
666 ErrorCode::TlsProtocolError
667 })
668 }
669}
670
671#[derive(Clone)]
673struct HttpConnector;
674
675impl HttpConnector {
676 async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
677 let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
678 Ok(TokioIo::new(stream))
679 }
680}
681
682impl Service<Uri> for HttpConnector {
683 type Response = TokioIo<PermittedTcpStream>;
684 type Error = ErrorCode;
685 type Future =
686 Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
687
688 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
689 Poll::Ready(Ok(()))
690 }
691
692 fn call(&mut self, uri: Uri) -> Self::Future {
693 Box::pin(async move { Self::connect(uri).await })
694 }
695}
696
697#[derive(Clone)]
699struct HttpsConnector;
700
701impl HttpsConnector {
702 async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
703 let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
704 Ok(TokioIo::new(RustlsStream(stream)))
705 }
706}
707
708impl Service<Uri> for HttpsConnector {
709 type Response = TokioIo<RustlsStream>;
710 type Error = ErrorCode;
711 type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
712
713 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
714 Poll::Ready(Ok(()))
715 }
716
717 fn call(&mut self, uri: Uri) -> Self::Future {
718 Box::pin(async move { Self::connect(uri).await })
719 }
720}
721
722struct RustlsStream(TlsStream<PermittedTcpStream>);
723
724impl Connection for RustlsStream {
725 fn connected(&self) -> Connected {
726 if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
727 self.0.get_ref().0.connected().negotiated_h2()
728 } else {
729 self.0.get_ref().0.connected()
730 }
731 }
732}
733
734impl AsyncRead for RustlsStream {
735 fn poll_read(
736 self: Pin<&mut Self>,
737 cx: &mut Context<'_>,
738 buf: &mut ReadBuf<'_>,
739 ) -> Poll<Result<(), std::io::Error>> {
740 Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
741 }
742}
743
744impl AsyncWrite for RustlsStream {
745 fn poll_write(
746 self: Pin<&mut Self>,
747 cx: &mut Context<'_>,
748 buf: &[u8],
749 ) -> Poll<Result<usize, std::io::Error>> {
750 Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
751 }
752
753 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
754 Pin::new(&mut self.get_mut().0).poll_flush(cx)
755 }
756
757 fn poll_shutdown(
758 self: Pin<&mut Self>,
759 cx: &mut Context<'_>,
760 ) -> Poll<Result<(), std::io::Error>> {
761 Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
762 }
763
764 fn poll_write_vectored(
765 self: Pin<&mut Self>,
766 cx: &mut Context<'_>,
767 bufs: &[IoSlice<'_>],
768 ) -> Poll<Result<usize, std::io::Error>> {
769 Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
770 }
771
772 fn is_write_vectored(&self) -> bool {
773 self.0.is_write_vectored()
774 }
775}
776
777struct PermittedTcpStream {
779 inner: TcpStream,
781 _permit: Option<OwnedSemaphorePermit>,
786}
787
788impl Connection for PermittedTcpStream {
789 fn connected(&self) -> Connected {
790 self.inner.connected()
791 }
792}
793
794impl AsyncRead for PermittedTcpStream {
795 fn poll_read(
796 self: Pin<&mut Self>,
797 cx: &mut Context<'_>,
798 buf: &mut ReadBuf<'_>,
799 ) -> Poll<std::io::Result<()>> {
800 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
801 }
802}
803
804impl AsyncWrite for PermittedTcpStream {
805 fn poll_write(
806 self: Pin<&mut Self>,
807 cx: &mut Context<'_>,
808 buf: &[u8],
809 ) -> Poll<Result<usize, std::io::Error>> {
810 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
811 }
812
813 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
814 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
815 }
816
817 fn poll_shutdown(
818 self: Pin<&mut Self>,
819 cx: &mut Context<'_>,
820 ) -> Poll<Result<(), std::io::Error>> {
821 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
822 }
823}
824
825fn hyper_request_error(err: hyper::Error) -> ErrorCode {
827 if let Some(cause) = err.source() {
829 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
830 return err.clone();
831 }
832 }
833
834 tracing::warn!("hyper request error: {err:?}");
835
836 ErrorCode::HttpProtocolError
837}
838
839fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
841 if let Some(cause) = err.source() {
843 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
844 return err.clone();
845 }
846 }
847
848 tracing::warn!("hyper request error: {err:?}");
849
850 ErrorCode::HttpProtocolError
851}
852
853fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
854 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
855 rcode: Some(rcode),
856 info_code: Some(info_code),
857 })
858}
859
860pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
863 match code {
864 p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
865 p2_types::ErrorCode::DnsError(payload) => {
866 p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
867 rcode: payload.rcode,
868 info_code: payload.info_code,
869 })
870 }
871 p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
872 p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
873 p2_types::ErrorCode::DestinationIpProhibited => {
874 p3_types::ErrorCode::DestinationIpProhibited
875 }
876 p2_types::ErrorCode::DestinationIpUnroutable => {
877 p3_types::ErrorCode::DestinationIpUnroutable
878 }
879 p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
880 p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
881 p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
882 p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
883 p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
884 p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
885 p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
886 p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
887 p2_types::ErrorCode::TlsAlertReceived(payload) => {
888 p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
889 alert_id: payload.alert_id,
890 alert_message: payload.alert_message,
891 })
892 }
893 p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
894 p2_types::ErrorCode::HttpRequestLengthRequired => {
895 p3_types::ErrorCode::HttpRequestLengthRequired
896 }
897 p2_types::ErrorCode::HttpRequestBodySize(payload) => {
898 p3_types::ErrorCode::HttpRequestBodySize(payload)
899 }
900 p2_types::ErrorCode::HttpRequestMethodInvalid => {
901 p3_types::ErrorCode::HttpRequestMethodInvalid
902 }
903 p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
904 p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
905 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
906 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
907 }
908 p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
909 p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
910 p3_types::FieldSizePayload {
911 field_name: payload.field_name,
912 field_size: payload.field_size,
913 }
914 }))
915 }
916 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
917 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
918 }
919 p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
920 p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
921 field_name: payload.field_name,
922 field_size: payload.field_size,
923 })
924 }
925 p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
926 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
927 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
928 }
929 p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
930 p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
931 field_name: payload.field_name,
932 field_size: payload.field_size,
933 })
934 }
935 p2_types::ErrorCode::HttpResponseBodySize(payload) => {
936 p3_types::ErrorCode::HttpResponseBodySize(payload)
937 }
938 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
939 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
940 }
941 p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
942 p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
943 field_name: payload.field_name,
944 field_size: payload.field_size,
945 })
946 }
947 p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
948 p3_types::ErrorCode::HttpResponseTransferCoding(payload)
949 }
950 p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
951 p3_types::ErrorCode::HttpResponseContentCoding(payload)
952 }
953 p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
954 p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
955 p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
956 p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
957 p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
958 p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
959 }
960}
961
962pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
965 match code {
966 p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
967 p3_types::ErrorCode::DnsError(payload) => {
968 p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
969 rcode: payload.rcode,
970 info_code: payload.info_code,
971 })
972 }
973 p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
974 p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
975 p3_types::ErrorCode::DestinationIpProhibited => {
976 p2_types::ErrorCode::DestinationIpProhibited
977 }
978 p3_types::ErrorCode::DestinationIpUnroutable => {
979 p2_types::ErrorCode::DestinationIpUnroutable
980 }
981 p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
982 p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
983 p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
984 p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
985 p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
986 p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
987 p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
988 p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
989 p3_types::ErrorCode::TlsAlertReceived(payload) => {
990 p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
991 alert_id: payload.alert_id,
992 alert_message: payload.alert_message,
993 })
994 }
995 p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
996 p3_types::ErrorCode::HttpRequestLengthRequired => {
997 p2_types::ErrorCode::HttpRequestLengthRequired
998 }
999 p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1000 p2_types::ErrorCode::HttpRequestBodySize(payload)
1001 }
1002 p3_types::ErrorCode::HttpRequestMethodInvalid => {
1003 p2_types::ErrorCode::HttpRequestMethodInvalid
1004 }
1005 p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1006 p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1007 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1008 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1009 }
1010 p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1011 p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1012 p2_types::FieldSizePayload {
1013 field_name: payload.field_name,
1014 field_size: payload.field_size,
1015 }
1016 }))
1017 }
1018 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1019 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1020 }
1021 p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1022 p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1023 field_name: payload.field_name,
1024 field_size: payload.field_size,
1025 })
1026 }
1027 p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1028 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1029 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1030 }
1031 p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1032 p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1033 field_name: payload.field_name,
1034 field_size: payload.field_size,
1035 })
1036 }
1037 p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1038 p2_types::ErrorCode::HttpResponseBodySize(payload)
1039 }
1040 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1041 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1042 }
1043 p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1044 p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1045 field_name: payload.field_name,
1046 field_size: payload.field_size,
1047 })
1048 }
1049 p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1050 p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1051 }
1052 p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1053 p2_types::ErrorCode::HttpResponseContentCoding(payload)
1054 }
1055 p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1056 p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1057 p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1058 p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1059 p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1060 p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1061 }
1062}