1use std::io::{self, Read, Write};
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use async_trait::async_trait;
7use tokio::io::{AsyncRead, AsyncWrite};
8use wasmtime_wasi::cli::{IsTerminal, StdinStream, StdoutStream};
9use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
10
/// A cloneable [`Write`]-backed stream that can be handed to WASI as stdout/stderr.
///
/// Every clone shares one inner writer behind an `Arc<Mutex<_>>`, so output written
/// through any clone ends up in the same `T`.
pub struct PipedWriteStream<T>(Arc<Mutex<T>>);

impl<T> PipedWriteStream<T> {
    /// Wraps `inner` so it can be shared by all clones of this stream.
    pub fn new(inner: T) -> Self {
        let guarded = Mutex::new(inner);
        PipedWriteStream(Arc::new(guarded))
    }
}
29
30impl<T> Clone for PipedWriteStream<T> {
31 fn clone(&self) -> Self {
32 Self(self.0.clone())
33 }
34}
35
36impl<T: Write + Send + Sync + 'static> OutputStream for PipedWriteStream<T> {
37 fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
38 self.0
39 .lock()
40 .unwrap()
41 .write_all(&bytes)
42 .map_err(|e| StreamError::LastOperationFailed(e.into()))
43 }
44
45 fn flush(&mut self) -> Result<(), StreamError> {
46 self.0
47 .lock()
48 .unwrap()
49 .flush()
50 .map_err(|e| StreamError::LastOperationFailed(e.into()))
51 }
52
53 fn check_write(&mut self) -> Result<usize, StreamError> {
54 Ok(1024 * 1024)
55 }
56}
57
58impl<T: Write + Send + Sync + 'static> AsyncWrite for PipedWriteStream<T> {
59 fn poll_write(
60 self: Pin<&mut Self>,
61 _cx: &mut Context<'_>,
62 buf: &[u8],
63 ) -> Poll<io::Result<usize>> {
64 Poll::Ready(self.0.lock().unwrap().write(buf))
65 }
66 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67 Poll::Ready(self.0.lock().unwrap().flush())
68 }
69 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
70 Poll::Ready(Ok(()))
71 }
72}
73
74impl<T> IsTerminal for PipedWriteStream<T> {
75 fn is_terminal(&self) -> bool {
76 false
77 }
78}
79
80impl<T: Write + Send + Sync + 'static> StdoutStream for PipedWriteStream<T> {
81 fn p2_stream(&self) -> Box<dyn OutputStream> {
82 Box::new(self.clone())
83 }
84 fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
85 Box::new(self.clone())
86 }
87}
88
89#[async_trait]
90impl<T: Write + Send + Sync + 'static> Pollable for PipedWriteStream<T> {
91 async fn ready(&mut self) {}
92}
93
/// A cloneable [`Read`]-backed stream that can be handed to WASI as stdin.
///
/// Clones share one inner reader behind an `Arc<Mutex<_>>`, while each clone keeps its
/// own 64 KiB scratch buffer.
pub struct PipeReadStream<T> {
    /// Scratch space that reads land in before being copied out.
    buffer: Vec<u8>,
    /// Reader shared by every clone of this stream.
    inner: Arc<Mutex<T>>,
}

impl<T> PipeReadStream<T> {
    /// Wraps `inner` and allocates a zero-filled 64 KiB read buffer.
    pub fn new(inner: T) -> Self {
        let buffer = vec![0_u8; 64 * 1024];
        PipeReadStream {
            inner: Arc::new(Mutex::new(inner)),
            buffer,
        }
    }
}
110
111impl<T> Clone for PipeReadStream<T> {
112 fn clone(&self) -> Self {
113 Self {
114 buffer: vec![0_u8; 64 * 1024],
115 inner: self.inner.clone(),
116 }
117 }
118}
119
120impl<T> IsTerminal for PipeReadStream<T> {
121 fn is_terminal(&self) -> bool {
122 false
123 }
124}
125
126impl<T: Read + Send + Sync + 'static> InputStream for PipeReadStream<T> {
127 fn read(&mut self, size: usize) -> wasmtime_wasi::p2::StreamResult<bytes::Bytes> {
128 let size = size.min(self.buffer.len());
129
130 let count = self
131 .inner
132 .lock()
133 .unwrap()
134 .read(&mut self.buffer[..size])
135 .map_err(|e| StreamError::LastOperationFailed(e.into()))?;
136 if count == 0 {
137 return Err(wasmtime_wasi::p2::StreamError::Closed);
138 }
139
140 Ok(bytes::Bytes::copy_from_slice(&self.buffer[..count]))
141 }
142}
143
144impl<T: Read + Send + Sync + 'static> AsyncRead for PipeReadStream<T> {
145 fn poll_read(
146 self: Pin<&mut Self>,
147 _cx: &mut Context<'_>,
148 buf: &mut tokio::io::ReadBuf<'_>,
149 ) -> Poll<io::Result<()>> {
150 let result = self
151 .inner
152 .lock()
153 .unwrap()
154 .read(buf.initialize_unfilled())
155 .map(|n| buf.advance(n));
156 Poll::Ready(result)
157 }
158}
159
160#[async_trait]
161impl<T: Read + Send + Sync + 'static> Pollable for PipeReadStream<T> {
162 async fn ready(&mut self) {}
163}
164
165impl<T: Read + Send + Sync + 'static> StdinStream for PipeReadStream<T> {
166 fn p2_stream(&self) -> Box<dyn InputStream> {
167 Box::new(self.clone())
168 }
169
170 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
171 Box::new(self.clone())
172 }
173}