Skip to main content

spin_trigger_http/
headers.rs

1use std::{
2    borrow::Cow,
3    net::SocketAddr,
4    str::{self, FromStr},
5};
6
7use anyhow::Result;
8use http::Uri;
9use hyper::Request;
10use spin_factor_outbound_networking::config::allowed_hosts::is_service_chaining_host;
11use spin_http::routes::RouteMatch;
12
13use crate::Body;
14
// We need to make the following pieces of information available to both executors.
// While the values we set are identical, the way they are passed to the
// modules is going to be different, so each executor must use the info
// in its standardized way (environment variables for the Wagi executor, and custom headers
// for the Spin HTTP executor).
//
// Each constant holds two alternative names for the same piece of information.
// `prepare_request_headers` turns the first name (index 0) into the HTTP header name;
// the second name (index 1) is presumably the Wagi-style spelling — confirm against
// the Wagi executor.

/// Key names for the full request URL (`{scheme}://{host}{path-and-query}`).
pub const FULL_URL: [&str; 2] = ["SPIN_FULL_URL", "X_FULL_URL"];
/// Key names for the part of the path matched by the route's trailing wildcard.
pub const PATH_INFO: [&str; 2] = ["SPIN_PATH_INFO", "PATH_INFO"];
/// Key names for the matched route as reported by `RouteMatch::based_route`.
pub const MATCHED_ROUTE: [&str; 2] = ["SPIN_MATCHED_ROUTE", "X_MATCHED_ROUTE"];
/// Key names for the component route as reported by `RouteMatch::raw_route_or_prefix`.
pub const COMPONENT_ROUTE: [&str; 2] = ["SPIN_COMPONENT_ROUTE", "X_COMPONENT_ROUTE"];
/// Key names for the component route as reported by `RouteMatch::raw_route`.
pub const RAW_COMPONENT_ROUTE: [&str; 2] = ["SPIN_RAW_COMPONENT_ROUTE", "X_RAW_COMPONENT_ROUTE"];
/// Key names for the base path; `compute_default_headers` always sets this to `/`.
pub const BASE_PATH: [&str; 2] = ["SPIN_BASE_PATH", "X_BASE_PATH"];
/// Key names for the client socket address, formatted with `SocketAddr::to_string`.
pub const CLIENT_ADDR: [&str; 2] = ["SPIN_CLIENT_ADDR", "X_CLIENT_ADDR"];
27
/// Header key/value pairs that use copy on write to avoid allocation.
///
/// The first tuple element holds the two alternative key names for the entry and the
/// second holds its value. Both are `Cow`s, so static names and values borrowed for `'a`
/// are stored as-is while computed strings are stored as owned data.
pub type HeaderPair<'a> = ([Cow<'static, str>; 2], Cow<'a, str>);
30
31/// Compute the default headers to be passed to the component.
32pub fn compute_default_headers<'a>(
33    uri: &Uri,
34    host: &str,
35    route_match: &'a RouteMatch,
36    client_addr: SocketAddr,
37) -> anyhow::Result<Vec<HeaderPair<'a>>> {
38    fn owned(strs: &[&'static str; 2]) -> [Cow<'static, str>; 2] {
39        [strs[0].into(), strs[1].into()]
40    }
41
42    let owned_full_url = owned(&FULL_URL);
43    let owned_path_info = owned(&PATH_INFO);
44    let owned_matched_route = owned(&MATCHED_ROUTE);
45    let owned_component_route = owned(&COMPONENT_ROUTE);
46    let owned_raw_component_route = owned(&RAW_COMPONENT_ROUTE);
47    let owned_base_path = owned(&BASE_PATH);
48    let owned_client_addr = owned(&CLIENT_ADDR);
49
50    let mut res = vec![];
51    let abs_path = uri
52        .path_and_query()
53        .expect("cannot get path and query")
54        .as_str();
55
56    let path_info = route_match.trailing_wildcard();
57
58    let scheme = uri.scheme_str().unwrap_or("http");
59
60    let full_url = format!("{scheme}://{host}{abs_path}");
61
62    res.push((owned_path_info, path_info));
63    res.push((owned_full_url, full_url.into()));
64    res.push((owned_matched_route, route_match.based_route().into()));
65
66    res.push((owned_base_path, "/".into()));
67    res.push((owned_raw_component_route, route_match.raw_route().into()));
68    res.push((
69        owned_component_route,
70        route_match.raw_route_or_prefix().into(),
71    ));
72    res.push((owned_client_addr, client_addr.to_string().into()));
73
74    for (wild_name, wild_value) in route_match.named_wildcards() {
75        let wild_header = format!("SPIN_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
76        let wild_wagi_header = format!("X_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
77        res.push(([wild_header, wild_wagi_header], wild_value.into()));
78    }
79
80    Ok(res)
81}
82
83pub fn strip_forbidden_headers(req: &mut Request<Body>) {
84    let headers = req.headers_mut();
85    if let Some(host_header) = headers.get("Host")
86        && let Ok(host) = host_header.to_str()
87        && is_service_chaining_host(host)
88    {
89        headers.remove("Host");
90    }
91}
92
93pub fn prepare_request_headers(
94    req: &Request<Body>,
95    route_match: &RouteMatch,
96    client_addr: SocketAddr,
97) -> Result<Vec<(String, String)>> {
98    let mut res = Vec::new();
99    for (name, value) in req
100        .headers()
101        .iter()
102        .map(|(name, value)| (name.to_string(), std::str::from_utf8(value.as_bytes())))
103    {
104        let value = value?.to_string();
105        res.push((name, value));
106    }
107
108    let default_host = http::HeaderValue::from_str("localhost")?;
109    let host = std::str::from_utf8(
110        req.headers()
111            .get("host")
112            .unwrap_or(&default_host)
113            .as_bytes(),
114    )?;
115
116    // Set the environment information (path info, base path, etc) as headers.
117    // In the future, we might want to have this information in a context
118    // object as opposed to headers.
119    for (keys, val) in compute_default_headers(req.uri(), host, route_match, client_addr)? {
120        res.push((prepare_header_key(&keys[0]), val.into_owned()));
121    }
122
123    Ok(res)
124}
125
126pub fn append_headers(
127    map: &mut http::HeaderMap,
128    headers: Option<Vec<(String, String)>>,
129) -> Result<()> {
130    if let Some(src) = headers {
131        for (k, v) in src.iter() {
132            map.insert(
133                http::header::HeaderName::from_str(k)?,
134                http::header::HeaderValue::from_str(v)?,
135            );
136        }
137    };
138
139    Ok(())
140}
141
/// Converts a key such as `SPIN_FULL_URL` into header form (`spin-full-url`).
///
/// Every underscore becomes a hyphen and ASCII letters are lower-cased; all other
/// characters are kept unchanged.
fn prepare_header_key(key: &str) -> String {
    key.chars()
        .map(|c| match c {
            '_' => '-',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
145
146#[cfg(test)]
147mod tests {
148    use std::borrow::Cow;
149
150    use super::*;
151    use anyhow::Result;
152    use spin_http::routes::Router;
153
154    #[test]
155    fn test_spin_header_keys() {
156        assert_eq!(
157            prepare_header_key("SPIN_FULL_URL"),
158            "spin-full-url".to_string()
159        );
160        assert_eq!(
161            prepare_header_key("SPIN_PATH_INFO"),
162            "spin-path-info".to_string()
163        );
164        assert_eq!(
165            prepare_header_key("SPIN_RAW_COMPONENT_ROUTE"),
166            "spin-raw-component-route".to_string()
167        );
168    }
169
170    #[test]
171    fn test_default_headers() -> Result<()> {
172        let scheme = "https";
173        let host = "fermyon.dev";
174        let trigger_route = "/foo/...";
175        let component_path = "/foo";
176        let path_info = "/bar";
177        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();
178
179        let req_uri = format!(
180            "{}://{}{}{}?key1=value1&key2=value2",
181            scheme, host, component_path, path_info
182        );
183
184        let req = http::Request::builder()
185            .method("POST")
186            .uri(req_uri)
187            .body("")?;
188
189        let router = Router::build(
190            "/",
191            [(
192                &spin_http::routes::TriggerLookupKey::Component("DUMMY".into()),
193                &trigger_route.into(),
194            )],
195            None,
196        )?;
197        let route_match = router.route("/foo/bar")?;
198
199        let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;
200
201        assert_eq!(
202            search(&FULL_URL, &default_headers).unwrap(),
203            "https://fermyon.dev/foo/bar?key1=value1&key2=value2".to_string()
204        );
205        assert_eq!(
206            search(&PATH_INFO, &default_headers).unwrap(),
207            "/bar".to_string()
208        );
209        assert_eq!(
210            search(&MATCHED_ROUTE, &default_headers).unwrap(),
211            "/foo/...".to_string()
212        );
213        assert_eq!(
214            search(&BASE_PATH, &default_headers).unwrap(),
215            "/".to_string()
216        );
217        assert_eq!(
218            search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(),
219            "/foo/...".to_string()
220        );
221        assert_eq!(
222            search(&COMPONENT_ROUTE, &default_headers).unwrap(),
223            "/foo".to_string()
224        );
225        assert_eq!(
226            search(&CLIENT_ADDR, &default_headers).unwrap(),
227            "127.0.0.1:8777".to_string()
228        );
229
230        Ok(())
231    }
232
233    #[test]
234    fn test_default_headers_with_named_wildcards() -> Result<()> {
235        let scheme = "https";
236        let host = "fermyon.dev";
237        let trigger_route = "/foo/:userid/...";
238        let component_path = "/foo";
239        let path_info = "/bar";
240        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();
241
242        let req_uri = format!(
243            "{}://{}{}/42{}?key1=value1&key2=value2",
244            scheme, host, component_path, path_info
245        );
246
247        let req = http::Request::builder()
248            .method("POST")
249            .uri(req_uri)
250            .body("")?;
251
252        let router = Router::build(
253            "/",
254            [(
255                &spin_http::routes::TriggerLookupKey::Component("DUMMY".into()),
256                &trigger_route.into(),
257            )],
258            None,
259        )?;
260        let route_match = router.route("/foo/42/bar")?;
261
262        let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;
263
264        assert_eq!(
265            search(&FULL_URL, &default_headers).unwrap(),
266            "https://fermyon.dev/foo/42/bar?key1=value1&key2=value2".to_string()
267        );
268        assert_eq!(
269            search(&PATH_INFO, &default_headers).unwrap(),
270            "/bar".to_string()
271        );
272        assert_eq!(
273            search(&MATCHED_ROUTE, &default_headers).unwrap(),
274            "/foo/:userid/...".to_string()
275        );
276        assert_eq!(
277            search(&BASE_PATH, &default_headers).unwrap(),
278            "/".to_string()
279        );
280        assert_eq!(
281            search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(),
282            "/foo/:userid/...".to_string()
283        );
284        assert_eq!(
285            search(&COMPONENT_ROUTE, &default_headers).unwrap(),
286            "/foo/:userid".to_string()
287        );
288        assert_eq!(
289            search(&CLIENT_ADDR, &default_headers).unwrap(),
290            "127.0.0.1:8777".to_string()
291        );
292
293        assert_eq!(
294            search(
295                &["SPIN_PATH_MATCH_USERID", "X_PATH_MATCH_USERID"],
296                &default_headers
297            )
298            .unwrap(),
299            "42".to_string()
300        );
301
302        Ok(())
303    }
304
305    #[test]
306    fn forbidden_headers_are_removed() {
307        let mut req = Request::get("http://test.spin.internal")
308            .header("Host", "test.spin.internal")
309            .header("accept", "text/plain")
310            .body(Default::default())
311            .unwrap();
312
313        strip_forbidden_headers(&mut req);
314
315        assert_eq!(1, req.headers().len());
316        assert!(req.headers().get("Host").is_none());
317
318        let mut req = Request::get("http://test.spin.internal")
319            .header("Host", "test.spin.internal:1234")
320            .header("accept", "text/plain")
321            .body(Default::default())
322            .unwrap();
323
324        strip_forbidden_headers(&mut req);
325
326        assert_eq!(1, req.headers().len());
327        assert!(req.headers().get("Host").is_none());
328    }
329
330    #[test]
331    fn non_forbidden_headers_are_not_removed() {
332        let mut req = Request::get("http://test.example.com")
333            .header("Host", "test.example.org")
334            .header("accept", "text/plain")
335            .body(Default::default())
336            .unwrap();
337
338        strip_forbidden_headers(&mut req);
339
340        assert_eq!(2, req.headers().len());
341        assert!(req.headers().get("Host").is_some());
342    }
343
344    fn search(
345        keys: &[&str; 2],
346        headers: &[([Cow<'static, str>; 2], Cow<'_, str>)],
347    ) -> Option<String> {
348        let mut res: Option<String> = None;
349        for (k, v) in headers {
350            if k[0] == keys[0] && k[1] == keys[1] {
351                res = Some(v.as_ref().to_owned());
352            }
353        }
354
355        res
356    }
357}