spin_trigger_http/
lib.rs

1//! Implementation for the Spin HTTP engine.
2
3mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11mod wasip3;
12
13use std::{
14    error::Error,
15    fmt::Display,
16    net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
17    path::PathBuf,
18    str::FromStr,
19    sync::Arc,
20    time::Duration,
21};
22
23use anyhow::{bail, Context};
24use clap::Args;
25use rand::{
26    distr::uniform::{SampleRange, SampleUniform},
27    RngCore,
28};
29use serde::Deserialize;
30use spin_app::App;
31use spin_factors::RuntimeFactors;
32use spin_trigger::Trigger;
33use wasmtime_wasi_http::bindings::http::types::ErrorCode;
34
35pub use server::HttpServer;
36
37pub use tls::TlsConfig;
38
39pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
40
/// Fallback for `CliArgs::max_instance_reuse_count` when the flag is not given:
/// the maximum number of requests sent to one component instance before it is
/// dropped.
const DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT: usize = 128;
/// Fallback for `CliArgs::max_instance_concurrent_reuse_count` when the flag is
/// not given: the maximum number of concurrent requests sent to one component
/// instance.
const DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT: usize = 16;
/// Default request timeout used by `InstanceReuseConfig::default()`; `None`
/// means no timeout range is configured.
const DEFAULT_REQUEST_TIMEOUT: Option<Range<Duration>> = None;
/// Default idle-instance timeout used by `InstanceReuseConfig::default()`: a
/// fixed one second. NOTE(review): the clap `default_value = "1s"` on
/// `CliArgs::idle_instance_timeout` duplicates this value; keep them in sync.
const DEFAULT_IDLE_INSTANCE_TIMEOUT: Range<Duration> = Range::Value(Duration::from_secs(1));
45
/// A [`spin_trigger::TriggerApp`] for the HTTP trigger.
///
/// Fixes the trigger type parameter to [`HttpTrigger`] and leaves the runtime
/// factors `F` generic.
pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;

/// A [`spin_trigger::TriggerInstanceBuilder`] for the HTTP trigger.
///
/// Fixes the trigger type parameter to [`HttpTrigger`] and leaves the lifetime
/// `'a` and runtime factors `F` generic.
pub(crate) type TriggerInstanceBuilder<'a, F> =
    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
52
// Command-line arguments accepted by the HTTP trigger.
//
// NOTE: plain `//` comments are used for maintainer notes in this struct on
// purpose: clap turns `///` doc comments on derived items into user-visible
// help text, so adding `///` here would change the `--help` output.
#[derive(Args)]
pub struct CliArgs {
    /// IP address and port to listen on
    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
    pub address: SocketAddr,

    /// The path to the certificate to use for https, if this is not set, normal http will be used. The cert should be in PEM format
    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
    pub tls_cert: Option<PathBuf>,

    /// The path to the certificate key to use for https, if this is not set, normal http will be used. The key should be in PKCS#8 format
    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
    pub tls_key: Option<PathBuf>,

    /// Sets the maximum buffer size (in bytes) for the HTTP connection. The minimum value allowed is 8192.
    #[clap(long, env = "SPIN_HTTP1_MAX_BUF_SIZE")]
    pub http1_max_buf_size: Option<usize>,

    // NOTE(review): this flag has no `///` help text, so `--find-free-port` is
    // listed without a description in `--help`. The value is forwarded
    // unchanged to `HttpServer::new`; presumably it lets the server pick
    // another port when the requested one is unavailable -- confirm in
    // `server.rs` before documenting it for users.
    #[clap(long = "find-free-port")]
    pub find_free_port: bool,

    /// Maximum number of requests to send to a single component instance before
    /// dropping it.
    ///
    /// This defaults to 1 for WASIp2 components and 128 for WASIp3 components.
    /// As of this writing, setting it to more than 1 will have no effect for
    /// WASIp2 components, but that may change in the future.
    ///
    /// This may be specified either as an integer value or as a range,
    /// e.g. 1..8.  If it's a range, a number will be selected from that range
    /// at random for each new instance.
    #[clap(long, value_parser = parse_usize_range)]
    max_instance_reuse_count: Option<Range<usize>>,

    /// Maximum number of concurrent requests to send to a single component
    /// instance.
    ///
    /// This defaults to 1 for WASIp2 components and 16 for WASIp3 components.
    /// Note that setting it to more than 1 will have no effect for WASIp2
    /// components since they cannot be called concurrently.
    ///
    /// This may be specified either as an integer value or as a range,
    /// e.g. 1..8.  If it's a range, a number will be selected from that range
    /// at random for each new instance.
    #[clap(long, value_parser = parse_usize_range)]
    max_instance_concurrent_reuse_count: Option<Range<usize>>,

    /// Request timeout to enforce.
    ///
    /// As of this writing, this only affects WASIp3 components.
    ///
    /// A number with no suffix or with an `s` suffix is interpreted as seconds;
    /// other accepted suffixes include `ms` (milliseconds), `us` or `μs`
    /// (microseconds), and `ns` (nanoseconds).
    ///
    /// This may be specified either as a single time value or as a range,
    /// e.g. 1..8s.  If it's a range, a value will be selected from that range
    /// at random for each new instance.
    #[clap(long, value_parser = parse_duration_range)]
    request_timeout: Option<Range<Duration>>,

    /// Time to hold an idle component instance for possible reuse before
    /// dropping it.
    ///
    /// A number with no suffix or with an `s` suffix is interpreted as seconds;
    /// other accepted suffixes include `ms` (milliseconds), `us` or `μs`
    /// (microseconds), and `ns` (nanoseconds).
    ///
    /// This may be specified either as a single time value or as a range,
    /// e.g. 1..8s.  If it's a range, a value will be selected from that range
    /// at random for each new instance.
    // The textual default below mirrors `DEFAULT_IDLE_INSTANCE_TIMEOUT`.
    #[clap(long, default_value = "1s", value_parser = parse_duration_range)]
    idle_instance_timeout: Range<Duration>,
}
127
128impl CliArgs {
129    fn into_tls_config(self) -> Option<TlsConfig> {
130        match (self.tls_cert, self.tls_key) {
131            (Some(cert_path), Some(key_path)) => Some(TlsConfig {
132                cert_path,
133                key_path,
134            }),
135            (None, None) => None,
136            _ => unreachable!(),
137        }
138    }
139}
140
/// Either a single fixed value or a pair of bounds to pick a value from.
///
/// `Bounds(a, b)` is treated as the half-open interval `a..b` by the
/// `SampleRange` implementation for this type.
#[derive(Clone, Copy)]
enum Range<T> {
    /// A single, fixed value.
    Value(T),
    /// Lower and upper bound, in that order.
    Bounds(T, T),
}

impl<T> Range<T> {
    /// Applies `fun` to every contained value, keeping the variant unchanged.
    ///
    /// For `Bounds`, the lower bound is converted before the upper bound.
    fn map<V>(self, fun: impl Fn(T) -> V) -> Range<V> {
        match self {
            Range::Bounds(lower, upper) => {
                let lower = fun(lower);
                let upper = fun(upper);
                Range::Bounds(lower, upper)
            }
            Range::Value(single) => Range::Value(fun(single)),
        }
    }
}
155
156impl<T: SampleUniform + PartialOrd> SampleRange<T> for Range<T> {
157    fn sample_single<R: RngCore + ?Sized>(
158        self,
159        rng: &mut R,
160    ) -> Result<T, rand::distr::uniform::Error> {
161        match self {
162            Self::Value(v) => Ok(v),
163            Self::Bounds(a, b) => (a..b).sample_single(rng),
164        }
165    }
166
167    fn is_empty(&self) -> bool {
168        match self {
169            Self::Value(_) => false,
170            Self::Bounds(a, b) => (a..b).is_empty(),
171        }
172    }
173}
174
175fn parse_range<T: FromStr>(s: &str) -> Result<Range<T>, String>
176where
177    T::Err: Display,
178{
179    let error = |e| format!("expected integer or range; got {s:?}; {e}");
180    if let Some((start, end)) = s.split_once("..") {
181        Ok(Range::Bounds(
182            start.parse().map_err(error)?,
183            end.parse().map_err(error)?,
184        ))
185    } else {
186        Ok(Range::Value(s.parse().map_err(error)?))
187    }
188}
189
190fn parse_usize_range(s: &str) -> Result<Range<usize>, String> {
191    parse_range(s)
192}
193
/// Newtype over [`Duration`] whose [`FromStr`] implementation understands the
/// unit suffixes accepted on the command line.
struct ParsedDuration(Duration);

impl FromStr for ParsedDuration {
    type Err = String;

    /// Parses an unsigned integer with an optional unit suffix.
    ///
    /// A bare number or an `s` suffix means seconds; `ms` is milliseconds,
    /// `us` / `μs` is microseconds and `ns` is nanoseconds.
    ///
    /// # Errors
    ///
    /// Returns a message naming the accepted suffixes, the original input and
    /// the integer parse error when the numeric part is not a valid `u64`
    /// (including an unknown suffix or an empty number).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Every sub-second suffix also ends in `s`, so the two-letter suffixes
        // must be tried before the bare `s`; otherwise e.g. "100ms" would be
        // read as the (invalid) number of seconds "100m".
        const UNITS: [(&str, fn(u64) -> Duration); 5] = [
            ("ms", Duration::from_millis),
            ("us", Duration::from_micros),
            ("μs", Duration::from_micros),
            ("ns", Duration::from_nanos),
            ("s", Duration::from_secs),
        ];

        // No recognised suffix: treat the whole input as a number of seconds.
        let (digits, unit) = UNITS
            .iter()
            .find_map(|&(suffix, unit)| s.strip_suffix(suffix).map(|digits| (digits, unit)))
            .unwrap_or((s, Duration::from_secs as fn(u64) -> Duration));

        digits
            .parse::<u64>()
            .map(|value| Self(unit(value)))
            .map_err(|e| {
                format!(
                    "expected integer suffixed by `s`, `ms`, `us`, `μs`, or `ns`; got {s:?}; {e}"
                )
            })
    }
}
221
222fn parse_duration_range(s: &str) -> Result<Range<Duration>, String> {
223    parse_range::<ParsedDuration>(s).map(|v| v.map(|v| v.0))
224}
225
/// Settings controlling how component instances are reused across requests.
///
/// Each field is a [`Range`]: either a fixed value or bounds from which a
/// value is picked (the CLI help states the pick is random, per new instance).
#[derive(Clone, Copy)]
pub struct InstanceReuseConfig {
    /// Maximum number of requests to send to a single instance before
    /// dropping it.
    max_instance_reuse_count: Range<usize>,
    /// Maximum number of concurrent requests to send to a single instance.
    max_instance_concurrent_reuse_count: Range<usize>,
    /// Request timeout to enforce; `None` when none is configured.
    request_timeout: Option<Range<Duration>>,
    /// Time to hold an idle instance for possible reuse before dropping it.
    idle_instance_timeout: Range<Duration>,
}
233
234impl Default for InstanceReuseConfig {
235    fn default() -> Self {
236        Self {
237            max_instance_reuse_count: Range::Value(DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT),
238            max_instance_concurrent_reuse_count: Range::Value(
239                DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT,
240            ),
241            request_timeout: DEFAULT_REQUEST_TIMEOUT,
242            idle_instance_timeout: DEFAULT_IDLE_INSTANCE_TIMEOUT,
243        }
244    }
245}
246
/// The Spin HTTP trigger.
///
/// Holds the listener settings gathered from the CLI (or supplied directly via
/// [`HttpTrigger::new`]) until they are handed to an [`HttpServer`].
pub struct HttpTrigger {
    /// The address the server should listen on.
    ///
    /// Note that this might not be the actual socket address that ends up being bound to.
    /// If the port is set to 0, the actual address will be determined by the OS.
    listen_addr: SocketAddr,
    /// Certificate and key paths for HTTPS; `None` when no TLS material was
    /// supplied (the CLI help states plain HTTP is used in that case).
    tls_config: Option<TlsConfig>,
    /// Value of the `--find-free-port` flag, forwarded to the server.
    /// NOTE(review): presumably allows falling back to another port when the
    /// requested one is unavailable -- confirm in `server.rs`.
    find_free_port: bool,
    /// Optional maximum buffer size (in bytes) for the HTTP connection.
    http1_max_buf_size: Option<usize>,
    /// Instance reuse limits and timeouts.
    reuse_config: InstanceReuseConfig,
}
259
260impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
261    const TYPE: &'static str = "http";
262
263    type CliArgs = CliArgs;
264    type InstanceState = ();
265
266    fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
267        let find_free_port = cli_args.find_free_port;
268        let http1_max_buf_size = cli_args.http1_max_buf_size;
269        let reuse_config = InstanceReuseConfig {
270            max_instance_reuse_count: cli_args
271                .max_instance_reuse_count
272                .unwrap_or(Range::Value(DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT)),
273            max_instance_concurrent_reuse_count: cli_args
274                .max_instance_concurrent_reuse_count
275                .unwrap_or(Range::Value(
276                    DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT,
277                )),
278            request_timeout: cli_args.request_timeout,
279            idle_instance_timeout: cli_args.idle_instance_timeout,
280        };
281
282        Self::new(
283            app,
284            cli_args.address,
285            cli_args.into_tls_config(),
286            find_free_port,
287            http1_max_buf_size,
288            reuse_config,
289        )
290    }
291
292    async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
293        let server = self.into_server(trigger_app)?;
294
295        server.serve().await?;
296
297        Ok(())
298    }
299
300    fn supported_host_requirements() -> Vec<&'static str> {
301        vec![spin_app::locked::SERVICE_CHAINING_KEY]
302    }
303}
304
305impl HttpTrigger {
306    /// Create a new `HttpTrigger`.
307    pub fn new(
308        app: &spin_app::App,
309        listen_addr: SocketAddr,
310        tls_config: Option<TlsConfig>,
311        find_free_port: bool,
312        http1_max_buf_size: Option<usize>,
313        reuse_config: InstanceReuseConfig,
314    ) -> anyhow::Result<Self> {
315        Self::validate_app(app)?;
316
317        Ok(Self {
318            listen_addr,
319            tls_config,
320            find_free_port,
321            http1_max_buf_size,
322            reuse_config,
323        })
324    }
325
326    /// Turn this [`HttpTrigger`] into an [`HttpServer`].
327    pub fn into_server<F: RuntimeFactors>(
328        self,
329        trigger_app: TriggerApp<F>,
330    ) -> anyhow::Result<Arc<HttpServer<F>>> {
331        let Self {
332            listen_addr,
333            tls_config,
334            find_free_port,
335            http1_max_buf_size,
336            reuse_config,
337        } = self;
338        let server = Arc::new(HttpServer::new(
339            listen_addr,
340            tls_config,
341            find_free_port,
342            trigger_app,
343            http1_max_buf_size,
344            reuse_config,
345        )?);
346        Ok(server)
347    }
348
349    fn validate_app(app: &App) -> anyhow::Result<()> {
350        #[derive(Deserialize)]
351        #[serde(deny_unknown_fields)]
352        struct TriggerMetadata {
353            base: Option<String>,
354        }
355        if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
356            if base == "/" {
357                tracing::warn!(
358                    "This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!"
359                );
360            } else {
361                bail!(
362                    "This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'."
363                )
364            }
365        }
366        Ok(())
367    }
368}
369
370fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
371    let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
372    // Prefer 127.0.0.1 over e.g. [::1] because CHANGE IS HARD
373    if let Some(addr) = addrs
374        .iter()
375        .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
376    {
377        return Ok(*addr);
378    }
379    // Otherwise, take the first addr (OS preference)
380    addrs.into_iter().next().context("couldn't resolve address")
381}
382
/// Classifies a request route for which no handler was found.
///
/// Not referenced in this file; it is consumed elsewhere in the crate.
#[derive(Debug, PartialEq)]
enum NotFoundRouteKind {
    /// An ordinary route. NOTE(review): the `String` presumably carries the
    /// unmatched route/path -- confirm against the constructing code.
    Normal(String),
    /// NOTE(review): presumably a route under the well-known prefix -- confirm
    /// against the constructing code.
    WellKnown,
}
388
389/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
390pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
391    // If there's a source, we might be able to extract a wasi-http error from it.
392    if let Some(cause) = err.source() {
393        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
394            return err.clone();
395        }
396    }
397
398    tracing::warn!("hyper request error: {err:?}");
399
400    ErrorCode::HttpProtocolError
401}
402
403pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
404    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
405        rcode: Some(rcode),
406        info_code: Some(info_code),
407    })
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    // Checks that `parse_listen_addr` picks the IPv4 loopback and keeps the
    // requested port. This relies on the host resolving "localhost" to a set
    // of addresses that includes 127.0.0.1.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let addr = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(addr.ip(), Ipv4Addr::LOCALHOST);
        assert_eq!(addr.port(), 12345);
    }
}
420}