1use std::{
2 error::Error,
3 future::Future,
4 io::IoSlice,
5 net::SocketAddr,
6 ops::DerefMut,
7 pin::Pin,
8 sync::{Arc, Mutex},
9 task::{self, Context, Poll},
10 time::Duration,
11};
12
13use bytes::Bytes;
14use http::{header::HOST, uri::Scheme, Uri};
15use http_body::{Body, Frame, SizeHint};
16use http_body_util::{combinators::UnsyncBoxBody, BodyExt};
17use hyper_util::{
18 client::legacy::{
19 connect::{Connected, Connection},
20 Client,
21 },
22 rt::{TokioExecutor, TokioIo},
23};
24use spin_factor_outbound_networking::{
25 config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks},
26 ComponentTlsClientConfigs, TlsClientConfig,
27};
28use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceState};
29use tokio::{
30 io::{AsyncRead, AsyncWrite, ReadBuf},
31 net::TcpStream,
32 sync::{OwnedSemaphorePermit, Semaphore},
33 time::timeout,
34};
35use tokio_rustls::client::TlsStream;
36use tower_service::Service;
37use tracing::{field::Empty, instrument, Instrument};
38use wasmtime::component::HasData;
39use wasmtime_wasi::TrappableError;
40use wasmtime_wasi_http::{
41 bindings::http::types::{self as p2_types, ErrorCode},
42 body::HyperOutgoingBody,
43 p3::{self, bindings::http::types as p3_types},
44 types::{HostFutureIncomingResponse, IncomingResponse, OutgoingRequestConfig},
45 HttpError, WasiHttpCtx, WasiHttpImpl, WasiHttpView,
46};
47
48use crate::{
49 intercept::{InterceptOutcome, OutboundHttpInterceptor},
50 wasi_2023_10_18, wasi_2023_11_10, InstanceState, OutboundHttpFactor, SelfRequestOrigin,
51};
52
/// Fallback (600 seconds) used for the connect, first-byte, and between-bytes
/// timeouts when a WASIp3 request's options do not specify a value.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);
54
/// A body wrapper that keeps the wrapped value behind a [`Mutex`].
///
/// NOTE(review): presumably this exists so that a `Send`-but-not-`Sync` body
/// can be handed to APIs requiring `Sync` (a `Mutex<T>` is `Sync` whenever
/// `T: Send`) — confirm against the call site.
pub struct MutexBody<T>(Mutex<T>);

impl<T> MutexBody<T> {
    /// Takes ownership of `body` and places it behind a fresh mutex.
    pub fn new(body: T) -> Self {
        let guarded = Mutex::new(body);
        MutexBody(guarded)
    }
}
62
63impl<T: Body + Unpin> Body for MutexBody<T> {
64 type Data = T::Data;
65 type Error = T::Error;
66
67 fn poll_frame(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
71 Pin::new(self.0.lock().unwrap().deref_mut()).poll_frame(cx)
72 }
73
74 fn is_end_stream(&self) -> bool {
75 self.0.lock().unwrap().is_end_stream()
76 }
77
78 fn size_hint(&self) -> SizeHint {
79 self.0.lock().unwrap().size_hint()
80 }
81}
82
/// Marker type used as the `HasData` parameter when adding the WASI HTTP
/// bindings to the linker.
pub(crate) struct HasHttp;

impl HasData for HasHttp {
    /// The per-call view handed to the generated bindings: a `WasiHttpImpl`
    /// wrapping this crate's `WasiHttpImplInner`.
    type Data<'a> = WasiHttpImpl<WasiHttpImplInner<'a>>;
}
88
/// WASIp3 outbound HTTP support for a component instance.
impl p3::WasiHttpCtx for InstanceState {
    /// Sends `request` using this instance's outbound HTTP configuration.
    ///
    /// The per-request timeouts are taken from `options`, falling back to
    /// `DEFAULT_TIMEOUT` for any value that is not set. TLS is used when the
    /// request URI's scheme is `https`.
    ///
    /// The returned future resolves to the response (its body wrapped in a
    /// `BetweenBytesTimeoutBody`) together with a second future that resolves
    /// immediately to `Ok(())`. On failure, an `HttpError` that downcasts to an
    /// error code is converted to its WASIp3 equivalent; anything else is
    /// surfaced as a trap.
    #[instrument(
        name = "spin_outbound_http.send_request",
        skip_all,
        fields(
            otel.kind = "client",
            url.full = Empty,
            http.request.method = %request.method(),
            otel.name = %request.method(),
            http.response.status_code = Empty,
            server.address = Empty,
            server.port = Empty,
        )
    )]
    #[allow(clippy::type_complexity)]
    fn send_request(
        &mut self,
        request: http::Request<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
        options: Option<p3::RequestOptions>,
        fut: Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
    ) -> Box<
        dyn Future<
                Output = Result<
                    (
                        http::Response<UnsyncBoxBody<Bytes, p3_types::ErrorCode>>,
                        Box<dyn Future<Output = Result<(), p3_types::ErrorCode>> + Send>,
                    ),
                    TrappableError<p3_types::ErrorCode>,
                >,
            > + Send,
    > {
        self.otel.reparent_tracing_span();

        // The caller-supplied future is discarded right away; nothing here
        // awaits or inspects it.
        _ = fut;

        // NOTE(review): the body is wrapped in `MutexBody` before `.boxed()`,
        // presumably to satisfy a `Sync` bound on the boxed body — confirm.
        let request = request.map(|body| MutexBody::new(body).boxed());

        // Snapshot everything the send needs so the returned future does not
        // borrow from `self`.
        let request_sender = RequestSender {
            allowed_hosts: self.allowed_hosts.clone(),
            component_tls_configs: self.component_tls_configs.clone(),
            request_interceptor: self.request_interceptor.clone(),
            self_request_origin: self.self_request_origin.clone(),
            blocked_networks: self.blocked_networks.clone(),
            http_clients: self.wasi_http_clients.clone(),
            concurrent_outbound_connections_semaphore: self
                .concurrent_outbound_connections_semaphore
                .clone(),
        };
        let config = OutgoingRequestConfig {
            use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
            connect_timeout: options
                .and_then(|v| v.connect_timeout)
                .unwrap_or(DEFAULT_TIMEOUT),
            first_byte_timeout: options
                .and_then(|v| v.first_byte_timeout)
                .unwrap_or(DEFAULT_TIMEOUT),
            between_bytes_timeout: options
                .and_then(|v| v.between_bytes_timeout)
                .unwrap_or(DEFAULT_TIMEOUT),
        };
        Box::new(async {
            match request_sender
                .send(
                    // `RequestSender` works with the WASIp2 error type, so the
                    // request body's errors are translated on the way in.
                    request.map(|body| body.map_err(p3_to_p2_error_code).boxed_unsync()),
                    config,
                )
                .await
            {
                Ok(IncomingResponse {
                    resp,
                    between_bytes_timeout,
                    ..
                }) => Ok((
                    // Enforce the between-bytes timeout while the response
                    // body is streamed; the wrapper also translates body
                    // errors back to the WASIp3 error type.
                    resp.map(|body| {
                        BetweenBytesTimeoutBody {
                            body,
                            sleep: None,
                            timeout: between_bytes_timeout,
                        }
                        .boxed_unsync()
                    }),
                    // This completion future always reports success.
                    Box::new(async {
                        Ok(())
                    }) as Box<dyn Future<Output = _> + Send>,
                )),
                Err(http_error) => match http_error.downcast() {
                    Ok(error_code) => Err(TrappableError::from(p2_to_p3_error_code(error_code))),
                    Err(trap) => Err(TrappableError::trap(trap)),
                },
            }
        })
    }
}
193
// A body adapter that fails with a read-timeout error when the wrapped body
// stays pending for longer than `timeout` between two frames.
pin_project_lite::pin_project! {
    struct BetweenBytesTimeoutBody<B> {
        // The wrapped body being streamed.
        #[pin]
        body: B,
        // Timer armed while the wrapped body is pending; `None` when no wait
        // is currently being timed.
        #[pin]
        sleep: Option<tokio::time::Sleep>,
        // Maximum time to wait for the next frame.
        timeout: Duration,
    }
}
203
/// Streams frames from the wrapped body, converting its WASIp2 error codes to
/// WASIp3 ones and enforcing the between-bytes timeout.
impl<B: Body<Error = p2_types::ErrorCode>> Body for BetweenBytesTimeoutBody<B> {
    type Data = B::Data;
    type Error = p3_types::ErrorCode;

    /// Polls the wrapped body for its next frame.
    ///
    /// Returns `ConnectionReadTimeout` as an error item if the wrapped body
    /// remains pending until the timer (armed on the first pending poll)
    /// elapses.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let mut me = self.project();
        match me.body.poll_frame(cx) {
            Poll::Ready(value) => {
                // Progress was made: drop the timer so the next wait starts
                // with a full `timeout`.
                me.sleep.as_mut().set(None);
                Poll::Ready(value.map(|v| v.map_err(p2_to_p3_error_code)))
            }
            Poll::Pending => {
                // Arm the timer lazily, only on the first pending poll of
                // this wait.
                if me.sleep.is_none() {
                    me.sleep.as_mut().set(Some(tokio::time::sleep(*me.timeout)));
                }
                // Returns `Pending` from this function unless the timer has
                // elapsed.
                task::ready!(me.sleep.as_pin_mut().unwrap().poll(cx));
                Poll::Ready(Some(Err(p3_types::ErrorCode::ConnectionReadTimeout)))
            }
        }
    }

    /// Delegates to the wrapped body.
    fn is_end_stream(&self) -> bool {
        self.body.is_end_stream()
    }

    /// Delegates to the wrapped body.
    fn size_hint(&self) -> SizeHint {
        self.body.size_hint()
    }
}
236
/// Registers the WASI HTTP host bindings on the linker owned by `ctx`.
///
/// This adds the `outgoing_handler` and `types` bindings from
/// `wasmtime_wasi_http::bindings`, the WASIp3 `client` and `types` bindings,
/// and the `wasi_2023_10_18` / `wasi_2023_11_10` bindings from this crate.
///
/// # Errors
/// Returns the first error reported by any of the `add_to_linker` calls.
pub(crate) fn add_to_linker<C>(ctx: &mut C) -> anyhow::Result<()>
where
    C: spin_factors::InitContext<OutboundHttpFactor>,
{
    let linker = ctx.linker();

    // Projects the store data into the view type used by the `HasHttp`
    // bindings.
    fn get_http<C>(store: &mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>
    where
        C: spin_factors::InitContext<OutboundHttpFactor>,
    {
        let (state, table) = C::get_data_with_table(store);
        WasiHttpImpl(WasiHttpImplInner { state, table })
    }

    // Cast the generic fn item to a plain fn pointer with the signature the
    // bindings are called with below.
    let get_http = get_http::<C> as fn(&mut C::StoreData) -> WasiHttpImpl<WasiHttpImplInner<'_>>;
    wasmtime_wasi_http::bindings::http::outgoing_handler::add_to_linker::<_, HasHttp>(
        linker, get_http,
    )?;
    wasmtime_wasi_http::bindings::http::types::add_to_linker::<_, HasHttp>(
        linker,
        &Default::default(),
        get_http,
    )?;

    // Projects the store data into the WASIp3 view, with the instance state
    // acting as the p3 HTTP context.
    fn get_http_p3<C>(store: &mut C::StoreData) -> p3::WasiHttpCtxView<'_>
    where
        C: spin_factors::InitContext<OutboundHttpFactor>,
    {
        let (state, table) = C::get_data_with_table(store);
        p3::WasiHttpCtxView { ctx: state, table }
    }

    let get_http_p3 = get_http_p3::<C> as fn(&mut C::StoreData) -> p3::WasiHttpCtxView<'_>;
    p3::bindings::http::client::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;
    p3::bindings::http::types::add_to_linker::<_, p3::WasiHttp>(linker, get_http_p3)?;

    // The older interface snapshots reuse the same view accessor.
    wasi_2023_10_18::add_to_linker(linker, get_http)?;
    wasi_2023_11_10::add_to_linker(linker, get_http)?;

    Ok(())
}
278
279impl OutboundHttpFactor {
280 pub fn get_wasi_http_impl(
281 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
282 ) -> Option<WasiHttpImpl<impl WasiHttpView + '_>> {
283 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
284 Some(WasiHttpImpl(WasiHttpImplInner { state, table }))
285 }
286
287 pub fn get_wasi_p3_http_impl(
288 runtime_instance_state: &mut impl RuntimeFactorsInstanceState,
289 ) -> Option<p3::WasiHttpCtxView<'_>> {
290 let (state, table) = runtime_instance_state.get_with_table::<OutboundHttpFactor>()?;
291 Some(p3::WasiHttpCtxView { ctx: state, table })
292 }
293}
294
/// Borrowed view pairing this factor's instance state with the store's
/// resource table; it is the type that implements `WasiHttpView`.
pub(crate) struct WasiHttpImplInner<'a> {
    /// The outbound HTTP factor's per-instance state.
    state: &'a mut InstanceState,
    /// The resource table returned from `WasiHttpView::table`.
    table: &'a mut ResourceTable,
}

/// An outgoing HTTP request carrying a `HyperOutgoingBody`.
type OutgoingRequest = http::Request<HyperOutgoingBody>;
301
/// `WasiHttpView` implementation that routes outgoing requests through
/// `RequestSender`.
impl WasiHttpView for WasiHttpImplInner<'_> {
    /// Returns the `WasiHttpCtx` stored in the instance state.
    fn ctx(&mut self) -> &mut WasiHttpCtx {
        &mut self.state.wasi_http_ctx
    }

    /// Returns the borrowed resource table.
    fn table(&mut self) -> &mut ResourceTable {
        self.table
    }

    /// Starts sending `request` and returns a pending response handle.
    ///
    /// The send itself runs in a future passed to
    /// `wasmtime_wasi::runtime::spawn`, instrumented with the current tracing
    /// span. That future resolves to `Ok(Ok(response))` on success,
    /// `Ok(Err(error_code))` when the `HttpError` downcasts to an error code,
    /// and `Err(trap)` otherwise. This method itself always returns `Ok`.
    #[instrument(
        name = "spin_outbound_http.send_request",
        skip_all,
        fields(
            otel.kind = "client",
            url.full = Empty,
            http.request.method = %request.method(),
            otel.name = %request.method(),
            http.response.status_code = Empty,
            server.address = Empty,
            server.port = Empty,
        )
    )]
    fn send_request(
        &mut self,
        request: OutgoingRequest,
        config: OutgoingRequestConfig,
    ) -> Result<wasmtime_wasi_http::types::HostFutureIncomingResponse, HttpError> {
        self.state.otel.reparent_tracing_span();

        // Snapshot everything the send needs so the spawned future does not
        // borrow from `self`.
        let request_sender = RequestSender {
            allowed_hosts: self.state.allowed_hosts.clone(),
            component_tls_configs: self.state.component_tls_configs.clone(),
            request_interceptor: self.state.request_interceptor.clone(),
            self_request_origin: self.state.self_request_origin.clone(),
            blocked_networks: self.state.blocked_networks.clone(),
            http_clients: self.state.wasi_http_clients.clone(),
            concurrent_outbound_connections_semaphore: self
                .state
                .concurrent_outbound_connections_semaphore
                .clone(),
        };
        Ok(HostFutureIncomingResponse::Pending(
            wasmtime_wasi::runtime::spawn(
                async {
                    match request_sender.send(request, config).await {
                        Ok(resp) => Ok(Ok(resp)),
                        // Separate guest-visible error codes from traps.
                        Err(http_error) => match http_error.downcast() {
                            Ok(error_code) => Ok(Err(error_code)),
                            Err(trap) => Err(trap),
                        },
                    }
                }
                .in_current_span(),
            ),
        ))
    }
}
359
/// Owned snapshot of the configuration needed to send one outbound request.
struct RequestSender {
    /// Used to decide whether the request's destination is permitted.
    allowed_hosts: OutboundAllowedHosts,
    /// Passed to the connector to filter out blocked socket addresses.
    blocked_networks: BlockedNetworks,
    /// Source of the per-host TLS client configuration.
    component_tls_configs: ComponentTlsClientConfigs,
    /// Origin that self-requests are rewritten to target, if one is set.
    self_request_origin: Option<SelfRequestOrigin>,
    /// Optional hook that may rewrite the request or answer it directly.
    request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
    /// The hyper clients used to perform the request.
    http_clients: HttpClients,
    /// Optional semaphore passed to the connector when acquiring a connection
    /// permit.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
369
370impl RequestSender {
371 async fn send(
372 self,
373 mut request: OutgoingRequest,
374 mut config: OutgoingRequestConfig,
375 ) -> Result<IncomingResponse, HttpError> {
376 self.prepare_request(&mut request, &mut config).await?;
377
378 spin_telemetry::inject_trace_context(&mut request);
380
381 let mut override_connect_addr = None;
383 if let Some(interceptor) = &self.request_interceptor {
384 let intercept_request = std::mem::take(&mut request).into();
385 match interceptor.intercept(intercept_request).await? {
386 InterceptOutcome::Continue(mut req) => {
387 override_connect_addr = req.override_connect_addr.take();
388 request = req.into_hyper_request();
389 }
390 InterceptOutcome::Complete(resp) => {
391 let resp = IncomingResponse {
392 resp,
393 worker: None,
394 between_bytes_timeout: config.between_bytes_timeout,
395 };
396 return Ok(resp);
397 }
398 }
399 }
400
401 let span = tracing::Span::current();
403 if let Some(addr) = override_connect_addr {
404 span.record("server.address", addr.ip().to_string());
405 span.record("server.port", addr.port());
406 } else if let Some(authority) = request.uri().authority() {
407 span.record("server.address", authority.host());
408 if let Some(port) = authority.port_u16() {
409 span.record("server.port", port);
410 }
411 }
412
413 Ok(self
414 .send_request(request, config, override_connect_addr)
415 .await?)
416 }
417
418 async fn prepare_request(
419 &self,
420 request: &mut OutgoingRequest,
421 config: &mut OutgoingRequestConfig,
422 ) -> Result<(), ErrorCode> {
423 let uri = request.uri_mut();
427 if uri
428 .authority()
429 .is_some_and(|authority| authority.host().is_empty())
430 {
431 let mut builder = http::uri::Builder::new();
432 if let Some(paq) = uri.path_and_query() {
433 builder = builder.path_and_query(paq.clone());
434 }
435 *uri = builder.build().unwrap();
436 }
437 tracing::Span::current().record("url.full", uri.to_string());
438
439 let is_self_request = match request.uri().authority() {
440 Some(authority) => authority.host() == "self.alt",
442 None => true,
444 };
445
446 let is_allowed = if is_self_request {
448 self.allowed_hosts
449 .check_relative_url(&["http", "https"])
450 .await
451 .unwrap_or(false)
452 } else {
453 self.allowed_hosts
454 .check_url(&request.uri().to_string(), "https")
455 .await
456 .unwrap_or(false)
457 };
458 if !is_allowed {
459 return Err(ErrorCode::HttpRequestDenied);
460 }
461
462 if is_self_request {
463 let Some(origin) = self.self_request_origin.as_ref() else {
465 tracing::error!(
466 "Couldn't handle outbound HTTP request to relative URI; no origin set"
467 );
468 return Err(ErrorCode::HttpRequestUriInvalid);
469 };
470
471 config.use_tls = origin.use_tls();
472
473 request.headers_mut().insert(HOST, origin.host_header());
474
475 let path_and_query = request.uri().path_and_query().cloned();
476 *request.uri_mut() = origin.clone().into_uri(path_and_query);
477 }
478
479 request.headers_mut().remove(HOST);
486 Ok(())
487 }
488
489 async fn send_request(
490 self,
491 request: OutgoingRequest,
492 config: OutgoingRequestConfig,
493 override_connect_addr: Option<SocketAddr>,
494 ) -> Result<IncomingResponse, ErrorCode> {
495 let OutgoingRequestConfig {
496 use_tls,
497 connect_timeout,
498 first_byte_timeout,
499 between_bytes_timeout,
500 } = config;
501
502 let tls_client_config = if use_tls {
503 let host = request.uri().host().unwrap_or_default();
504 Some(self.component_tls_configs.get_client_config(host).clone())
505 } else {
506 None
507 };
508
509 let resp = CONNECT_OPTIONS.scope(
510 ConnectOptions {
511 blocked_networks: self.blocked_networks,
512 connect_timeout,
513 tls_client_config,
514 override_connect_addr,
515 concurrent_outbound_connections_semaphore: self
516 .concurrent_outbound_connections_semaphore,
517 },
518 async move {
519 if use_tls {
520 self.http_clients.https.request(request).await
521 } else {
522 let h2c_prior_knowledge_host =
524 std::env::var("SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE").ok();
525 let use_h2c = h2c_prior_knowledge_host.as_deref()
526 == request.uri().authority().map(|a| a.as_str());
527
528 if use_h2c {
529 self.http_clients.http2.request(request).await
530 } else {
531 self.http_clients.http1.request(request).await
532 }
533 }
534 },
535 );
536
537 let resp = timeout(first_byte_timeout, resp)
538 .await
539 .map_err(|_| ErrorCode::ConnectionReadTimeout)?
540 .map_err(hyper_legacy_request_error)?
541 .map(|body| body.map_err(hyper_request_error).boxed_unsync());
542
543 tracing::Span::current().record("http.response.status_code", resp.status().as_u16());
544
545 Ok(IncomingResponse {
546 resp,
547 worker: None,
548 between_bytes_timeout,
549 })
550 }
551}
552
/// Hyper client that connects over plain TCP via `HttpConnector`.
type HttpClient = Client<HttpConnector, HyperOutgoingBody>;
/// Hyper client that connects over TLS via `HttpsConnector`.
type HttpsClient = Client<HttpsConnector, HyperOutgoingBody>;

/// The set of hyper clients used for outbound requests.
#[derive(Clone)]
pub(super) struct HttpClients {
    /// Used for non-TLS requests unless HTTP/2 prior knowledge applies.
    http1: HttpClient,
    /// Used for non-TLS requests sent as HTTP/2 (built with `http2_only`).
    http2: HttpClient,
    /// Used for TLS requests.
    https: HttpsClient,
}
565
566impl HttpClients {
567 pub(super) fn new(enable_pooling: bool) -> Self {
568 let builder = move || {
569 let mut builder = Client::builder(TokioExecutor::new());
570 if !enable_pooling {
571 builder.pool_max_idle_per_host(0);
572 }
573 builder
574 };
575 Self {
576 http1: builder().build(HttpConnector),
577 http2: builder().http2_only(true).build(HttpConnector),
578 https: builder().build(HttpsConnector),
579 }
580 }
581}
582
// Per-request connection settings. The value is put in scope around the
// client call in `RequestSender::send_request` and read by the connectors via
// `CONNECT_OPTIONS.get()`.
tokio::task_local! {
    static CONNECT_OPTIONS: ConnectOptions;
}
593
/// Settings the connectors use when establishing a connection for a request.
#[derive(Clone)]
struct ConnectOptions {
    /// Networks whose addresses are removed from the connection candidates.
    blocked_networks: BlockedNetworks,
    /// Upper bound on the time spent in `TcpStream::connect`.
    connect_timeout: Duration,
    /// TLS client configuration; populated only for TLS requests.
    tls_client_config: Option<TlsClientConfig>,
    /// When set, this address is connected to instead of resolving the URI's
    /// authority.
    override_connect_addr: Option<SocketAddr>,
    /// Optional semaphore from which a permit is acquired before connecting.
    concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}
607
608impl ConnectOptions {
609 async fn connect_tcp(
611 &self,
612 uri: &Uri,
613 default_port: u16,
614 ) -> Result<PermittedTcpStream, ErrorCode> {
615 let mut socket_addrs = match self.override_connect_addr {
616 Some(override_connect_addr) => vec![override_connect_addr],
617 None => {
618 let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
619
620 let host_and_port = if authority.port().is_some() {
621 authority.as_str().to_string()
622 } else {
623 format!("{}:{}", authority.as_str(), default_port)
624 };
625
626 let socket_addrs = tokio::net::lookup_host(&host_and_port)
627 .await
628 .map_err(|err| {
629 tracing::debug!(?host_and_port, ?err, "Error resolving host");
630 dns_error("address not available".into(), 0)
631 })?
632 .collect::<Vec<_>>();
633 tracing::debug!(?host_and_port, ?socket_addrs, "Resolved host");
634 socket_addrs
635 }
636 };
637
638 crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;
640
641 let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore(
644 "wasi",
645 &self.concurrent_outbound_connections_semaphore,
646 )
647 .await;
648
649 let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
650 .await
651 .map_err(|_| ErrorCode::ConnectionTimeout)?
652 .map_err(|err| match err.kind() {
653 std::io::ErrorKind::AddrNotAvailable => {
654 dns_error("address not available".into(), 0)
655 }
656 _ => ErrorCode::ConnectionRefused,
657 })?;
658 Ok(PermittedTcpStream {
659 inner: stream,
660 _permit: permit,
661 })
662 }
663
664 async fn connect_tls(
666 &self,
667 uri: &Uri,
668 default_port: u16,
669 ) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
670 let tcp_stream = self.connect_tcp(uri, default_port).await?;
671
672 let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
673 tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
674
675 let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
676 let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
677 .map_err(|e| {
678 tracing::warn!("dns lookup error: {e:?}");
679 dns_error("invalid dns name".into(), 0)
680 })?
681 .to_owned();
682 connector.connect(domain, tcp_stream).await.map_err(|e| {
683 tracing::warn!("tls protocol error: {e:?}");
684 ErrorCode::TlsProtocolError
685 })
686 }
687}
688
689#[derive(Clone)]
691struct HttpConnector;
692
693impl HttpConnector {
694 async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
695 let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
696 Ok(TokioIo::new(stream))
697 }
698}
699
700impl Service<Uri> for HttpConnector {
701 type Response = TokioIo<PermittedTcpStream>;
702 type Error = ErrorCode;
703 type Future =
704 Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
705
706 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
707 Poll::Ready(Ok(()))
708 }
709
710 fn call(&mut self, uri: Uri) -> Self::Future {
711 Box::pin(async move { Self::connect(uri).await })
712 }
713}
714
715#[derive(Clone)]
717struct HttpsConnector;
718
719impl HttpsConnector {
720 async fn connect(uri: Uri) -> Result<TokioIo<RustlsStream>, ErrorCode> {
721 let stream = CONNECT_OPTIONS.get().connect_tls(&uri, 443).await?;
722 Ok(TokioIo::new(RustlsStream(stream)))
723 }
724}
725
726impl Service<Uri> for HttpsConnector {
727 type Response = TokioIo<RustlsStream>;
728 type Error = ErrorCode;
729 type Future = Pin<Box<dyn Future<Output = Result<TokioIo<RustlsStream>, ErrorCode>> + Send>>;
730
731 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
732 Poll::Ready(Ok(()))
733 }
734
735 fn call(&mut self, uri: Uri) -> Self::Future {
736 Box::pin(async move { Self::connect(uri).await })
737 }
738}
739
740struct RustlsStream(TlsStream<PermittedTcpStream>);
741
742impl Connection for RustlsStream {
743 fn connected(&self) -> Connected {
744 if self.0.get_ref().1.alpn_protocol() == Some(b"h2") {
745 self.0.get_ref().0.connected().negotiated_h2()
746 } else {
747 self.0.get_ref().0.connected()
748 }
749 }
750}
751
752impl AsyncRead for RustlsStream {
753 fn poll_read(
754 self: Pin<&mut Self>,
755 cx: &mut Context<'_>,
756 buf: &mut ReadBuf<'_>,
757 ) -> Poll<Result<(), std::io::Error>> {
758 Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
759 }
760}
761
762impl AsyncWrite for RustlsStream {
763 fn poll_write(
764 self: Pin<&mut Self>,
765 cx: &mut Context<'_>,
766 buf: &[u8],
767 ) -> Poll<Result<usize, std::io::Error>> {
768 Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
769 }
770
771 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
772 Pin::new(&mut self.get_mut().0).poll_flush(cx)
773 }
774
775 fn poll_shutdown(
776 self: Pin<&mut Self>,
777 cx: &mut Context<'_>,
778 ) -> Poll<Result<(), std::io::Error>> {
779 Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
780 }
781
782 fn poll_write_vectored(
783 self: Pin<&mut Self>,
784 cx: &mut Context<'_>,
785 bufs: &[IoSlice<'_>],
786 ) -> Poll<Result<usize, std::io::Error>> {
787 Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
788 }
789
790 fn is_write_vectored(&self) -> bool {
791 self.0.is_write_vectored()
792 }
793}
794
/// A TCP stream bundled with the connection permit (if any) that was acquired
/// for it.
struct PermittedTcpStream {
    /// The wrapped TCP stream; all I/O is forwarded to it.
    inner: TcpStream,
    /// Permit acquired before connecting. It is never read; it is stored so
    /// that it is released only when the stream is dropped.
    _permit: Option<OwnedSemaphorePermit>,
}
805
impl Connection for PermittedTcpStream {
    /// Reports the connection metadata of the wrapped TCP stream.
    fn connected(&self) -> Connected {
        self.inner.connected()
    }
}
811
812impl AsyncRead for PermittedTcpStream {
813 fn poll_read(
814 self: Pin<&mut Self>,
815 cx: &mut Context<'_>,
816 buf: &mut ReadBuf<'_>,
817 ) -> Poll<std::io::Result<()>> {
818 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
819 }
820}
821
822impl AsyncWrite for PermittedTcpStream {
823 fn poll_write(
824 self: Pin<&mut Self>,
825 cx: &mut Context<'_>,
826 buf: &[u8],
827 ) -> Poll<Result<usize, std::io::Error>> {
828 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
829 }
830
831 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
832 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
833 }
834
835 fn poll_shutdown(
836 self: Pin<&mut Self>,
837 cx: &mut Context<'_>,
838 ) -> Poll<Result<(), std::io::Error>> {
839 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
840 }
841}
842
843fn hyper_request_error(err: hyper::Error) -> ErrorCode {
845 if let Some(cause) = err.source() {
847 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
848 return err.clone();
849 }
850 }
851
852 tracing::warn!("hyper request error: {err:?}");
853
854 ErrorCode::HttpProtocolError
855}
856
857fn hyper_legacy_request_error(err: hyper_util::client::legacy::Error) -> ErrorCode {
859 if let Some(cause) = err.source() {
861 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
862 return err.clone();
863 }
864 }
865
866 tracing::warn!("hyper request error: {err:?}");
867
868 ErrorCode::HttpProtocolError
869}
870
871fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
872 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
873 rcode: Some(rcode),
874 info_code: Some(info_code),
875 })
876}
877
878pub fn p2_to_p3_error_code(code: p2_types::ErrorCode) -> p3_types::ErrorCode {
881 match code {
882 p2_types::ErrorCode::DnsTimeout => p3_types::ErrorCode::DnsTimeout,
883 p2_types::ErrorCode::DnsError(payload) => {
884 p3_types::ErrorCode::DnsError(p3_types::DnsErrorPayload {
885 rcode: payload.rcode,
886 info_code: payload.info_code,
887 })
888 }
889 p2_types::ErrorCode::DestinationNotFound => p3_types::ErrorCode::DestinationNotFound,
890 p2_types::ErrorCode::DestinationUnavailable => p3_types::ErrorCode::DestinationUnavailable,
891 p2_types::ErrorCode::DestinationIpProhibited => {
892 p3_types::ErrorCode::DestinationIpProhibited
893 }
894 p2_types::ErrorCode::DestinationIpUnroutable => {
895 p3_types::ErrorCode::DestinationIpUnroutable
896 }
897 p2_types::ErrorCode::ConnectionRefused => p3_types::ErrorCode::ConnectionRefused,
898 p2_types::ErrorCode::ConnectionTerminated => p3_types::ErrorCode::ConnectionTerminated,
899 p2_types::ErrorCode::ConnectionTimeout => p3_types::ErrorCode::ConnectionTimeout,
900 p2_types::ErrorCode::ConnectionReadTimeout => p3_types::ErrorCode::ConnectionReadTimeout,
901 p2_types::ErrorCode::ConnectionWriteTimeout => p3_types::ErrorCode::ConnectionWriteTimeout,
902 p2_types::ErrorCode::ConnectionLimitReached => p3_types::ErrorCode::ConnectionLimitReached,
903 p2_types::ErrorCode::TlsProtocolError => p3_types::ErrorCode::TlsProtocolError,
904 p2_types::ErrorCode::TlsCertificateError => p3_types::ErrorCode::TlsCertificateError,
905 p2_types::ErrorCode::TlsAlertReceived(payload) => {
906 p3_types::ErrorCode::TlsAlertReceived(p3_types::TlsAlertReceivedPayload {
907 alert_id: payload.alert_id,
908 alert_message: payload.alert_message,
909 })
910 }
911 p2_types::ErrorCode::HttpRequestDenied => p3_types::ErrorCode::HttpRequestDenied,
912 p2_types::ErrorCode::HttpRequestLengthRequired => {
913 p3_types::ErrorCode::HttpRequestLengthRequired
914 }
915 p2_types::ErrorCode::HttpRequestBodySize(payload) => {
916 p3_types::ErrorCode::HttpRequestBodySize(payload)
917 }
918 p2_types::ErrorCode::HttpRequestMethodInvalid => {
919 p3_types::ErrorCode::HttpRequestMethodInvalid
920 }
921 p2_types::ErrorCode::HttpRequestUriInvalid => p3_types::ErrorCode::HttpRequestUriInvalid,
922 p2_types::ErrorCode::HttpRequestUriTooLong => p3_types::ErrorCode::HttpRequestUriTooLong,
923 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
924 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
925 }
926 p2_types::ErrorCode::HttpRequestHeaderSize(payload) => {
927 p3_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
928 p3_types::FieldSizePayload {
929 field_name: payload.field_name,
930 field_size: payload.field_size,
931 }
932 }))
933 }
934 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
935 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
936 }
937 p2_types::ErrorCode::HttpRequestTrailerSize(payload) => {
938 p3_types::ErrorCode::HttpRequestTrailerSize(p3_types::FieldSizePayload {
939 field_name: payload.field_name,
940 field_size: payload.field_size,
941 })
942 }
943 p2_types::ErrorCode::HttpResponseIncomplete => p3_types::ErrorCode::HttpResponseIncomplete,
944 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
945 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
946 }
947 p2_types::ErrorCode::HttpResponseHeaderSize(payload) => {
948 p3_types::ErrorCode::HttpResponseHeaderSize(p3_types::FieldSizePayload {
949 field_name: payload.field_name,
950 field_size: payload.field_size,
951 })
952 }
953 p2_types::ErrorCode::HttpResponseBodySize(payload) => {
954 p3_types::ErrorCode::HttpResponseBodySize(payload)
955 }
956 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
957 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
958 }
959 p2_types::ErrorCode::HttpResponseTrailerSize(payload) => {
960 p3_types::ErrorCode::HttpResponseTrailerSize(p3_types::FieldSizePayload {
961 field_name: payload.field_name,
962 field_size: payload.field_size,
963 })
964 }
965 p2_types::ErrorCode::HttpResponseTransferCoding(payload) => {
966 p3_types::ErrorCode::HttpResponseTransferCoding(payload)
967 }
968 p2_types::ErrorCode::HttpResponseContentCoding(payload) => {
969 p3_types::ErrorCode::HttpResponseContentCoding(payload)
970 }
971 p2_types::ErrorCode::HttpResponseTimeout => p3_types::ErrorCode::HttpResponseTimeout,
972 p2_types::ErrorCode::HttpUpgradeFailed => p3_types::ErrorCode::HttpUpgradeFailed,
973 p2_types::ErrorCode::HttpProtocolError => p3_types::ErrorCode::HttpProtocolError,
974 p2_types::ErrorCode::LoopDetected => p3_types::ErrorCode::LoopDetected,
975 p2_types::ErrorCode::ConfigurationError => p3_types::ErrorCode::ConfigurationError,
976 p2_types::ErrorCode::InternalError(payload) => p3_types::ErrorCode::InternalError(payload),
977 }
978}
979
980pub fn p3_to_p2_error_code(code: p3_types::ErrorCode) -> p2_types::ErrorCode {
983 match code {
984 p3_types::ErrorCode::DnsTimeout => p2_types::ErrorCode::DnsTimeout,
985 p3_types::ErrorCode::DnsError(payload) => {
986 p2_types::ErrorCode::DnsError(p2_types::DnsErrorPayload {
987 rcode: payload.rcode,
988 info_code: payload.info_code,
989 })
990 }
991 p3_types::ErrorCode::DestinationNotFound => p2_types::ErrorCode::DestinationNotFound,
992 p3_types::ErrorCode::DestinationUnavailable => p2_types::ErrorCode::DestinationUnavailable,
993 p3_types::ErrorCode::DestinationIpProhibited => {
994 p2_types::ErrorCode::DestinationIpProhibited
995 }
996 p3_types::ErrorCode::DestinationIpUnroutable => {
997 p2_types::ErrorCode::DestinationIpUnroutable
998 }
999 p3_types::ErrorCode::ConnectionRefused => p2_types::ErrorCode::ConnectionRefused,
1000 p3_types::ErrorCode::ConnectionTerminated => p2_types::ErrorCode::ConnectionTerminated,
1001 p3_types::ErrorCode::ConnectionTimeout => p2_types::ErrorCode::ConnectionTimeout,
1002 p3_types::ErrorCode::ConnectionReadTimeout => p2_types::ErrorCode::ConnectionReadTimeout,
1003 p3_types::ErrorCode::ConnectionWriteTimeout => p2_types::ErrorCode::ConnectionWriteTimeout,
1004 p3_types::ErrorCode::ConnectionLimitReached => p2_types::ErrorCode::ConnectionLimitReached,
1005 p3_types::ErrorCode::TlsProtocolError => p2_types::ErrorCode::TlsProtocolError,
1006 p3_types::ErrorCode::TlsCertificateError => p2_types::ErrorCode::TlsCertificateError,
1007 p3_types::ErrorCode::TlsAlertReceived(payload) => {
1008 p2_types::ErrorCode::TlsAlertReceived(p2_types::TlsAlertReceivedPayload {
1009 alert_id: payload.alert_id,
1010 alert_message: payload.alert_message,
1011 })
1012 }
1013 p3_types::ErrorCode::HttpRequestDenied => p2_types::ErrorCode::HttpRequestDenied,
1014 p3_types::ErrorCode::HttpRequestLengthRequired => {
1015 p2_types::ErrorCode::HttpRequestLengthRequired
1016 }
1017 p3_types::ErrorCode::HttpRequestBodySize(payload) => {
1018 p2_types::ErrorCode::HttpRequestBodySize(payload)
1019 }
1020 p3_types::ErrorCode::HttpRequestMethodInvalid => {
1021 p2_types::ErrorCode::HttpRequestMethodInvalid
1022 }
1023 p3_types::ErrorCode::HttpRequestUriInvalid => p2_types::ErrorCode::HttpRequestUriInvalid,
1024 p3_types::ErrorCode::HttpRequestUriTooLong => p2_types::ErrorCode::HttpRequestUriTooLong,
1025 p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => {
1026 p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload)
1027 }
1028 p3_types::ErrorCode::HttpRequestHeaderSize(payload) => {
1029 p2_types::ErrorCode::HttpRequestHeaderSize(payload.map(|payload| {
1030 p2_types::FieldSizePayload {
1031 field_name: payload.field_name,
1032 field_size: payload.field_size,
1033 }
1034 }))
1035 }
1036 p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => {
1037 p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload)
1038 }
1039 p3_types::ErrorCode::HttpRequestTrailerSize(payload) => {
1040 p2_types::ErrorCode::HttpRequestTrailerSize(p2_types::FieldSizePayload {
1041 field_name: payload.field_name,
1042 field_size: payload.field_size,
1043 })
1044 }
1045 p3_types::ErrorCode::HttpResponseIncomplete => p2_types::ErrorCode::HttpResponseIncomplete,
1046 p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => {
1047 p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload)
1048 }
1049 p3_types::ErrorCode::HttpResponseHeaderSize(payload) => {
1050 p2_types::ErrorCode::HttpResponseHeaderSize(p2_types::FieldSizePayload {
1051 field_name: payload.field_name,
1052 field_size: payload.field_size,
1053 })
1054 }
1055 p3_types::ErrorCode::HttpResponseBodySize(payload) => {
1056 p2_types::ErrorCode::HttpResponseBodySize(payload)
1057 }
1058 p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => {
1059 p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload)
1060 }
1061 p3_types::ErrorCode::HttpResponseTrailerSize(payload) => {
1062 p2_types::ErrorCode::HttpResponseTrailerSize(p2_types::FieldSizePayload {
1063 field_name: payload.field_name,
1064 field_size: payload.field_size,
1065 })
1066 }
1067 p3_types::ErrorCode::HttpResponseTransferCoding(payload) => {
1068 p2_types::ErrorCode::HttpResponseTransferCoding(payload)
1069 }
1070 p3_types::ErrorCode::HttpResponseContentCoding(payload) => {
1071 p2_types::ErrorCode::HttpResponseContentCoding(payload)
1072 }
1073 p3_types::ErrorCode::HttpResponseTimeout => p2_types::ErrorCode::HttpResponseTimeout,
1074 p3_types::ErrorCode::HttpUpgradeFailed => p2_types::ErrorCode::HttpUpgradeFailed,
1075 p3_types::ErrorCode::HttpProtocolError => p2_types::ErrorCode::HttpProtocolError,
1076 p3_types::ErrorCode::LoopDetected => p2_types::ErrorCode::LoopDetected,
1077 p3_types::ErrorCode::ConfigurationError => p2_types::ErrorCode::ConfigurationError,
1078 p3_types::ErrorCode::InternalError(payload) => p2_types::ErrorCode::InternalError(payload),
1079 }
1080}