spin_factor_wasi/
io.rs

1use std::io::{self, Read, Write};
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use async_trait::async_trait;
7use spin_factors::anyhow;
8use tokio::io::{AsyncRead, AsyncWrite};
9use wasmtime_wasi::cli::{IsTerminal, StdinStream, StdoutStream};
10use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
11
/// An [`OutputStream`] that writes to a `Write` type.
///
/// `StdinStream::p2_stream`/`async_stream` and `StdoutStream::p2_stream`/`async_stream` can be
/// called more than once in components which are composed of multiple subcomponents, since each
/// subcomponent will potentially want its own handle. This means the streams need to be
/// shareable. The easiest way to do that is to provide cloneable implementations of streams which
/// operate synchronously: every clone shares the same underlying writer behind an
/// `Arc<Mutex<_>>`.
///
/// Note that this amounts to doing synchronous I/O in an asynchronous context, which we'd normally
/// prefer to avoid, but properly asynchronous `InputStream`/`OutputStream` implementations based on
/// `AsyncRead`/`AsyncWrite` are quite hairy and probably not worth it for "normal" stdio streams in
/// Spin. If this does prove to be a performance bottleneck, though, we can certainly revisit it.
pub struct PipedWriteStream<T>(Arc<Mutex<T>>);
24
25impl<T> PipedWriteStream<T> {
26    pub fn new(inner: T) -> Self {
27        Self(Arc::new(Mutex::new(inner)))
28    }
29}
30
31impl<T> Clone for PipedWriteStream<T> {
32    fn clone(&self) -> Self {
33        Self(self.0.clone())
34    }
35}
36
37impl<T: Write + Send + Sync + 'static> OutputStream for PipedWriteStream<T> {
38    fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
39        self.0
40            .lock()
41            .unwrap()
42            .write_all(&bytes)
43            .map_err(|e| StreamError::LastOperationFailed(anyhow::anyhow!(e)))
44    }
45
46    fn flush(&mut self) -> Result<(), StreamError> {
47        self.0
48            .lock()
49            .unwrap()
50            .flush()
51            .map_err(|e| StreamError::LastOperationFailed(anyhow::anyhow!(e)))
52    }
53
54    fn check_write(&mut self) -> Result<usize, StreamError> {
55        Ok(1024 * 1024)
56    }
57}
58
59impl<T: Write + Send + Sync + 'static> AsyncWrite for PipedWriteStream<T> {
60    fn poll_write(
61        self: Pin<&mut Self>,
62        _cx: &mut Context<'_>,
63        buf: &[u8],
64    ) -> Poll<io::Result<usize>> {
65        Poll::Ready(self.0.lock().unwrap().write(buf))
66    }
67    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68        Poll::Ready(self.0.lock().unwrap().flush())
69    }
70    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71        Poll::Ready(Ok(()))
72    }
73}
74
75impl<T> IsTerminal for PipedWriteStream<T> {
76    fn is_terminal(&self) -> bool {
77        false
78    }
79}
80
81impl<T: Write + Send + Sync + 'static> StdoutStream for PipedWriteStream<T> {
82    fn p2_stream(&self) -> Box<dyn OutputStream> {
83        Box::new(self.clone())
84    }
85    fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
86        Box::new(self.clone())
87    }
88}
89
90#[async_trait]
91impl<T: Write + Send + Sync + 'static> Pollable for PipedWriteStream<T> {
92    async fn ready(&mut self) {}
93}
94
/// An [`InputStream`] that reads from a `Read` type.
///
/// See [`PipedWriteStream`] for more information on why this is synchronous.
pub struct PipeReadStream<T> {
    /// Scratch space that `InputStream::read` reads into before copying out; each clone of the
    /// stream gets its own buffer.
    buffer: Vec<u8>,
    /// The reader, shared by every clone of this stream.
    inner: Arc<Mutex<T>>,
}
102
103impl<T> PipeReadStream<T> {
104    pub fn new(inner: T) -> Self {
105        Self {
106            buffer: vec![0_u8; 64 * 1024],
107            inner: Arc::new(Mutex::new(inner)),
108        }
109    }
110}
111
112impl<T> Clone for PipeReadStream<T> {
113    fn clone(&self) -> Self {
114        Self {
115            buffer: vec![0_u8; 64 * 1024],
116            inner: self.inner.clone(),
117        }
118    }
119}
120
121impl<T> IsTerminal for PipeReadStream<T> {
122    fn is_terminal(&self) -> bool {
123        false
124    }
125}
126
127impl<T: Read + Send + Sync + 'static> InputStream for PipeReadStream<T> {
128    fn read(&mut self, size: usize) -> wasmtime_wasi::p2::StreamResult<bytes::Bytes> {
129        let size = size.min(self.buffer.len());
130
131        let count = self
132            .inner
133            .lock()
134            .unwrap()
135            .read(&mut self.buffer[..size])
136            .map_err(|e| StreamError::LastOperationFailed(anyhow::anyhow!(e)))?;
137        if count == 0 {
138            return Err(wasmtime_wasi::p2::StreamError::Closed);
139        }
140
141        Ok(bytes::Bytes::copy_from_slice(&self.buffer[..count]))
142    }
143}
144
145impl<T: Read + Send + Sync + 'static> AsyncRead for PipeReadStream<T> {
146    fn poll_read(
147        self: Pin<&mut Self>,
148        _cx: &mut Context<'_>,
149        buf: &mut tokio::io::ReadBuf<'_>,
150    ) -> Poll<io::Result<()>> {
151        let result = self
152            .inner
153            .lock()
154            .unwrap()
155            .read(buf.initialize_unfilled())
156            .map(|n| buf.advance(n));
157        Poll::Ready(result)
158    }
159}
160
161#[async_trait]
162impl<T: Read + Send + Sync + 'static> Pollable for PipeReadStream<T> {
163    async fn ready(&mut self) {}
164}
165
166impl<T: Read + Send + Sync + 'static> StdinStream for PipeReadStream<T> {
167    fn p2_stream(&self) -> Box<dyn InputStream> {
168        Box::new(self.clone())
169    }
170
171    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
172        Box::new(self.clone())
173    }
174}