spin_trigger_http/
lib.rs

1//! Implementation for the Spin HTTP engine.
2
3mod headers;
4mod instrument;
5mod outbound_http;
6mod server;
7mod spin;
8mod tls;
9mod wagi;
10mod wasi;
11mod wasip3;
12
13use std::{
14    error::Error,
15    net::{Ipv4Addr, SocketAddr, ToSocketAddrs},
16    path::PathBuf,
17    sync::Arc,
18};
19
20use anyhow::{bail, Context};
21use clap::Args;
22use serde::Deserialize;
23use spin_app::App;
24use spin_factors::RuntimeFactors;
25use spin_trigger::Trigger;
26use wasmtime_wasi_http::bindings::http::types::ErrorCode;
27
28pub use server::HttpServer;
29
30pub use tls::TlsConfig;
31
32pub(crate) use wasmtime_wasi_http::body::HyperIncomingBody as Body;
33
34/// A [`spin_trigger::TriggerApp`] for the HTTP trigger.
35pub(crate) type TriggerApp<F> = spin_trigger::TriggerApp<HttpTrigger, F>;
36
37/// A [`spin_trigger::TriggerInstanceBuilder`] for the HTTP trigger.
38pub(crate) type TriggerInstanceBuilder<'a, F> =
39    spin_trigger::TriggerInstanceBuilder<'a, HttpTrigger, F>;
40
41#[derive(Args)]
42pub struct CliArgs {
43    /// IP address and port to listen on
44    #[clap(long = "listen", env = "SPIN_HTTP_LISTEN_ADDR", default_value = "127.0.0.1:3000", value_parser = parse_listen_addr)]
45    pub address: SocketAddr,
46
47    /// The path to the certificate to use for https, if this is not set, normal http will be used. The cert should be in PEM format
48    #[clap(long, env = "SPIN_TLS_CERT", requires = "tls-key")]
49    pub tls_cert: Option<PathBuf>,
50
51    /// The path to the certificate key to use for https, if this is not set, normal http will be used. The key should be in PKCS#8 format
52    #[clap(long, env = "SPIN_TLS_KEY", requires = "tls-cert")]
53    pub tls_key: Option<PathBuf>,
54
55    /// Sets the maximum buffer size (in bytes) for the HTTP connection. The minimum value allowed is 8192.
56    #[clap(long, env = "SPIN_HTTP1_MAX_BUF_SIZE")]
57    pub http1_max_buf_size: Option<usize>,
58
59    #[clap(long = "find-free-port")]
60    pub find_free_port: bool,
61}
62
63impl CliArgs {
64    fn into_tls_config(self) -> Option<TlsConfig> {
65        match (self.tls_cert, self.tls_key) {
66            (Some(cert_path), Some(key_path)) => Some(TlsConfig {
67                cert_path,
68                key_path,
69            }),
70            (None, None) => None,
71            _ => unreachable!(),
72        }
73    }
74}
75
76/// The Spin HTTP trigger.
77pub struct HttpTrigger {
78    /// The address the server should listen on.
79    ///
80    /// Note that this might not be the actual socket address that ends up being bound to.
81    /// If the port is set to 0, the actual address will be determined by the OS.
82    listen_addr: SocketAddr,
83    tls_config: Option<TlsConfig>,
84    find_free_port: bool,
85    http1_max_buf_size: Option<usize>,
86}
87
88impl<F: RuntimeFactors> Trigger<F> for HttpTrigger {
89    const TYPE: &'static str = "http";
90
91    type CliArgs = CliArgs;
92    type InstanceState = ();
93
94    fn new(cli_args: Self::CliArgs, app: &spin_app::App) -> anyhow::Result<Self> {
95        let find_free_port = cli_args.find_free_port;
96        let http1_max_buf_size = cli_args.http1_max_buf_size;
97
98        Self::new(
99            app,
100            cli_args.address,
101            cli_args.into_tls_config(),
102            find_free_port,
103            http1_max_buf_size,
104        )
105    }
106
107    async fn run(self, trigger_app: TriggerApp<F>) -> anyhow::Result<()> {
108        let server = self.into_server(trigger_app)?;
109
110        server.serve().await?;
111
112        Ok(())
113    }
114
115    fn supported_host_requirements() -> Vec<&'static str> {
116        vec![spin_app::locked::SERVICE_CHAINING_KEY]
117    }
118}
119
120impl HttpTrigger {
121    /// Create a new `HttpTrigger`.
122    pub fn new(
123        app: &spin_app::App,
124        listen_addr: SocketAddr,
125        tls_config: Option<TlsConfig>,
126        find_free_port: bool,
127        http1_max_buf_size: Option<usize>,
128    ) -> anyhow::Result<Self> {
129        Self::validate_app(app)?;
130
131        Ok(Self {
132            listen_addr,
133            tls_config,
134            find_free_port,
135            http1_max_buf_size,
136        })
137    }
138
139    /// Turn this [`HttpTrigger`] into an [`HttpServer`].
140    pub fn into_server<F: RuntimeFactors>(
141        self,
142        trigger_app: TriggerApp<F>,
143    ) -> anyhow::Result<Arc<HttpServer<F>>> {
144        let Self {
145            listen_addr,
146            tls_config,
147            find_free_port,
148            http1_max_buf_size,
149        } = self;
150        let server = Arc::new(HttpServer::new(
151            listen_addr,
152            tls_config,
153            find_free_port,
154            trigger_app,
155            http1_max_buf_size,
156        )?);
157        Ok(server)
158    }
159
160    fn validate_app(app: &App) -> anyhow::Result<()> {
161        #[derive(Deserialize)]
162        #[serde(deny_unknown_fields)]
163        struct TriggerMetadata {
164            base: Option<String>,
165        }
166        if let Some(TriggerMetadata { base: Some(base) }) = app.get_trigger_metadata("http")? {
167            if base == "/" {
168                tracing::warn!(
169                    "This application has the deprecated trigger 'base' set to the default value '/'. This may be an error in the future!"
170                );
171            } else {
172                bail!(
173                    "This application is using the deprecated trigger 'base' field. The base must be prepended to each [[trigger.http]]'s 'route'."
174                )
175            }
176        }
177        Ok(())
178    }
179}
180
181fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {
182    let addrs: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
183    // Prefer 127.0.0.1 over e.g. [::1] because CHANGE IS HARD
184    if let Some(addr) = addrs
185        .iter()
186        .find(|addr| addr.is_ipv4() && addr.ip() == Ipv4Addr::LOCALHOST)
187    {
188        return Ok(*addr);
189    }
190    // Otherwise, take the first addr (OS preference)
191    addrs.into_iter().next().context("couldn't resolve address")
192}
193
/// Kind of route for a request that matched nothing: either `Normal`, carrying a `String`
/// (presumably the unmatched path — confirm at the use site), or `WellKnown`.
#[derive(Debug)]
#[derive(PartialEq)]
enum NotFoundRouteKind {
    /// An ordinary unmatched route.
    Normal(String),
    /// A "well-known" route; carries no payload.
    WellKnown,
}
199
200/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
201pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
202    // If there's a source, we might be able to extract a wasi-http error from it.
203    if let Some(cause) = err.source() {
204        if let Some(err) = cause.downcast_ref::<ErrorCode>() {
205            return err.clone();
206        }
207    }
208
209    tracing::warn!("hyper request error: {err:?}");
210
211    ErrorCode::HttpProtocolError
212}
213
214pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
215    ErrorCode::DnsError(wasmtime_wasi_http::bindings::http::types::DnsErrorPayload {
216        rcode: Some(rcode),
217        info_code: Some(info_code),
218    })
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// `localhost` may resolve to both `127.0.0.1` and `[::1]`; the IPv4 loopback must win.
    #[test]
    fn parse_listen_addr_prefers_ipv4() {
        let resolved = parse_listen_addr("localhost:12345").unwrap();
        assert_eq!(resolved, SocketAddr::from((Ipv4Addr::LOCALHOST, 12345)));
    }
}