Skip to main content

spin_trigger_http/
lib.rs

1//! Implementation for the Spin HTTP engine.
2
3mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11mod wasip3;
12
13use std::{
14    error::Error,
15    fmt::Display,
16    net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
17    path::PathBuf,
18    str::FromStr,
19    sync::Arc,
20    time::Duration,
21};
22
23use anyhow::{Context, bail};
24use clap::Args;
25use rand::{
26    RngCore,
27    distr::uniform::{SampleRange, SampleUniform},
28};
29use serde::Deserialize;
30use spin_app::App;
31use spin_factors::RuntimeFactors;
32use spin_trigger::Trigger;
33use wasmtime_wasi_http::p2::bindings::http::types::ErrorCode;
34
35pub use server::HttpServer;
36
37pub use tls::TlsConfig;
38
39pub(crate) use wasmtime_wasi_http::p2::body::HyperIncomingBody as Body;
40
/// Fallback for `--max-instance-reuse-count` when the flag is not given; the
/// flag's help text describes this as the WASIp3 default.
const DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT: usize = 128;
/// Fallback for `--max-instance-concurrent-reuse-count` when the flag is not
/// given; the flag's help text describes this as the WASIp3 default.
const DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT: usize = 16;
/// Request timeout used by `InstanceReuseConfig::default`: unset, matching the
/// CLI, where `--request-timeout` has no default value.
const DEFAULT_REQUEST_TIMEOUT: Option<Range<Duration>> = None;
/// Idle-instance timeout used by `InstanceReuseConfig::default`; mirrors the
/// `default_value = "1s"` on `--idle-instance-timeout`.
const DEFAULT_IDLE_INSTANCE_TIMEOUT: Range<Duration> = Range::Value(Duration::from_secs(1));
45
/// The format in which to print startup route information.
// Selected with `--format`; the `#[default]` variant below also supplies the
// CLI default through `default_value_t = OutputFormat::default()`. clap may
// show the variant `///` comments as help for the possible values, so treat
// them as user-facing text.
#[derive(clap::ValueEnum, Clone, Copy, Debug, Default)]
pub enum OutputFormat {
    /// Human-readable plain text output (the default).
    #[default]
    Plain,
    /// Machine-readable JSON output.
    Json,
}
55
// Crate-internal shorthands that fix the trigger type parameter of the generic
// `spin_trigger` types to `HttpTrigger`.

/// A [`spin_trigger::TriggerApp`] for the HTTP trigger.
pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;

/// A [`spin_trigger::TriggerInstanceBuilder`] for the HTTP trigger.
pub(crate) type TriggerInstanceBuilder<'a, F> =
    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
62
// Command-line options for the HTTP trigger.
//
// NOTE: clap renders the `///` comments on the fields below as `--help` text,
// so editing them changes user-facing output. Maintainer notes in this struct
// therefore use plain `//` comments.
#[derive(Args)]
pub struct CliArgs {
    // Resolved by `parse_listen_addr`, which picks the IPv4 loopback address
    // when it is among the resolved addresses and otherwise the first one.
    /// IP address and port to listen on
    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
    pub address: SocketAddr,

    // `requires` makes `--tls-cert` and `--tls-key` mutually dependent;
    // `CliArgs::into_tls_config` relies on that to treat a half-specified pair
    // as unreachable.
    /// The path to the certificate to use for https, if this is not set, normal http will be used. The cert should be in PEM format
    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls_key")]
    pub tls_cert: Option<PathBuf>,

    /// The path to the certificate key to use for https, if this is not set, normal http will be used. The key should be in PKCS#8 format
    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls_cert")]
    pub tls_key: Option<PathBuf>,

    // NOTE(review): the 8192 minimum stated in the help text is not validated
    // here — confirm it is enforced before the value reaches the HTTP server.
    /// Sets the maximum buffer size (in bytes) for the HTTP connection. The minimum value allowed is 8192.
    #[clap(long, env = "SPIN_HTTP1_MAX_BUF_SIZE")]
    pub http1_max_buf_size: Option<usize>,

    // NOTE(review): no `///` help text, so `--help` shows no description for
    // this flag. The value is passed through unchanged to `HttpServer::new`.
    #[clap(long = "find-free-port")]
    pub find_free_port: bool,

    // NOTE(review): no `///` help text on this field; the default is
    // `OutputFormat::Plain` via `OutputFormat::default()`.
    #[clap(value_enum, long = "format", default_value_t = OutputFormat::default())]
    pub format: OutputFormat,

    // `None` when the flag is omitted; `Trigger::new` then substitutes
    // `DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT`. The WASIp2 default of 1
    // described below is not applied in this file (presumably handled
    // elsewhere — confirm).
    /// Maximum number of requests to send to a single component instance before
    /// dropping it.
    ///
    /// This defaults to 1 for WASIp2 components and 128 for WASIp3 components.
    /// As of this writing, setting it to more than 1 will have no effect for
    /// WASIp2 components, but that may change in the future.
    ///
    /// This may be specified either as an integer value or as a range,
    /// e.g. 1..8.  If it's a range, a number will be selected from that range
    /// at random for each new instance.
    #[clap(long, value_parser = parse_usize_range)]
    pub max_instance_reuse_count: Option<Range<usize>>,

    // `None` when the flag is omitted; `Trigger::new` then substitutes
    // `DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT`. As above, the WASIp2
    // default is not applied in this file.
    /// Maximum number of concurrent requests to send to a single component
    /// instance.
    ///
    /// This defaults to 1 for WASIp2 components and 16 for WASIp3 components.
    /// Note that setting it to more than 1 will have no effect for WASIp2
    /// components since they cannot be called concurrently.
    ///
    /// This may be specified either as an integer value or as a range,
    /// e.g. 1..8.  If it's a range, a number will be selected from that range
    /// at random for each new instance.
    #[clap(long, value_parser = parse_usize_range)]
    pub max_instance_concurrent_reuse_count: Option<Range<usize>>,

    // Stays `None` when omitted; it is copied as-is into `InstanceReuseConfig`.
    /// Request timeout to enforce.
    ///
    /// As of this writing, this only affects WASIp3 components.
    ///
    /// A number with no suffix or with an `s` suffix is interpreted as seconds;
    /// other accepted suffixes include `ms` (milliseconds), `us` or `μs`
    /// (microseconds), and `ns` (nanoseconds).
    ///
    /// This may be specified either as a single time value or as a range,
    /// e.g. 1..8s.  If it's a range, a value will be selected from that range
    /// at random for each new instance.
    #[clap(long, value_parser = parse_duration_range)]
    pub request_timeout: Option<Range<Duration>>,

    // Keep `default_value` in sync with `DEFAULT_IDLE_INSTANCE_TIMEOUT`, which
    // `InstanceReuseConfig::default` uses.
    /// Time to hold an idle component instance for possible reuse before
    /// dropping it.
    ///
    /// A number with no suffix or with an `s` suffix is interpreted as seconds;
    /// other accepted suffixes include `ms` (milliseconds), `us` or `μs`
    /// (microseconds), and `ns` (nanoseconds).
    ///
    /// This may be specified either as a single time value or as a range,
    /// e.g. 1..8s.  If it's a range, a value will be selected from that range
    /// at random for each new instance.
    #[clap(long, default_value = "1s", value_parser = parse_duration_range)]
    pub idle_instance_timeout: Range<Duration>,
}
140
141impl CliArgs {
142    fn into_tls_config(self) -> Option<TlsConfig> {
143        match (self.tls_cert, self.tls_key) {
144            (Some(cert_path), Some(key_path)) => Some(TlsConfig {
145                cert_path,
146                key_path,
147            }),
148            (None, None) => None,
149            _ => unreachable!(),
150        }
151    }
152}
153
/// A setting that is either a fixed value or a half-open `start..end` range.
///
/// Parsed from CLI strings such as `8` or `1..8` (see `parse_range`). When a
/// concrete value is needed, the `SampleRange` implementation yields the fixed
/// value as-is, or draws from `start..end` (end excluded) for `Bounds`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Range<T> {
    /// A single fixed value.
    Value(T),
    /// The half-open interval `start..end`.
    Bounds(T, T),
}

impl<T> Range<T> {
    /// Applies `fun` to the fixed value, or to both bounds (start first), and
    /// returns a range of the results with the same shape.
    fn map<V>(self, mut fun: impl FnMut(T) -> V) -> Range<V> {
        match self {
            Self::Value(v) => Range::Value(fun(v)),
            Self::Bounds(a, b) => Range::Bounds(fun(a), fun(b)),
        }
    }
}
168
169impl<T: SampleUniform + PartialOrd> SampleRange<T> for Range<T> {
170    fn sample_single<R: RngCore + ?Sized>(
171        self,
172        rng: &mut R,
173    ) -> Result<T, rand::distr::uniform::Error> {
174        match self {
175            Self::Value(v) => Ok(v),
176            Self::Bounds(a, b) => (a..b).sample_single(rng),
177        }
178    }
179
180    fn is_empty(&self) -> bool {
181        match self {
182            Self::Value(_) => false,
183            Self::Bounds(a, b) => (a..b).is_empty(),
184        }
185    }
186}
187
188fn parse_range<T: FromStr>(s: &str) -> Result<Range<T>, String>
189where
190    T::Err: Display,
191{
192    let error = |e| format!("expected integer or range; got {s:?}; {e}");
193    if let Some((start, end)) = s.split_once("..") {
194        Ok(Range::Bounds(
195            start.parse().map_err(error)?,
196            end.parse().map_err(error)?,
197        ))
198    } else {
199        Ok(Range::Value(s.parse().map_err(error)?))
200    }
201}
202
203fn parse_usize_range(s: &str) -> Result<Range<usize>, String> {
204    parse_range(s)
205}
206
/// A [`Duration`] parsed from a CLI string: an unsigned integer with an
/// optional unit suffix (`s`, `ms`, `us`/`μs`, or `ns`; no suffix means
/// seconds).
struct ParsedDuration(Duration);

impl FromStr for ParsedDuration {
    type Err = String;

    /// Parses strings such as `5`, `5s`, `250ms`, `10us`, `10μs`, or `100ns`.
    ///
    /// Every accepted suffix ends in `s`, so the multi-character suffixes are
    /// tried before the bare `s`. Checking `s` first would turn `250ms` into
    /// `250m` and reject it.
    ///
    /// # Errors
    ///
    /// Returns a message listing the accepted suffixes when the numeric part
    /// is missing or not a valid `u64`, or when the suffix is not recognized.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let error = |e: std::num::ParseIntError| {
            format!("expected integer suffixed by `s`, `ms`, `us`, `μs`, or `ns`; got {s:?}; {e}")
        };
        Ok(Self(match s.parse::<u64>() {
            // A bare integer is a number of seconds.
            Ok(secs) => Duration::from_secs(secs),
            Err(err) => {
                if let Some(num) = s.strip_suffix("ms") {
                    Duration::from_millis(num.parse().map_err(error)?)
                } else if let Some(num) = s.strip_suffix("us").or_else(|| s.strip_suffix("μs")) {
                    Duration::from_micros(num.parse().map_err(error)?)
                } else if let Some(num) = s.strip_suffix("ns") {
                    Duration::from_nanos(num.parse().map_err(error)?)
                } else if let Some(num) = s.strip_suffix('s') {
                    // Must stay last: every other suffix also ends in `s`.
                    Duration::from_secs(num.parse().map_err(error)?)
                } else {
                    return Err(error(err));
                }
            }
        }))
    }
}
234
235fn parse_duration_range(s: &str) -> Result<Range<Duration>, String> {
236    parse_range::<ParsedDuration>(s).map(|v| v.map(|v| v.0))
237}
238
/// Instance-reuse limits and timeouts for the HTTP trigger.
///
/// Built from [`CliArgs`] in `Trigger::new`, or via [`Default`] from the
/// `DEFAULT_*` constants. Each [`Range`] field is either a fixed value or a
/// range; per the CLI help, a range is sampled at random for each new instance.
#[derive(Clone, Copy)]
pub struct InstanceReuseConfig {
    /// Maximum number of requests sent to one instance before it is dropped.
    max_instance_reuse_count: Range<usize>,
    /// Maximum number of concurrent requests sent to one instance.
    max_instance_concurrent_reuse_count: Range<usize>,
    /// Request timeout; `None` when `--request-timeout` was not given.
    request_timeout: Option<Range<Duration>>,
    /// How long an idle instance is held for possible reuse before it is dropped.
    idle_instance_timeout: Range<Duration>,
}
246
247impl Default for InstanceReuseConfig {
248    fn default() -> Self {
249        Self {
250            max_instance_reuse_count: Range::Value(DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT),
251            max_instance_concurrent_reuse_count: Range::Value(
252                DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT,
253            ),
254            request_timeout: DEFAULT_REQUEST_TIMEOUT,
255            idle_instance_timeout: DEFAULT_IDLE_INSTANCE_TIMEOUT,
256        }
257    }
258}
259
/// The Spin HTTP trigger.
///
/// Holds the listener and server settings (from [`CliArgs`] or passed to
/// [`HttpTrigger::new`]) until [`HttpTrigger::into_server`] hands them to an
/// [`HttpServer`].
pub struct HttpTrigger {
    /// The address the server should listen on.
    ///
    /// Note that this might not be the actual socket address that ends up being bound to.
    /// If the port is set to 0, the actual address will be determined by the OS.
    listen_addr: SocketAddr,
    /// TLS certificate and key paths; `None` when `--tls-cert` / `--tls-key`
    /// are not given (the CLI help says plain HTTP is used then).
    tls_config: Option<TlsConfig>,
    /// Value of `--find-free-port`, forwarded unchanged to `HttpServer::new`.
    find_free_port: bool,
    /// Optional HTTP/1 connection buffer size from `--http1-max-buf-size`.
    http1_max_buf_size: Option<usize>,
    /// Instance-reuse limits and timeouts.
    reuse_config: InstanceReuseConfig,
    /// Format for printing startup route information.
    output_format: OutputFormat,
}
273
274impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
275    const TYPE: &'static str = "http";
276
277    type CliArgs = CliArgs;
278    type InstanceState = ();
279
280    fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
281        let find_free_port = cli_args.find_free_port;
282        let http1_max_buf_size = cli_args.http1_max_buf_size;
283        let output_format = cli_args.format;
284        let reuse_config = InstanceReuseConfig {
285            max_instance_reuse_count: cli_args
286                .max_instance_reuse_count
287                .unwrap_or(Range::Value(DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT)),
288            max_instance_concurrent_reuse_count: cli_args
289                .max_instance_concurrent_reuse_count
290                .unwrap_or(Range::Value(
291                    DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT,
292                )),
293            request_timeout: cli_args.request_timeout,
294            idle_instance_timeout: cli_args.idle_instance_timeout,
295        };
296
297        Self::new(
298            app,
299            cli_args.address,
300            cli_args.into_tls_config(),
301            find_free_port,
302            http1_max_buf_size,
303            reuse_config,
304            output_format,
305        )
306    }
307
308    async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
309        let server = self.into_server(trigger_app)?;
310
311        server.serve().await?;
312
313        Ok(())
314    }
315
316    fn supported_host_requirements() -> Vec<&'static str> {
317        vec![spin_app::locked::SERVICE_CHAINING_KEY]
318    }
319
320    fn display_name() -> String {
321        "HTTP".to_string()
322    }
323}
324
325impl HttpTrigger {
326    /// Create a new `HttpTrigger`.
327    pub fn new(
328        app: &spin_app::App,
329        listen_addr: SocketAddr,
330        tls_config: Option<TlsConfig>,
331        find_free_port: bool,
332        http1_max_buf_size: Option<usize>,
333        reuse_config: InstanceReuseConfig,
334        output_format: OutputFormat,
335    ) -> anyhow::Result<Self> {
336        Self::validate_app(app)?;
337
338        Ok(Self {
339            listen_addr,
340            tls_config,
341            find_free_port,
342            http1_max_buf_size,
343            reuse_config,
344            output_format,
345        })
346    }
347
348    /// Turn this [`HttpTrigger`] into an [`HttpServer`].
349    pub fn into_server<F: RuntimeFactors>(
350        self,
351        trigger_app: TriggerApp<F>,
352    ) -> anyhow::Result<Arc<HttpServer<F>>> {
353        let Self {
354            listen_addr,
355            tls_config,
356            find_free_port,
357            http1_max_buf_size,
358            reuse_config,
359            output_format,
360        } = self;
361        let server = Arc::new(HttpServer::new(
362            listen_addr,
363            tls_config,
364            find_free_port,
365            trigger_app,
366            http1_max_buf_size,
367            reuse_config,
368            output_format,
369        )?);
370        Ok(server)
371    }
372
373    fn validate_app(app: &App) -> anyhow::Result<()> {
374        #[derive(Deserialize)]
375        #[serde(deny_unknown_fields)]
376        struct TriggerMetadata {
377            base: Option<String>,
378        }
379        if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
380            if base == "/" {
381                tracing::warn!(
382                    "This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!"
383                );
384            } else {
385                bail!(
386                    "This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'."
387                )
388            }
389        }
390        Ok(())
391    }
392}
393
394fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
395    let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
396    // Prefer 127.0.0.1 over e.g. [::1] because CHANGE IS HARD
397    if let Some(addr) = addrs
398        .iter()
399        .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
400    {
401        return Ok(*addr);
402    }
403    // Otherwise, take the first addr (OS preference)
404    addrs.into_iter().next().context("couldn't resolve address")
405}
406
// NOTE(review): not used in this file. Presumably this classifies a request
// path that matched no route: `Normal` carrying the path/route text and
// `WellKnown` marking `/.well-known/...` paths. Confirm against the routing
// and server code before relying on this.
#[derive(Debug, PartialEq)]
enum NotFoundRouteKind {
    Normal(String),
    WellKnown,
}
412
413/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
414pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
415    // If there's a source, we might be able to extract a wasi-http error from it.
416    if let Some(cause) = err.source()
417        && let Some(err) = cause.downcast_ref::<ErrorCode>()
418    {
419        return err.clone();
420    }
421
422    tracing::warn!("hyper request error: {err:?}");
423
424    ErrorCode::HttpProtocolError
425}
426
427pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
428    ErrorCode::DnsError(
429        wasmtime_wasi_http::p2::bindings::http::types::DnsErrorPayload {
430            rcode: Some(rcode),
431            info_code: Some(info_code),
432        },
433    )
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// `localhost` typically resolves to both 127.0.0.1 and `[::1]`; the IPv4
    /// loopback address must be the one selected, with the port preserved.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let expected = SocketAddr::from((Ipv4Addr::LOCALHOST, 12345));
        let resolved = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(resolved, expected);
    }
}