spin_trigger_http/
lib.rs

1//! Implementation for the Spin HTTP engine.
2
3mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11
12use std::{
13    error::Error,
14    net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
15    path::PathBuf,
16    sync::Arc,
17};
18
19use anyhow::{bail, Context};
20use clap::Args;
21use serde::Deserialize;
22use spin_app::App;
23use spin_factors::RuntimeFactors;
24use spin_trigger::Trigger;
25use wasmtime_wasi_http::bindings::http::types::ErrorCode;
26
27pub use server::HttpServer;
28
29pub use tls::TlsConfig;
30
31pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
32
33/// A [`spin_trigger::TriggerApp`] for the HTTP trigger.
34pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;
35
36/// A [`spin_trigger::TriggerInstanceBuilder`] for the HTTP trigger.
37pub(crate) type TriggerInstanceBuilder<'a, F> =
38    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
39
40#[derive(Args)]
41pub struct CliArgs {
42    /// IP address and port to listen on
43    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
44    pub address: SocketAddr,
45
46    /// The path to the certificate to use for https, if this is not set, normal http will be used. The cert should be in PEM format
47    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
48    pub tls_cert: Option<PathBuf>,
49
50    /// The path to the certificate key to use for https, if this is not set, normal http will be used. The key should be in PKCS#8 format
51    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
52    pub tls_key: Option<PathBuf>,
53
54    #[clap(long = "find-free-port")]
55    pub find_free_port: bool,
56}
57
58impl CliArgs {
59    fn into_tls_config(self) -> Option<TlsConfig> {
60        match (self.tls_cert, self.tls_key) {
61            (Some(cert_path), Some(key_path)) => Some(TlsConfig {
62                cert_path,
63                key_path,
64            }),
65            (None, None) => None,
66            _ => unreachable!(),
67        }
68    }
69}
70
71/// The Spin HTTP trigger.
72pub struct HttpTrigger {
73    /// The address the server should listen on.
74    ///
75    /// Note that this might not be the actual socket address that ends up being bound to.
76    /// If the port is set to 0, the actual address will be determined by the OS.
77    listen_addr: SocketAddr,
78    tls_config: Option<TlsConfig>,
79    find_free_port: bool,
80}
81
82impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
83    const TYPE: &'static str = "http";
84
85    type CliArgs = CliArgs;
86    type InstanceState = ();
87
88    fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
89        let find_free_port = cli_args.find_free_port;
90
91        Self::new(
92            app,
93            cli_args.address,
94            cli_args.into_tls_config(),
95            find_free_port,
96        )
97    }
98
99    async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
100        let server = self.into_server(trigger_app)?;
101
102        server.serve().await?;
103
104        Ok(())
105    }
106
107    fn supported_host_requirements() -> Vec<&'static str> {
108        vec![spin_app::locked::SERVICE_CHAINING_KEY]
109    }
110}
111
112impl HttpTrigger {
113    /// Create a new `HttpTrigger`.
114    pub fn new(
115        app: &spin_app::App,
116        listen_addr: SocketAddr,
117        tls_config: Option<TlsConfig>,
118        find_free_port: bool,
119    ) -> anyhow::Result<Self> {
120        Self::validate_app(app)?;
121
122        Ok(Self {
123            listen_addr,
124            tls_config,
125            find_free_port,
126        })
127    }
128
129    /// Turn this [`HttpTrigger`] into an [`HttpServer`].
130    pub fn into_server<F: RuntimeFactors>(
131        self,
132        trigger_app: TriggerApp<F>,
133    ) -> anyhow::Result<Arc<HttpServer<F>>> {
134        let Self {
135            listen_addr,
136            tls_config,
137            find_free_port,
138        } = self;
139        let server = Arc::new(HttpServer::new(
140            listen_addr,
141            tls_config,
142            find_free_port,
143            trigger_app,
144        )?);
145        Ok(server)
146    }
147
148    fn validate_app(app: &App) -> anyhow::Result<()> {
149        #[derive(Deserialize)]
150        #[serde(deny_unknown_fields)]
151        struct TriggerMetadata {
152            base: Option<String>,
153        }
154        if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
155            if base == "/" {
156                tracing::warn!("This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!");
157            } else {
158                bail!("This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'.")
159            }
160        }
161        Ok(())
162    }
163}
164
165fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
166    let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
167    // Prefer 127.0.0.1 over e.g. [::1] because CHANGE IS HARD
168    if let Some(addr) = addrs
169        .iter()
170        .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
171    {
172        return Ok(*addr);
173    }
174    // Otherwise, take the first addr (OS preference)
175    addrs.into_iter().next().context("couldn't resolve address")
176}
177
/// Classifies a route lookup that found no match.
///
/// NOTE(review): the variant meanings below are inferred from their names; confirm them
/// against the server/router code that constructs this enum.
#[derive(PartialEq, Debug)]
enum NotFoundRouteKind {
    /// An ordinary route; the `String` presumably identifies the unmatched route — TODO confirm.
    Normal(String),
    /// Presumably a `/.well-known/...` route — TODO confirm.
    WellKnown,
}
183
184/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
185pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
186    // If there's a source, we might be able to extract a wasi-http error from it.
187    if let Some(cause) = err.source() {
188        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
189            return err.clone();
190        }
191    }
192
193    tracing::warn!("hyper request error: {err:?}");
194
195    ErrorCode::HttpProtocolError
196}
197
198pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
199    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
200        rcode: Some(rcode),
201        info_code: Some(info_code),
202    })
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Depends on the host resolving `localhost` to include 127.0.0.1.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let addr = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(addr.ip(), Ipv4Addr::LOCALHOST);
        assert_eq!(addr.port(), 12345);
    }

    /// Literal addresses parse without any DNS lookup, so this test is host-independent.
    #[test]
    fn parse_listen_addr_accepts_literal_addresses() {
        let v4 = parse_listen_addr("127.0.0.1:3000").unwrap();
        assert_eq!(v4, SocketAddr::from((Ipv4Addr::LOCALHOST, 3000)));

        // An explicit IPv6 literal is returned as-is: the IPv4 preference only applies when
        // 127.0.0.1 is among the resolved candidates.
        let v6 = parse_listen_addr("[::1]:3000").unwrap();
        assert_eq!(v6, SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 3000)));
    }

    /// A host without a port is rejected before any name resolution is attempted.
    #[test]
    fn parse_listen_addr_rejects_missing_port() {
        assert!(parse_listen_addr("localhost").is_err());
    }
}