Skip to main content

spin_factor_wasi/
sockets.rs

1//! Socket quota tracking and WASI socket host implementations.
2//!
3//! This module provides [`SocketPermitState`], [`SpinSocketsView`], and
4//! [`SpinSockets`] — the types needed to intercept WASI TCP/UDP socket
5//! creation and enforce a per-app cap on the number of concurrently open
6//! sockets.
7
8use std::{
9    collections::HashMap,
10    sync::{Arc, Mutex},
11};
12
13use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore};
14use wasmtime::component::{HasData, Resource};
15use wasmtime_wasi::p2::bindings::sockets::network::{
16    ErrorCode as SocketErrorCode, Host as NetworkHost, Network,
17};
18use wasmtime_wasi::p2::bindings::sockets::tcp::{self as p2_tcp, IpSocketAddress, ShutdownType};
19use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create;
20use wasmtime_wasi::p2::bindings::sockets::udp as p2_udp;
21use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create;
22use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable};
23use wasmtime_wasi::sockets::{TcpSocket, UdpSocket, WasiSocketsCtxView};
24
/// Shared state for tracking per-socket semaphore permits.
///
/// A permit is taken from the semaphore when a socket first consumes quota —
/// at `start_connect` for TCP (not when the TCP socket is created), and at
/// `create_udp_socket` for UDP — and is released when the socket resource is
/// dropped.
///
/// NOTE(review): entries are keyed by resource rep, and reps are presumably
/// only unique within a single resource table, so one `SocketPermitState`
/// should back a single instance's view (the view docs call it a per-instance
/// store). Confirm at the construction site.
pub struct SocketPermitState {
    /// Semaphore that permits are drawn from.
    semaphore: ConnectionSemaphore,
    /// Active permits keyed by socket resource rep, released when the resource is dropped.
    active: Mutex<HashMap<u32, ConnectionPermit>>,
}
33
34impl SocketPermitState {
35    pub fn new(semaphore: ConnectionSemaphore) -> Arc<Self> {
36        Arc::new(Self {
37            semaphore,
38            active: Mutex::new(HashMap::new()),
39        })
40    }
41}
42
/// A view over WASI socket state that carries an optional per-instance socket
/// permit store, enabling per-connection quota tracking.
///
/// When `permit_state` is `None`, no quota is enforced: permit acquisition
/// succeeds without taking anything from a semaphore.
pub struct SpinSocketsView<'a> {
    /// The wrapped wasmtime-wasi view that socket operations are forwarded to.
    pub(crate) inner: WasiSocketsCtxView<'a>,
    /// Permit bookkeeping, or `None` when no socket quota is configured.
    pub(crate) permit_state: Option<Arc<SocketPermitState>>,
}
49
50impl<'a> std::ops::Deref for SpinSocketsView<'a> {
51    type Target = WasiSocketsCtxView<'a>;
52    fn deref(&self) -> &Self::Target {
53        &self.inner
54    }
55}
56
57impl std::ops::DerefMut for SpinSocketsView<'_> {
58    fn deref_mut(&mut self) -> &mut Self::Target {
59        &mut self.inner
60    }
61}
62
/// [`HasData`] accessor for [`SpinSocketsView`], used in place of
/// [`WasiSockets`](wasmtime_wasi::sockets::WasiSockets) when registering socket
/// bindings so that socket quota tracking can intercept them.
///
/// The view implements the TCP, UDP, network and socket-creation host traits.
/// Quota is charged in TCP `start_connect` and UDP `create_udp_socket`, and
/// released in the corresponding socket resource `drop`.
pub struct SpinSockets;
67
68impl HasData for SpinSockets {
69    type Data<'a> = SpinSocketsView<'a>;
70}
71
72impl SpinSocketsView<'_> {
73    /// Attempts to acquire a connection permit from the semaphore.
74    ///
75    /// Returns `Ok(None)` when no quota is configured, `Ok(Some(permit))` on
76    /// success, or `Err(())` when the quota is exhausted.
77    ///
78    /// The returned permit is unregistered — call [`Self::register_permit`] once
79    /// the socket resource rep is known to tie its lifetime to the socket.
80    pub(crate) fn try_acquire(&self) -> Result<Option<ConnectionPermit>, ()> {
81        let Some(state) = &self.permit_state else {
82            return Ok(None);
83        };
84        state.semaphore.try_acquire().map(Some).ok_or(())
85    }
86
87    /// Registers `permit` under `socket_rep` so it is held until the socket is
88    /// dropped. No-op when `permit` is `None` (no quota configured).
89    pub(crate) fn register_permit(&self, socket_rep: u32, permit: Option<ConnectionPermit>) {
90        let (Some(state), Some(permit)) = (&self.permit_state, permit) else {
91            return;
92        };
93        state
94            .active
95            .lock()
96            .unwrap_or_else(|e| e.into_inner())
97            .insert(socket_rep, permit);
98    }
99
100    /// Releases the connection permit for `socket_rep`, if any.
101    pub(crate) fn release_permit(&self, socket_rep: u32) {
102        if let Some(state) = &self.permit_state {
103            state
104                .active
105                .lock()
106                .unwrap_or_else(|e| e.into_inner())
107                .remove(&socket_rep);
108        }
109    }
110}
111
112impl p2_tcp::Host for SpinSocketsView<'_> {}
113
114impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> {
115    async fn start_bind(
116        &mut self,
117        this: Resource<TcpSocket>,
118        network: Resource<Network>,
119        local_address: IpSocketAddress,
120    ) -> wasmtime_wasi::p2::SocketResult<()> {
121        p2_tcp::HostTcpSocket::start_bind(&mut self.inner, this, network, local_address).await
122    }
123
124    fn finish_bind(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<()> {
125        p2_tcp::HostTcpSocket::finish_bind(&mut self.inner, this)
126    }
127
128    async fn start_connect(
129        &mut self,
130        this: Resource<TcpSocket>,
131        network: Resource<Network>,
132        remote_address: IpSocketAddress,
133    ) -> wasmtime_wasi::p2::SocketResult<()> {
134        let socket_rep = this.rep();
135        // Unlike outbound HTTP (which queues when its permit pool is exhausted),
136        // sockets fail immediately. Waiting would risk deadlock if a component
137        // holds sockets open across async yield points, and raw-socket callers
138        // are better positioned to implement their own retry logic.
139        let Ok(permit) = self.try_acquire() else {
140            tracing::warn!("TCP socket connection refused: connection quota exhausted");
141            return Err(SocketErrorCode::NewSocketLimit.into());
142        };
143        let result =
144            p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address)
145                .await;
146        if result.is_ok() {
147            self.register_permit(socket_rep, permit);
148        }
149        // On error, `permit` is dropped here, automatically releasing the semaphore slot.
150        result
151    }
152
153    fn finish_connect(
154        &mut self,
155        this: Resource<TcpSocket>,
156    ) -> wasmtime_wasi::p2::SocketResult<(Resource<DynInputStream>, Resource<DynOutputStream>)>
157    {
158        p2_tcp::HostTcpSocket::finish_connect(&mut self.inner, this)
159    }
160
161    fn start_listen(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<()> {
162        p2_tcp::HostTcpSocket::start_listen(&mut self.inner, this)
163    }
164
165    fn finish_listen(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<()> {
166        p2_tcp::HostTcpSocket::finish_listen(&mut self.inner, this)
167    }
168
169    fn accept(
170        &mut self,
171        this: Resource<TcpSocket>,
172    ) -> wasmtime_wasi::p2::SocketResult<(
173        Resource<TcpSocket>,
174        Resource<DynInputStream>,
175        Resource<DynOutputStream>,
176    )> {
177        p2_tcp::HostTcpSocket::accept(&mut self.inner, this)
178    }
179
180    fn local_address(
181        &mut self,
182        this: Resource<TcpSocket>,
183    ) -> wasmtime_wasi::p2::SocketResult<IpSocketAddress> {
184        p2_tcp::HostTcpSocket::local_address(&mut self.inner, this)
185    }
186
187    fn remote_address(
188        &mut self,
189        this: Resource<TcpSocket>,
190    ) -> wasmtime_wasi::p2::SocketResult<IpSocketAddress> {
191        p2_tcp::HostTcpSocket::remote_address(&mut self.inner, this)
192    }
193
194    fn is_listening(&mut self, this: Resource<TcpSocket>) -> wasmtime::Result<bool> {
195        p2_tcp::HostTcpSocket::is_listening(&mut self.inner, this)
196    }
197
198    fn address_family(
199        &mut self,
200        this: Resource<TcpSocket>,
201    ) -> wasmtime::Result<wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily> {
202        p2_tcp::HostTcpSocket::address_family(&mut self.inner, this)
203    }
204
205    fn set_listen_backlog_size(
206        &mut self,
207        this: Resource<TcpSocket>,
208        value: u64,
209    ) -> wasmtime_wasi::p2::SocketResult<()> {
210        p2_tcp::HostTcpSocket::set_listen_backlog_size(&mut self.inner, this, value)
211    }
212
213    fn keep_alive_enabled(
214        &mut self,
215        this: Resource<TcpSocket>,
216    ) -> wasmtime_wasi::p2::SocketResult<bool> {
217        p2_tcp::HostTcpSocket::keep_alive_enabled(&mut self.inner, this)
218    }
219
220    fn set_keep_alive_enabled(
221        &mut self,
222        this: Resource<TcpSocket>,
223        value: bool,
224    ) -> wasmtime_wasi::p2::SocketResult<()> {
225        p2_tcp::HostTcpSocket::set_keep_alive_enabled(&mut self.inner, this, value)
226    }
227
228    fn keep_alive_idle_time(
229        &mut self,
230        this: Resource<TcpSocket>,
231    ) -> wasmtime_wasi::p2::SocketResult<u64> {
232        p2_tcp::HostTcpSocket::keep_alive_idle_time(&mut self.inner, this)
233    }
234
235    fn set_keep_alive_idle_time(
236        &mut self,
237        this: Resource<TcpSocket>,
238        value: u64,
239    ) -> wasmtime_wasi::p2::SocketResult<()> {
240        p2_tcp::HostTcpSocket::set_keep_alive_idle_time(&mut self.inner, this, value)
241    }
242
243    fn keep_alive_interval(
244        &mut self,
245        this: Resource<TcpSocket>,
246    ) -> wasmtime_wasi::p2::SocketResult<u64> {
247        p2_tcp::HostTcpSocket::keep_alive_interval(&mut self.inner, this)
248    }
249
250    fn set_keep_alive_interval(
251        &mut self,
252        this: Resource<TcpSocket>,
253        value: u64,
254    ) -> wasmtime_wasi::p2::SocketResult<()> {
255        p2_tcp::HostTcpSocket::set_keep_alive_interval(&mut self.inner, this, value)
256    }
257
258    fn keep_alive_count(
259        &mut self,
260        this: Resource<TcpSocket>,
261    ) -> wasmtime_wasi::p2::SocketResult<u32> {
262        p2_tcp::HostTcpSocket::keep_alive_count(&mut self.inner, this)
263    }
264
265    fn set_keep_alive_count(
266        &mut self,
267        this: Resource<TcpSocket>,
268        value: u32,
269    ) -> wasmtime_wasi::p2::SocketResult<()> {
270        p2_tcp::HostTcpSocket::set_keep_alive_count(&mut self.inner, this, value)
271    }
272
273    fn hop_limit(&mut self, this: Resource<TcpSocket>) -> wasmtime_wasi::p2::SocketResult<u8> {
274        p2_tcp::HostTcpSocket::hop_limit(&mut self.inner, this)
275    }
276
277    fn set_hop_limit(
278        &mut self,
279        this: Resource<TcpSocket>,
280        value: u8,
281    ) -> wasmtime_wasi::p2::SocketResult<()> {
282        p2_tcp::HostTcpSocket::set_hop_limit(&mut self.inner, this, value)
283    }
284
285    fn receive_buffer_size(
286        &mut self,
287        this: Resource<TcpSocket>,
288    ) -> wasmtime_wasi::p2::SocketResult<u64> {
289        p2_tcp::HostTcpSocket::receive_buffer_size(&mut self.inner, this)
290    }
291
292    fn set_receive_buffer_size(
293        &mut self,
294        this: Resource<TcpSocket>,
295        value: u64,
296    ) -> wasmtime_wasi::p2::SocketResult<()> {
297        p2_tcp::HostTcpSocket::set_receive_buffer_size(&mut self.inner, this, value)
298    }
299
300    fn send_buffer_size(
301        &mut self,
302        this: Resource<TcpSocket>,
303    ) -> wasmtime_wasi::p2::SocketResult<u64> {
304        p2_tcp::HostTcpSocket::send_buffer_size(&mut self.inner, this)
305    }
306
307    fn set_send_buffer_size(
308        &mut self,
309        this: Resource<TcpSocket>,
310        value: u64,
311    ) -> wasmtime_wasi::p2::SocketResult<()> {
312        p2_tcp::HostTcpSocket::set_send_buffer_size(&mut self.inner, this, value)
313    }
314
315    fn subscribe(&mut self, this: Resource<TcpSocket>) -> wasmtime::Result<Resource<DynPollable>> {
316        p2_tcp::HostTcpSocket::subscribe(&mut self.inner, this)
317    }
318
319    fn shutdown(
320        &mut self,
321        this: Resource<TcpSocket>,
322        shutdown_type: ShutdownType,
323    ) -> wasmtime_wasi::p2::SocketResult<()> {
324        p2_tcp::HostTcpSocket::shutdown(&mut self.inner, this, shutdown_type)
325    }
326
327    fn drop(&mut self, this: Resource<TcpSocket>) -> wasmtime::Result<()> {
328        self.release_permit(this.rep());
329        p2_tcp::HostTcpSocket::drop(&mut self.inner, this)
330    }
331}
332
333impl NetworkHost for SpinSocketsView<'_> {
334    fn convert_error_code(
335        &mut self,
336        error: wasmtime_wasi::p2::SocketError,
337    ) -> wasmtime::Result<wasmtime_wasi::p2::bindings::sockets::network::ErrorCode> {
338        NetworkHost::convert_error_code(&mut self.inner, error)
339    }
340
341    fn network_error_code(
342        &mut self,
343        err: Resource<wasmtime::Error>,
344    ) -> wasmtime::Result<Option<wasmtime_wasi::p2::bindings::sockets::network::ErrorCode>> {
345        NetworkHost::network_error_code(&mut self.inner, err)
346    }
347}
348
349impl wasmtime_wasi::p2::bindings::sockets::network::HostNetwork for SpinSocketsView<'_> {
350    fn drop(&mut self, this: Resource<Network>) -> wasmtime::Result<()> {
351        wasmtime_wasi::p2::bindings::sockets::network::HostNetwork::drop(&mut self.inner, this)
352    }
353}
354
355impl p2_tcp_create::Host for SpinSocketsView<'_> {
356    fn create_tcp_socket(
357        &mut self,
358        address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily,
359    ) -> wasmtime_wasi::p2::SocketResult<Resource<TcpSocket>> {
360        p2_tcp_create::Host::create_tcp_socket(&mut self.inner, address_family)
361    }
362}
363
364impl p2_udp::Host for SpinSocketsView<'_> {}
365
366impl p2_udp::HostUdpSocket for SpinSocketsView<'_> {
367    async fn start_bind(
368        &mut self,
369        this: Resource<p2_udp::UdpSocket>,
370        network: Resource<p2_udp::Network>,
371        local_address: p2_udp::IpSocketAddress,
372    ) -> wasmtime_wasi::p2::SocketResult<()> {
373        p2_udp::HostUdpSocket::start_bind(&mut self.inner, this, network, local_address).await
374    }
375
376    fn finish_bind(
377        &mut self,
378        this: Resource<p2_udp::UdpSocket>,
379    ) -> wasmtime_wasi::p2::SocketResult<()> {
380        p2_udp::HostUdpSocket::finish_bind(&mut self.inner, this)
381    }
382
383    async fn stream(
384        &mut self,
385        this: Resource<p2_udp::UdpSocket>,
386        remote_address: Option<p2_udp::IpSocketAddress>,
387    ) -> wasmtime_wasi::p2::SocketResult<(
388        Resource<p2_udp::IncomingDatagramStream>,
389        Resource<p2_udp::OutgoingDatagramStream>,
390    )> {
391        p2_udp::HostUdpSocket::stream(&mut self.inner, this, remote_address).await
392    }
393
394    fn local_address(
395        &mut self,
396        this: Resource<p2_udp::UdpSocket>,
397    ) -> wasmtime_wasi::p2::SocketResult<p2_udp::IpSocketAddress> {
398        p2_udp::HostUdpSocket::local_address(&mut self.inner, this)
399    }
400
401    fn remote_address(
402        &mut self,
403        this: Resource<p2_udp::UdpSocket>,
404    ) -> wasmtime_wasi::p2::SocketResult<p2_udp::IpSocketAddress> {
405        p2_udp::HostUdpSocket::remote_address(&mut self.inner, this)
406    }
407
408    fn address_family(
409        &mut self,
410        this: Resource<p2_udp::UdpSocket>,
411    ) -> wasmtime::Result<p2_udp::IpAddressFamily> {
412        p2_udp::HostUdpSocket::address_family(&mut self.inner, this)
413    }
414
415    fn unicast_hop_limit(
416        &mut self,
417        this: Resource<p2_udp::UdpSocket>,
418    ) -> wasmtime_wasi::p2::SocketResult<u8> {
419        p2_udp::HostUdpSocket::unicast_hop_limit(&mut self.inner, this)
420    }
421
422    fn set_unicast_hop_limit(
423        &mut self,
424        this: Resource<p2_udp::UdpSocket>,
425        value: u8,
426    ) -> wasmtime_wasi::p2::SocketResult<()> {
427        p2_udp::HostUdpSocket::set_unicast_hop_limit(&mut self.inner, this, value)
428    }
429
430    fn receive_buffer_size(
431        &mut self,
432        this: Resource<p2_udp::UdpSocket>,
433    ) -> wasmtime_wasi::p2::SocketResult<u64> {
434        p2_udp::HostUdpSocket::receive_buffer_size(&mut self.inner, this)
435    }
436
437    fn set_receive_buffer_size(
438        &mut self,
439        this: Resource<p2_udp::UdpSocket>,
440        value: u64,
441    ) -> wasmtime_wasi::p2::SocketResult<()> {
442        p2_udp::HostUdpSocket::set_receive_buffer_size(&mut self.inner, this, value)
443    }
444
445    fn send_buffer_size(
446        &mut self,
447        this: Resource<p2_udp::UdpSocket>,
448    ) -> wasmtime_wasi::p2::SocketResult<u64> {
449        p2_udp::HostUdpSocket::send_buffer_size(&mut self.inner, this)
450    }
451
452    fn set_send_buffer_size(
453        &mut self,
454        this: Resource<p2_udp::UdpSocket>,
455        value: u64,
456    ) -> wasmtime_wasi::p2::SocketResult<()> {
457        p2_udp::HostUdpSocket::set_send_buffer_size(&mut self.inner, this, value)
458    }
459
460    fn subscribe(
461        &mut self,
462        this: Resource<p2_udp::UdpSocket>,
463    ) -> wasmtime::Result<Resource<DynPollable>> {
464        p2_udp::HostUdpSocket::subscribe(&mut self.inner, this)
465    }
466
467    fn drop(&mut self, this: Resource<p2_udp::UdpSocket>) -> wasmtime::Result<()> {
468        self.release_permit(this.rep());
469        p2_udp::HostUdpSocket::drop(&mut self.inner, this)
470    }
471}
472
473impl p2_udp::HostIncomingDatagramStream for SpinSocketsView<'_> {
474    fn receive(
475        &mut self,
476        this: Resource<p2_udp::IncomingDatagramStream>,
477        max_results: u64,
478    ) -> wasmtime_wasi::p2::SocketResult<Vec<p2_udp::IncomingDatagram>> {
479        p2_udp::HostIncomingDatagramStream::receive(&mut self.inner, this, max_results)
480    }
481
482    fn subscribe(
483        &mut self,
484        this: Resource<p2_udp::IncomingDatagramStream>,
485    ) -> wasmtime::Result<Resource<DynPollable>> {
486        p2_udp::HostIncomingDatagramStream::subscribe(&mut self.inner, this)
487    }
488
489    fn drop(&mut self, this: Resource<p2_udp::IncomingDatagramStream>) -> wasmtime::Result<()> {
490        p2_udp::HostIncomingDatagramStream::drop(&mut self.inner, this)
491    }
492}
493
494impl p2_udp::HostOutgoingDatagramStream for SpinSocketsView<'_> {
495    fn check_send(
496        &mut self,
497        this: Resource<p2_udp::OutgoingDatagramStream>,
498    ) -> wasmtime_wasi::p2::SocketResult<u64> {
499        p2_udp::HostOutgoingDatagramStream::check_send(&mut self.inner, this)
500    }
501
502    async fn send(
503        &mut self,
504        this: Resource<p2_udp::OutgoingDatagramStream>,
505        datagrams: Vec<p2_udp::OutgoingDatagram>,
506    ) -> wasmtime_wasi::p2::SocketResult<u64> {
507        p2_udp::HostOutgoingDatagramStream::send(&mut self.inner, this, datagrams).await
508    }
509
510    fn subscribe(
511        &mut self,
512        this: Resource<p2_udp::OutgoingDatagramStream>,
513    ) -> wasmtime::Result<Resource<DynPollable>> {
514        p2_udp::HostOutgoingDatagramStream::subscribe(&mut self.inner, this)
515    }
516
517    fn drop(&mut self, this: Resource<p2_udp::OutgoingDatagramStream>) -> wasmtime::Result<()> {
518        p2_udp::HostOutgoingDatagramStream::drop(&mut self.inner, this)
519    }
520}
521
522impl p2_udp_create::Host for SpinSocketsView<'_> {
523    fn create_udp_socket(
524        &mut self,
525        address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily,
526    ) -> wasmtime_wasi::p2::SocketResult<Resource<UdpSocket>> {
527        // Check quota before allocating the socket resource.
528        // See the analogous comment in `start_connect` for why we fail
529        // immediately rather than waiting (as outbound HTTP does).
530        let Ok(permit) = self.try_acquire() else {
531            tracing::warn!("UDP socket creation refused: connection quota exhausted");
532            return Err(SocketErrorCode::NewSocketLimit.into());
533        };
534        let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?;
535        self.register_permit(sock.rep(), permit);
536        Ok(sock)
537    }
538}