spin_factor_outbound_http/
spin.rs1use std::sync::Arc;
2
3use futures::stream::TryStreamExt as _;
4use http_body_util::BodyExt;
5use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
6use spin_world::MAX_HOST_BUFFERED_BYTES;
7use spin_world::v1::{
8 http as spin_http,
9 http_types::{self, HttpError, Method, Request, Response},
10};
11use tracing::{Span, field::Empty, instrument};
12
13use crate::intercept::InterceptOutcome;
14
15impl spin_http::Host for crate::InstanceState {
16 #[instrument(name = "spin_outbound_http.send_request", skip_all,
17 fields(otel.kind = "client", url.full = Empty, http.request.method = Empty,
18 http.response.status_code = Empty, otel.name = Empty, server.address = Empty, server.port = Empty))]
19 async fn send_request(&mut self, req: Request) -> Result<Response, HttpError> {
20 self.hooks.otel.reparent_tracing_span();
21
22 let span = Span::current();
23 record_request_fields(&span, &req);
24
25 let uri = req.uri;
26 tracing::trace!("Sending outbound HTTP to {uri:?}");
27
28 if !req.params.is_empty() {
29 tracing::warn!("HTTP params field is deprecated");
30 }
31 let req_url = if !uri.starts_with('/') {
32 let is_allowed = self
34 .hooks
35 .allowed_hosts
36 .check_url(&uri, "https")
37 .await
38 .unwrap_or(false);
39 if !is_allowed {
40 return Err(HttpError::DestinationNotAllowed);
41 }
42 uri.parse().map_err(|_| HttpError::InvalidUrl)?
43 } else {
44 let is_allowed = self
46 .hooks
47 .allowed_hosts
48 .check_relative_url(&["http", "https"])
49 .await
50 .unwrap_or(false);
51 if !is_allowed {
52 return Err(HttpError::DestinationNotAllowed);
53 }
54
55 let Some(origin) = &self.hooks.self_request_origin else {
56 tracing::error!(
57 "Couldn't handle outbound HTTP request to relative URI; no origin set"
58 );
59 return Err(HttpError::InvalidUrl);
60 };
61 let path_and_query = uri.parse().map_err(|_| HttpError::InvalidUrl)?;
62 origin.clone().into_uri(Some(path_and_query))
63 };
64
65 let mut req = {
67 let mut builder = http::Request::builder()
68 .method(hyper_method(req.method))
69 .uri(&req_url);
70 for (key, val) in req.headers {
71 builder = builder.header(key, val);
72 }
73 builder.body(req.body.unwrap_or_default())
74 }
75 .map_err(|err| {
76 tracing::error!("Error building outbound request: {err}");
77 HttpError::RuntimeError
78 })?;
79
80 spin_telemetry::inject_trace_context(req.headers_mut());
81
82 if let Some(interceptor) = &self.hooks.request_interceptor {
83 let intercepted_request = std::mem::take(&mut req).into();
84 match interceptor.intercept(intercepted_request).await {
85 Ok(InterceptOutcome::Continue(intercepted_request)) => {
86 req = intercepted_request.into_vec_request().unwrap();
87 }
88 Ok(InterceptOutcome::Complete(resp)) => return response_from_hyper(resp).await,
89 Err(err) => {
90 tracing::error!("Error in outbound HTTP interceptor: {err}");
91 return Err(HttpError::RuntimeError);
92 }
93 }
94 }
95
96 let req = reqwest::Request::try_from(req).map_err(|_| HttpError::InvalidUrl)?;
98
99 let client = self.hooks.spin_http_client.get_or_insert_with(|| {
102 let mut builder = reqwest::Client::builder().dns_resolver(Arc::new(SpinDnsResolver(
103 self.hooks.blocked_networks.clone(),
104 )));
105 if !self.hooks.connection_pooling_enabled {
106 builder = builder.pool_max_idle_per_host(0);
107 }
108 builder.build().unwrap()
109 });
110
111 let permit = crate::concurrent_outbound_connections::acquire_semaphore(
115 "spin",
116 &self.hooks.concurrent_outbound_connections_semaphore,
117 )
118 .await;
119 let resp = client.execute(req).await.map_err(log_reqwest_error)?;
120 drop(permit);
121
122 tracing::trace!("Returning response from outbound request to {req_url}");
123 span.record("http.response.status_code", resp.status().as_u16());
124 response_from_reqwest(resp).await
125 }
126}
127
128struct SpinDnsResolver(BlockedNetworks);
130
131impl reqwest::dns::Resolve for SpinDnsResolver {
132 fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
133 let blocked_networks = self.0.clone();
134 Box::pin(async move {
135 let mut addrs = tokio::net::lookup_host((name.as_str(), 0))
136 .await
137 .map_err(Box::new)?
138 .collect::<Vec<_>>();
139 crate::remove_blocked_addrs(&blocked_networks, &mut addrs).map_err(Box::new)?;
141 Ok(Box::new(addrs.into_iter()) as reqwest::dns::Addrs)
142 })
143 }
144}
145
146impl http_types::Host for crate::InstanceState {
147 fn convert_http_error(&mut self, err: HttpError) -> anyhow::Result<HttpError> {
148 Ok(err)
149 }
150}
151
152fn record_request_fields(span: &Span, req: &Request) {
153 let method = match req.method {
154 Method::Get => "GET",
155 Method::Post => "POST",
156 Method::Put => "PUT",
157 Method::Delete => "DELETE",
158 Method::Patch => "PATCH",
159 Method::Head => "HEAD",
160 Method::Options => "OPTIONS",
161 };
162 span.record("otel.name", method)
165 .record("http.request.method", method)
166 .record("url.full", req.uri.clone());
167 if let Ok(uri) = req.uri.parse::<http::Uri>()
168 && let Some(authority) = uri.authority()
169 {
170 span.record("server.address", authority.host());
171 if let Some(port) = authority.port() {
172 span.record("server.port", port.as_u16());
173 }
174 }
175}
176
177fn hyper_method(m: Method) -> http::Method {
178 match m {
179 Method::Get => http::Method::GET,
180 Method::Post => http::Method::POST,
181 Method::Put => http::Method::PUT,
182 Method::Delete => http::Method::DELETE,
183 Method::Patch => http::Method::PATCH,
184 Method::Head => http::Method::HEAD,
185 Method::Options => http::Method::OPTIONS,
186 }
187}
188
189async fn response_from_hyper(resp: crate::Response) -> Result<Response, HttpError> {
190 let status = resp.status().as_u16();
191
192 let headers = headers_from_map(resp.headers());
193
194 let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
195 + headers
196 .iter()
197 .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
198 .sum::<usize>();
199
200 let mut stream = resp.into_body().into_data_stream();
201 let mut body = Vec::new();
202 while let Some(chunk) = stream
203 .try_next()
204 .await
205 .map_err(|_| HttpError::RuntimeError)?
206 {
207 body.extend(chunk);
208 check_byte_count(header_bytes + body.len())?;
209 }
210
211 check_byte_count(header_bytes + body.len())?;
213
214 Ok(Response {
215 status,
216 headers: Some(headers),
217 body: Some(body),
218 })
219}
220
221fn log_reqwest_error(err: reqwest::Error) -> HttpError {
222 let error_desc = if err.is_timeout() {
223 "timeout error"
224 } else if err.is_connect() {
225 "connection error"
226 } else if err.is_body() || err.is_decode() {
227 "message body error"
228 } else if err.is_request() {
229 "request error"
230 } else {
231 "error"
232 };
233 tracing::warn!(
234 "Outbound HTTP {}: URL {}, error detail {:?}",
235 error_desc,
236 err.url()
237 .map(|u| u.to_string())
238 .unwrap_or_else(|| "<unknown>".to_owned()),
239 err
240 );
241 HttpError::RuntimeError
242}
243
244async fn response_from_reqwest(res: reqwest::Response) -> Result<Response, HttpError> {
245 let status = res.status().as_u16();
246
247 let headers = headers_from_map(res.headers());
248
249 let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
250 + headers
251 .iter()
252 .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
253 .sum::<usize>();
254
255 let mut stream = res.bytes_stream();
256 let mut body = Vec::new();
257 while let Some(chunk) = stream
258 .try_next()
259 .await
260 .map_err(|_| HttpError::RuntimeError)?
261 {
262 body.extend(chunk);
263 check_byte_count(header_bytes + body.len())?;
264 }
265
266 check_byte_count(header_bytes + body.len())?;
268
269 Ok(Response {
270 status,
271 headers: Some(headers),
272 body: Some(body),
273 })
274}
275
276fn check_byte_count(count: usize) -> Result<(), HttpError> {
277 if count > MAX_HOST_BUFFERED_BYTES {
278 tracing::warn!("query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes");
279 Err(HttpError::RuntimeError)
280 } else {
281 Ok(())
282 }
283}
284
285fn headers_from_map(map: &http::HeaderMap) -> Vec<(String, String)> {
286 map.iter()
287 .filter_map(|(key, val)| {
288 Some((
289 key.to_string(),
290 val.to_str()
291 .ok()
292 .or_else(|| {
293 tracing::warn!("Non-ascii response header value for {key}");
294 None
295 })?
296 .to_string(),
297 ))
298 })
299 .collect()
300}