Skip to main content

spin_factor_outbound_http/
spin.rs

1use std::sync::Arc;
2
3use futures::stream::TryStreamExt as _;
4use http_body_util::BodyExt;
5use opentelemetry_semantic_conventions::attribute as otel_attribute;
6use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
7use spin_world::MAX_HOST_BUFFERED_BYTES;
8use spin_world::v1::{
9    http as spin_http,
10    http_types::{self, HttpError, Method, Request, Response},
11};
12use tracing::{Span, field::Empty, instrument};
13
14use crate::intercept::InterceptOutcome;
15
16impl spin_http::Host for crate::InstanceState {
17    #[instrument(name = "spin_outbound_http.send_request", skip_all,
18        fields(otel.kind = "client", {otel_attribute::URL_FULL} = Empty, {otel_attribute::HTTP_REQUEST_METHOD} = Empty,
19        {otel_attribute::HTTP_RESPONSE_STATUS_CODE} = Empty, otel.name = Empty, {otel_attribute::SERVER_ADDRESS} = Empty, {otel_attribute::SERVER_PORT} = Empty))]
20    async fn send_request(&mut self, req: Request) -> Result<Response, HttpError> {
21        self.hooks.otel.reparent_tracing_span();
22
23        let span = Span::current();
24        record_request_fields(&span, &req);
25
26        let uri = req.uri;
27        tracing::trace!("Sending outbound HTTP to {uri:?}");
28
29        if !req.params.is_empty() {
30            tracing::warn!("HTTP params field is deprecated");
31        }
32        let req_url = if !uri.starts_with('/') {
33            // Absolute URI
34            let is_allowed = self
35                .hooks
36                .allowed_hosts
37                .check_url(&uri, "https")
38                .await
39                .unwrap_or(false);
40            if !is_allowed {
41                return Err(HttpError::DestinationNotAllowed);
42            }
43            uri.parse().map_err(|_| HttpError::InvalidUrl)?
44        } else {
45            // Relative URI ("self" request)
46            let is_allowed = self
47                .hooks
48                .allowed_hosts
49                .check_relative_url(&["http", "https"])
50                .await
51                .unwrap_or(false);
52            if !is_allowed {
53                return Err(HttpError::DestinationNotAllowed);
54            }
55
56            let Some(origin) = &self.hooks.self_request_origin else {
57                tracing::error!(
58                    "Couldn't handle outbound HTTP request to relative URI; no origin set"
59                );
60                return Err(HttpError::InvalidUrl);
61            };
62            let path_and_query = uri.parse().map_err(|_| HttpError::InvalidUrl)?;
63            origin.clone().into_uri(Some(path_and_query))
64        };
65
66        // Build an http::Request for OutboundHttpInterceptor
67        let mut req = {
68            let mut builder = http::Request::builder()
69                .method(hyper_method(req.method))
70                .uri(&req_url);
71            for (key, val) in req.headers {
72                builder = builder.header(key, val);
73            }
74            builder.body(req.body.unwrap_or_default())
75        }
76        .map_err(|err| {
77            tracing::error!("Error building outbound request: {err}");
78            HttpError::RuntimeError
79        })?;
80
81        spin_telemetry::inject_trace_context(req.headers_mut());
82
83        if let Some(interceptor) = &self.hooks.request_interceptor {
84            let intercepted_request = std::mem::take(&mut req).into();
85            match interceptor.intercept(intercepted_request).await {
86                Ok(InterceptOutcome::Continue(intercepted_request)) => {
87                    req = intercepted_request.into_vec_request().unwrap();
88                }
89                Ok(InterceptOutcome::Complete(resp)) => return response_from_hyper(resp).await,
90                Err(err) => {
91                    tracing::error!("Error in outbound HTTP interceptor: {err}");
92                    return Err(HttpError::RuntimeError);
93                }
94            }
95        }
96
97        // Convert http::Request to reqwest::Request
98        let req = reqwest::Request::try_from(req).map_err(|_| HttpError::InvalidUrl)?;
99
100        // Allow reuse of Client's internal connection pool for multiple requests
101        // in a single component execution
102        let client = self.hooks.spin_http_client.get_or_insert_with(|| {
103            let mut builder = reqwest::Client::builder().dns_resolver(Arc::new(SpinDnsResolver(
104                self.hooks.blocked_networks.clone(),
105            )));
106            if !self.hooks.connection_pooling_enabled {
107                builder = builder.pool_max_idle_per_host(0);
108            }
109            builder.build().unwrap()
110        });
111
112        // If we're limiting concurrent outbound requests, acquire a permit
113        // Note: since we don't have access to the underlying connection, we can only
114        // limit the number of concurrent requests, not connections.
115        let permit = self
116            .hooks
117            .semaphore
118            .acquire()
119            .await
120            .map_err(|_| HttpError::TooManyRequests)?;
121        let resp = client.execute(req).await.map_err(log_reqwest_error)?;
122        drop(permit);
123
124        tracing::trace!("Returning response from outbound request to {req_url}");
125        span.record(
126            otel_attribute::HTTP_RESPONSE_STATUS_CODE,
127            resp.status().as_u16(),
128        );
129        response_from_reqwest(resp).await
130    }
131}
132
133/// Resolves DNS using Tokio's resolver, filtering out blocked IPs.
134struct SpinDnsResolver(BlockedNetworks);
135
136impl reqwest::dns::Resolve for SpinDnsResolver {
137    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
138        let blocked_networks = self.0.clone();
139        Box::pin(async move {
140            let mut addrs = tokio::net::lookup_host((name.as_str(), 0))
141                .await
142                .map_err(Box::new)?
143                .collect::<Vec<_>>();
144            // Remove blocked IPs
145            crate::remove_blocked_addrs(&blocked_networks, &mut addrs).map_err(Box::new)?;
146            Ok(Box::new(addrs.into_iter()) as reqwest::dns::Addrs)
147        })
148    }
149}
150
151impl http_types::Host for crate::InstanceState {
152    fn convert_http_error(&mut self, err: HttpError) -> anyhow::Result<HttpError> {
153        Ok(err)
154    }
155}
156
157fn record_request_fields(span: &Span, req: &Request) {
158    let method = match req.method {
159        Method::Get => "GET",
160        Method::Post => "POST",
161        Method::Put => "PUT",
162        Method::Delete => "DELETE",
163        Method::Patch => "PATCH",
164        Method::Head => "HEAD",
165        Method::Options => "OPTIONS",
166    };
167    // Set otel.name to just the method name to fit with OpenTelemetry conventions
168    // <https://opentelemetry.io/docs/specs/semconv/http/http-spans/#name>
169    span.record("otel.name", method)
170        .record(otel_attribute::HTTP_REQUEST_METHOD, method)
171        .record(otel_attribute::URL_FULL, req.uri.clone());
172    if let Ok(uri) = req.uri.parse::<http::Uri>()
173        && let Some(authority) = uri.authority()
174    {
175        span.record(otel_attribute::SERVER_ADDRESS, authority.host());
176        if let Some(port) = authority.port() {
177            span.record(otel_attribute::SERVER_PORT, port.as_u16());
178        }
179    }
180}
181
182fn hyper_method(m: Method) -> http::Method {
183    match m {
184        Method::Get => http::Method::GET,
185        Method::Post => http::Method::POST,
186        Method::Put => http::Method::PUT,
187        Method::Delete => http::Method::DELETE,
188        Method::Patch => http::Method::PATCH,
189        Method::Head => http::Method::HEAD,
190        Method::Options => http::Method::OPTIONS,
191    }
192}
193
194async fn response_from_hyper(resp: crate::Response) -> Result<Response, HttpError> {
195    let status = resp.status().as_u16();
196
197    let headers = headers_from_map(resp.headers());
198
199    let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
200        + headers
201            .iter()
202            .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
203            .sum::<usize>();
204
205    let mut stream = resp.into_body().into_data_stream();
206    let mut body = Vec::new();
207    while let Some(chunk) = stream
208        .try_next()
209        .await
210        .map_err(|_| HttpError::RuntimeError)?
211    {
212        body.extend(chunk);
213        check_byte_count(header_bytes + body.len())?;
214    }
215
216    // One more check in case the body was empty:
217    check_byte_count(header_bytes + body.len())?;
218
219    Ok(Response {
220        status,
221        headers: Some(headers),
222        body: Some(body),
223    })
224}
225
226fn log_reqwest_error(err: reqwest::Error) -> HttpError {
227    let error_desc = if err.is_timeout() {
228        "timeout error"
229    } else if err.is_connect() {
230        "connection error"
231    } else if err.is_body() || err.is_decode() {
232        "message body error"
233    } else if err.is_request() {
234        "request error"
235    } else {
236        "error"
237    };
238    tracing::warn!(
239        "Outbound HTTP {}: URL {}, error detail {:?}",
240        error_desc,
241        err.url()
242            .map(|u| u.to_string())
243            .unwrap_or_else(|| "<unknown>".to_owned()),
244        err
245    );
246    HttpError::RuntimeError
247}
248
249async fn response_from_reqwest(res: reqwest::Response) -> Result<Response, HttpError> {
250    let status = res.status().as_u16();
251
252    let headers = headers_from_map(res.headers());
253
254    let header_bytes = std::mem::size_of::<Vec<(String, String)>>()
255        + headers
256            .iter()
257            .map(|(k, v)| std::mem::size_of::<(String, String)>() + k.len() + v.len())
258            .sum::<usize>();
259
260    let mut stream = res.bytes_stream();
261    let mut body = Vec::new();
262    while let Some(chunk) = stream
263        .try_next()
264        .await
265        .map_err(|_| HttpError::RuntimeError)?
266    {
267        body.extend(chunk);
268        check_byte_count(header_bytes + body.len())?;
269    }
270
271    // One more check in case the body was empty:
272    check_byte_count(header_bytes + body.len())?;
273
274    Ok(Response {
275        status,
276        headers: Some(headers),
277        body: Some(body),
278    })
279}
280
281fn check_byte_count(count: usize) -> Result<(), HttpError> {
282    if count > MAX_HOST_BUFFERED_BYTES {
283        tracing::warn!("query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes");
284        Err(HttpError::RuntimeError)
285    } else {
286        Ok(())
287    }
288}
289
290fn headers_from_map(map: &http::HeaderMap) -> Vec<(String, String)> {
291    map.iter()
292        .filter_map(|(key, val)| {
293            Some((
294                key.to_string(),
295                val.to_str()
296                    .ok()
297                    .or_else(|| {
298                        tracing::warn!("Non-ascii response header value for {key}");
299                        None
300                    })?
301                    .to_string(),
302            ))
303        })
304        .collect()
305}