1mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11mod wasip3;
12
13use std::{
14 error::Error,
15 net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
16 path::PathBuf,
17 sync::Arc,
18};
19
20use anyhow::{bail, Context};
21use clap::Args;
22use serde::Deserialize;
23use spin_app::App;
24use spin_factors::RuntimeFactors;
25use spin_trigger::Trigger;
26use wasmtime_wasi_http::bindings::http::types::ErrorCode;
27
28pub use server::HttpServer;
29
30pub use tls::TlsConfig;
31
32pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
33
/// [`spin_trigger::TriggerApp`] specialised to [`HttpTrigger`], generic over the runtime
/// factors `F`.
pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;

/// [`spin_trigger::TriggerInstanceBuilder`] specialised to [`HttpTrigger`], generic over the
/// runtime factors `F`.
pub(crate) type TriggerInstanceBuilder<'a, F> =
    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
40
// Command-line arguments accepted by the HTTP trigger.
//
// NOTE: plain `//` comments are used deliberately throughout this struct. With clap's derive,
// `///` doc comments are turned into `--help` text, so adding them would change CLI output.
#[derive(Args)]
pub struct CliArgs {
    // Address to listen on (`--listen`, or env `SPIN_HTTP_LISTEN_ADDR`).
    // Parsed by `parse_listen_addr`, which accepts `host:port` and prefers 127.0.0.1
    // when the host resolves to several addresses.
    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
    pub address: SocketAddr,

    // TLS certificate path (`--tls-cert`, or env `SPIN_TLS_CERT`).
    // clap's `requires` makes `tls-key` mandatory whenever this is given.
    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
    pub tls_cert: Option<PathBuf>,

    // TLS private key path (`--tls-key`, or env `SPIN_TLS_KEY`).
    // clap's `requires` makes `tls-cert` mandatory whenever this is given.
    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
    pub tls_key: Option<PathBuf>,

    // Optional HTTP/1 buffer-size limit (env `SPIN_HTTP1_MAX_BUF_SIZE`).
    // It is forwarded unchanged to `HttpServer::new`.
    #[clap(long, env = "SPIN_HTTP1_MAX_BUF_SIZE")]
    pub http1_max_buf_size: Option<usize>,

    // `--find-free-port` flag, forwarded unchanged to `HttpServer::new`.
    // NOTE(review): presumably lets the server pick another port when `address` is busy;
    // confirm in server.rs.
    #[clap(long = "find-free-port")]
    pub find_free_port: bool,
}
62
63impl CliArgs {
64 fn into_tls_config(self) -> Option<TlsConfig> {
65 match (self.tls_cert, self.tls_key) {
66 (Some(cert_path), Some(key_path)) => Some(TlsConfig {
67 cert_path,
68 key_path,
69 }),
70 (None, None) => None,
71 _ => unreachable!(),
72 }
73 }
74}
75
/// The HTTP trigger: holds the settings used to build an [`HttpServer`] for the app.
pub struct HttpTrigger {
    /// Socket address passed to [`HttpServer::new`].
    listen_addr: SocketAddr,
    /// Optional TLS settings passed to [`HttpServer::new`].
    tls_config: Option<TlsConfig>,
    /// Flag forwarded unchanged to [`HttpServer::new`].
    find_free_port: bool,
    /// Optional HTTP/1 buffer-size limit forwarded unchanged to [`HttpServer::new`].
    http1_max_buf_size: Option<usize>,
}
87
88impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
89 const TYPE: &'static str = "http";
90
91 type CliArgs = CliArgs;
92 type InstanceState = ();
93
94 fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
95 let find_free_port = cli_args.find_free_port;
96 let http1_max_buf_size = cli_args.http1_max_buf_size;
97
98 Self::new(
99 app,
100 cli_args.address,
101 cli_args.into_tls_config(),
102 find_free_port,
103 http1_max_buf_size,
104 )
105 }
106
107 async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
108 let server = self.into_server(trigger_app)?;
109
110 server.serve().await?;
111
112 Ok(())
113 }
114
115 fn supported_host_requirements() -> Vec<&'static str> {
116 vec![spin_app::locked::SERVICE_CHAINING_KEY]
117 }
118}
119
120impl HttpTrigger {
121 pub fn new(
123 app: &spin_app::App,
124 listen_addr: SocketAddr,
125 tls_config: Option<TlsConfig>,
126 find_free_port: bool,
127 http1_max_buf_size: Option<usize>,
128 ) -> anyhow::Result<Self> {
129 Self::validate_app(app)?;
130
131 Ok(Self {
132 listen_addr,
133 tls_config,
134 find_free_port,
135 http1_max_buf_size,
136 })
137 }
138
139 pub fn into_server<F: RuntimeFactors>(
141 self,
142 trigger_app: TriggerApp<F>,
143 ) -> anyhow::Result<Arc<HttpServer<F>>> {
144 let Self {
145 listen_addr,
146 tls_config,
147 find_free_port,
148 http1_max_buf_size,
149 } = self;
150 let server = Arc::new(HttpServer::new(
151 listen_addr,
152 tls_config,
153 find_free_port,
154 trigger_app,
155 http1_max_buf_size,
156 )?);
157 Ok(server)
158 }
159
160 fn validate_app(app: &App) -> anyhow::Result<()> {
161 #[derive(Deserialize)]
162 #[serde(deny_unknown_fields)]
163 struct TriggerMetadata {
164 base: Option<String>,
165 }
166 if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
167 if base == "/" {
168 tracing::warn!(
169 "This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!"
170 );
171 } else {
172 bail!(
173 "This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'."
174 )
175 }
176 }
177 Ok(())
178 }
179}
180
181fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
182 let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
183 if let Some(addr) = addrs
185 .iter()
186 .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
187 {
188 return Ok(*addr);
189 }
190 addrs.into_iter().next().context("couldn't resolve address")
192}
193
/// Classifies a route lookup that found no match.
///
/// NOTE(review): the use site is not visible in this file; the descriptions below are
/// inferred from the variant names and should be confirmed against the caller.
#[derive(Debug, PartialEq)]
enum NotFoundRouteKind {
    /// A not-found route carrying a string (presumably the requested path — confirm).
    Normal(String),
    /// Presumably a not-found request under a `/.well-known/` path — confirm.
    WellKnown,
}
199
200pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
202 if let Some(cause) = err.source() {
204 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
205 return err.clone();
206 }
207 }
208
209 tracing::warn!("hyper request error: {err:?}");
210
211 ErrorCode::HttpProtocolError
212}
213
214pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
215 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
216 rcode: Some(rcode),
217 info_code: Some(info_code),
218 })
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let addr = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(addr.ip(), Ipv4Addr::LOCALHOST);
        assert_eq!(addr.port(), 12345);
    }

    #[test]
    fn parse_listen_addr_falls_back_to_first_address() {
        // A literal address is parsed without DNS. With no IPv4 loopback among the results,
        // the first (here, only) address must be returned unchanged.
        let addr = parse_listen_addr("[::1]:8080").unwrap();
        assert_eq!(addr, "[::1]:8080".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_listen_addr_rejects_missing_port() {
        // No `:port` suffix, so std rejects the input before any DNS lookup.
        assert!(parse_listen_addr("127.0.0.1").is_err());
    }
}