// spin_connection_semaphore/lib.rs
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::anyhow;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time;

8#[derive(Clone)]
10pub struct ConnectionSemaphore {
11 global: Option<Arc<Semaphore>>,
12 factor_specific: Option<Arc<Semaphore>>,
13 factor: &'static str,
14 wait_timeout: Option<Duration>,
15}
16
17impl ConnectionSemaphore {
18 pub fn new(
25 global: Option<Arc<Semaphore>>,
26 factor_specific_limit: Option<usize>,
27 factor: &'static str,
28 wait_timeout: Option<Duration>,
29 ) -> Self {
30 Self {
31 global,
32 factor_specific: factor_specific_limit.map(|n| Arc::new(Semaphore::new(n))),
33 factor,
34 wait_timeout,
35 }
36 }
37
38 #[cfg(test)]
39 pub(crate) fn from_raw(
40 global: Option<Arc<Semaphore>>,
41 factor_specific: Option<Arc<Semaphore>>,
42 factor: &'static str,
43 wait_timeout: Option<Duration>,
44 ) -> Self {
45 Self {
46 global,
47 factor_specific,
48 factor,
49 wait_timeout,
50 }
51 }
52
53 pub async fn acquire(&self) -> anyhow::Result<ConnectionPermit> {
63 if let Ok(permit) = self.try_acquire_permits() {
65 spin_telemetry::monotonic_counter!(
66 outbound_connection_permits_acquired = 1,
67 kind = self.factor,
68 waited = false
69 );
70 return Ok(permit);
71 }
72
73 match self.wait_timeout {
74 Some(timeout) => time::timeout(timeout, self.acquire_inner())
75 .await
76 .map_err(|_| anyhow!("connection semaphore timed out after {timeout:?}"))?,
77 None => self.acquire_inner().await,
78 }
79 }
80
81 async fn acquire_inner(&self) -> anyhow::Result<ConnectionPermit> {
83 async fn acquire_one(
87 sem: &Arc<Semaphore>,
88 waited: &mut bool,
89 label: &str,
90 ) -> anyhow::Result<OwnedSemaphorePermit> {
91 match sem.clone().try_acquire_owned() {
92 Ok(p) => Ok(p),
93 Err(TryAcquireError::NoPermits) => {
94 *waited = true;
95 sem.clone()
96 .acquire_owned()
97 .await
98 .map_err(|_| anyhow!("{label} connection semaphore closed"))
99 }
100 Err(_) => Err(anyhow!("{label} connection semaphore closed")),
101 }
102 }
103 let mut waited = false;
104 let start = std::time::Instant::now();
105
106 let factor_specific = match &self.factor_specific {
109 Some(f) => Some(acquire_one(f, &mut waited, "factor").await?),
110 None => None,
111 };
112 let global = match &self.global {
115 Some(g) => Some(acquire_one(g, &mut waited, "global").await?),
116 None => None,
117 };
118
119 let factor = self.factor;
120 if waited {
121 spin_telemetry::histogram!(
122 outbound_connection_permit_wait_duration_ms = start.elapsed().as_millis() as f64,
123 kind = factor
124 );
125 }
126 spin_telemetry::monotonic_counter!(
127 outbound_connection_permits_acquired = 1,
128 kind = factor,
129 waited = waited
130 );
131
132 Ok(ConnectionPermit {
133 _global: global,
134 _factor_specific: factor_specific,
135 })
136 }
137
138 pub fn try_acquire(&self) -> Option<ConnectionPermit> {
144 match self.try_acquire_permits() {
145 Ok(permit) => {
146 spin_telemetry::monotonic_counter!(
147 outbound_connection_permits_acquired = 1,
148 kind = self.factor,
149 waited = false
150 );
151 Some(permit)
152 }
153 Err(limit) => {
154 spin_telemetry::monotonic_counter!(
155 outbound_connection_permits_rejected = 1,
156 kind = self.factor,
157 limit = limit
158 );
159 None
160 }
161 }
162 }
163
164 fn try_acquire_permits(&self) -> Result<ConnectionPermit, &'static str> {
170 let global = match &self.global {
172 Some(s) => match s.clone().try_acquire_owned() {
173 Ok(p) => Some(p),
174 Err(_) => return Err("global"),
175 },
176 None => None,
177 };
178 let factor_specific = match &self.factor_specific {
181 Some(s) => match s.clone().try_acquire_owned() {
182 Ok(p) => Some(p),
183 Err(_) => return Err("factor"),
184 },
185 None => None,
186 };
187 Ok(ConnectionPermit {
188 _global: global,
189 _factor_specific: factor_specific,
190 })
191 }
192}
193
194#[derive(Debug)]
200pub struct ConnectionPermit {
201 _global: Option<OwnedSemaphorePermit>,
202 _factor_specific: Option<OwnedSemaphorePermit>,
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn no_limits_acquire_always_succeeds() {
        let sem = ConnectionSemaphore::new(None, None, "test", None);
        let permit = sem.acquire().await.expect("should succeed");
        drop(permit);
        let _permit2 = sem.acquire().await.expect("should succeed again");
    }

    #[test]
    fn no_limits_try_acquire_always_succeeds() {
        let sem = ConnectionSemaphore::new(None, None, "test", None);
        let permit = sem.try_acquire().expect("should succeed");
        drop(permit);
        let _permit2 = sem.try_acquire().expect("should succeed again");
    }

    #[test]
    fn global_limit_only_exhausted() {
        let global = Arc::new(Semaphore::new(1));
        let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test", None);
        let permit1 = sem.try_acquire().expect("first should succeed");
        assert!(
            sem.try_acquire().is_none(),
            "second should fail: global exhausted"
        );
        drop(permit1);
        assert_eq!(global.available_permits(), 1);
        let _permit3 = sem.try_acquire().expect("after release should succeed");
    }

    #[test]
    fn factor_limit_only_exhausted() {
        let sem = ConnectionSemaphore::new(None, Some(1), "test", None);
        let permit1 = sem.try_acquire().expect("first should succeed");
        assert!(
            sem.try_acquire().is_none(),
            "second should fail: factor exhausted"
        );
        drop(permit1);
        let _permit3 = sem.try_acquire().expect("after release should succeed");
    }

    #[test]
    fn both_limits_global_exhausted_first() {
        let global = Arc::new(Semaphore::new(1));
        let factor = Arc::new(Semaphore::new(2));
        let sem =
            ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test", None);

        let permit1 = sem.try_acquire().expect("first should succeed");
        let factor_before = factor.available_permits();

        // The global semaphore is tried first, so a global rejection must leave
        // the factor-specific semaphore's permit count unchanged.
        assert!(sem.try_acquire().is_none(), "should fail: global exhausted");
        assert_eq!(
            factor.available_permits(),
            factor_before,
            "factor permits should not be consumed when global is exhausted"
        );
        drop(permit1);
    }

    #[test]
    fn both_limits_factor_exhausted_global_released() {
        let global = Arc::new(Semaphore::new(2));
        let factor = Arc::new(Semaphore::new(1));
        let sem =
            ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test", None);

        let permit1 = sem.try_acquire().expect("first should succeed");
        let result = sem.try_acquire();
        // The global permit taken before the factor check failed must have been
        // handed back, leaving only permit1's global permit in use.
        assert!(result.is_none(), "should fail: factor exhausted");
        assert_eq!(global.available_permits(), 1);
        drop(permit1);
        assert_eq!(global.available_permits(), 2);
    }

    #[tokio::test]
    async fn acquire_waits_for_release() {
        let global = Arc::new(Semaphore::new(1));
        let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test", None);

        let permit = sem.try_acquire().expect("first should succeed");

        let sem2 = sem.clone();
        let handle = tokio::spawn(async move {
            let _p = sem2.acquire().await.expect("should eventually acquire");
        });

        // `#[tokio::test]` uses a current-thread runtime, so the spawned task
        // does not run until we yield. Without this, the permit would be
        // released before the task ever polls, and it would take the
        // non-blocking fast path instead of the waiting path under test.
        tokio::task::yield_now().await;
        assert!(
            !handle.is_finished(),
            "acquire() should be blocked while the only permit is held"
        );

        drop(permit);
        handle.await.expect("task should complete");
        assert_eq!(global.available_permits(), 1);
    }

    /// acquire() takes the factor-specific permit before the global one, so a
    /// task blocked on the factor limit must not be holding a global permit.
    #[tokio::test]
    async fn acquire_releases_global_while_waiting_for_factor() {
        let global = Arc::new(Semaphore::new(1));
        let factor = Arc::new(Semaphore::new(1));
        let sem =
            ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test", None);

        let factor_hold = factor.clone().acquire_owned().await.unwrap();

        let sem_clone = sem.clone();
        let handle = tokio::spawn(async move {
            sem_clone
                .acquire()
                .await
                .expect("should succeed after factor is released")
        });

        // Let the spawned task run until it blocks on the factor semaphore.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        assert!(
            !handle.is_finished(),
            "acquire() should be waiting for the factor-specific permit"
        );
        assert_eq!(
            global.available_permits(),
            1,
            "global should be free while acquire() waits for factor-specific"
        );

        drop(factor_hold);
        let permit = handle.await.expect("task should complete");
        assert_eq!(global.available_permits(), 0);
        drop(permit);
        assert_eq!(global.available_permits(), 1);
        assert_eq!(factor.available_permits(), 1);
    }

    #[tokio::test]
    async fn acquire_times_out_when_semaphore_exhausted() {
        let global = Arc::new(Semaphore::new(1));
        let sem = ConnectionSemaphore::new(
            Some(global.clone()),
            None,
            "test",
            Some(Duration::from_millis(10)),
        );

        let _permit = sem.try_acquire().expect("first should succeed");

        let err = sem.acquire().await.expect_err("should time out");
        assert!(
            err.to_string().contains("timed out"),
            "error message should mention timed out: {err}"
        );
    }

    #[tokio::test]
    async fn timed_out_acquire_releases_factor_permit() {
        let global = Arc::new(Semaphore::new(1));
        let factor = Arc::new(Semaphore::new(1));
        let sem = ConnectionSemaphore::from_raw(
            Some(global.clone()),
            Some(factor.clone()),
            "test",
            Some(Duration::from_millis(10)),
        );

        // Exhaust the global limit so acquire() obtains the factor permit and
        // then blocks on global until the timeout fires.
        let _global_hold = global.clone().try_acquire_owned().unwrap();

        sem.acquire()
            .await
            .expect_err("should time out waiting for the global permit");
        assert_eq!(
            factor.available_permits(),
            1,
            "factor permit must be released when acquire() times out"
        );
    }

    #[tokio::test]
    async fn acquire_with_timeout_succeeds_when_permits_available() {
        let sem =
            ConnectionSemaphore::new(None, Some(1), "test", Some(Duration::from_millis(10)));
        let _permit = sem
            .acquire()
            .await
            .expect("should succeed without waiting");
    }
}