spin_trigger_http/
lib.rs

//! Implementation of the Spin HTTP trigger.
2
3mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11
12use std::{
13    error::Error,
14    net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
15    path::PathBuf,
16    sync::Arc,
17};
18
19use anyhow::{bail, Context};
20use clap::Args;
21use serde::Deserialize;
22use spin_app::App;
23use spin_factors::RuntimeFactors;
24use spin_trigger::Trigger;
25use wasmtime_wasi_http::bindings::http::types::ErrorCode;
26
27pub use server::HttpServer;
28
29pub use tls::TlsConfig;
30
31pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
32
33/// A [`spin_trigger::TriggerApp`] for the HTTP trigger.
34pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;
35
36/// A [`spin_trigger::TriggerInstanceBuilder`] for the HTTP trigger.
37pub(crate) type TriggerInstanceBuilder<'a, F> =
38    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
39
40#[derive(Args)]
41pub struct CliArgs {
42    /// IP address and port to listen on
43    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
44    pub address: SocketAddr,
45
46    /// The path to the certificate to use for https, if this is not set, normal http will be used. The cert should be in PEM format
47    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
48    pub tls_cert: Option<PathBuf>,
49
50    /// The path to the certificate key to use for https, if this is not set, normal http will be used. The key should be in PKCS#8 format
51    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
52    pub tls_key: Option<PathBuf>,
53}
54
55impl CliArgs {
56    fn into_tls_config(self) -> Option<TlsConfig> {
57        match (self.tls_cert, self.tls_key) {
58            (Some(cert_path), Some(key_path)) => Some(TlsConfig {
59                cert_path,
60                key_path,
61            }),
62            (None, None) => None,
63            _ => unreachable!(),
64        }
65    }
66}
67
68/// The Spin HTTP trigger.
69pub struct HttpTrigger {
70    /// The address the server should listen on.
71    ///
72    /// Note that this might not be the actual socket address that ends up being bound to.
73    /// If the port is set to 0, the actual address will be determined by the OS.
74    listen_addr: SocketAddr,
75    tls_config: Option<TlsConfig>,
76}
77
78impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
79    const TYPE: &'static str = "http";
80
81    type CliArgs = CliArgs;
82    type InstanceState = ();
83
84    fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
85        Self::new(app, cli_args.address, cli_args.into_tls_config())
86    }
87
88    async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
89        let server = self.into_server(trigger_app)?;
90
91        server.serve().await?;
92
93        Ok(())
94    }
95
96    fn supported_host_requirements() -> Vec<&'static str> {
97        vec![spin_app::locked::SERVICE_CHAINING_KEY]
98    }
99}
100
101impl HttpTrigger {
102    /// Create a new `HttpTrigger`.
103    pub fn new(
104        app: &spin_app::App,
105        listen_addr: SocketAddr,
106        tls_config: Option<TlsConfig>,
107    ) -> anyhow::Result<Self> {
108        Self::validate_app(app)?;
109
110        Ok(Self {
111            listen_addr,
112            tls_config,
113        })
114    }
115
116    /// Turn this [`HttpTrigger`] into an [`HttpServer`].
117    pub fn into_server<F: RuntimeFactors>(
118        self,
119        trigger_app: TriggerApp<F>,
120    ) -> anyhow::Result<Arc<HttpServer<F>>> {
121        let Self {
122            listen_addr,
123            tls_config,
124        } = self;
125        let server = Arc::new(HttpServer::new(listen_addr, tls_config, trigger_app)?);
126        Ok(server)
127    }
128
129    fn validate_app(app: &App) -> anyhow::Result<()> {
130        #[derive(Deserialize)]
131        #[serde(deny_unknown_fields)]
132        struct TriggerMetadata {
133            base: Option<String>,
134        }
135        if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
136            if base == "/" {
137                tracing::warn!("This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!");
138            } else {
139                bail!("This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'.")
140            }
141        }
142        Ok(())
143    }
144}
145
146fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
147    let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
148    // Prefer 127.0.0.1 over e.g. [::1] because CHANGE IS HARD
149    if let Some(addr) = addrs
150        .iter()
151        .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
152    {
153        return Ok(*addr);
154    }
155    // Otherwise, take the first addr (OS preference)
156    addrs.into_iter().next().context("couldn't resolve address")
157}
158
/// Distinguishes kinds of "route not found" outcomes.
///
/// NOTE(review): constructed elsewhere in the crate; the exact meaning of each
/// variant should be confirmed against the router/server code.
#[derive(Debug, PartialEq)]
enum NotFoundRouteKind {
    /// A regular route miss; the `String` presumably carries the requested path or
    /// route — TODO confirm with the constructing code.
    Normal(String),
    /// A miss in the "well-known" route space (presumably `/.well-known/...` — verify).
    WellKnown,
}
164
165/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
166pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
167    // If there's a source, we might be able to extract a wasi-http error from it.
168    if let Some(cause) = err.source() {
169        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
170            return err.clone();
171        }
172    }
173
174    tracing::warn!("hyper request error: {err:?}");
175
176    ErrorCode::HttpProtocolError
177}
178
179pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
180    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
181        rcode: Some(rcode),
182        info_code: Some(info_code),
183    })
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv6Addr;

    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        // `localhost` may resolve to both [::1] and 127.0.0.1; the IPv4 loopback must win.
        let addr = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(addr.ip(), Ipv4Addr::LOCALHOST);
        assert_eq!(addr.port(), 12345);
    }

    #[test]
    fn parse_listen_addr_accepts_ipv4_literal() {
        let addr = parse_listen_addr("0.0.0.0:3000").unwrap();
        assert_eq!(addr.ip(), Ipv4Addr::UNSPECIFIED);
        assert_eq!(addr.port(), 3000);
    }

    #[test]
    fn parse_listen_addr_falls_back_to_first_address() {
        // No IPv4 loopback candidate exists, so the first (only) resolved address is used.
        let addr = parse_listen_addr("[::1]:8080").unwrap();
        assert_eq!(addr.ip(), Ipv6Addr::LOCALHOST);
        assert_eq!(addr.port(), 8080);
    }

    #[test]
    fn parse_listen_addr_rejects_missing_port() {
        assert!(parse_listen_addr("localhost").is_err());
    }
}