1mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11mod wasip3;
12
13use std::{
14 error::Error,
15 fmt::Display,
16 net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
17 path::PathBuf,
18 str::FromStr,
19 sync::Arc,
20 time::Duration,
21};
22
23use anyhow::{bail, Context};
24use clap::Args;
25use rand::{
26 distr::uniform::{SampleRange, SampleUniform},
27 RngCore,
28};
29use serde::Deserialize;
30use spin_app::App;
31use spin_factors::RuntimeFactors;
32use spin_trigger::Trigger;
33use wasmtime_wasi_http::bindings::http::types::ErrorCode;
34
35pub use server::HttpServer;
36
37pub use tls::TlsConfig;
38
39pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
40
/// Fallback for `--max-instance-reuse-count` when the flag is not given
/// (used by both `Trigger::new` and `InstanceReuseConfig::default`).
/// NOTE(review): the `WASIP3` in the name suggests this default targets WASIp3
/// components specifically — confirm in the server module.
const DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT: usize = 128;
/// Fallback for `--max-instance-concurrent-reuse-count` when the flag is not given.
const DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT: usize = 16;
/// Request timeout used by `InstanceReuseConfig::default`: `None`, i.e. no
/// timeout is configured (presumably meaning requests are not time-limited — confirm).
const DEFAULT_REQUEST_TIMEOUT: Option<Range<Duration>> = None;
/// Idle-instance timeout used by `InstanceReuseConfig::default`.
/// Must be kept in sync by hand with `default_value = "1s"` on
/// `CliArgs::idle_instance_timeout`.
const DEFAULT_IDLE_INSTANCE_TIMEOUT: Range<Duration> = Range::Value(Duration::from_secs(1));
45
/// `spin_trigger::TriggerApp` specialized to [`HttpTrigger`].
pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;

/// `spin_trigger::TriggerInstanceBuilder` specialized to [`HttpTrigger`].
pub(crate) type TriggerInstanceBuilder<'a, F> =
    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
52
// Command-line arguments accepted by the HTTP trigger.
//
// NOTE: everything in this struct is documented with `//` rather than `///` on
// purpose — clap's derive turns `///` doc comments into `--help`/about text.
#[derive(Args)]
pub struct CliArgs {
    // Address to listen on (`--listen` or `SPIN_HTTP_LISTEN_ADDR`, default
    // `127.0.0.1:3000`). Host names are resolved by `parse_listen_addr`, which
    // prefers the IPv4 loopback address when the name resolves to it.
    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
    pub address: SocketAddr,

    // Path to the TLS certificate; clap requires `--tls-key` alongside it.
    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
    pub tls_cert: Option<PathBuf>,

    // Path to the TLS private key; clap requires `--tls-cert` alongside it.
    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
    pub tls_key: Option<PathBuf>,

    // Optional HTTP/1 maximum buffer size, forwarded unchanged to `HttpServer::new`.
    #[clap(long, env = "SPIN_HTTP1_MAX_BUF_SIZE")]
    pub http1_max_buf_size: Option<usize>,

    // Flag forwarded to `HttpServer::new`; presumably lets the server pick
    // another port when `address` is taken — confirm in the server module.
    #[clap(long = "find-free-port")]
    pub find_free_port: bool,

    // `N` or a half-open `start..end` range; when unset, `Trigger::new` falls
    // back to `DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT`.
    #[clap(long, value_parser = parse_usize_range)]
    max_instance_reuse_count: Option<Range<usize>>,

    // `N` or a half-open `start..end` range; when unset, `Trigger::new` falls
    // back to `DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT`.
    #[clap(long, value_parser = parse_usize_range)]
    max_instance_concurrent_reuse_count: Option<Range<usize>>,

    // A duration or duration range (see `parse_duration_range`); left as
    // `None` when the flag is not given.
    #[clap(long, value_parser = parse_duration_range)]
    request_timeout: Option<Range<Duration>>,

    // A duration or duration range; the `"1s"` default mirrors
    // `DEFAULT_IDLE_INSTANCE_TIMEOUT` and must be kept in sync with it.
    #[clap(long, default_value = "1s", value_parser = parse_duration_range)]
    idle_instance_timeout: Range<Duration>,
}
127
128impl CliArgs {
129 fn into_tls_config(self) -> Option<TlsConfig> {
130 match (self.tls_cert, self.tls_key) {
131 (Some(cert_path), Some(key_path)) => Some(TlsConfig {
132 cert_path,
133 key_path,
134 }),
135 (None, None) => None,
136 _ => unreachable!(),
137 }
138 }
139}
140
/// A configuration value that is either fixed or drawn from a half-open range.
///
/// Parsed from CLI input such as `8` (`Value`) or `1..8` (`Bounds`). The upper
/// bound of `Bounds` is exclusive, matching Rust's `start..end` syntax.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Range<T> {
    /// A single fixed value.
    Value(T),
    /// Lower (inclusive) and upper (exclusive) bounds.
    Bounds(T, T),
}

impl<T> Range<T> {
    /// Converts the contained value(s) with `fun`, preserving the variant.
    ///
    /// For `Bounds`, `fun` is applied to the lower bound first, then the upper,
    /// which is why a stateful (`FnMut`) closure is accepted.
    fn map<V>(self, mut fun: impl FnMut(T) -> V) -> Range<V> {
        match self {
            Self::Value(v) => Range::Value(fun(v)),
            Self::Bounds(a, b) => Range::Bounds(fun(a), fun(b)),
        }
    }
}
155
156impl<T: SampleUniform + PartialOrd> SampleRange<T> for Range<T> {
157 fn sample_single<R: RngCore + ?Sized>(
158 self,
159 rng: &mut R,
160 ) -> Result<T, rand::distr::uniform::Error> {
161 match self {
162 Self::Value(v) => Ok(v),
163 Self::Bounds(a, b) => (a..b).sample_single(rng),
164 }
165 }
166
167 fn is_empty(&self) -> bool {
168 match self {
169 Self::Value(_) => false,
170 Self::Bounds(a, b) => (a..b).is_empty(),
171 }
172 }
173}
174
175fn parse_range<T: FromStr>(s: &str) -> Result<Range<T>, String>
176where
177 T::Err: Display,
178{
179 let error = |e| format!("expected integer or range; got {s:?}; {e}");
180 if let Some((start, end)) = s.split_once("..") {
181 Ok(Range::Bounds(
182 start.parse().map_err(error)?,
183 end.parse().map_err(error)?,
184 ))
185 } else {
186 Ok(Range::Value(s.parse().map_err(error)?))
187 }
188}
189
190fn parse_usize_range(s: &str) -> Result<Range<usize>, String> {
191 parse_range(s)
192}
193
/// Newtype around [`Duration`] so durations can be parsed through `FromStr`
/// (and therefore through `parse_range`).
struct ParsedDuration(Duration);

impl FromStr for ParsedDuration {
    type Err = String;

    /// Parses an unsigned integer with an optional unit suffix.
    ///
    /// A bare integer means whole seconds. Otherwise the suffix selects the
    /// unit: `s`, `ms`, `us` / `μs`, or `ns`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if the number is not a valid `u64` or
    /// the suffix is not recognized.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let error = |e: std::num::ParseIntError| {
            format!("expected integer suffixed by `s`, `ms`, `us`, `μs`, or `ns`; got {s:?}; {e}")
        };

        let bare_err = match s.parse::<u64>() {
            Ok(secs) => return Ok(Self(Duration::from_secs(secs))),
            Err(err) => err,
        };

        // Order matters: every sub-second suffix also ends in `s`, so the plain
        // `s` suffix must be tried last. Checking it first turned `100ms` into
        // `100m` and rejected every sub-second unit.
        let units: [(&str, fn(u64) -> Duration); 5] = [
            ("ms", Duration::from_millis),
            ("us", Duration::from_micros),
            ("μs", Duration::from_micros),
            ("ns", Duration::from_nanos),
            ("s", Duration::from_secs),
        ];
        for &(suffix, to_duration) in &units {
            if let Some(num) = s.strip_suffix(suffix) {
                return num.parse().map(|n| Self(to_duration(n))).map_err(error);
            }
        }

        Err(error(bare_err))
    }
}
221
222fn parse_duration_range(s: &str) -> Result<Range<Duration>, String> {
223 parse_range::<ParsedDuration>(s).map(|v| v.map(|v| v.0))
224}
225
/// Instance-reuse settings handed to `HttpServer::new` by `HttpTrigger::into_server`.
///
/// Built from `CliArgs` in `Trigger::new`, or via `Default`. Each `Range` is
/// either a fixed value or a half-open `start..end`. Because `Range` implements
/// rand's `SampleRange`, bounds are presumably sampled by the server (e.g. per
/// instance) — confirm in the server module.
#[derive(Clone, Copy)]
pub struct InstanceReuseConfig {
    /// From `--max-instance-reuse-count` (default `DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT`).
    max_instance_reuse_count: Range<usize>,
    /// From `--max-instance-concurrent-reuse-count`
    /// (default `DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT`).
    max_instance_concurrent_reuse_count: Range<usize>,
    /// From `--request-timeout`; `None` when not configured.
    request_timeout: Option<Range<Duration>>,
    /// From `--idle-instance-timeout` (default `DEFAULT_IDLE_INSTANCE_TIMEOUT`).
    idle_instance_timeout: Range<Duration>,
}
233
234impl Default for InstanceReuseConfig {
235 fn default() -> Self {
236 Self {
237 max_instance_reuse_count: Range::Value(DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT),
238 max_instance_concurrent_reuse_count: Range::Value(
239 DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT,
240 ),
241 request_timeout: DEFAULT_REQUEST_TIMEOUT,
242 idle_instance_timeout: DEFAULT_IDLE_INSTANCE_TIMEOUT,
243 }
244 }
245}
246
/// The Spin HTTP trigger: the validated settings needed to build an [`HttpServer`].
///
/// Every field is passed through unchanged to `HttpServer::new` by
/// `HttpTrigger::into_server`.
pub struct HttpTrigger {
    /// Socket address to listen on (from `--listen`).
    listen_addr: SocketAddr,
    /// TLS certificate/key paths; `None` when serving plain HTTP.
    tls_config: Option<TlsConfig>,
    /// From `--find-free-port`; interpretation is up to `HttpServer` — confirm there.
    find_free_port: bool,
    /// Optional HTTP/1 maximum buffer size (from `--http1-max-buf-size`).
    http1_max_buf_size: Option<usize>,
    /// Instance-reuse settings for the server.
    reuse_config: InstanceReuseConfig,
}
259
260impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
261 const TYPE: &'static str = "http";
262
263 type CliArgs = CliArgs;
264 type InstanceState = ();
265
266 fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
267 let find_free_port = cli_args.find_free_port;
268 let http1_max_buf_size = cli_args.http1_max_buf_size;
269 let reuse_config = InstanceReuseConfig {
270 max_instance_reuse_count: cli_args
271 .max_instance_reuse_count
272 .unwrap_or(Range::Value(DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT)),
273 max_instance_concurrent_reuse_count: cli_args
274 .max_instance_concurrent_reuse_count
275 .unwrap_or(Range::Value(
276 DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT,
277 )),
278 request_timeout: cli_args.request_timeout,
279 idle_instance_timeout: cli_args.idle_instance_timeout,
280 };
281
282 Self::new(
283 app,
284 cli_args.address,
285 cli_args.into_tls_config(),
286 find_free_port,
287 http1_max_buf_size,
288 reuse_config,
289 )
290 }
291
292 async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
293 let server = self.into_server(trigger_app)?;
294
295 server.serve().await?;
296
297 Ok(())
298 }
299
300 fn supported_host_requirements() -> Vec<&'static str> {
301 vec![spin_app::locked::SERVICE_CHAINING_KEY]
302 }
303}
304
305impl HttpTrigger {
306 pub fn new(
308 app: &spin_app::App,
309 listen_addr: SocketAddr,
310 tls_config: Option<TlsConfig>,
311 find_free_port: bool,
312 http1_max_buf_size: Option<usize>,
313 reuse_config: InstanceReuseConfig,
314 ) -> anyhow::Result<Self> {
315 Self::validate_app(app)?;
316
317 Ok(Self {
318 listen_addr,
319 tls_config,
320 find_free_port,
321 http1_max_buf_size,
322 reuse_config,
323 })
324 }
325
326 pub fn into_server<F: RuntimeFactors>(
328 self,
329 trigger_app: TriggerApp<F>,
330 ) -> anyhow::Result<Arc<HttpServer<F>>> {
331 let Self {
332 listen_addr,
333 tls_config,
334 find_free_port,
335 http1_max_buf_size,
336 reuse_config,
337 } = self;
338 let server = Arc::new(HttpServer::new(
339 listen_addr,
340 tls_config,
341 find_free_port,
342 trigger_app,
343 http1_max_buf_size,
344 reuse_config,
345 )?);
346 Ok(server)
347 }
348
349 fn validate_app(app: &App) -> anyhow::Result<()> {
350 #[derive(Deserialize)]
351 #[serde(deny_unknown_fields)]
352 struct TriggerMetadata {
353 base: Option<String>,
354 }
355 if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
356 if base == "/" {
357 tracing::warn!(
358 "This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!"
359 );
360 } else {
361 bail!(
362 "This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'."
363 )
364 }
365 }
366 Ok(())
367 }
368}
369
370fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
371 let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
372 if let Some(addr) = addrs
374 .iter()
375 .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
376 {
377 return Ok(*addr);
378 }
379 addrs.into_iter().next().context("couldn't resolve address")
381}
382
/// Classification of a route that produced a "not found" outcome.
///
/// NOTE(review): this enum is used outside this chunk. The names suggest that
/// `Normal` carries the requested route and `WellKnown` marks a `/.well-known/`
/// path — confirm against the router/server code.
#[derive(Debug, PartialEq)]
enum NotFoundRouteKind {
    Normal(String),
    WellKnown,
}
388
389pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
391 if let Some(cause) = err.source() {
393 if let Some(err) = cause.downcast_ref::<ErrorCode>() {
394 return err.clone();
395 }
396 }
397
398 tracing::warn!("hyper request error: {err:?}");
399
400 ErrorCode::HttpProtocolError
401}
402
403pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
404 ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
405 rcode: Some(rcode),
406 info_code: Some(info_code),
407 })
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// `localhost` may resolve to both IPv4 and IPv6 loopback addresses; the
    /// IPv4 one must be selected, with the requested port kept intact.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let parsed = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(parsed, SocketAddr::from((Ipv4Addr::LOCALHOST, 12345)));
    }
}