spin_trigger_http/
lib.rs

1//! Implementation for the Spin HTTP engine.
2
3mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11mod wasip3;
12
13use std::{
14    error::Error,
15    net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
16    path::PathBuf,
17    sync::Arc,
18};
19
20use anyhow::{bail, Context};
21use clap::Args;
22use serde::Deserialize;
23use spin_app::App;
24use spin_factors::RuntimeFactors;
25use spin_trigger::Trigger;
26use wasmtime_wasi_http::bindings::http::types::ErrorCode;
27
28pub use server::HttpServer;
29
30pub use tls::TlsConfig;
31
32pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
33
34/// A [`spin_trigger::TriggerApp`] for the HTTP trigger.
35pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;
36
37/// A [`spin_trigger::TriggerInstanceBuilder`] for the HTTP trigger.
38pub(crate) type TriggerInstanceBuilder<'a, F> =
39    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
40
41#[derive(Args)]
42pub struct CliArgs {
43    /// IP address and port to listen on
44    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
45    pub address: SocketAddr,
46
47    /// The path to the certificate to use for https, if this is not set, normal http will be used. The cert should be in PEM format
48    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
49    pub tls_cert: Option<PathBuf>,
50
51    /// The path to the certificate key to use for https, if this is not set, normal http will be used. The key should be in PKCS#8 format
52    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
53    pub tls_key: Option<PathBuf>,
54
55    #[clap(long = "find-free-port")]
56    pub find_free_port: bool,
57}
58
59impl CliArgs {
60    fn into_tls_config(self) -> Option<TlsConfig> {
61        match (self.tls_cert, self.tls_key) {
62            (Some(cert_path), Some(key_path)) => Some(TlsConfig {
63                cert_path,
64                key_path,
65            }),
66            (None, None) => None,
67            _ => unreachable!(),
68        }
69    }
70}
71
72/// The Spin HTTP trigger.
73pub struct HttpTrigger {
74    /// The address the server should listen on.
75    ///
76    /// Note that this might not be the actual socket address that ends up being bound to.
77    /// If the port is set to 0, the actual address will be determined by the OS.
78    listen_addr: SocketAddr,
79    tls_config: Option<TlsConfig>,
80    find_free_port: bool,
81}
82
83impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
84    const TYPE: &'static str = "http";
85
86    type CliArgs = CliArgs;
87    type InstanceState = ();
88
89    fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
90        let find_free_port = cli_args.find_free_port;
91
92        Self::new(
93            app,
94            cli_args.address,
95            cli_args.into_tls_config(),
96            find_free_port,
97        )
98    }
99
100    async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
101        let server = self.into_server(trigger_app)?;
102
103        server.serve().await?;
104
105        Ok(())
106    }
107
108    fn supported_host_requirements() -> Vec<&'static str> {
109        vec![spin_app::locked::SERVICE_CHAINING_KEY]
110    }
111}
112
113impl HttpTrigger {
114    /// Create a new `HttpTrigger`.
115    pub fn new(
116        app: &spin_app::App,
117        listen_addr: SocketAddr,
118        tls_config: Option<TlsConfig>,
119        find_free_port: bool,
120    ) -> anyhow::Result<Self> {
121        Self::validate_app(app)?;
122
123        Ok(Self {
124            listen_addr,
125            tls_config,
126            find_free_port,
127        })
128    }
129
130    /// Turn this [`HttpTrigger`] into an [`HttpServer`].
131    pub fn into_server<F: RuntimeFactors>(
132        self,
133        trigger_app: TriggerApp<F>,
134    ) -> anyhow::Result<Arc<HttpServer<F>>> {
135        let Self {
136            listen_addr,
137            tls_config,
138            find_free_port,
139        } = self;
140        let server = Arc::new(HttpServer::new(
141            listen_addr,
142            tls_config,
143            find_free_port,
144            trigger_app,
145        )?);
146        Ok(server)
147    }
148
149    fn validate_app(app: &App) -> anyhow::Result<()> {
150        #[derive(Deserialize)]
151        #[serde(deny_unknown_fields)]
152        struct TriggerMetadata {
153            base: Option<String>,
154        }
155        if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
156            if base == "/" {
157                tracing::warn!(
158                    "This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!"
159                );
160            } else {
161                bail!(
162                    "This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'."
163                )
164            }
165        }
166        Ok(())
167    }
168}
169
170fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
171    let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
172    // Prefer 127.0.0.1 over e.g. [::1] because CHANGE IS HARD
173    if let Some(addr) = addrs
174        .iter()
175        .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
176    {
177        return Ok(*addr);
178    }
179    // Otherwise, take the first addr (OS preference)
180    addrs.into_iter().next().context("couldn't resolve address")
181}
182
// NOTE(review): this enum is built and consumed outside this file. The variant meanings
// below are inferred from the names only — confirm against the router/server modules.
/// The kind of route reported as not found.
#[derive(PartialEq, Debug)]
enum NotFoundRouteKind {
    /// An ordinary route; the `String` presumably carries the requested route or path.
    Normal(String),
    /// Presumably a route under the `/.well-known/` prefix.
    WellKnown,
}
188
189/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
190pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
191    // If there's a source, we might be able to extract a wasi-http error from it.
192    if let Some(cause) = err.source() {
193        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
194            return err.clone();
195        }
196    }
197
198    tracing::warn!("hyper request error: {err:?}");
199
200    ErrorCode::HttpProtocolError
201}
202
203pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
204    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
205        rcode: Some(rcode),
206        info_code: Some(info_code),
207    })
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// `localhost` can resolve to both `127.0.0.1` and `[::1]`; the IPv4 loopback must
    /// be chosen, and the port must be kept.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let parsed = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(parsed, SocketAddr::from((Ipv4Addr::LOCALHOST, 12345)));
    }
}
220}