Skip to main content

spin_factor_outbound_http/
spin.rs

1use std::sync::Arc;
2
3use futures::stream::TryStreamExt as _;
4use http_body_util::BodyExt;
5use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
6use spin_world::v1::{
7    http as spin_http,
8    http_types::{self, HttpError, Method, Request, Response},
9};
10use spin_world::MAX_HOST_BUFFERED_BYTES;
11use tracing::{field::Empty, instrument, Span};
12
13use crate::intercept::InterceptOutcome;
14
/// Host-side implementation of the `spin_world::v1` `http` interface, backed by
/// the state of a single instance.
impl spin_http::Host for crate::InstanceState {
    /// Sends one outbound HTTP request and returns the fully buffered response.
    ///
    /// The request URI is treated as absolute unless it starts with `/`, in
    /// which case it is resolved against `self.self_request_origin`. Either form
    /// is checked against `self.allowed_hosts` before anything is sent. If a
    /// request interceptor is configured it sees the request first and may
    /// either hand back a request to send or supply the response itself.
    ///
    /// # Errors
    ///
    /// - `HttpError::DestinationNotAllowed` when the allowed-hosts check denies
    ///   the destination or the check itself fails.
    /// - `HttpError::InvalidUrl` when the URI cannot be parsed, when a relative
    ///   URI is used but no self-request origin is set, or when the request
    ///   cannot be converted into a `reqwest::Request`.
    /// - `HttpError::RuntimeError` when building the request fails, the
    ///   interceptor returns an error, the HTTP exchange fails, or the response
    ///   cannot be buffered (see `response_from_reqwest` / `response_from_hyper`).
    ///
    /// # Panics
    ///
    /// Panics if the interceptor returns `Continue` with a request for which
    /// `into_vec_request()` does not yield a value, or if the `reqwest` client
    /// cannot be built.
    #[instrument(name = "spin_outbound_http.send_request", skip_all,
        fields(otel.kind = "client", url.full = Empty, http.request.method = Empty,
        http.response.status_code = Empty, otel.name = Empty, server.address = Empty, server.port = Empty))]
    async fn send_request(&mut self, req: Request) -> Result<Response, HttpError> {
        self.otel.reparent_tracing_span();

        // Fill in the span fields declared as `Empty` in the `instrument`
        // attribute that can be derived from the request.
        let span = Span::current();
        record_request_fields(&span, &req);

        let uri = req.uri;
        tracing::trace!("Sending outbound HTTP to {uri:?}");

        // `params` is still accepted but only triggers a warning; it is not
        // used anywhere else in this function.
        if !req.params.is_empty() {
            tracing::warn!("HTTP params field is deprecated");
        }
        let req_url = if !uri.starts_with('/') {
            // Absolute URI
            // NOTE(review): the "https" argument looks like a default scheme for
            // URLs that omit one — confirm against `check_url`.
            let is_allowed = self
                .allowed_hosts
                .check_url(&uri, "https")
                .await
                // A failed check is treated the same as a denial.
                .unwrap_or(false);
            if !is_allowed {
                return Err(HttpError::DestinationNotAllowed);
            }
            uri.parse().map_err(|_| HttpError::InvalidUrl)?
        } else {
            // Relative URI ("self" request)
            let is_allowed = self
                .allowed_hosts
                .check_relative_url(&["http", "https"])
                .await
                // A failed check is treated the same as a denial.
                .unwrap_or(false);
            if !is_allowed {
                return Err(HttpError::DestinationNotAllowed);
            }

            let Some(origin) = &self.self_request_origin else {
                tracing::error!(
                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
                );
                return Err(HttpError::InvalidUrl);
            };
            // The relative URI becomes the path-and-query part of the origin.
            let path_and_query = uri.parse().map_err(|_| HttpError::InvalidUrl)?;
            origin.clone().into_uri(Some(path_and_query))
        };

        // Build an http::Request for OutboundHttpInterceptor
        let mut req = {
            let mut builder = http::Request::builder()
                .method(hyper_method(req.method))
                .uri(&req_url);
            for (key, val) in req.headers {
                builder = builder.header(key, val);
            }
            // A missing body is sent as the body type's default value.
            builder.body(req.body.unwrap_or_default())
        }
        .map_err(|err| {
            tracing::error!("Error building outbound request: {err}");
            HttpError::RuntimeError
        })?;

        // NOTE(review): presumably adds trace-context propagation headers to the
        // outgoing request — confirm against `spin_telemetry`.
        spin_telemetry::inject_trace_context(req.headers_mut());

        if let Some(interceptor) = &self.request_interceptor {
            // `take` leaves a default request in `req`; it is overwritten in the
            // `Continue` arm and every other arm returns.
            let intercepted_request = std::mem::take(&mut req).into();
            match interceptor.intercept(intercepted_request).await {
                Ok(InterceptOutcome::Continue(intercepted_request)) => {
                    req = intercepted_request.into_vec_request().unwrap();
                }
                // The interceptor produced the response itself; nothing is sent.
                Ok(InterceptOutcome::Complete(resp)) => return response_from_hyper(resp).await,
                Err(err) => {
                    tracing::error!("Error in outbound HTTP interceptor: {err}");
                    return Err(HttpError::RuntimeError);
                }
            }
        }

        // Convert http::Request to reqwest::Request
        let req = reqwest::Request::try_from(req).map_err(|_| HttpError::InvalidUrl)?;

        // Allow reuse of Client's internal connection pool for multiple requests
        // in a single component execution
        let client = self.spin_http_client.get_or_insert_with(|| {
            // Name resolution goes through `SpinDnsResolver`, which is given a
            // copy of this instance's blocked networks.
            let mut builder = reqwest::Client::builder()
                .dns_resolver(Arc::new(SpinDnsResolver(self.blocked_networks.clone())));
            if !self.connection_pooling_enabled {
                // Pooling disabled: keep no idle connections per host.
                builder = builder.pool_max_idle_per_host(0);
            }
            builder.build().unwrap()
        });

        // If we're limiting concurrent outbound requests, acquire a permit
        // Note: since we don't have access to the underlying connection, we can only
        // limit the number of concurrent requests, not connections.
        let permit = crate::concurrent_outbound_connections::acquire_semaphore(
            "spin",
            &self.concurrent_outbound_connections_semaphore,
        )
        .await;
        let resp = client.execute(req).await.map_err(log_reqwest_error)?;
        // Release the permit once `execute` has returned; the response body is
        // buffered below without holding it.
        drop(permit);

        tracing::trace!("Returning response from outbound request to {req_url}");
        span.record("http.response.status_code", resp.status().as_u16());
        response_from_reqwest(resp).await
    }
}
124
/// DNS resolver installed on the `reqwest` client that filters out blocked IPs.
///
/// Names are looked up with Tokio's resolver (`tokio::net::lookup_host`); the
/// resulting addresses are then passed through `crate::remove_blocked_addrs`
/// together with the wrapped [`BlockedNetworks`] before being returned.
struct SpinDnsResolver(BlockedNetworks);
127
128impl reqwest::dns::Resolve for SpinDnsResolver {
129    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
130        let blocked_networks = self.0.clone();
131        Box::pin(async move {
132            let mut addrs = tokio::net::lookup_host((name.as_str(), 0))
133                .await
134                .map_err(Box::new)?
135                .collect::<Vec<_>>();
136            // Remove blocked IPs
137            crate::remove_blocked_addrs(&blocked_networks, &mut addrs).map_err(Box::new)?;
138            Ok(Box::new(addrs.into_iter()) as reqwest::dns::Addrs)
139        })
140    }
141}
142
/// Host-side implementation of the `http_types` interface for an instance.
impl http_types::Host for crate::InstanceState {
    /// Returns the given `HttpError` unchanged, wrapped in `Ok`.
    ///
    /// This implementation never produces an `Err`.
    fn convert_http_error(&mut self, err: HttpError) -> anyhow::Result<HttpError> {
        Ok(err)
    }
}
148
149fn record_request_fields(span: &Span, req: &Request) {
150    let method = match req.method {
151        Method::Get => "GET",
152        Method::Post => "POST",
153        Method::Put => "PUT",
154        Method::Delete => "DELETE",
155        Method::Patch => "PATCH",
156        Method::Head => "HEAD",
157        Method::Options => "OPTIONS",
158    };
159    // Set otel.name to just the method name to fit with OpenTelemetry conventions
160    // <https://opentelemetry.io/docs/specs/semconv/http/http-spans/#name>
161    span.record("otel.name", method)
162        .record("http.request.method", method)
163        .record("url.full", req.uri.clone());
164    if let Ok(uri) = req.uri.parse::<http::Uri>() {
165        if let Some(authority) = uri.authority() {
166            span.record("server.address", authority.host());
167            if let Some(port) = authority.port() {
168                span.record("server.port", port.as_u16());
169            }
170        }
171    }
172}
173
174fn hyper_method(m: Method) -> http::Method {
175    match m {
176        Method::Get => http::Method::GET,
177        Method::Post => http::Method::POST,
178        Method::Put => http::Method::PUT,
179        Method::Delete => http::Method::DELETE,
180        Method::Patch => http::Method::PATCH,
181        Method::Head => http::Method::HEAD,
182        Method::Options => http::Method::OPTIONS,
183    }
184}
185
186async fn response_from_hyper(resp: crate::Response) -> Result<Response, HttpError> {
187    let status = resp.status().as_u16();
188
189    let headers = headers_from_map(resp.headers());
190
191    let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
192        + headers
193            .iter()
194            .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
195            .sum::<usize>();
196
197    let mut stream = resp.into_body().into_data_stream();
198    let mut body = Vec::new();
199    while let Some(chunk) = stream
200        .try_next()
201        .await
202        .map_err(|_| HttpError::RuntimeError)?
203    {
204        body.extend(chunk);
205        check_byte_count(header_bytes + body.len())?;
206    }
207
208    // One more check in case the body was empty:
209    check_byte_count(header_bytes + body.len())?;
210
211    Ok(Response {
212        status,
213        headers: Some(headers),
214        body: Some(body),
215    })
216}
217
218fn log_reqwest_error(err: reqwest::Error) -> HttpError {
219    let error_desc = if err.is_timeout() {
220        "timeout error"
221    } else if err.is_connect() {
222        "connection error"
223    } else if err.is_body() || err.is_decode() {
224        "message body error"
225    } else if err.is_request() {
226        "request error"
227    } else {
228        "error"
229    };
230    tracing::warn!(
231        "Outbound HTTP {}: URL {}, error detail {:?}",
232        error_desc,
233        err.url()
234            .map(|u| u.to_string())
235            .unwrap_or_else(|| "<unknown>".to_owned()),
236        err
237    );
238    HttpError::RuntimeError
239}
240
241async fn response_from_reqwest(res: reqwest::Response) -> Result<Response, HttpError> {
242    let status = res.status().as_u16();
243
244    let headers = headers_from_map(res.headers());
245
246    let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
247        + headers
248            .iter()
249            .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
250            .sum::<usize>();
251
252    let mut stream = res.bytes_stream();
253    let mut body = Vec::new();
254    while let Some(chunk) = stream
255        .try_next()
256        .await
257        .map_err(|_| HttpError::RuntimeError)?
258    {
259        body.extend(chunk);
260        check_byte_count(header_bytes + body.len())?;
261    }
262
263    // One more check in case the body was empty:
264    check_byte_count(header_bytes + body.len())?;
265
266    Ok(Response {
267        status,
268        headers: Some(headers),
269        body: Some(body),
270    })
271}
272
273fn check_byte_count(count: usize) -> Result<(), HttpError> {
274    if count > MAX_HOST_BUFFERED_BYTES {
275        tracing::warn!("query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes");
276        Err(HttpError::RuntimeError)
277    } else {
278        Ok(())
279    }
280}
281
282fn headers_from_map(map: &http::HeaderMap) -> Vec<(String, String)> {
283    map.iter()
284        .filter_map(|(key, val)| {
285            Some((
286                key.to_string(),
287                val.to_str()
288                    .ok()
289                    .or_else(|| {
290                        tracing::warn!("Non-ascii response header value for {key}");
291                        None
292                    })?
293                    .to_string(),
294            ))
295        })
296        .collect()
297}