Skip to main content

spin_factor_wasi/
io.rs

1use std::io::{self, Read, Write};
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use async_trait::async_trait;
7use tokio::io::{AsyncRead, AsyncWrite};
8use wasmtime_wasi::cli::{IsTerminal, StdinStream, StdoutStream};
9use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
10
/// An [`OutputStream`] that writes to a `Write` type.
///
/// The `p2_stream`/`async_stream` methods of `StdinStream` and `StdoutStream` can be called more
/// than once in components which are composed of multiple subcomponents, since each subcomponent
/// will potentially want its own handle. This means the streams need to be shareable. The easiest
/// way to do that is to provide cloneable implementations of streams which operate synchronously;
/// every clone shares the same `Arc<Mutex<T>>`, so all handles write to the same underlying `T`.
///
/// Note that this amounts to doing synchronous I/O in an asynchronous context, which we'd normally
/// prefer to avoid, but properly asynchronous `Host{In|Out}putStream` implementations based on
/// `AsyncRead`/`AsyncWrite` are quite hairy and probably not worth it for "normal" stdio streams
/// in Spin. If this does prove to be a performance bottleneck, though, we can certainly revisit it.
pub struct PipedWriteStream<T>(Arc<Mutex<T>>);
23
24impl<T> PipedWriteStream<T> {
25    pub fn new(inner: T) -> Self {
26        Self(Arc::new(Mutex::new(inner)))
27    }
28}
29
30impl<T> Clone for PipedWriteStream<T> {
31    fn clone(&self) -> Self {
32        Self(self.0.clone())
33    }
34}
35
36impl<T: Write + Send + Sync + 'static> OutputStream for PipedWriteStream<T> {
37    fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
38        self.0
39            .lock()
40            .unwrap()
41            .write_all(&bytes)
42            .map_err(|e| StreamError::LastOperationFailed(e.into()))
43    }
44
45    fn flush(&mut self) -> Result<(), StreamError> {
46        self.0
47            .lock()
48            .unwrap()
49            .flush()
50            .map_err(|e| StreamError::LastOperationFailed(e.into()))
51    }
52
53    fn check_write(&mut self) -> Result<usize, StreamError> {
54        Ok(1024 * 1024)
55    }
56}
57
58impl<T: Write + Send + Sync + 'static> AsyncWrite for PipedWriteStream<T> {
59    fn poll_write(
60        self: Pin<&mut Self>,
61        _cx: &mut Context<'_>,
62        buf: &[u8],
63    ) -> Poll<io::Result<usize>> {
64        Poll::Ready(self.0.lock().unwrap().write(buf))
65    }
66    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67        Poll::Ready(self.0.lock().unwrap().flush())
68    }
69    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
70        Poll::Ready(Ok(()))
71    }
72}
73
74impl<T> IsTerminal for PipedWriteStream<T> {
75    fn is_terminal(&self) -> bool {
76        false
77    }
78}
79
80impl<T: Write + Send + Sync + 'static> StdoutStream for PipedWriteStream<T> {
81    fn p2_stream(&self) -> Box<dyn OutputStream> {
82        Box::new(self.clone())
83    }
84    fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
85        Box::new(self.clone())
86    }
87}
88
89#[async_trait]
90impl<T: Write + Send + Sync + 'static> Pollable for PipedWriteStream<T> {
91    async fn ready(&mut self) {}
92}
93
/// An [`InputStream`] that reads from a `Read` type.
///
/// See [`PipedWriteStream`] for more information on why this is synchronous.
pub struct PipeReadStream<T> {
    /// Per-handle scratch space that reads land in before being copied out; each clone gets its
    /// own, so it is never shared.
    buffer: Vec<u8>,
    /// The reader, shared by every clone of this stream.
    inner: Arc<Mutex<T>>,
}
101
102impl<T> PipeReadStream<T> {
103    pub fn new(inner: T) -> Self {
104        Self {
105            buffer: vec![0_u8; 64 * 1024],
106            inner: Arc::new(Mutex::new(inner)),
107        }
108    }
109}
110
111impl<T> Clone for PipeReadStream<T> {
112    fn clone(&self) -> Self {
113        Self {
114            buffer: vec![0_u8; 64 * 1024],
115            inner: self.inner.clone(),
116        }
117    }
118}
119
120impl<T> IsTerminal for PipeReadStream<T> {
121    fn is_terminal(&self) -> bool {
122        false
123    }
124}
125
126impl<T: Read + Send + Sync + 'static> InputStream for PipeReadStream<T> {
127    fn read(&mut self, size: usize) -> wasmtime_wasi::p2::StreamResult<bytes::Bytes> {
128        let size = size.min(self.buffer.len());
129
130        let count = self
131            .inner
132            .lock()
133            .unwrap()
134            .read(&mut self.buffer[..size])
135            .map_err(|e| StreamError::LastOperationFailed(e.into()))?;
136        if count == 0 {
137            return Err(wasmtime_wasi::p2::StreamError::Closed);
138        }
139
140        Ok(bytes::Bytes::copy_from_slice(&self.buffer[..count]))
141    }
142}
143
144impl<T: Read + Send + Sync + 'static> AsyncRead for PipeReadStream<T> {
145    fn poll_read(
146        self: Pin<&mut Self>,
147        _cx: &mut Context<'_>,
148        buf: &mut tokio::io::ReadBuf<'_>,
149    ) -> Poll<io::Result<()>> {
150        let result = self
151            .inner
152            .lock()
153            .unwrap()
154            .read(buf.initialize_unfilled())
155            .map(|n| buf.advance(n));
156        Poll::Ready(result)
157    }
158}
159
160#[async_trait]
161impl<T: Read + Send + Sync + 'static> Pollable for PipeReadStream<T> {
162    async fn ready(&mut self) {}
163}
164
165impl<T: Read + Send + Sync + 'static> StdinStream for PipeReadStream<T> {
166    fn p2_stream(&self) -> Box<dyn InputStream> {
167        Box::new(self.clone())
168    }
169
170    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
171        Box::new(self.clone())
172    }
173}