Skip to main content

spin_connection_semaphore/
lib.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use anyhow::anyhow;
5use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
6use tokio::time;
7
8/// Wraps an optional global and an optional factor-specific semaphore.
9#[derive(Clone)]
10pub struct ConnectionSemaphore {
11    global: Option<Arc<Semaphore>>,
12    factor_specific: Option<Arc<Semaphore>>,
13    factor: &'static str,
14    wait_timeout: Option<Duration>,
15}
16
17impl ConnectionSemaphore {
18    /// Creates a new `ConnectionSemaphore`.
19    ///
20    /// `global` is an optional semaphore shared across factors; `factor_specific_limit`
21    /// is an optional permit limit for this specific factor. If either is `None`, that level of
22    /// limiting is disabled. `factor` is a label used in emitted telemetry, and `wait_timeout` is
23    /// an optional duration to wait for a permit before giving up and returning an error.
24    pub fn new(
25        global: Option<Arc<Semaphore>>,
26        factor_specific_limit: Option<usize>,
27        factor: &'static str,
28        wait_timeout: Option<Duration>,
29    ) -> Self {
30        Self {
31            global,
32            factor_specific: factor_specific_limit.map(|n| Arc::new(Semaphore::new(n))),
33            factor,
34            wait_timeout,
35        }
36    }
37
38    #[cfg(test)]
39    pub(crate) fn from_raw(
40        global: Option<Arc<Semaphore>>,
41        factor_specific: Option<Arc<Semaphore>>,
42        factor: &'static str,
43        wait_timeout: Option<Duration>,
44    ) -> Self {
45        Self {
46            global,
47            factor_specific,
48            factor,
49            wait_timeout,
50        }
51    }
52
53    /// Acquire both configured semaphore slots, returning a permit that holds
54    /// them until dropped.
55    ///
56    /// When both a global and a factor-specific semaphore are configured, this
57    /// method acquires factor-specific first, then global, ensuring the global
58    /// permit is never held while blocking on a factor-specific backlog.
59    ///
60    /// If `wait_timeout` is configured and the permits cannot be acquired within
61    /// that duration, an error is returned.
62    pub async fn acquire(&self) -> anyhow::Result<ConnectionPermit> {
63        // Fast path: all required permits are already available
64        if let Ok(permit) = self.try_acquire_permits() {
65            spin_telemetry::monotonic_counter!(
66                outbound_connection_permits_acquired = 1,
67                kind = self.factor,
68                waited = false
69            );
70            return Ok(permit);
71        }
72
73        match self.wait_timeout {
74            Some(timeout) => time::timeout(timeout, self.acquire_inner())
75                .await
76                .map_err(|_| anyhow!("connection semaphore timed out after {timeout:?}"))?,
77            None => self.acquire_inner().await,
78        }
79    }
80
81    /// Inner logic for [`Self::acquire`], separated so the caller can apply a timeout.
82    async fn acquire_inner(&self) -> anyhow::Result<ConnectionPermit> {
83        /// Acquires a single permit from `sem`, trying non-blocking first.
84        ///
85        /// Sets `*waited = true` if a blocking wait was required.
86        async fn acquire_one(
87            sem: &Arc<Semaphore>,
88            waited: &mut bool,
89            label: &str,
90        ) -> anyhow::Result<OwnedSemaphorePermit> {
91            match sem.clone().try_acquire_owned() {
92                Ok(p) => Ok(p),
93                Err(TryAcquireError::NoPermits) => {
94                    *waited = true;
95                    sem.clone()
96                        .acquire_owned()
97                        .await
98                        .map_err(|_| anyhow!("{label} connection semaphore closed"))
99                }
100                Err(_) => Err(anyhow!("{label} connection semaphore closed")),
101            }
102        }
103        let mut waited = false;
104        let start = std::time::Instant::now();
105
106        // Acquire factor-specific first, then global. This ensures we never hold
107        // the global permit while blocking on factor-specific backlog.
108        let factor_specific = match &self.factor_specific {
109            Some(f) => Some(acquire_one(f, &mut waited, "factor").await?),
110            None => None,
111        };
112        // It's fine to hold the factor-specific permit while waiting for the global slot, since
113        // other consumers of the factor-specific would also end up waiting for the same global slot.
114        let global = match &self.global {
115            Some(g) => Some(acquire_one(g, &mut waited, "global").await?),
116            None => None,
117        };
118
119        let factor = self.factor;
120        if waited {
121            spin_telemetry::histogram!(
122                outbound_connection_permit_wait_duration_ms = start.elapsed().as_millis() as f64,
123                kind = factor
124            );
125        }
126        spin_telemetry::monotonic_counter!(
127            outbound_connection_permits_acquired = 1,
128            kind = factor,
129            waited = waited
130        );
131
132        Ok(ConnectionPermit {
133            _global: global,
134            _factor_specific: factor_specific,
135        })
136    }
137
138    /// Attempt to acquire both configured slots without waiting.
139    /// Returns `None` if either semaphore is exhausted.
140    ///
141    /// If the global permit is acquired but the factor-specific permit is not
142    /// available, the global permit is released before returning `None`.
143    pub fn try_acquire(&self) -> Option<ConnectionPermit> {
144        match self.try_acquire_permits() {
145            Ok(permit) => {
146                spin_telemetry::monotonic_counter!(
147                    outbound_connection_permits_acquired = 1,
148                    kind = self.factor,
149                    waited = false
150                );
151                Some(permit)
152            }
153            Err(limit) => {
154                spin_telemetry::monotonic_counter!(
155                    outbound_connection_permits_rejected = 1,
156                    kind = self.factor,
157                    limit = limit
158                );
159                None
160            }
161        }
162    }
163
164    /// Inner logic for [`Self::try_acquire`], separated so the caller can emit
165    /// telemetry based on whether a permit was obtained.
166    ///
167    /// Returns `Err("global")` or `Err("factor")` to indicate which limit was
168    /// exhausted, so the caller can tag the rejection metric accordingly.
169    fn try_acquire_permits(&self) -> Result<ConnectionPermit, &'static str> {
170        // Acquire global first. If it fails, nothing is consumed.
171        let global = match &self.global {
172            Some(s) => match s.clone().try_acquire_owned() {
173                Ok(p) => Some(p),
174                Err(_) => return Err("global"),
175            },
176            None => None,
177        };
178        // Now attempt the factor-specific permit.
179        // On failure, `global` is dropped here, releasing the global slot.
180        let factor_specific = match &self.factor_specific {
181            Some(s) => match s.clone().try_acquire_owned() {
182                Ok(p) => Some(p),
183                Err(_) => return Err("factor"),
184            },
185            None => None,
186        };
187        Ok(ConnectionPermit {
188            _global: global,
189            _factor_specific: factor_specific,
190        })
191    }
192}
193
194/// Holds up to two semaphore permits (global + factor-specific).
195/// Both permits are released when this value is dropped.
196/// All-`None` fields are valid and represent the no-limits case.
197///
198/// Fields are intentionally prefixed with `_` — they exist solely to be dropped.
199#[derive(Debug)]
200pub struct ConnectionPermit {
201    _global: Option<OwnedSemaphorePermit>,
202    _factor_specific: Option<OwnedSemaphorePermit>,
203}
204
/// Unit tests for `ConnectionSemaphore`. The async tests use the default
/// `#[tokio::test]` runtime, on which a spawned task is only polled once the
/// test body yields.
#[cfg(test)]
mod tests {
    use super::*;

    /// With no limits configured, `acquire` never has anything to wait for.
    #[tokio::test]
    async fn no_limits_acquire_always_succeeds() {
        let sem = ConnectionSemaphore::new(None, None, "test", None);
        let permit = sem.acquire().await.expect("should succeed");
        drop(permit);
        let _permit2 = sem.acquire().await.expect("should succeed again");
    }

    /// With no limits configured, `try_acquire` always yields a permit.
    #[test]
    fn no_limits_try_acquire_always_succeeds() {
        let sem = ConnectionSemaphore::new(None, None, "test", None);
        let permit = sem.try_acquire().expect("should succeed");
        drop(permit);
        let _permit2 = sem.try_acquire().expect("should succeed again");
    }

    /// A global-only limit rejects once exhausted and recovers on release.
    #[test]
    fn global_limit_only_exhausted() {
        let global = Arc::new(Semaphore::new(1));
        let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test", None);
        let permit1 = sem.try_acquire().expect("first should succeed");
        assert!(
            sem.try_acquire().is_none(),
            "second should fail: global exhausted"
        );
        drop(permit1);
        assert_eq!(global.available_permits(), 1);
        let _permit3 = sem.try_acquire().expect("after release should succeed");
    }

    /// A factor-only limit rejects once exhausted and recovers on release.
    #[test]
    fn factor_limit_only_exhausted() {
        let sem = ConnectionSemaphore::new(None, Some(1), "test", None);
        let permit1 = sem.try_acquire().expect("first should succeed");
        assert!(
            sem.try_acquire().is_none(),
            "second should fail: factor exhausted"
        );
        drop(permit1);
        let _permit3 = sem.try_acquire().expect("after release should succeed");
    }

    /// A rejection caused by the global limit must not consume a factor permit.
    #[test]
    fn both_limits_global_exhausted_first() {
        let global = Arc::new(Semaphore::new(1));
        let factor = Arc::new(Semaphore::new(2));
        let sem =
            ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test", None);

        let permit1 = sem.try_acquire().expect("first should succeed");
        // After permit1: global=0, factor=1
        let factor_before = factor.available_permits();

        // Second try_acquire should fail because global is exhausted.
        assert!(sem.try_acquire().is_none(), "should fail: global exhausted");
        // Factor must NOT have been consumed by the failed attempt.
        assert_eq!(
            factor.available_permits(),
            factor_before,
            "factor permits should not be consumed when global is exhausted"
        );
        drop(permit1);
    }

    /// A rejection caused by the factor limit must hand the global slot back.
    #[test]
    fn both_limits_factor_exhausted_global_released() {
        let global = Arc::new(Semaphore::new(2));
        let factor = Arc::new(Semaphore::new(1));
        let sem =
            ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test", None);

        let permit1 = sem.try_acquire().expect("first should succeed");
        // Global still has 1, factor exhausted
        let result = sem.try_acquire();
        assert!(result.is_none(), "should fail: factor exhausted");
        // Global slot must have been released (back to 1)
        assert_eq!(global.available_permits(), 1);
        drop(permit1);
        assert_eq!(global.available_permits(), 2);
    }

    /// `acquire` parks while the semaphore is exhausted and completes once a
    /// permit is released.
    #[tokio::test]
    async fn acquire_waits_for_release() {
        let global = Arc::new(Semaphore::new(1));
        let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test", None);

        let permit = sem.try_acquire().expect("first should succeed");

        let sem2 = sem.clone();
        let handle = tokio::spawn(async move {
            let _p = sem2.acquire().await.expect("should eventually acquire");
        });

        // Give the spawned task a chance to run and block on the exhausted
        // semaphore. Without yielding, the permit below would be released
        // before the task is first polled and the waiting path would never
        // be exercised.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;
        assert!(
            !handle.is_finished(),
            "acquire() must still be waiting while the only permit is held"
        );

        drop(permit); // release so the spawned task can proceed
        handle.await.expect("task should complete");
        assert_eq!(
            global.available_permits(),
            1,
            "permit should be returned once the task finishes"
        );
    }

    /// Verifies that when factor-specific is exhausted, acquire() doesn't hold
    /// a global permit while waiting, so other connection types aren't blocked.
    #[tokio::test]
    async fn acquire_releases_global_while_waiting_for_factor() {
        let global = Arc::new(Semaphore::new(1));
        let factor = Arc::new(Semaphore::new(1));
        let sem =
            ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test", None);

        // Exhaust factor-specific from outside.
        let factor_hold = factor.clone().acquire_owned().await.unwrap();

        let global_clone = global.clone();
        let sem_clone = sem.clone();
        let handle = tokio::spawn(async move {
            sem_clone
                .acquire()
                .await
                .expect("should succeed after factor is released")
        });

        // Yield so the spawned task runs until it blocks waiting for the
        // factor-specific permit; the second yield is extra slack for the
        // scheduler before the global semaphore is inspected.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        assert_eq!(
            global_clone.available_permits(),
            1,
            "global should be free while acquire() waits for factor-specific"
        );

        drop(factor_hold);
        let _permit = handle.await.expect("task should complete");
    }

    /// `acquire` gives up with a "timed out" error once `wait_timeout` elapses
    /// and leaves no permit behind.
    #[tokio::test]
    async fn acquire_times_out_when_semaphore_exhausted() {
        let global = Arc::new(Semaphore::new(1));
        let sem = ConnectionSemaphore::new(
            Some(global.clone()),
            None,
            "test",
            Some(Duration::from_millis(10)),
        );

        let held = sem.try_acquire().expect("first should succeed");

        let err = sem.acquire().await.expect_err("should time out");
        assert!(
            err.to_string().contains("timed out"),
            "error message should mention timed out: {err}"
        );

        // The timed-out attempt must not have kept a slot for itself.
        drop(held);
        assert_eq!(global.available_permits(), 1);
    }
}