1mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11
12use std::{
13 error::Error,
14 net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
15 path::PathBuf,
16 sync::Arc,
17};
18
19use anyhow::{bail, Context};
20use clap::Args;
21use serde::Deserialize;
22use spin_app::App;
23use spin_factors::RuntimeFactors;
24use spin_trigger::Trigger;
25use wasmtime_wasi_http::bindings::http::types::ErrorCode;
26
27pub use server::HttpServer;
28
29pub use tls::TlsConfig;
30
31pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
32
33pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;
35
36pub(crate) type TriggerInstanceBuilder<'a, F> =
38 spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
39
// Command-line arguments accepted by the HTTP trigger.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive, `///` doc
// comments on these fields would become `--help` text and change CLI output.
#[derive(Args)]
pub struct CliArgs {
    // Address and port to listen on: `--listen` or `SPIN_HTTP_LISTEN_ADDR`,
    // defaulting to 127.0.0.1:3000. The value goes through `parse_listen_addr`,
    // which resolves host names and prefers the IPv4 loopback result.
    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
    pub address: SocketAddr,

    // TLS certificate path (also settable via `SPIN_TLS_CERT`). The `requires`
    // attribute pairs it with the key; together they become a `TlsConfig`.
    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
    pub tls_cert: Option<PathBuf>,

    // TLS private-key path (also settable via `SPIN_TLS_KEY`). The `requires`
    // attribute pairs it with the certificate.
    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
    pub tls_key: Option<PathBuf>,

    // `--find-free-port` flag, forwarded to `HttpServer::new`. Presumably it
    // lets the server pick another port if `address` is taken — confirm in the
    // `server` module.
    #[clap(long = "find-free-port")]
    pub find_free_port: bool,
}
57
58impl CliArgs {
59 fn into_tls_config(self) -> Option<TlsConfig> {
60 match (self.tls_cert, self.tls_key) {
61 (Some(cert_path), Some(key_path)) => Some(TlsConfig {
62 cert_path,
63 key_path,
64 }),
65 (None, None) => None,
66 _ => unreachable!(),
67 }
68 }
69}
70
71pub struct HttpTrigger {
73 listen_addr: SocketAddr,
78 tls_config: Option<TlsConfig>,
79 find_free_port: bool,
80}
81
82impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
83 const TYPE: &'static str = "http";
84
85 type CliArgs = CliArgs;
86 type InstanceState = ();
87
88 fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
89 let find_free_port = cli_args.find_free_port;
90
91 Self::new(
92 app,
93 cli_args.address,
94 cli_args.into_tls_config(),
95 find_free_port,
96 )
97 }
98
99 async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
100 let server = self.into_server(trigger_app)?;
101
102 server.serve().await?;
103
104 Ok(())
105 }
106
107 fn supported_host_requirements() -> Vec<&'static str> {
108 vec![spin_app::locked::SERVICE_CHAINING_KEY]
109 }
110}
111
112impl HttpTrigger {
113 pub fn new(
115 app: &spin_app::App,
116 listen_addr: SocketAddr,
117 tls_config: Option<TlsConfig>,
118 find_free_port: bool,
119 ) -> anyhow::Result<Self> {
120 Self::validate_app(app)?;
121
122 Ok(Self {
123 listen_addr,
124 tls_config,
125 find_free_port,
126 })
127 }
128
129 pub fn into_server<F: RuntimeFactors>(
131 self,
132 trigger_app: TriggerApp<F>,
133 ) -> anyhow::Result<Arc<HttpServer<F>>> {
134 let Self {
135 listen_addr,
136 tls_config,
137 find_free_port,
138 } = self;
139 let server = Arc::new(HttpServer::new(
140 listen_addr,
141 tls_config,
142 find_free_port,
143 trigger_app,
144 )?);
145 Ok(server)
146 }
147
148 fn validate_app(app: &App) -> anyhow::Result<()> {
149 #[derive(Deserialize)]
150 #[serde(deny_unknown_fields)]
151 struct TriggerMetadata {
152 base: Option<String>,
153 }
154 if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
155 if base == "/" {
156 tracing::warn!("This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!");
157 } else {
158 bail!("This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'.")
159 }
160 }
161 Ok(())
162 }
163}
164
165fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
166 let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
167 if let Some(addr) = addrs
169 .iter()
170 .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
171 {
172 return Ok(*addr);
173 }
174 addrs.into_iter().next().context("couldn't resolve address")
176}
177
/// Classification of a request that matched no route.
///
/// NOTE(review): not used in this chunk. Presumably the server module uses it
/// to decide how to answer unmatched requests — confirm there.
#[derive(PartialEq, Debug)]
enum NotFoundRouteKind {
    /// An ordinary unmatched request. The `String` is presumably the requested
    /// route or path — verify against the construction site.
    Normal(String),
    /// Presumably a request under a `/.well-known/` prefix — verify against
    /// the construction site.
    WellKnown,
}
183
184pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
186 if let Some(cause) = err.source() {
188 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
189 return err.clone();
190 }
191 }
192
193 tracing::warn!("hyper request error: {err:?}");
194
195 ErrorCode::HttpProtocolError
196}
197
198pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
199 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
200 rcode: Some(rcode),
201 info_code: Some(info_code),
202 })
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    // `localhost` typically resolves to both ::1 and 127.0.0.1; the IPv4
    // loopback result must be the one chosen.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let addr = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(addr.ip(), Ipv4Addr::LOCALHOST);
        assert_eq!(addr.port(), 12345);
    }

    // Socket-address literals are parsed without DNS. They are returned
    // unchanged even when they are not the IPv4 loopback, which exercises the
    // "first resolved address" fallback.
    #[test]
    fn parse_listen_addr_passes_through_literals() {
        for literal in ["127.0.0.1:3000", "0.0.0.0:8080", "[::1]:8080"] {
            let expected: SocketAddr = literal.parse().unwrap();
            assert_eq!(parse_listen_addr(literal).unwrap(), expected);
        }
    }

    // Input missing a port, or with a non-numeric port, is rejected before
    // any DNS lookup happens.
    #[test]
    fn parse_listen_addr_rejects_malformed_input() {
        assert!(parse_listen_addr("127.0.0.1").is_err());
        assert!(parse_listen_addr("localhost:not-a-port").is_err());
    }
}