spin_trigger_http/
headers.rs

1use std::{
2    borrow::Cow,
3    net::SocketAddr,
4    str::{self, FromStr},
5};
6
7use anyhow::Result;
8use http::Uri;
9use hyper::Request;
10use spin_factor_outbound_networking::is_service_chaining_host;
11use spin_http::routes::RouteMatch;
12
13use crate::Body;
14
// We need to make the following pieces of information available to both executors.
// While the values we set are identical, the way they are passed to the
// modules is going to be different, so each executor must use the info
// in its standardized way (environment variables for the Wagi executor, and custom headers
// for the Spin HTTP executor).
//
// Each constant is a `[spin_name, wagi_name]` pair: index 0 is the `SPIN_*` name
// (turned into a `spin-*` header for the Spin HTTP executor), index 1 is the
// Wagi-style name.

/// Full request URL: `scheme://host` followed by the request path and query.
pub const FULL_URL: [&str; 2] = ["SPIN_FULL_URL", "X_FULL_URL"];
/// The part of the path matched by the route's trailing wildcard (e.g. `/bar`).
pub const PATH_INFO: [&str; 2] = ["SPIN_PATH_INFO", "PATH_INFO"];
/// The route pattern that matched the request (from `RouteMatch::based_route`).
pub const MATCHED_ROUTE: [&str; 2] = ["SPIN_MATCHED_ROUTE", "X_MATCHED_ROUTE"];
/// The component route without its trailing wildcard (e.g. `/foo/...` -> `/foo`).
pub const COMPONENT_ROUTE: [&str; 2] = ["SPIN_COMPONENT_ROUTE", "X_COMPONENT_ROUTE"];
/// The component route exactly as declared, wildcard included (e.g. `/foo/...`).
pub const RAW_COMPONENT_ROUTE: [&str; 2] = ["SPIN_RAW_COMPONENT_ROUTE", "X_RAW_COMPONENT_ROUTE"];
/// The base path of the application (currently always `/`).
pub const BASE_PATH: [&str; 2] = ["SPIN_BASE_PATH", "X_BASE_PATH"];
/// The client's socket address (e.g. `127.0.0.1:8777`).
pub const CLIENT_ADDR: [&str; 2] = ["SPIN_CLIENT_ADDR", "X_CLIENT_ADDR"];
27
/// A default header entry: its two names and its value.
///
/// The names are `[spin_name, wagi_name]`. `Cow` is used on both sides to avoid
/// allocating where possible: fixed names come from `'static` constants and many
/// values are borrowed from the route match, while generated names (such as
/// `SPIN_PATH_MATCH_*`) and computed values (such as the full URL) are owned.
pub type HeaderPair<'a> = ([Cow<'static, str>; 2], Cow<'a, str>);
30
31/// Compute the default headers to be passed to the component.
32pub fn compute_default_headers<'a>(
33    uri: &Uri,
34    host: &str,
35    route_match: &'a RouteMatch,
36    client_addr: SocketAddr,
37) -> anyhow::Result<Vec<HeaderPair<'a>>> {
38    fn owned(strs: &[&'static str; 2]) -> [Cow<'static, str>; 2] {
39        [strs[0].into(), strs[1].into()]
40    }
41
42    let owned_full_url = owned(&FULL_URL);
43    let owned_path_info = owned(&PATH_INFO);
44    let owned_matched_route = owned(&MATCHED_ROUTE);
45    let owned_component_route = owned(&COMPONENT_ROUTE);
46    let owned_raw_component_route = owned(&RAW_COMPONENT_ROUTE);
47    let owned_base_path = owned(&BASE_PATH);
48    let owned_client_addr = owned(&CLIENT_ADDR);
49
50    let mut res = vec![];
51    let abs_path = uri
52        .path_and_query()
53        .expect("cannot get path and query")
54        .as_str();
55
56    let path_info = route_match.trailing_wildcard();
57
58    let scheme = uri.scheme_str().unwrap_or("http");
59
60    let full_url = format!("{}://{}{}", scheme, host, abs_path);
61
62    res.push((owned_path_info, path_info));
63    res.push((owned_full_url, full_url.into()));
64    res.push((owned_matched_route, route_match.based_route().into()));
65
66    res.push((owned_base_path, "/".into()));
67    res.push((owned_raw_component_route, route_match.raw_route().into()));
68    res.push((
69        owned_component_route,
70        route_match.raw_route_or_prefix().into(),
71    ));
72    res.push((owned_client_addr, client_addr.to_string().into()));
73
74    for (wild_name, wild_value) in route_match.named_wildcards() {
75        let wild_header = format!("SPIN_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
76        let wild_wagi_header = format!("X_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
77        res.push(([wild_header, wild_wagi_header], wild_value.into()));
78    }
79
80    Ok(res)
81}
82
83pub fn strip_forbidden_headers(req: &mut Request<Body>) {
84    let headers = req.headers_mut();
85    if let Some(host_header) = headers.get("Host") {
86        if let Ok(host) = host_header.to_str() {
87            if is_service_chaining_host(host) {
88                headers.remove("Host");
89            }
90        }
91    }
92}
93
94pub fn prepare_request_headers(
95    req: &Request<Body>,
96    route_match: &RouteMatch,
97    client_addr: SocketAddr,
98) -> Result<Vec<(String, String)>> {
99    let mut res = Vec::new();
100    for (name, value) in req
101        .headers()
102        .iter()
103        .map(|(name, value)| (name.to_string(), std::str::from_utf8(value.as_bytes())))
104    {
105        let value = value?.to_string();
106        res.push((name, value));
107    }
108
109    let default_host = http::HeaderValue::from_str("localhost")?;
110    let host = std::str::from_utf8(
111        req.headers()
112            .get("host")
113            .unwrap_or(&default_host)
114            .as_bytes(),
115    )?;
116
117    // Set the environment information (path info, base path, etc) as headers.
118    // In the future, we might want to have this information in a context
119    // object as opposed to headers.
120    for (keys, val) in compute_default_headers(req.uri(), host, route_match, client_addr)? {
121        res.push((prepare_header_key(&keys[0]), val.into_owned()));
122    }
123
124    Ok(res)
125}
126
127pub fn append_headers(
128    map: &mut http::HeaderMap,
129    headers: Option<Vec<(String, String)>>,
130) -> Result<()> {
131    if let Some(src) = headers {
132        for (k, v) in src.iter() {
133            map.insert(
134                http::header::HeaderName::from_str(k)?,
135                http::header::HeaderValue::from_str(v)?,
136            );
137        }
138    };
139
140    Ok(())
141}
142
/// Convert a Spin environment-style name into an HTTP header name.
///
/// Underscores become hyphens and ASCII letters are lowercased; any other
/// character is passed through unchanged (e.g. `SPIN_FULL_URL` -> `spin-full-url`).
fn prepare_header_key(key: &str) -> String {
    key.chars()
        .map(|ch| match ch {
            '_' => '-',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
146
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::*;
    use anyhow::Result;
    use spin_http::routes::Router;

    #[test]
    fn test_spin_header_keys() {
        let cases = [
            ("SPIN_FULL_URL", "spin-full-url"),
            ("SPIN_PATH_INFO", "spin-path-info"),
            ("SPIN_RAW_COMPONENT_ROUTE", "spin-raw-component-route"),
        ];
        for (key, expected) in cases {
            assert_eq!(prepare_header_key(key), expected.to_string());
        }
    }

    #[test]
    fn test_default_headers() -> Result<()> {
        let (scheme, host) = ("https", "fermyon.dev");
        let trigger_route = "/foo/...";
        let (component_path, path_info) = ("/foo", "/bar");
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        let req = http::Request::builder()
            .method("POST")
            .uri(format!(
                "{}://{}{}{}?key1=value1&key2=value2",
                scheme, host, component_path, path_info
            ))
            .body("")?;

        let router = Router::build("/", [("DUMMY", &trigger_route.into())], None)?;
        let route_match = router.route("/foo/bar")?;
        let headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        assert_header(
            &headers,
            &FULL_URL,
            "https://fermyon.dev/foo/bar?key1=value1&key2=value2",
        );
        assert_header(&headers, &PATH_INFO, "/bar");
        assert_header(&headers, &MATCHED_ROUTE, "/foo/...");
        assert_header(&headers, &BASE_PATH, "/");
        assert_header(&headers, &RAW_COMPONENT_ROUTE, "/foo/...");
        assert_header(&headers, &COMPONENT_ROUTE, "/foo");
        assert_header(&headers, &CLIENT_ADDR, "127.0.0.1:8777");

        Ok(())
    }

    #[test]
    fn test_default_headers_with_named_wildcards() -> Result<()> {
        let (scheme, host) = ("https", "fermyon.dev");
        let trigger_route = "/foo/:userid/...";
        let (component_path, path_info) = ("/foo", "/bar");
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        let req = http::Request::builder()
            .method("POST")
            .uri(format!(
                "{}://{}{}/42{}?key1=value1&key2=value2",
                scheme, host, component_path, path_info
            ))
            .body("")?;

        let router = Router::build("/", [("DUMMY", &trigger_route.into())], None)?;
        let route_match = router.route("/foo/42/bar")?;
        let headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        assert_header(
            &headers,
            &FULL_URL,
            "https://fermyon.dev/foo/42/bar?key1=value1&key2=value2",
        );
        assert_header(&headers, &PATH_INFO, "/bar");
        assert_header(&headers, &MATCHED_ROUTE, "/foo/:userid/...");
        assert_header(&headers, &BASE_PATH, "/");
        assert_header(&headers, &RAW_COMPONENT_ROUTE, "/foo/:userid/...");
        assert_header(&headers, &COMPONENT_ROUTE, "/foo/:userid");
        assert_header(&headers, &CLIENT_ADDR, "127.0.0.1:8777");
        assert_header(
            &headers,
            &["SPIN_PATH_MATCH_USERID", "X_PATH_MATCH_USERID"],
            "42",
        );

        Ok(())
    }

    #[test]
    fn forbidden_headers_are_removed() {
        // Both a bare service-chaining host and one with a port must be stripped.
        for host in ["test.spin.internal", "test.spin.internal:1234"] {
            let mut req = Request::get("http://test.spin.internal")
                .header("Host", host)
                .header("accept", "text/plain")
                .body(Default::default())
                .unwrap();

            strip_forbidden_headers(&mut req);

            assert_eq!(1, req.headers().len());
            assert!(req.headers().get("Host").is_none());
        }
    }

    #[test]
    fn non_forbidden_headers_are_not_removed() {
        let mut req = Request::get("http://test.example.com")
            .header("Host", "test.example.org")
            .header("accept", "text/plain")
            .body(Default::default())
            .unwrap();

        strip_forbidden_headers(&mut req);

        assert_eq!(2, req.headers().len());
        assert!(req.headers().get("Host").is_some());
    }

    /// Assert that `headers` contains the entry named `keys` with value `expected`.
    fn assert_header(headers: &[HeaderPair<'_>], keys: &[&str; 2], expected: &str) {
        assert_eq!(
            search(keys, headers).as_deref(),
            Some(expected),
            "unexpected value for {:?}",
            keys
        );
    }

    /// Return the value of the last entry whose names equal `keys`, if any.
    fn search(
        keys: &[&str; 2],
        headers: &[([Cow<'static, str>; 2], Cow<'_, str>)],
    ) -> Option<String> {
        headers
            .iter()
            .rev()
            .find(|(names, _)| names[0] == keys[0] && names[1] == keys[1])
            .map(|(_, value)| value.to_string())
    }
}