1use std::{
2 borrow::Cow,
3 net::SocketAddr,
4 str::{self, FromStr},
5};
6
7use anyhow::Result;
8use http::Uri;
9use hyper::Request;
10use spin_factor_outbound_networking::config::allowed_hosts::is_service_chaining_host;
11use spin_http::routes::RouteMatch;
12
13use crate::Body;
14
/// Header carrying the full request URL, built as `{scheme}://{host}{path_and_query}`
/// (e.g. `https://fermyon.dev/foo/bar?key1=value1&key2=value2`).
///
/// Each constant in this group lists two names for the same value: the
/// `SPIN_`-prefixed name first, then an alternative name (`X_`-prefixed for all
/// of them except [`PATH_INFO`], whose alternative is the bare `PATH_INFO`).
pub const FULL_URL: [&str; 2] = ["SPIN_FULL_URL", "X_FULL_URL"];
/// Header carrying the part of the path matched by the route's trailing wildcard
/// (`RouteMatch::trailing_wildcard`), e.g. `/bar` when `/foo/bar` matches `/foo/...`.
pub const PATH_INFO: [&str; 2] = ["SPIN_PATH_INFO", "PATH_INFO"];
/// Header carrying `RouteMatch::based_route()`, e.g. `/foo/...`.
pub const MATCHED_ROUTE: [&str; 2] = ["SPIN_MATCHED_ROUTE", "X_MATCHED_ROUTE"];
/// Header carrying `RouteMatch::raw_route_or_prefix()`, e.g. `/foo` for the route
/// `/foo/...` or `/foo/:userid` for `/foo/:userid/...`.
pub const COMPONENT_ROUTE: [&str; 2] = ["SPIN_COMPONENT_ROUTE", "X_COMPONENT_ROUTE"];
/// Header carrying `RouteMatch::raw_route()`, e.g. `/foo/...`.
pub const RAW_COMPONENT_ROUTE: [&str; 2] = ["SPIN_RAW_COMPONENT_ROUTE", "X_RAW_COMPONENT_ROUTE"];
/// Header carrying the base path; `compute_default_headers` always sets it to `/`.
pub const BASE_PATH: [&str; 2] = ["SPIN_BASE_PATH", "X_BASE_PATH"];
/// Header carrying the client's socket address, e.g. `127.0.0.1:8777`.
pub const CLIENT_ADDR: [&str; 2] = ["SPIN_CLIENT_ADDR", "X_CLIENT_ADDR"];

/// A header to set on a request: both of its names (`SPIN_` name first) and its value.
///
/// The names are always `'static` or owned; the value may borrow from the
/// `RouteMatch` it was computed from (lifetime `'a`).
pub type HeaderPair<'a> = ([Cow<'static, str>; 2], Cow<'a, str>);
30
31pub fn compute_default_headers<'a>(
33 uri: &Uri,
34 host: &str,
35 route_match: &'a RouteMatch,
36 client_addr: SocketAddr,
37) -> anyhow::Result<Vec<HeaderPair<'a>>> {
38 fn owned(strs: &[&'static str; 2]) -> [Cow<'static, str>; 2] {
39 [strs[0].into(), strs[1].into()]
40 }
41
42 let owned_full_url = owned(&FULL_URL);
43 let owned_path_info = owned(&PATH_INFO);
44 let owned_matched_route = owned(&MATCHED_ROUTE);
45 let owned_component_route = owned(&COMPONENT_ROUTE);
46 let owned_raw_component_route = owned(&RAW_COMPONENT_ROUTE);
47 let owned_base_path = owned(&BASE_PATH);
48 let owned_client_addr = owned(&CLIENT_ADDR);
49
50 let mut res = vec![];
51 let abs_path = uri
52 .path_and_query()
53 .expect("cannot get path and query")
54 .as_str();
55
56 let path_info = route_match.trailing_wildcard();
57
58 let scheme = uri.scheme_str().unwrap_or("http");
59
60 let full_url = format!("{scheme}://{host}{abs_path}");
61
62 res.push((owned_path_info, path_info));
63 res.push((owned_full_url, full_url.into()));
64 res.push((owned_matched_route, route_match.based_route().into()));
65
66 res.push((owned_base_path, "/".into()));
67 res.push((owned_raw_component_route, route_match.raw_route().into()));
68 res.push((
69 owned_component_route,
70 route_match.raw_route_or_prefix().into(),
71 ));
72 res.push((owned_client_addr, client_addr.to_string().into()));
73
74 for (wild_name, wild_value) in route_match.named_wildcards() {
75 let wild_header = format!("SPIN_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
76 let wild_wagi_header = format!("X_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
77 res.push(([wild_header, wild_wagi_header], wild_value.into()));
78 }
79
80 Ok(res)
81}
82
83pub fn strip_forbidden_headers(req: &mut Request<Body>) {
84 let headers = req.headers_mut();
85 if let Some(host_header) = headers.get("Host") {
86 if let Ok(host) = host_header.to_str() {
87 if is_service_chaining_host(host) {
88 headers.remove("Host");
89 }
90 }
91 }
92}
93
94pub fn prepare_request_headers(
95 req: &Request<Body>,
96 route_match: &RouteMatch,
97 client_addr: SocketAddr,
98) -> Result<Vec<(String, String)>> {
99 let mut res = Vec::new();
100 for (name, value) in req
101 .headers()
102 .iter()
103 .map(|(name, value)| (name.to_string(), std::str::from_utf8(value.as_bytes())))
104 {
105 let value = value?.to_string();
106 res.push((name, value));
107 }
108
109 let default_host = http::HeaderValue::from_str("localhost")?;
110 let host = std::str::from_utf8(
111 req.headers()
112 .get("host")
113 .unwrap_or(&default_host)
114 .as_bytes(),
115 )?;
116
117 for (keys, val) in compute_default_headers(req.uri(), host, route_match, client_addr)? {
121 res.push((prepare_header_key(&keys[0]), val.into_owned()));
122 }
123
124 Ok(res)
125}
126
127pub fn append_headers(
128 map: &mut http::HeaderMap,
129 headers: Option<Vec<(String, String)>>,
130) -> Result<()> {
131 if let Some(src) = headers {
132 for (k, v) in src.iter() {
133 map.insert(
134 http::header::HeaderName::from_str(k)?,
135 http::header::HeaderValue::from_str(v)?,
136 );
137 }
138 };
139
140 Ok(())
141}
142
/// Converts an environment-style key into an HTTP header name.
///
/// Underscores become hyphens and ASCII letters are lowercased
/// (`SPIN_FULL_URL` becomes `spin-full-url`). Non-ASCII characters pass through
/// unchanged.
fn prepare_header_key(key: &str) -> String {
    key.chars()
        .map(|ch| match ch {
            '_' => '-',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
146
// Unit tests for header-key normalisation, default-header computation and
// forbidden-header stripping.
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::*;
    use anyhow::Result;
    use spin_http::routes::Router;

    // `prepare_header_key` turns the SPIN_* constant names into lowercase,
    // hyphen-separated HTTP header names.
    #[test]
    fn test_spin_header_keys() {
        assert_eq!(
            prepare_header_key("SPIN_FULL_URL"),
            "spin-full-url".to_string()
        );
        assert_eq!(
            prepare_header_key("SPIN_PATH_INFO"),
            "spin-path-info".to_string()
        );
        assert_eq!(
            prepare_header_key("SPIN_RAW_COMPONENT_ROUTE"),
            "spin-raw-component-route".to_string()
        );
    }

    // Route with only a trailing wildcard: `/foo/bar` matched by `/foo/...`
    // should yield `/bar` as path info and `/foo` as the component route.
    #[test]
    fn test_default_headers() -> Result<()> {
        let scheme = "https";
        let host = "fermyon.dev";
        let trigger_route = "/foo/...";
        let component_path = "/foo";
        let path_info = "/bar";
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        // Absolute URI with a query string, so the full URL must preserve it.
        let req_uri = format!(
            "{}://{}{}{}?key1=value1&key2=value2",
            scheme, host, component_path, path_info
        );

        let req = http::Request::builder()
            .method("POST")
            .uri(req_uri)
            .body("")?;

        let router = Router::build(
            "/",
            [(
                &spin_http::routes::TriggerLookupKey::Component("DUMMY".into()),
                &trigger_route.into(),
            )],
            None,
        )?;
        // Routing uses the path only; the full URI is passed separately below.
        let route_match = router.route("/foo/bar")?;

        let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        assert_eq!(
            search(&FULL_URL, &default_headers).unwrap(),
            "https://fermyon.dev/foo/bar?key1=value1&key2=value2".to_string()
        );
        assert_eq!(
            search(&PATH_INFO, &default_headers).unwrap(),
            "/bar".to_string()
        );
        assert_eq!(
            search(&MATCHED_ROUTE, &default_headers).unwrap(),
            "/foo/...".to_string()
        );
        assert_eq!(
            search(&BASE_PATH, &default_headers).unwrap(),
            "/".to_string()
        );
        assert_eq!(
            search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(),
            "/foo/...".to_string()
        );
        assert_eq!(
            search(&COMPONENT_ROUTE, &default_headers).unwrap(),
            "/foo".to_string()
        );
        assert_eq!(
            search(&CLIENT_ADDR, &default_headers).unwrap(),
            "127.0.0.1:8777".to_string()
        );

        Ok(())
    }

    // Route with a named wildcard (`:userid`) before the trailing wildcard: the
    // named segment must also surface as SPIN_PATH_MATCH_USERID / X_PATH_MATCH_USERID.
    #[test]
    fn test_default_headers_with_named_wildcards() -> Result<()> {
        let scheme = "https";
        let host = "fermyon.dev";
        let trigger_route = "/foo/:userid/...";
        let component_path = "/foo";
        let path_info = "/bar";
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        // `42` fills the `:userid` segment.
        let req_uri = format!(
            "{}://{}{}/42{}?key1=value1&key2=value2",
            scheme, host, component_path, path_info
        );

        let req = http::Request::builder()
            .method("POST")
            .uri(req_uri)
            .body("")?;

        let router = Router::build(
            "/",
            [(
                &spin_http::routes::TriggerLookupKey::Component("DUMMY".into()),
                &trigger_route.into(),
            )],
            None,
        )?;
        let route_match = router.route("/foo/42/bar")?;

        let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        assert_eq!(
            search(&FULL_URL, &default_headers).unwrap(),
            "https://fermyon.dev/foo/42/bar?key1=value1&key2=value2".to_string()
        );
        assert_eq!(
            search(&PATH_INFO, &default_headers).unwrap(),
            "/bar".to_string()
        );
        assert_eq!(
            search(&MATCHED_ROUTE, &default_headers).unwrap(),
            "/foo/:userid/...".to_string()
        );
        assert_eq!(
            search(&BASE_PATH, &default_headers).unwrap(),
            "/".to_string()
        );
        assert_eq!(
            search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(),
            "/foo/:userid/...".to_string()
        );
        assert_eq!(
            search(&COMPONENT_ROUTE, &default_headers).unwrap(),
            "/foo/:userid".to_string()
        );
        assert_eq!(
            search(&CLIENT_ADDR, &default_headers).unwrap(),
            "127.0.0.1:8777".to_string()
        );

        // The wildcard name is uppercased in both generated header names.
        assert_eq!(
            search(
                &["SPIN_PATH_MATCH_USERID", "X_PATH_MATCH_USERID"],
                &default_headers
            )
            .unwrap(),
            "42".to_string()
        );

        Ok(())
    }

    // A service-chaining `Host` is removed whether or not it carries a port;
    // unrelated headers (`accept`) are left alone.
    #[test]
    fn forbidden_headers_are_removed() {
        let mut req = Request::get("http://test.spin.internal")
            .header("Host", "test.spin.internal")
            .header("accept", "text/plain")
            .body(Default::default())
            .unwrap();

        strip_forbidden_headers(&mut req);

        assert_eq!(1, req.headers().len());
        assert!(req.headers().get("Host").is_none());

        let mut req = Request::get("http://test.spin.internal")
            .header("Host", "test.spin.internal:1234")
            .header("accept", "text/plain")
            .body(Default::default())
            .unwrap();

        strip_forbidden_headers(&mut req);

        assert_eq!(1, req.headers().len());
        assert!(req.headers().get("Host").is_none());
    }

    // A non-chaining `Host` is kept. The decision follows the `Host` header
    // (`.org`), not the request URI (`.com`).
    #[test]
    fn non_forbidden_headers_are_not_removed() {
        let mut req = Request::get("http://test.example.com")
            .header("Host", "test.example.org")
            .header("accept", "text/plain")
            .body(Default::default())
            .unwrap();

        strip_forbidden_headers(&mut req);

        assert_eq!(2, req.headers().len());
        assert!(req.headers().get("Host").is_some());
    }

    // Returns the value of the last entry in `headers` whose name pair equals
    // `keys` exactly (both names), or `None` if no entry matches.
    fn search(
        keys: &[&str; 2],
        headers: &[([Cow<'static, str>; 2], Cow<'_, str>)],
    ) -> Option<String> {
        let mut res: Option<String> = None;
        for (k, v) in headers {
            if k[0] == keys[0] && k[1] == keys[1] {
                res = Some(v.as_ref().to_owned());
            }
        }

        res
    }
}