1mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11mod wasip3;
12
13use std::{
14 error::Error,
15 net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
16 path::PathBuf,
17 sync::Arc,
18};
19
20use anyhow::{bail, Context};
21use clap::Args;
22use serde::Deserialize;
23use spin_app::App;
24use spin_factors::RuntimeFactors;
25use spin_trigger::Trigger;
26use wasmtime_wasi_http::bindings::http::types::ErrorCode;
27
28pub use server::HttpServer;
29
30pub use tls::TlsConfig;
31
32pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
33
/// Crate-local shorthand for [`spin_trigger::TriggerApp`] specialised to
/// [`HttpTrigger`]; still generic over the runtime factors type `F`.
pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;
36
/// Crate-local shorthand for [`spin_trigger::TriggerInstanceBuilder`] specialised
/// to [`HttpTrigger`]; still generic over the lifetime `'a` and the runtime
/// factors type `F`.
pub(crate) type TriggerInstanceBuilder<'a, F> =
    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
40
// Command-line options accepted by the HTTP trigger, parsed through clap's
// `Args` derive.
//
// NOTE(review): plain `//` comments are used deliberately in this struct. clap's
// derive turns `///` doc comments into `--help` text, so adding them would change
// the program's output.
#[derive(Args)]
pub struct CliArgs {
    // Socket address to listen on. Set with `--listen` or the
    // `SPIN_HTTP_LISTEN_ADDR` environment variable; defaults to `127.0.0.1:3000`.
    // The raw string is converted by `parse_listen_addr`.
    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
    pub address: SocketAddr,

    // Path to the TLS certificate (also read from `SPIN_TLS_CERT`). Declared as
    // requiring the `tls-key` argument, so the two are expected to be given together.
    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
    pub tls_cert: Option<PathBuf>,

    // Path to the TLS private key (also read from `SPIN_TLS_KEY`). Declared as
    // requiring the `tls-cert` argument.
    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
    pub tls_key: Option<PathBuf>,

    // Boolean switch `--find-free-port`, forwarded unchanged to `HttpServer::new`.
    // NOTE(review): presumably lets the server fall back to another port when the
    // requested one is taken — confirm in the server module.
    #[clap(long = "find-free-port")]
    pub find_free_port: bool,
}
58
59impl CliArgs {
60 fn into_tls_config(self) -> Option<TlsConfig> {
61 match (self.tls_cert, self.tls_key) {
62 (Some(cert_path), Some(key_path)) => Some(TlsConfig {
63 cert_path,
64 key_path,
65 }),
66 (None, None) => None,
67 _ => unreachable!(),
68 }
69 }
70}
71
/// The HTTP trigger: holds the listener settings that are later handed to
/// [`HttpServer::new`] by `into_server`.
pub struct HttpTrigger {
    /// Address the server is asked to listen on.
    listen_addr: SocketAddr,
    /// Certificate/key paths for TLS; `None` when no TLS options were supplied.
    tls_config: Option<TlsConfig>,
    /// Passed through unchanged to the server.
    /// NOTE(review): presumably enables picking an alternative port when
    /// `listen_addr`'s port is unavailable — confirm in the server module.
    find_free_port: bool,
}
82
83impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
84 const TYPE: &'static str = "http";
85
86 type CliArgs = CliArgs;
87 type InstanceState = ();
88
89 fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
90 let find_free_port = cli_args.find_free_port;
91
92 Self::new(
93 app,
94 cli_args.address,
95 cli_args.into_tls_config(),
96 find_free_port,
97 )
98 }
99
100 async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
101 let server = self.into_server(trigger_app)?;
102
103 server.serve().await?;
104
105 Ok(())
106 }
107
108 fn supported_host_requirements() -> Vec<&'static str> {
109 vec![spin_app::locked::SERVICE_CHAINING_KEY]
110 }
111}
112
113impl HttpTrigger {
114 pub fn new(
116 app: &spin_app::App,
117 listen_addr: SocketAddr,
118 tls_config: Option<TlsConfig>,
119 find_free_port: bool,
120 ) -> anyhow::Result<Self> {
121 Self::validate_app(app)?;
122
123 Ok(Self {
124 listen_addr,
125 tls_config,
126 find_free_port,
127 })
128 }
129
130 pub fn into_server<F: RuntimeFactors>(
132 self,
133 trigger_app: TriggerApp<F>,
134 ) -> anyhow::Result<Arc<HttpServer<F>>> {
135 let Self {
136 listen_addr,
137 tls_config,
138 find_free_port,
139 } = self;
140 let server = Arc::new(HttpServer::new(
141 listen_addr,
142 tls_config,
143 find_free_port,
144 trigger_app,
145 )?);
146 Ok(server)
147 }
148
149 fn validate_app(app: &App) -> anyhow::Result<()> {
150 #[derive(Deserialize)]
151 #[serde(deny_unknown_fields)]
152 struct TriggerMetadata {
153 base: Option<String>,
154 }
155 if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
156 if base == "/" {
157 tracing::warn!(
158 "This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!"
159 );
160 } else {
161 bail!(
162 "This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'."
163 )
164 }
165 }
166 Ok(())
167 }
168}
169
170fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
171 let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
172 if let Some(addr) = addrs
174 .iter()
175 .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
176 {
177 return Ok(*addr);
178 }
179 addrs.into_iter().next().context("couldn't resolve address")
181}
182
/// Classifies a route for which nothing was found.
///
/// NOTE(review): not referenced in this part of the file — presumably consumed
/// by the server module; confirm there.
#[derive(Debug, PartialEq)]
enum NotFoundRouteKind {
    /// Carries a `String` — presumably the unmatched route; confirm at the use site.
    Normal(String),
    /// NOTE(review): presumably a route under `/.well-known` — confirm at the use site.
    WellKnown,
}
188
189pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
191 if let Some(cause) = err.source() {
193 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
194 return err.clone();
195 }
196 }
197
198 tracing::warn!("hyper request error: {err:?}");
199
200 ErrorCode::HttpProtocolError
201}
202
203pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
204 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
205 rcode: Some(rcode),
206 info_code: Some(info_code),
207 })
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// `localhost` should resolve to the IPv4 loopback address with the
    /// requested port left intact.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let resolved = parse_listen_addr("localhost:12345").unwrap();
        let expected = SocketAddr::from((Ipv4Addr::LOCALHOST, 12345));
        assert_eq!(resolved, expected);
    }
}