1use std::{
2 borrow::Cow,
3 net::SocketAddr,
4 str::{self, FromStr},
5};
6
7use anyhow::Result;
8use http::Uri;
9use hyper::Request;
10use spin_factor_outbound_networking::config::allowed_hosts::is_service_chaining_host;
11use spin_http::routes::RouteMatch;
12
13use crate::Body;
14
// Each constant below pairs two names for the same request header value:
// index 0 is the `SPIN_`-prefixed name and index 1 is the `X_`-prefixed
// alternative. `prepare_request_headers` exposes only index 0, after
// `prepare_header_key` converts it (e.g. `SPIN_FULL_URL` -> `spin-full-url`).

/// Names for the full request URL, built as `{scheme}://{host}{path_and_query}`.
pub const FULL_URL: [&str; 2] = ["SPIN_FULL_URL", "X_FULL_URL"];
/// Names for the portion of the path matched by the route's trailing wildcard.
pub const PATH_INFO: [&str; 2] = ["SPIN_PATH_INFO", "PATH_INFO"];
/// Names for the matched route, as reported by `RouteMatch::based_route`.
pub const MATCHED_ROUTE: [&str; 2] = ["SPIN_MATCHED_ROUTE", "X_MATCHED_ROUTE"];
/// Names for the component route, as reported by `RouteMatch::raw_route_or_prefix`.
pub const COMPONENT_ROUTE: [&str; 2] = ["SPIN_COMPONENT_ROUTE", "X_COMPONENT_ROUTE"];
/// Names for the raw component route, as reported by `RouteMatch::raw_route`.
pub const RAW_COMPONENT_ROUTE: [&str; 2] = ["SPIN_RAW_COMPONENT_ROUTE", "X_RAW_COMPONENT_ROUTE"];
/// Names for the base path; `compute_default_headers` always sets this to `/`.
pub const BASE_PATH: [&str; 2] = ["SPIN_BASE_PATH", "X_BASE_PATH"];
/// Names for the client's socket address, formatted as `ip:port`.
pub const CLIENT_ADDR: [&str; 2] = ["SPIN_CLIENT_ADDR", "X_CLIENT_ADDR"];

/// A pair of header names (`SPIN_` name, `X_` name) together with the value
/// published under them. The value may borrow from the route match (`'a`).
pub type HeaderPair<'a> = ([Cow<'static, str>; 2], Cow<'a, str>);
30
31pub fn compute_default_headers<'a>(
33 uri: &Uri,
34 host: &str,
35 route_match: &'a RouteMatch,
36 client_addr: SocketAddr,
37) -> anyhow::Result<Vec<HeaderPair<'a>>> {
38 fn owned(strs: &[&'static str; 2]) -> [Cow<'static, str>; 2] {
39 [strs[0].into(), strs[1].into()]
40 }
41
42 let owned_full_url = owned(&FULL_URL);
43 let owned_path_info = owned(&PATH_INFO);
44 let owned_matched_route = owned(&MATCHED_ROUTE);
45 let owned_component_route = owned(&COMPONENT_ROUTE);
46 let owned_raw_component_route = owned(&RAW_COMPONENT_ROUTE);
47 let owned_base_path = owned(&BASE_PATH);
48 let owned_client_addr = owned(&CLIENT_ADDR);
49
50 let mut res = vec![];
51 let abs_path = uri
52 .path_and_query()
53 .expect("cannot get path and query")
54 .as_str();
55
56 let path_info = route_match.trailing_wildcard();
57
58 let scheme = uri.scheme_str().unwrap_or("http");
59
60 let full_url = format!("{scheme}://{host}{abs_path}");
61
62 res.push((owned_path_info, path_info));
63 res.push((owned_full_url, full_url.into()));
64 res.push((owned_matched_route, route_match.based_route().into()));
65
66 res.push((owned_base_path, "/".into()));
67 res.push((owned_raw_component_route, route_match.raw_route().into()));
68 res.push((
69 owned_component_route,
70 route_match.raw_route_or_prefix().into(),
71 ));
72 res.push((owned_client_addr, client_addr.to_string().into()));
73
74 for (wild_name, wild_value) in route_match.named_wildcards() {
75 let wild_header = format!("SPIN_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
76 let wild_wagi_header = format!("X_PATH_MATCH_{}", wild_name.to_ascii_uppercase()).into();
77 res.push(([wild_header, wild_wagi_header], wild_value.into()));
78 }
79
80 Ok(res)
81}
82
83pub fn strip_forbidden_headers(req: &mut Request<Body>) {
84 let headers = req.headers_mut();
85 if let Some(host_header) = headers.get("Host")
86 && let Ok(host) = host_header.to_str()
87 && is_service_chaining_host(host)
88 {
89 headers.remove("Host");
90 }
91}
92
93pub fn prepare_request_headers(
94 req: &Request<Body>,
95 route_match: &RouteMatch,
96 client_addr: SocketAddr,
97) -> Result<Vec<(String, String)>> {
98 let mut res = Vec::new();
99 for (name, value) in req
100 .headers()
101 .iter()
102 .map(|(name, value)| (name.to_string(), std::str::from_utf8(value.as_bytes())))
103 {
104 let value = value?.to_string();
105 res.push((name, value));
106 }
107
108 let default_host = http::HeaderValue::from_str("localhost")?;
109 let host = std::str::from_utf8(
110 req.headers()
111 .get("host")
112 .unwrap_or(&default_host)
113 .as_bytes(),
114 )?;
115
116 for (keys, val) in compute_default_headers(req.uri(), host, route_match, client_addr)? {
120 res.push((prepare_header_key(&keys[0]), val.into_owned()));
121 }
122
123 Ok(res)
124}
125
126pub fn append_headers(
127 map: &mut http::HeaderMap,
128 headers: Option<Vec<(String, String)>>,
129) -> Result<()> {
130 if let Some(src) = headers {
131 for (k, v) in src.iter() {
132 map.insert(
133 http::header::HeaderName::from_str(k)?,
134 http::header::HeaderValue::from_str(v)?,
135 );
136 }
137 };
138
139 Ok(())
140}
141
/// Converts an environment-style key into a header name.
///
/// Each `_` becomes `-` and ASCII letters are lower-cased, so `SPIN_FULL_URL`
/// becomes `spin-full-url`. Non-ASCII characters are passed through unchanged.
fn prepare_header_key(key: &str) -> String {
    key.chars()
        .map(|ch| match ch {
            '_' => '-',
            other => other.to_ascii_lowercase(),
        })
        .collect()
}
145
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::*;
    use anyhow::Result;
    use spin_http::routes::Router;

    #[test]
    fn test_spin_header_keys() {
        let cases = [
            ("SPIN_FULL_URL", "spin-full-url"),
            ("SPIN_PATH_INFO", "spin-path-info"),
            ("SPIN_RAW_COMPONENT_ROUTE", "spin-raw-component-route"),
        ];
        for (raw, expected) in cases {
            assert_eq!(prepare_header_key(raw), expected);
        }
    }

    #[test]
    fn test_default_headers() -> Result<()> {
        let (scheme, host) = ("https", "fermyon.dev");
        let trigger_route = "/foo/...";
        let (component_path, path_info) = ("/foo", "/bar");
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        let req = http::Request::builder()
            .method("POST")
            .uri(format!(
                "{scheme}://{host}{component_path}{path_info}?key1=value1&key2=value2"
            ))
            .body("")?;

        let router = Router::build(
            "/",
            [(
                &spin_http::routes::TriggerLookupKey::Component("DUMMY".into()),
                &trigger_route.into(),
            )],
            None,
        )?;
        let route_match = router.route("/foo/bar")?;

        let headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        let expectations: [(&[&str; 2], &str); 7] = [
            (
                &FULL_URL,
                "https://fermyon.dev/foo/bar?key1=value1&key2=value2",
            ),
            (&PATH_INFO, "/bar"),
            (&MATCHED_ROUTE, "/foo/..."),
            (&BASE_PATH, "/"),
            (&RAW_COMPONENT_ROUTE, "/foo/..."),
            (&COMPONENT_ROUTE, "/foo"),
            (&CLIENT_ADDR, "127.0.0.1:8777"),
        ];
        for (keys, expected) in expectations {
            assert_eq!(
                search(keys, &headers).as_deref(),
                Some(expected),
                "header {keys:?}"
            );
        }

        Ok(())
    }

    #[test]
    fn test_default_headers_with_named_wildcards() -> Result<()> {
        let (scheme, host) = ("https", "fermyon.dev");
        let trigger_route = "/foo/:userid/...";
        let (component_path, path_info) = ("/foo", "/bar");
        let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap();

        let req = http::Request::builder()
            .method("POST")
            .uri(format!(
                "{scheme}://{host}{component_path}/42{path_info}?key1=value1&key2=value2"
            ))
            .body("")?;

        let router = Router::build(
            "/",
            [(
                &spin_http::routes::TriggerLookupKey::Component("DUMMY".into()),
                &trigger_route.into(),
            )],
            None,
        )?;
        let route_match = router.route("/foo/42/bar")?;

        let headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?;

        let expectations: [(&[&str; 2], &str); 8] = [
            (
                &FULL_URL,
                "https://fermyon.dev/foo/42/bar?key1=value1&key2=value2",
            ),
            (&PATH_INFO, "/bar"),
            (&MATCHED_ROUTE, "/foo/:userid/..."),
            (&BASE_PATH, "/"),
            (&RAW_COMPONENT_ROUTE, "/foo/:userid/..."),
            (&COMPONENT_ROUTE, "/foo/:userid"),
            (&CLIENT_ADDR, "127.0.0.1:8777"),
            (&["SPIN_PATH_MATCH_USERID", "X_PATH_MATCH_USERID"], "42"),
        ];
        for (keys, expected) in expectations {
            assert_eq!(
                search(keys, &headers).as_deref(),
                Some(expected),
                "header {keys:?}"
            );
        }

        Ok(())
    }

    /// Builds a GET request for `uri` with the given `Host` header and an
    /// unrelated `accept` header, so tests can check which headers survive.
    fn request_with_host(uri: &str, host: &str) -> Request<Body> {
        Request::get(uri)
            .header("Host", host)
            .header("accept", "text/plain")
            .body(Default::default())
            .unwrap()
    }

    #[test]
    fn forbidden_headers_are_removed() {
        for host in ["test.spin.internal", "test.spin.internal:1234"] {
            let mut req = request_with_host("http://test.spin.internal", host);

            strip_forbidden_headers(&mut req);

            assert_eq!(1, req.headers().len());
            assert!(req.headers().get("Host").is_none());
        }
    }

    #[test]
    fn non_forbidden_headers_are_not_removed() {
        let mut req = request_with_host("http://test.example.com", "test.example.org");

        strip_forbidden_headers(&mut req);

        assert_eq!(2, req.headers().len());
        assert!(req.headers().get("Host").is_some());
    }

    /// Returns the value of the last entry whose name pair equals `keys`, if any.
    fn search(
        keys: &[&str; 2],
        headers: &[([Cow<'static, str>; 2], Cow<'_, str>)],
    ) -> Option<String> {
        headers
            .iter()
            .rev()
            .find(|(names, _)| names[0] == keys[0] && names[1] == keys[1])
            .map(|(_, value)| value.to_string())
    }
}