1use std::{
9 collections::HashMap,
10 sync::{Arc, Mutex},
11};
12
13use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore};
14use wasmtime::component::{HasData, Resource};
15use wasmtime_wasi::p2::bindings::sockets::network::{
16 ErrorCode as SocketErrorCode, Host as NetworkHost, Network,
17};
18use wasmtime_wasi::p2::bindings::sockets::tcp::{self as p2_tcp, IpSocketAddress, ShutdownType};
19use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create;
20use wasmtime_wasi::p2::bindings::sockets::udp as p2_udp;
21use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create;
22use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable};
23use wasmtime_wasi::sockets::{TcpSocket, UdpSocket, WasiSocketsCtxView};
24
25pub struct SocketPermitState {
29 semaphore: ConnectionSemaphore,
30 active: Mutex<HashMap<u32, ConnectionPermit>>,
32}
33
34impl SocketPermitState {
35 pub fn new(semaphore: ConnectionSemaphore) -> Arc<Self> {
36 Arc::new(Self {
37 semaphore,
38 active: Mutex::new(HashMap::new()),
39 })
40 }
41}
42
43pub struct SpinSocketsView<'a> {
46 pub(crate) inner: WasiSocketsCtxView<'a>,
47 pub(crate) permit_state: Option<Arc<SocketPermitState>>,
48}
49
50impl<'a> std::ops::Deref for SpinSocketsView<'a> {
51 type Target = WasiSocketsCtxView<'a>;
52 fn deref(&self) -> &Self::Target {
53 &self.inner
54 }
55}
56
57impl std::ops::DerefMut for SpinSocketsView<'_> {
58 fn deref_mut(&mut self) -> &mut Self::Target {
59 &mut self.inner
60 }
61}
62
63pub struct SpinSockets;
67
68impl HasData for SpinSockets {
69 type Data<'a> = SpinSocketsView<'a>;
70}
71
72impl SpinSocketsView<'_> {
73 pub(crate) fn try_acquire(&self) -> Result<Option<ConnectionPermit>, ()> {
81 let Some(state) = &self.permit_state else {
82 return Ok(None);
83 };
84 state.semaphore.try_acquire().map(Some).ok_or(())
85 }
86
87 pub(crate) fn register_permit(&self, socket_rep: u32, permit: Option<ConnectionPermit>) {
90 let (Some(state), Some(permit)) = (&self.permit_state, permit) else {
91 return;
92 };
93 state
94 .active
95 .lock()
96 .unwrap_or_else(|e| e.into_inner())
97 .insert(socket_rep, permit);
98 }
99
100 pub(crate) fn release_permit(&self, socket_rep: u32) {
102 if let Some(state) = &self.permit_state {
103 state
104 .active
105 .lock()
106 .unwrap_or_else(|e| e.into_inner())
107 .remove(&socket_rep);
108 }
109 }
110}
111
112impl p2_tcp::Host for SpinSocketsView<'_> {}
113
114impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> {
115 async fn start_bind(
116 &mut self,
117 this: Resource<TcpSocket>,
118 network: Resource<Network>,
119 local_address: IpSocketAddress,
120 ) -> wasmtime_wasi::p2::SocketResult<()> {
121 p2_tcp::HostTcpSocket::start_bind(&mut self.inner, this, network, local_address).await
122 }
123
124 fn finish_bind(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<()> {
125 p2_tcp::HostTcpSocket::finish_bind(&mut self.inner, this)
126 }
127
128 async fn start_connect(
129 &mut self,
130 this: Resource<TcpSocket>,
131 network: Resource<Network>,
132 remote_address: IpSocketAddress,
133 ) -> wasmtime_wasi::p2::SocketResult<()> {
134 let socket_rep = this.rep();
135 let Ok(permit) = self.try_acquire() else {
140 tracing::warn!("TCP socket connection refused: connection quota exhausted");
141 return Err(SocketErrorCode::NewSocketLimit.into());
142 };
143 let result =
144 p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address)
145 .await;
146 if result.is_ok() {
147 self.register_permit(socket_rep, permit);
148 }
149 result
151 }
152
153 fn finish_connect(
154 &mut self,
155 this: Resource<TcpSocket>,
156 ) -> wasmtime_wasi::p2::SocketResult<(Resource<DynInputStream>, Resource<DynOutputStream>)>
157 {
158 p2_tcp::HostTcpSocket::finish_connect(&mut self.inner, this)
159 }
160
161 fn start_listen(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<()> {
162 p2_tcp::HostTcpSocket::start_listen(&mut self.inner, this)
163 }
164
165 fn finish_listen(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<()> {
166 p2_tcp::HostTcpSocket::finish_listen(&mut self.inner, this)
167 }
168
169 fn accept(
170 &mut self,
171 this: Resource<TcpSocket>,
172 ) -> wasmtime_wasi::p2::SocketResult<(
173 Resource<TcpSocket>,
174 Resource<DynInputStream>,
175 Resource<DynOutputStream>,
176 )> {
177 p2_tcp::HostTcpSocket::accept(&mut self.inner, this)
178 }
179
180 fn local_address(
181 &mut self,
182 this: Resource<TcpSocket>,
183 ) -> wasmtime_wasi::p2::SocketResult<IpSocketAddress> {
184 p2_tcp::HostTcpSocket::local_address(&mut self.inner, this)
185 }
186
187 fn remote_address(
188 &mut self,
189 this: Resource<TcpSocket>,
190 ) -> wasmtime_wasi::p2::SocketResult<IpSocketAddress> {
191 p2_tcp::HostTcpSocket::remote_address(&mut self.inner, this)
192 }
193
194 fn is_listening(&mut self, this: Resource<TcpSocket>) -> wasmtime::Result<bool> {
195 p2_tcp::HostTcpSocket::is_listening(&mut self.inner, this)
196 }
197
198 fn address_family(
199 &mut self,
200 this: Resource<TcpSocket>,
201 ) -> wasmtime::Result<wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily> {
202 p2_tcp::HostTcpSocket::address_family(&mut self.inner, this)
203 }
204
205 fn set_listen_backlog_size(
206 &mut self,
207 this: Resource<TcpSocket>,
208 value: u64,
209 ) -> wasmtime_wasi::p2::SocketResult<()> {
210 p2_tcp::HostTcpSocket::set_listen_backlog_size(&mut self.inner, this, value)
211 }
212
213 fn keep_alive_enabled(
214 &mut self,
215 this: Resource<TcpSocket>,
216 ) -> wasmtime_wasi::p2::SocketResult<bool> {
217 p2_tcp::HostTcpSocket::keep_alive_enabled(&mut self.inner, this)
218 }
219
220 fn set_keep_alive_enabled(
221 &mut self,
222 this: Resource<TcpSocket>,
223 value: bool,
224 ) -> wasmtime_wasi::p2::SocketResult<()> {
225 p2_tcp::HostTcpSocket::set_keep_alive_enabled(&mut self.inner, this, value)
226 }
227
228 fn keep_alive_idle_time(
229 &mut self,
230 this: Resource<TcpSocket>,
231 ) -> wasmtime_wasi::p2::SocketResult<u64> {
232 p2_tcp::HostTcpSocket::keep_alive_idle_time(&mut self.inner, this)
233 }
234
235 fn set_keep_alive_idle_time(
236 &mut self,
237 this: Resource<TcpSocket>,
238 value: u64,
239 ) -> wasmtime_wasi::p2::SocketResult<()> {
240 p2_tcp::HostTcpSocket::set_keep_alive_idle_time(&mut self.inner, this, value)
241 }
242
243 fn keep_alive_interval(
244 &mut self,
245 this: Resource<TcpSocket>,
246 ) -> wasmtime_wasi::p2::SocketResult<u64> {
247 p2_tcp::HostTcpSocket::keep_alive_interval(&mut self.inner, this)
248 }
249
250 fn set_keep_alive_interval(
251 &mut self,
252 this: Resource<TcpSocket>,
253 value: u64,
254 ) -> wasmtime_wasi::p2::SocketResult<()> {
255 p2_tcp::HostTcpSocket::set_keep_alive_interval(&mut self.inner, this, value)
256 }
257
258 fn keep_alive_count(
259 &mut self,
260 this: Resource<TcpSocket>,
261 ) -> wasmtime_wasi::p2::SocketResult<u32> {
262 p2_tcp::HostTcpSocket::keep_alive_count(&mut self.inner, this)
263 }
264
265 fn set_keep_alive_count(
266 &mut self,
267 this: Resource<TcpSocket>,
268 value: u32,
269 ) -> wasmtime_wasi::p2::SocketResult<()> {
270 p2_tcp::HostTcpSocket::set_keep_alive_count(&mut self.inner, this, value)
271 }
272
273 fn hop_limit(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<u8> {
274 p2_tcp::HostTcpSocket::hop_limit(&mut self.inner, this)
275 }
276
277 fn set_hop_limit(
278 &mut self,
279 this: Resource<TcpSocket>,
280 value: u8,
281 ) -> wasmtime_wasi::p2::SocketResult<()> {
282 p2_tcp::HostTcpSocket::set_hop_limit(&mut self.inner, this, value)
283 }
284
285 fn receive_buffer_size(
286 &mut self,
287 this: Resource<TcpSocket>,
288 ) -> wasmtime_wasi::p2::SocketResult<u64> {
289 p2_tcp::HostTcpSocket::receive_buffer_size(&mut self.inner, this)
290 }
291
292 fn set_receive_buffer_size(
293 &mut self,
294 this: Resource<TcpSocket>,
295 value: u64,
296 ) -> wasmtime_wasi::p2::SocketResult<()> {
297 p2_tcp::HostTcpSocket::set_receive_buffer_size(&mut self.inner, this, value)
298 }
299
300 fn send_buffer_size(
301 &mut self,
302 this: Resource<TcpSocket>,
303 ) -> wasmtime_wasi::p2::SocketResult<u64> {
304 p2_tcp::HostTcpSocket::send_buffer_size(&mut self.inner, this)
305 }
306
307 fn set_send_buffer_size(
308 &mut self,
309 this: Resource<TcpSocket>,
310 value: u64,
311 ) -> wasmtime_wasi::p2::SocketResult<()> {
312 p2_tcp::HostTcpSocket::set_send_buffer_size(&mut self.inner, this, value)
313 }
314
315 fn subscribe(&mut self, this: Resource<TcpSocket>) -> wasmtime::Result<Resource<DynPollable>> {
316 p2_tcp::HostTcpSocket::subscribe(&mut self.inner, this)
317 }
318
319 fn shutdown(
320 &mut self,
321 this: Resource<TcpSocket>,
322 shutdown_type: ShutdownType,
323 ) -> wasmtime_wasi::p2::SocketResult<()> {
324 p2_tcp::HostTcpSocket::shutdown(&mut self.inner, this, shutdown_type)
325 }
326
327 fn drop(&mut self, this: Resource<TcpSocket>) -> wasmtime::Result<()> {
328 self.release_permit(this.rep());
329 p2_tcp::HostTcpSocket::drop(&mut self.inner, this)
330 }
331}
332
333impl NetworkHost for SpinSocketsView<'_> {
334 fn convert_error_code(
335 &mut self,
336 error: wasmtime_wasi::p2::SocketError,
337 ) -> wasmtime::Result<wasmtime_wasi::p2::bindings::sockets::network::ErrorCode> {
338 NetworkHost::convert_error_code(&mut self.inner, error)
339 }
340
341 fn network_error_code(
342 &mut self,
343 err: Resource<wasmtime::Error>,
344 ) -> wasmtime::Result<Option<wasmtime_wasi::p2::bindings::sockets::network::ErrorCode>> {
345 NetworkHost::network_error_code(&mut self.inner, err)
346 }
347}
348
349impl wasmtime_wasi::p2::bindings::sockets::network::HostNetwork for SpinSocketsView<'_> {
350 fn drop(&mut self, this: Resource<Network>) -> wasmtime::Result<()> {
351 wasmtime_wasi::p2::bindings::sockets::network::HostNetwork::drop(&mut self.inner, this)
352 }
353}
354
355impl p2_tcp_create::Host for SpinSocketsView<'_> {
356 fn create_tcp_socket(
357 &mut self,
358 address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily,
359 ) -> wasmtime_wasi::p2::SocketResult<Resource<TcpSocket>> {
360 p2_tcp_create::Host::create_tcp_socket(&mut self.inner, address_family)
361 }
362}
363
364impl p2_udp::Host for SpinSocketsView<'_> {}
365
366impl p2_udp::HostUdpSocket for SpinSocketsView<'_> {
367 async fn start_bind(
368 &mut self,
369 this: Resource<p2_udp::UdpSocket>,
370 network: Resource<p2_udp::Network>,
371 local_address: p2_udp::IpSocketAddress,
372 ) -> wasmtime_wasi::p2::SocketResult<()> {
373 p2_udp::HostUdpSocket::start_bind(&mut self.inner, this, network, local_address).await
374 }
375
376 fn finish_bind(
377 &mut self,
378 this: Resource<p2_udp::UdpSocket>,
379 ) -> wasmtime_wasi::p2::SocketResult<()> {
380 p2_udp::HostUdpSocket::finish_bind(&mut self.inner, this)
381 }
382
383 async fn stream(
384 &mut self,
385 this: Resource<p2_udp::UdpSocket>,
386 remote_address: Option<p2_udp::IpSocketAddress>,
387 ) -> wasmtime_wasi::p2::SocketResult<(
388 Resource<p2_udp::IncomingDatagramStream>,
389 Resource<p2_udp::OutgoingDatagramStream>,
390 )> {
391 p2_udp::HostUdpSocket::stream(&mut self.inner, this, remote_address).await
392 }
393
394 fn local_address(
395 &mut self,
396 this: Resource<p2_udp::UdpSocket>,
397 ) -> wasmtime_wasi::p2::SocketResult<p2_udp::IpSocketAddress> {
398 p2_udp::HostUdpSocket::local_address(&mut self.inner, this)
399 }
400
401 fn remote_address(
402 &mut self,
403 this: Resource<p2_udp::UdpSocket>,
404 ) -> wasmtime_wasi::p2::SocketResult<p2_udp::IpSocketAddress> {
405 p2_udp::HostUdpSocket::remote_address(&mut self.inner, this)
406 }
407
408 fn address_family(
409 &mut self,
410 this: Resource<p2_udp::UdpSocket>,
411 ) -> wasmtime::Result<p2_udp::IpAddressFamily> {
412 p2_udp::HostUdpSocket::address_family(&mut self.inner, this)
413 }
414
415 fn unicast_hop_limit(
416 &mut self,
417 this: Resource<p2_udp::UdpSocket>,
418 ) -> wasmtime_wasi::p2::SocketResult<u8> {
419 p2_udp::HostUdpSocket::unicast_hop_limit(&mut self.inner, this)
420 }
421
422 fn set_unicast_hop_limit(
423 &mut self,
424 this: Resource<p2_udp::UdpSocket>,
425 value: u8,
426 ) -> wasmtime_wasi::p2::SocketResult<()> {
427 p2_udp::HostUdpSocket::set_unicast_hop_limit(&mut self.inner, this, value)
428 }
429
430 fn receive_buffer_size(
431 &mut self,
432 this: Resource<p2_udp::UdpSocket>,
433 ) -> wasmtime_wasi::p2::SocketResult<u64> {
434 p2_udp::HostUdpSocket::receive_buffer_size(&mut self.inner, this)
435 }
436
437 fn set_receive_buffer_size(
438 &mut self,
439 this: Resource<p2_udp::UdpSocket>,
440 value: u64,
441 ) -> wasmtime_wasi::p2::SocketResult<()> {
442 p2_udp::HostUdpSocket::set_receive_buffer_size(&mut self.inner, this, value)
443 }
444
445 fn send_buffer_size(
446 &mut self,
447 this: Resource<p2_udp::UdpSocket>,
448 ) -> wasmtime_wasi::p2::SocketResult<u64> {
449 p2_udp::HostUdpSocket::send_buffer_size(&mut self.inner, this)
450 }
451
452 fn set_send_buffer_size(
453 &mut self,
454 this: Resource<p2_udp::UdpSocket>,
455 value: u64,
456 ) -> wasmtime_wasi::p2::SocketResult<()> {
457 p2_udp::HostUdpSocket::set_send_buffer_size(&mut self.inner, this, value)
458 }
459
460 fn subscribe(
461 &mut self,
462 this: Resource<p2_udp::UdpSocket>,
463 ) -> wasmtime::Result<Resource<DynPollable>> {
464 p2_udp::HostUdpSocket::subscribe(&mut self.inner, this)
465 }
466
467 fn drop(&mut self, this: Resource<p2_udp::UdpSocket>) -> wasmtime::Result<()> {
468 self.release_permit(this.rep());
469 p2_udp::HostUdpSocket::drop(&mut self.inner, this)
470 }
471}
472
473impl p2_udp::HostIncomingDatagramStream for SpinSocketsView<'_> {
474 fn receive(
475 &mut self,
476 this: Resource<p2_udp::IncomingDatagramStream>,
477 max_results: u64,
478 ) -> wasmtime_wasi::p2::SocketResult<Vec<p2_udp::IncomingDatagram>> {
479 p2_udp::HostIncomingDatagramStream::receive(&mut self.inner, this, max_results)
480 }
481
482 fn subscribe(
483 &mut self,
484 this: Resource<p2_udp::IncomingDatagramStream>,
485 ) -> wasmtime::Result<Resource<DynPollable>> {
486 p2_udp::HostIncomingDatagramStream::subscribe(&mut self.inner, this)
487 }
488
489 fn drop(&mut self, this: Resource<p2_udp::IncomingDatagramStream>) -> wasmtime::Result<()> {
490 p2_udp::HostIncomingDatagramStream::drop(&mut self.inner, this)
491 }
492}
493
494impl p2_udp::HostOutgoingDatagramStream for SpinSocketsView<'_> {
495 fn check_send(
496 &mut self,
497 this: Resource<p2_udp::OutgoingDatagramStream>,
498 ) -> wasmtime_wasi::p2::SocketResult<u64> {
499 p2_udp::HostOutgoingDatagramStream::check_send(&mut self.inner, this)
500 }
501
502 async fn send(
503 &mut self,
504 this: Resource<p2_udp::OutgoingDatagramStream>,
505 datagrams: Vec<p2_udp::OutgoingDatagram>,
506 ) -> wasmtime_wasi::p2::SocketResult<u64> {
507 p2_udp::HostOutgoingDatagramStream::send(&mut self.inner, this, datagrams).await
508 }
509
510 fn subscribe(
511 &mut self,
512 this: Resource<p2_udp::OutgoingDatagramStream>,
513 ) -> wasmtime::Result<Resource<DynPollable>> {
514 p2_udp::HostOutgoingDatagramStream::subscribe(&mut self.inner, this)
515 }
516
517 fn drop(&mut self, this: Resource<p2_udp::OutgoingDatagramStream>) -> wasmtime::Result<()> {
518 p2_udp::HostOutgoingDatagramStream::drop(&mut self.inner, this)
519 }
520}
521
522impl p2_udp_create::Host for SpinSocketsView<'_> {
523 fn create_udp_socket(
524 &mut self,
525 address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily,
526 ) -> wasmtime_wasi::p2::SocketResult<Resource<UdpSocket>> {
527 let Ok(permit) = self.try_acquire() else {
531 tracing::warn!("UDP socket creation refused: connection quota exhausted");
532 return Err(SocketErrorCode::NewSocketLimit.into());
533 };
534 let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?;
535 self.register_permit(sock.rep(), permit);
536 Ok(sock)
537 }
538}