1use std::{
2 borrow::Cow,
3 net::SocketAddr,
4 str::{self, FromStr},
5};
6
7use anyhow::Result;
8use http::Uri;
9use hyper::Request;
10use spin_factor_outbound_networking::is_service_chaining_host;
11use spin_http::routes::RouteMatch;
12
13use crate::Body;
14
// Each constant below holds two names for the same piece of request metadata.
// `prepare_request_headers` only uses the first name (lower-cased, with `_`
// replaced by `-`) as the HTTP header name.
// NOTE(review): the second name looks like the Wagi-style variant (the wildcard
// loop in `compute_default_headers` calls it `wild_wagi_header`) — confirm.

/// Names for the full request URL, built as `scheme://host` followed by the path and query.
pub const FULL_URL: [&str; 2] = ["SPIN_FULL_URL", "X_FULL_URL"];
/// Names for the trailing-wildcard part of the matched route.
pub const PATH_INFO: [&str; 2] = ["SPIN_PATH_INFO", "PATH_INFO"];
/// Names for the matched route as reported by `RouteMatch::based_route`.
pub const MATCHED_ROUTE: [&str; 2] = ["SPIN_MATCHED_ROUTE", "X_MATCHED_ROUTE"];
/// Names for the component route as reported by `RouteMatch::raw_route_or_prefix`.
pub const COMPONENT_ROUTE: [&str; 2] = ["SPIN_COMPONENT_ROUTE", "X_COMPONENT_ROUTE"];
/// Names for the raw component route as reported by `RouteMatch::raw_route`.
pub const RAW_COMPONENT_ROUTE: [&str; 2] = ["SPIN_RAW_COMPONENT_ROUTE", "X_RAW_COMPONENT_ROUTE"];
/// Names for the base path; `compute_default_headers` always sets the value to `/`.
pub const BASE_PATH: [&str; 2] = ["SPIN_BASE_PATH", "X_BASE_PATH"];
/// Names for the client socket address, rendered with `SocketAddr::to_string`.
pub const CLIENT_ADDR: [&str; 2] = ["SPIN_CLIENT_ADDR", "X_CLIENT_ADDR"];
27
/// Two alternative names for a header together with the single value they share.
///
/// The names are `'static` or owned strings; the value may borrow for `'a`.
pub type HeaderPair<'a> = ([Cow<'static, str>; 2], Cow<'a, str>);
30
31pub fn compute_default_headers<'a>(
33 uri: &Uri,
34 host: &str,
35 route_match: &'a RouteMatch,
36 client_addr: SocketAddr,
37) -> anyhow::Result<Vec<HeaderPair<'a>>> {
38 fn owned(strs: &[&'static str; 2]) -> [Cow<'static, str>; 2] {
39 [strs[0].into(), strs[1].into()]
40 }
41
42 let owned_full_url = owned(&FULL_URL);
43 let owned_path_info = owned(&PATH_INFO);
44 let owned_matched_route = owned(&MATCHED_ROUTE);
45 let owned_component_route = owned(&COMPONENT_ROUTE);
46 let owned_raw_component_route = owned(&RAW_COMPONENT_ROUTE);
47 let owned_base_path = owned(&BASE_PATH);
48 let owned_client_addr = owned(&CLIENT_ADDR);
49
50 let mut res = vec![];
51 let abs_path = uri
52 .path_and_query()
53 .expect("cannot get path and query")
54 .as_str();
55
56 let path_info = route_match.trailing_wildcard();
57
58 let scheme = uri.scheme_str().unwrap_or("http");
59
60 let full_url = format!("{}://{}{}", scheme, host, abs_path);
61
62 res.push((owned_path_info, path_info));
63 res.push((owned_full_url, full_url.into()));
64 res.push((owned_matched_route, route_match.based_route().into()));
65
66 res.push((owned_base_path, "/".into()));
67 res.push((owned_raw_component_route, route_match.raw_route().into()));
68 res.push((
69 owned_component_route,
70 route_match.raw_route_or_prefix().into(),
71 ));
72 res.push((owned_client_addr, client_addr.to_string().into()));
73
74 for (wild_name, wild_value) in route_match.named_wildcards() {
75 let wild_header = format!("SPIN_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
76 let wild_wagi_header = format!("X_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
77 res.push(([wild_header, wild_wagi_header], wild_value.into()));
78 }
79
80 Ok(res)
81}
82
83pub fn strip_forbidden_headers(req: &mut Request<Body>) {
84 let headers = req.headers_mut();
85 if let Some(host_header) = headers.get("Host") {
86 if let Ok(host) = host_header.to_str() {
87 if is_service_chaining_host(host) {
88 headers.remove("Host");
89 }
90 }
91 }
92}
93
94pub fn prepare_request_headers(
95 req: &Request<Body>,
96 route_match: &RouteMatch,
97 client_addr: SocketAddr,
98) -> Result<Vec<(String, String)>> {
99 let mut res = Vec::new();
100 for (name, value) in req
101 .headers()
102 .iter()
103 .map(|(name, value)| (name.to_string(), std::str::from_utf8(value.as_bytes())))
104 {
105 let value = value?.to_string();
106 res.push((name, value));
107 }
108
109 let default_host = http::HeaderValue::from_str("localhost")?;
110 let host = std::str::from_utf8(
111 req.headers()
112 .get("host")
113 .unwrap_or(&default_host)
114 .as_bytes(),
115 )?;
116
117 for (keys, val) in compute_default_headers(req.uri(), host, route_match, client_addr)? {
121 res.push((prepare_header_key(&keys[0]), val.into_owned()));
122 }
123
124 Ok(res)
125}
126
127pub fn append_headers(
128 map: &mut http::HeaderMap,
129 headers: Option<Vec<(String, String)>>,
130) -> Result<()> {
131 if let Some(src) = headers {
132 for (k, v) in src.iter() {
133 map.insert(
134 http::header::HeaderName::from_str(k)?,
135 http::header::HeaderValue::from_str(v)?,
136 );
137 }
138 };
139
140 Ok(())
141}
142
/// Converts a metadata name such as `SPIN_FULL_URL` into header form
/// (`spin-full-url`): every `_` becomes `-` and ASCII letters are lower-cased.
/// Non-ASCII characters are left unchanged.
fn prepare_header_key(key: &str) -> String {
    key.chars()
        .map(|ch| match ch {
            '_' => '-',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
146
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::*;
    use anyhow::Result;
    use spin_http::routes::Router;

    /// Metadata names are lower-cased and use dashes instead of underscores.
    #[test]
    fn test_spin_header_keys() {
        let cases = [
            ("SPIN_FULL_URL", "spin-full-url"),
            ("SPIN_PATH_INFO", "spin-path-info"),
            ("SPIN_RAW_COMPONENT_ROUTE", "spin-raw-component-route"),
        ];

        for (raw, expected) in cases {
            assert_eq!(prepare_header_key(raw), expected.to_string());
        }
    }

    /// Default headers for a route with only a trailing wildcard.
    #[test]
    fn test_default_headers() -> Result<()> {
        let host = "fermyon.dev";
        let trigger_route = "/foo/...";
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        let req_uri = format!(
            "{}://{}{}{}?key1=value1&key2=value2",
            "https", host, "/foo", "/bar"
        );
        let req = http::Request::builder()
            .method("POST")
            .uri(req_uri)
            .body("")?;

        let router = Router::build("/", [("DUMMY", &trigger_route.into())], None)?;
        let route_match = router.route("/foo/bar")?;

        let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        assert_headers(
            &default_headers,
            &[
                (
                    &FULL_URL,
                    "https://fermyon.dev/foo/bar?key1=value1&key2=value2",
                ),
                (&PATH_INFO, "/bar"),
                (&MATCHED_ROUTE, "/foo/..."),
                (&BASE_PATH, "/"),
                (&RAW_COMPONENT_ROUTE, "/foo/..."),
                (&COMPONENT_ROUTE, "/foo"),
                (&CLIENT_ADDR, "127.0.0.1:8777"),
            ],
        );

        Ok(())
    }

    /// Default headers for a route that also has a named wildcard segment.
    #[test]
    fn test_default_headers_with_named_wildcards() -> Result<()> {
        let host = "fermyon.dev";
        let trigger_route = "/foo/:userid/...";
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        let req_uri = format!(
            "{}://{}{}/42{}?key1=value1&key2=value2",
            "https", host, "/foo", "/bar"
        );
        let req = http::Request::builder()
            .method("POST")
            .uri(req_uri)
            .body("")?;

        let router = Router::build("/", [("DUMMY", &trigger_route.into())], None)?;
        let route_match = router.route("/foo/42/bar")?;

        let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        assert_headers(
            &default_headers,
            &[
                (
                    &FULL_URL,
                    "https://fermyon.dev/foo/42/bar?key1=value1&key2=value2",
                ),
                (&PATH_INFO, "/bar"),
                (&MATCHED_ROUTE, "/foo/:userid/..."),
                (&BASE_PATH, "/"),
                (&RAW_COMPONENT_ROUTE, "/foo/:userid/..."),
                (&COMPONENT_ROUTE, "/foo/:userid"),
                (&CLIENT_ADDR, "127.0.0.1:8777"),
                (&["SPIN_PATH_MATCH_USERID", "X_PATH_MATCH_USERID"], "42"),
            ],
        );

        Ok(())
    }

    /// A `Host` header naming a service chaining host is stripped, with or without a port.
    #[test]
    fn forbidden_headers_are_removed() {
        for host in ["test.spin.internal", "test.spin.internal:1234"] {
            let mut req = Request::get("http://test.spin.internal")
                .header("Host", host)
                .header("accept", "text/plain")
                .body(Default::default())
                .unwrap();

            strip_forbidden_headers(&mut req);

            assert_eq!(1, req.headers().len());
            assert!(req.headers().get("Host").is_none());
        }
    }

    /// A `Host` header naming an ordinary host is left in place.
    #[test]
    fn non_forbidden_headers_are_not_removed() {
        let mut req = Request::get("http://test.example.com")
            .header("Host", "test.example.org")
            .header("accept", "text/plain")
            .body(Default::default())
            .unwrap();

        strip_forbidden_headers(&mut req);

        assert_eq!(2, req.headers().len());
        assert!(req.headers().get("Host").is_some());
    }

    /// Asserts that each expected `(names, value)` entry is present in `headers`.
    fn assert_headers(
        headers: &[([Cow<'static, str>; 2], Cow<'_, str>)],
        expected: &[(&[&str; 2], &str)],
    ) {
        for (keys, value) in expected {
            assert_eq!(
                search(keys, headers),
                Some(value.to_string()),
                "unexpected value for {:?}",
                keys
            );
        }
    }

    /// Returns the value of the last entry whose two names both equal `keys`.
    fn search(
        keys: &[&str; 2],
        headers: &[([Cow<'static, str>; 2], Cow<'_, str>)],
    ) -> Option<String> {
        headers
            .iter()
            .rev()
            .find(|(names, _)| names[0] == keys[0] && names[1] == keys[1])
            .map(|(_, value)| value.to_string())
    }
}