Skip to main content

spin_factor_outbound_http/
spin.rs

1use std::sync::Arc;
2
3use futures::stream::TryStreamExt as _;
4use http_body_util::BodyExt;
5use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
6use spin_world::MAX_HOST_BUFFERED_BYTES;
7use spin_world::v1::{
8    http as spin_http,
9    http_types::{self, HttpError, Method, Request, Response},
10};
11use tracing::{Span, field::Empty, instrument};
12
13use crate::intercept::InterceptOutcome;
14
15impl spin_http::Host for crate::InstanceState {
16    #[instrument(name = "spin_outbound_http.send_request", skip_all,
17        fields(otel.kind = "client", url.full = Empty, http.request.method = Empty,
18        http.response.status_code = Empty, otel.name = Empty, server.address = Empty, server.port = Empty))]
19    async fn send_request(&mut self, req: Request) -> Result<Response, HttpError> {
20        self.hooks.otel.reparent_tracing_span();
21
22        let span = Span::current();
23        record_request_fields(&span, &req);
24
25        let uri = req.uri;
26        tracing::trace!("Sending outbound HTTP to {uri:?}");
27
28        if !req.params.is_empty() {
29            tracing::warn!("HTTP params field is deprecated");
30        }
31        let req_url = if !uri.starts_with('/') {
32            // Absolute URI
33            let is_allowed = self
34                .hooks
35                .allowed_hosts
36                .check_url(&uri, "https")
37                .await
38                .unwrap_or(false);
39            if !is_allowed {
40                return Err(HttpError::DestinationNotAllowed);
41            }
42            uri.parse().map_err(|_| HttpError::InvalidUrl)?
43        } else {
44            // Relative URI ("self" request)
45            let is_allowed = self
46                .hooks
47                .allowed_hosts
48                .check_relative_url(&["http", "https"])
49                .await
50                .unwrap_or(false);
51            if !is_allowed {
52                return Err(HttpError::DestinationNotAllowed);
53            }
54
55            let Some(origin) = &self.hooks.self_request_origin else {
56                tracing::error!(
57                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
58                );
59                return Err(HttpError::InvalidUrl);
60            };
61            let path_and_query = uri.parse().map_err(|_| HttpError::InvalidUrl)?;
62            origin.clone().into_uri(Some(path_and_query))
63        };
64
65        // Build an http::Request for OutboundHttpInterceptor
66        let mut req = {
67            let mut builder = http::Request::builder()
68                .method(hyper_method(req.method))
69                .uri(&req_url);
70            for (key, val) in req.headers {
71                builder = builder.header(key, val);
72            }
73            builder.body(req.body.unwrap_or_default())
74        }
75        .map_err(|err| {
76            tracing::error!("Error building outbound request: {err}");
77            HttpError::RuntimeError
78        })?;
79
80        spin_telemetry::inject_trace_context(req.headers_mut());
81
82        if let Some(interceptor) = &self.hooks.request_interceptor {
83            let intercepted_request = std::mem::take(&mut req).into();
84            match interceptor.intercept(intercepted_request).await {
85                Ok(InterceptOutcome::Continue(intercepted_request)) => {
86                    req = intercepted_request.into_vec_request().unwrap();
87                }
88                Ok(InterceptOutcome::Complete(resp)) => return response_from_hyper(resp).await,
89                Err(err) => {
90                    tracing::error!("Error in outbound HTTP interceptor: {err}");
91                    return Err(HttpError::RuntimeError);
92                }
93            }
94        }
95
96        // Convert http::Request to reqwest::Request
97        let req = reqwest::Request::try_from(req).map_err(|_| HttpError::InvalidUrl)?;
98
99        // Allow reuse of Client's internal connection pool for multiple requests
100        // in a single component execution
101        let client = self.hooks.spin_http_client.get_or_insert_with(|| {
102            let mut builder = reqwest::Client::builder().dns_resolver(Arc::new(SpinDnsResolver(
103                self.hooks.blocked_networks.clone(),
104            )));
105            if !self.hooks.connection_pooling_enabled {
106                builder = builder.pool_max_idle_per_host(0);
107            }
108            builder.build().unwrap()
109        });
110
111        // If we're limiting concurrent outbound requests, acquire a permit
112        // Note: since we don't have access to the underlying connection, we can only
113        // limit the number of concurrent requests, not connections.
114        let permit = crate::concurrent_outbound_connections::acquire_semaphore(
115            "spin",
116            &self.hooks.concurrent_outbound_connections_semaphore,
117        )
118        .await;
119        let resp = client.execute(req).await.map_err(log_reqwest_error)?;
120        drop(permit);
121
122        tracing::trace!("Returning response from outbound request to {req_url}");
123        span.record("http.response.status_code", resp.status().as_u16());
124        response_from_reqwest(resp).await
125    }
126}
127
128/// Resolves DNS using Tokio's resolver, filtering out blocked IPs.
129struct SpinDnsResolver(BlockedNetworks);
130
131impl reqwest::dns::Resolve for SpinDnsResolver {
132    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
133        let blocked_networks = self.0.clone();
134        Box::pin(async move {
135            let mut addrs = tokio::net::lookup_host((name.as_str(), 0))
136                .await
137                .map_err(Box::new)?
138                .collect::<Vec<_>>();
139            // Remove blocked IPs
140            crate::remove_blocked_addrs(&blocked_networks, &mut addrs).map_err(Box::new)?;
141            Ok(Box::new(addrs.into_iter()) as reqwest::dns::Addrs)
142        })
143    }
144}
145
146impl http_types::Host for crate::InstanceState {
147    fn convert_http_error(&mut self, err: HttpError) -> anyhow::Result<HttpError> {
148        Ok(err)
149    }
150}
151
152fn record_request_fields(span: &Span, req: &Request) {
153    let method = match req.method {
154        Method::Get => "GET",
155        Method::Post => "POST",
156        Method::Put => "PUT",
157        Method::Delete => "DELETE",
158        Method::Patch => "PATCH",
159        Method::Head => "HEAD",
160        Method::Options => "OPTIONS",
161    };
162    // Set otel.name to just the method name to fit with OpenTelemetry conventions
163    // <https://opentelemetry.io/docs/specs/semconv/http/http-spans/#name>
164    span.record("otel.name", method)
165        .record("http.request.method", method)
166        .record("url.full", req.uri.clone());
167    if let Ok(uri) = req.uri.parse::<http::Uri>()
168        && let Some(authority) = uri.authority()
169    {
170        span.record("server.address", authority.host());
171        if let Some(port) = authority.port() {
172            span.record("server.port", port.as_u16());
173        }
174    }
175}
176
177fn hyper_method(m: Method) -> http::Method {
178    match m {
179        Method::Get => http::Method::GET,
180        Method::Post => http::Method::POST,
181        Method::Put => http::Method::PUT,
182        Method::Delete => http::Method::DELETE,
183        Method::Patch => http::Method::PATCH,
184        Method::Head => http::Method::HEAD,
185        Method::Options => http::Method::OPTIONS,
186    }
187}
188
189async fn response_from_hyper(resp: crate::Response) -> Result<Response, HttpError> {
190    let status = resp.status().as_u16();
191
192    let headers = headers_from_map(resp.headers());
193
194    let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
195        + headers
196            .iter()
197            .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
198            .sum::<usize>();
199
200    let mut stream = resp.into_body().into_data_stream();
201    let mut body = Vec::new();
202    while let Some(chunk) = stream
203        .try_next()
204        .await
205        .map_err(|_| HttpError::RuntimeError)?
206    {
207        body.extend(chunk);
208        check_byte_count(header_bytes + body.len())?;
209    }
210
211    // One more check in case the body was empty:
212    check_byte_count(header_bytes + body.len())?;
213
214    Ok(Response {
215        status,
216        headers: Some(headers),
217        body: Some(body),
218    })
219}
220
221fn log_reqwest_error(err: reqwest::Error) -> HttpError {
222    let error_desc = if err.is_timeout() {
223        "timeout error"
224    } else if err.is_connect() {
225        "connection error"
226    } else if err.is_body() || err.is_decode() {
227        "message body error"
228    } else if err.is_request() {
229        "request error"
230    } else {
231        "error"
232    };
233    tracing::warn!(
234        "Outbound HTTP {}: URL {}, error detail {:?}",
235        error_desc,
236        err.url()
237            .map(|u| u.to_string())
238            .unwrap_or_else(|| "<unknown>".to_owned()),
239        err
240    );
241    HttpError::RuntimeError
242}
243
244async fn response_from_reqwest(res: reqwest::Response) -> Result<Response, HttpError> {
245    let status = res.status().as_u16();
246
247    let headers = headers_from_map(res.headers());
248
249    let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
250        + headers
251            .iter()
252            .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
253            .sum::<usize>();
254
255    let mut stream = res.bytes_stream();
256    let mut body = Vec::new();
257    while let Some(chunk) = stream
258        .try_next()
259        .await
260        .map_err(|_| HttpError::RuntimeError)?
261    {
262        body.extend(chunk);
263        check_byte_count(header_bytes + body.len())?;
264    }
265
266    // One more check in case the body was empty:
267    check_byte_count(header_bytes + body.len())?;
268
269    Ok(Response {
270        status,
271        headers: Some(headers),
272        body: Some(body),
273    })
274}
275
276fn check_byte_count(count: usize) -> Result<(), HttpError> {
277    if count > MAX_HOST_BUFFERED_BYTES {
278        tracing::warn!("query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes");
279        Err(HttpError::RuntimeError)
280    } else {
281        Ok(())
282    }
283}
284
285fn headers_from_map(map: &http::HeaderMap) -> Vec<(String, String)> {
286    map.iter()
287        .filter_map(|(key, val)| {
288            Some((
289                key.to_string(),
290                val.to_str()
291                    .ok()
292                    .or_else(|| {
293                        tracing::warn!("Non-ascii response header value for {key}");
294                        None
295                    })?
296                    .to_string(),
297            ))
298        })
299        .collect()
300}