1use std::io::{self, Read, Write};
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use async_trait::async_trait;
7use spin_factors::anyhow;
8use tokio::io::{AsyncRead, AsyncWrite};
9use wasmtime_wasi::cli::{IsTerminal, StdinStream, StdoutStream};
10use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
11
/// A writable pipe end that can be handed to a guest as an output stream
/// (it implements `StdoutStream`).
///
/// The writer `T` lives behind `Arc<Mutex<_>>`. Every clone of this handle
/// therefore writes to the same underlying `T`, and the mutex serializes
/// concurrent writes.
pub struct PipedWriteStream<T>(Arc<Mutex<T>>);
24
25impl<T> PipedWriteStream<T> {
26 pub fn new(inner: T) -> Self {
27 Self(Arc::new(Mutex::new(inner)))
28 }
29}
30
31impl<T> Clone for PipedWriteStream<T> {
32 fn clone(&self) -> Self {
33 Self(self.0.clone())
34 }
35}
36
37impl<T: Write + Send + Sync + 'static> OutputStream for PipedWriteStream<T> {
38 fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
39 self.0
40 .lock()
41 .unwrap()
42 .write_all(&bytes)
43 .map_err(|e| StreamError::LastOperationFailed(anyhow::anyhow!(e)))
44 }
45
46 fn flush(&mut self) -> Result<(), StreamError> {
47 self.0
48 .lock()
49 .unwrap()
50 .flush()
51 .map_err(|e| StreamError::LastOperationFailed(anyhow::anyhow!(e)))
52 }
53
54 fn check_write(&mut self) -> Result<usize, StreamError> {
55 Ok(1024 * 1024)
56 }
57}
58
59impl<T: Write + Send + Sync + 'static> AsyncWrite for PipedWriteStream<T> {
60 fn poll_write(
61 self: Pin<&mut Self>,
62 _cx: &mut Context<'_>,
63 buf: &[u8],
64 ) -> Poll<io::Result<usize>> {
65 Poll::Ready(self.0.lock().unwrap().write(buf))
66 }
67 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68 Poll::Ready(self.0.lock().unwrap().flush())
69 }
70 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71 Poll::Ready(Ok(()))
72 }
73}
74
impl<T> IsTerminal for PipedWriteStream<T> {
    /// Always `false`.
    ///
    /// This wrapper makes no attempt to detect whether the inner writer `T` is
    /// attached to a TTY.
    fn is_terminal(&self) -> bool {
        false
    }
}
80
81impl<T: Write + Send + Sync + 'static> StdoutStream for PipedWriteStream<T> {
82 fn p2_stream(&self) -> Box<dyn OutputStream> {
83 Box::new(self.clone())
84 }
85 fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
86 Box::new(self.clone())
87 }
88}
89
#[async_trait]
impl<T: Write + Send + Sync + 'static> Pollable for PipedWriteStream<T> {
    /// Resolves immediately, so this stream never reports "not ready".
    ///
    /// Writes go straight to the inner writer under the mutex (possibly
    /// blocking), and `check_write` always grants a fixed budget, so there is no
    /// readiness state to wait on here.
    async fn ready(&mut self) {}
}
94
/// A readable pipe end that can be handed to a guest as stdin (it implements
/// `StdinStream`).
///
/// The reader `T` lives behind `Arc<Mutex<_>>`, so cloned handles all consume
/// from the same underlying `T`.
pub struct PipeReadStream<T> {
    /// Shared reader; the mutex serializes reads coming from different handles.
    inner: Arc<Mutex<T>>,
    /// Per-handle scratch space that bytes are read into before being copied out.
    /// Its length also caps how much a single read can return.
    buffer: Vec<u8>,
}
102
103impl<T> PipeReadStream<T> {
104 pub fn new(inner: T) -> Self {
105 Self {
106 buffer: vec![0_u8; 64 * 1024],
107 inner: Arc::new(Mutex::new(inner)),
108 }
109 }
110}
111
112impl<T> Clone for PipeReadStream<T> {
113 fn clone(&self) -> Self {
114 Self {
115 buffer: vec![0_u8; 64 * 1024],
116 inner: self.inner.clone(),
117 }
118 }
119}
120
impl<T> IsTerminal for PipeReadStream<T> {
    /// Always `false`.
    ///
    /// This wrapper makes no attempt to detect whether the inner reader `T` is
    /// attached to a TTY.
    fn is_terminal(&self) -> bool {
        false
    }
}
126
127impl<T: Read + Send + Sync + 'static> InputStream for PipeReadStream<T> {
128 fn read(&mut self, size: usize) -> wasmtime_wasi::p2::StreamResult<bytes::Bytes> {
129 let size = size.min(self.buffer.len());
130
131 let count = self
132 .inner
133 .lock()
134 .unwrap()
135 .read(&mut self.buffer[..size])
136 .map_err(|e| StreamError::LastOperationFailed(anyhow::anyhow!(e)))?;
137 if count == 0 {
138 return Err(wasmtime_wasi::p2::StreamError::Closed);
139 }
140
141 Ok(bytes::Bytes::copy_from_slice(&self.buffer[..count]))
142 }
143}
144
145impl<T: Read + Send + Sync + 'static> AsyncRead for PipeReadStream<T> {
146 fn poll_read(
147 self: Pin<&mut Self>,
148 _cx: &mut Context<'_>,
149 buf: &mut tokio::io::ReadBuf<'_>,
150 ) -> Poll<io::Result<()>> {
151 let result = self
152 .inner
153 .lock()
154 .unwrap()
155 .read(buf.initialize_unfilled())
156 .map(|n| buf.advance(n));
157 Poll::Ready(result)
158 }
159}
160
#[async_trait]
impl<T: Read + Send + Sync + 'static> Pollable for PipeReadStream<T> {
    /// Resolves immediately, so this stream never reports "not ready".
    ///
    /// `read` calls the inner reader directly under the mutex (possibly
    /// blocking) rather than waiting for readiness here.
    async fn ready(&mut self) {}
}
165
166impl<T: Read + Send + Sync + 'static> StdinStream for PipeReadStream<T> {
167 fn p2_stream(&self) -> Box<dyn InputStream> {
168 Box::new(self.clone())
169 }
170
171 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
172 Box::new(self.clone())
173 }
174}