1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use azure_data_cosmos::{
4 prelude::{
5 AuthorizationToken, CollectionClient, CosmosClient, CosmosClientBuilder, Operation, Query,
6 },
7 CosmosEntity,
8};
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use spin_factor_key_value::{
12 log_cas_error, log_error, log_error_v3, v3, Cas, Error, Store, StoreManager, SwapError,
13};
14use std::sync::{Arc, Mutex};
15
16pub struct KeyValueAzureCosmos {
17 client: CollectionClient,
18 app_id: Option<String>,
24}
25
/// Authentication settings supplied directly through runtime configuration.
#[derive(Debug, Clone)]
pub struct KeyValueAzureCosmosRuntimeConfigOptions {
    /// Account key handed to `AuthorizationToken::primary_key` when the
    /// client is built.
    key: String,
}
31
32impl KeyValueAzureCosmosRuntimeConfigOptions {
33 pub fn new(key: String) -> Self {
34 Self { key }
35 }
36}
37
38#[derive(Clone, Debug)]
40pub enum KeyValueAzureCosmosAuthOptions {
41 RuntimeConfigValues(KeyValueAzureCosmosRuntimeConfigOptions),
43 Environmental,
74}
75
76impl KeyValueAzureCosmos {
77 pub fn new(
78 account: String,
79 database: String,
80 container: String,
81 auth_options: KeyValueAzureCosmosAuthOptions,
82 app_id: Option<String>,
83 ) -> Result<Self> {
84 let token = match auth_options {
85 KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => {
86 AuthorizationToken::primary_key(config.key).map_err(log_error)?
87 }
88 KeyValueAzureCosmosAuthOptions::Environmental => {
89 AuthorizationToken::from_token_credential(
90 azure_identity::create_default_credential()?,
91 )
92 }
93 };
94 let cosmos_client = cosmos_client(account, token)?;
95 let database_client = cosmos_client.database_client(database);
96 let client = database_client.collection_client(container);
97
98 Ok(Self { client, app_id })
99 }
100}
101
102fn cosmos_client(account: impl Into<String>, token: AuthorizationToken) -> Result<CosmosClient> {
103 if cfg!(feature = "connection-pooling") {
104 let client = reqwest::ClientBuilder::new()
105 .build()
106 .context("failed to build reqwest client")?;
107 let transport_options = azure_core::TransportOptions::new(std::sync::Arc::new(client));
108 Ok(CosmosClientBuilder::new(account, token)
109 .transport(transport_options)
110 .build())
111 } else {
112 Ok(CosmosClient::new(account, token))
113 }
114}
115
116#[async_trait]
117impl StoreManager for KeyValueAzureCosmos {
118 async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
119 Ok(Arc::new(AzureCosmosStore {
120 client: self.client.clone(),
121 store_id: self.app_id.as_ref().map(|i| format!("{i}/{name}")),
122 }))
123 }
124
125 fn is_defined(&self, _store_name: &str) -> bool {
126 true
127 }
128
129 fn summary(&self, _store_name: &str) -> Option<String> {
130 let database = self.client.database_client().database_name();
131 let collection = self.client.collection_name();
132 Some(format!(
133 "Azure CosmosDB database: {database}, collection: {collection}"
134 ))
135 }
136}
137
138#[derive(Clone)]
139struct AzureCosmosStore {
140 client: CollectionClient,
141 store_id: Option<String>,
149}
150
151#[async_trait]
152impl Store for AzureCosmosStore {
153 async fn get(&self, key: &str, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
154 let pair = self.get_entity::<Pair>(key).await?;
155 let value = pair.map(|p| p.value);
156
157 if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
162 > max_result_bytes
163 {
164 Err(Error::Other(format!(
165 "query result exceeds limit of {max_result_bytes} bytes"
166 )))
167 } else {
168 Ok(value)
169 }
170 }
171
172 async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
173 let illegal_chars = ['/', '\\', '?', '#'];
174
175 if key.contains(|c| illegal_chars.contains(&c)) {
176 return Err(Error::Other(format!(
177 "Key contains an illegal character. Keys must not include any of: {}",
178 illegal_chars.iter().collect::<String>()
179 )));
180 }
181
182 let pair = Pair {
183 id: key.to_string(),
184 value: value.to_vec(),
185 store_id: self.store_id.clone(),
186 };
187 self.client
188 .create_document(pair)
189 .is_upsert(true)
190 .await
191 .map_err(log_error)?;
192 Ok(())
193 }
194
195 async fn delete(&self, key: &str) -> Result<(), Error> {
196 let document_client = self
197 .client
198 .document_client(key, &self.store_id.clone().unwrap_or(key.to_string()))
199 .map_err(log_error)?;
200 if let Err(e) = document_client.delete_document().await {
201 if e.as_http_error().map(|e| e.status() != 404).unwrap_or(true) {
202 return Err(log_error(e));
203 }
204 }
205 Ok(())
206 }
207
208 async fn exists(&self, key: &str) -> Result<bool, Error> {
209 let mut stream = self
210 .client
211 .query_documents(Query::new(self.get_id_query(key)))
212 .query_cross_partition(true)
213 .max_item_count(1)
214 .into_stream::<Key>();
215
216 match stream.next().await {
217 Some(Ok(res)) => Ok(!res.results.is_empty()),
218 Some(Err(e)) => Err(log_error(e)),
219 None => Ok(false),
220 }
221 }
222
223 async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
224 self.get_keys(max_result_bytes).await
225 }
226
227 async fn get_keys_async(
228 &self,
229 max_result_bytes: usize,
230 ) -> (
231 tokio::sync::mpsc::Receiver<String>,
232 tokio::sync::oneshot::Receiver<Result<(), v3::Error>>,
233 ) {
234 let (keys_tx, keys_rx) = tokio::sync::mpsc::channel(4);
235 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
236
237 let query = self
238 .client
239 .query_documents(Query::new(self.get_keys_query()))
240 .query_cross_partition(true);
241
242 let the_work = async move {
243 let mut stream = query.into_stream::<Key>();
244 while let Some(resp) = stream.next().await {
245 let resp = resp.map_err(log_error_v3)?;
246
247 if resp.results.iter().map(|(k, _)| k.id.len()).sum::<usize>() > max_result_bytes {
248 return Err(v3::Error::Other(format!(
249 "query exceeds limit of {max_result_bytes} bytes"
250 )));
251 }
252
253 for (key, _) in resp.results {
254 keys_tx.send(key.id).await.map_err(log_error_v3)?;
255 }
256 }
257 Ok(())
258 };
259 tokio::spawn(async move {
260 let res = the_work.await;
261 _ = err_tx.send(res);
262 });
263
264 (keys_rx, err_rx)
265 }
266
267 async fn get_many(
268 &self,
269 keys: Vec<String>,
270 max_result_bytes: usize,
271 ) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
272 let stmt = Query::new(self.get_in_query(keys));
273 let query = self
274 .client
275 .query_documents(stmt)
276 .query_cross_partition(true);
277
278 let mut res = Vec::new();
279 let mut stream = query.into_stream::<Pair>();
280 let mut byte_count = std::mem::size_of::<Vec<(String, Option<Vec<u8>>)>>();
281 while let Some(resp) = stream.next().await {
282 let resp = resp.map_err(log_error)?.results;
283 byte_count += resp
284 .iter()
285 .map(|(pair, _)| {
286 std::mem::size_of::<(String, Option<Vec<u8>>)>()
287 + pair.id.len()
288 + pair.value.len()
289 })
290 .sum::<usize>();
291 if byte_count > max_result_bytes {
292 return Err(Error::Other(format!(
293 "query result exceeds limit of {max_result_bytes} bytes"
294 )));
295 }
296 res.extend(
297 resp.into_iter()
298 .map(|(pair, _)| (pair.id, Some(pair.value))),
299 );
300 }
301 Ok(res)
302 }
303
304 async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
305 for (key, value) in key_values {
306 self.set(key.as_ref(), &value).await?
307 }
308 Ok(())
309 }
310
311 async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
312 for key in keys {
313 self.delete(key.as_ref()).await?
314 }
315 Ok(())
316 }
317
318 async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
327 let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
328 match self
329 .client
330 .document_client(&key, &self.store_id.clone().unwrap_or(key.to_string()))
331 .map_err(log_error)?
332 .patch_document(operations)
333 .await
334 {
335 Err(e) => {
336 if e.as_http_error()
337 .map(|e| e.status() == 404)
338 .unwrap_or(false)
339 {
340 let counter = Counter {
341 id: key.clone(),
342 value: delta,
343 store_id: self.store_id.clone(),
344 };
345 if let Err(e) = self.client.create_document(counter).is_upsert(false).await {
346 if e.as_http_error()
347 .map(|e| e.status())
348 .unwrap_or(azure_core::StatusCode::Continue)
349 == 409
350 {
351 self.increment(key, delta).await?;
353 } else {
354 return Err(log_error(e));
355 }
356 }
357 Ok(delta)
358 } else {
359 Err(log_error(e))
360 }
361 }
362 Ok(_) => self
363 .get_entity::<Counter>(key.as_ref())
364 .await?
365 .map(|c| c.value)
366 .ok_or(Error::Other(
367 "increment returned an empty value after patching, which indicates a bug"
368 .to_string(),
369 )),
370 }
371 }
372
373 async fn new_compare_and_swap(
374 &self,
375 bucket_rep: u32,
376 key: &str,
377 ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
378 Ok(Arc::new(CompareAndSwap {
379 key: key.to_string(),
380 client: self.client.clone(),
381 etag: Mutex::new(None),
382 bucket_rep,
383 store_id: self.store_id.clone(),
384 }))
385 }
386}
387
388struct CompareAndSwap {
389 key: String,
390 client: CollectionClient,
391 bucket_rep: u32,
392 etag: Mutex<Option<String>>,
393 store_id: Option<String>,
394}
395
396impl CompareAndSwap {
397 fn get_query(&self) -> String {
398 let mut query = format!("SELECT * FROM c WHERE c.id='{}'", self.key);
399 self.append_store_id(&mut query, true);
400 query
401 }
402
403 fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
404 append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
405 }
406}
407
408#[async_trait]
409impl Cas for CompareAndSwap {
410 async fn current(&self, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
413 let mut stream = self
414 .client
415 .query_documents(Query::new(self.get_query()))
416 .query_cross_partition(true)
417 .max_item_count(1)
418 .into_stream::<Pair>();
419
420 let current_value: Option<(Vec<u8>, Option<String>)> = match stream.next().await {
421 Some(r) => {
422 let r = r.map_err(log_error)?;
423 match r.results.first() {
424 Some((item, Some(attr))) => {
425 Some((item.clone().value, Some(attr.etag().to_string())))
426 }
427 Some((item, None)) => Some((item.clone().value, None)),
428 _ => None,
429 }
430 }
431 None => None,
432 };
433
434 let value = match current_value {
435 Some((value, etag)) => {
436 self.etag.lock().unwrap().clone_from(&etag);
437 Some(value)
438 }
439 None => None,
440 };
441
442 if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
447 > max_result_bytes
448 {
449 Err(Error::Other(format!(
450 "query result exceeds limit of {max_result_bytes} bytes"
451 )))
452 } else {
453 Ok(value)
454 }
455 }
456
457 async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
460 let pair = Pair {
461 id: self.key.clone(),
462 value,
463 store_id: self.store_id.clone(),
464 };
465
466 let doc_client = self
467 .client
468 .document_client(&self.key, &pair.partition_key())
469 .map_err(log_cas_error)?;
470
471 let etag_value = self.etag.lock().unwrap().clone();
472 match etag_value {
473 Some(etag) => {
474 doc_client
476 .replace_document(pair)
477 .if_match_condition(azure_core::request_options::IfMatchCondition::Match(etag))
478 .await
479 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
480 .map(drop)
481 }
482 None => {
483 self.client
485 .create_document(pair)
486 .await
487 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
488 .map(drop)
489 }
490 }
491 }
492
493 async fn bucket_rep(&self) -> u32 {
494 self.bucket_rep
495 }
496
497 async fn key(&self) -> String {
498 self.key.clone()
499 }
500}
501
502impl AzureCosmosStore {
503 async fn get_entity<F>(&self, key: &str) -> Result<Option<F>, Error>
504 where
505 F: CosmosEntity + Send + Sync + serde::de::DeserializeOwned + Clone,
506 {
507 let query = self
508 .client
509 .query_documents(Query::new(self.get_query(key)))
510 .query_cross_partition(true)
511 .max_item_count(1);
512
513 let mut stream = query.into_stream::<F>();
515 let Some(res) = stream.next().await else {
516 return Ok(None);
517 };
518 Ok(res
519 .map_err(log_error)?
520 .results
521 .first()
522 .map(|(p, _)| p.clone()))
523 }
524
525 async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
526 let query = self
527 .client
528 .query_documents(Query::new(self.get_keys_query()))
529 .query_cross_partition(true);
530 let mut res = Vec::new();
531
532 let mut stream = query.into_stream::<Key>();
533 let mut byte_count = std::mem::size_of::<Vec<String>>();
534 while let Some(resp) = stream.next().await {
535 let resp = resp.map_err(log_error)?.results;
536 byte_count += resp
537 .iter()
538 .map(|(key, _)| std::mem::size_of::<String>() + key.id.len())
539 .sum::<usize>();
540 if byte_count > max_result_bytes {
541 return Err(Error::Other(format!(
542 "query result exceeds limit of {max_result_bytes} bytes"
543 )));
544 }
545 res.extend(resp.into_iter().map(|(key, _)| key.id));
546 }
547
548 Ok(res)
549 }
550
551 fn get_query(&self, key: &str) -> String {
552 let mut query = format!("SELECT * FROM c WHERE c.id='{key}'");
553 self.append_store_id(&mut query, true);
554 query
555 }
556
557 fn get_id_query(&self, key: &str) -> String {
558 let mut query = format!("SELECT c.id, c.store_id FROM c WHERE c.id='{key}'");
559 self.append_store_id(&mut query, true);
560 query
561 }
562
563 fn get_keys_query(&self) -> String {
564 let mut query = "SELECT c.id, c.store_id FROM c".to_owned();
565 self.append_store_id(&mut query, false);
566 query
567 }
568
569 fn get_in_query(&self, keys: Vec<String>) -> String {
570 let in_clause: String = keys
571 .into_iter()
572 .map(|k| format!("'{k}'"))
573 .collect::<Vec<String>>()
574 .join(", ");
575
576 let mut query = format!("SELECT * FROM c WHERE c.id IN ({in_clause})");
577 self.append_store_id(&mut query, true);
578 query
579 }
580
581 fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
582 append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
583 }
584}
585
/// Appends a `c.store_id='<id>'` filter to `query` when a store id is given.
///
/// * `query` - the SQL text built so far; modified in place.
/// * `store_id` - the store id to filter on; `None` leaves `query` untouched.
/// * `condition_already_exists` - `true` joins the filter with ` AND`,
///   `false` starts a new ` WHERE` clause.
///
/// The store id is embedded as a single-quoted string literal, so any
/// backslash or single quote in it is prefixed with a backslash to keep it
/// from terminating or altering the literal.
fn append_store_id_condition(
    query: &mut String,
    store_id: Option<&str>,
    condition_already_exists: bool,
) {
    let Some(store_id) = store_id else {
        return;
    };

    let joiner = if condition_already_exists {
        " AND"
    } else {
        " WHERE"
    };
    query.push_str(joiner);
    query.push_str(" c.store_id='");
    for ch in store_id.chars() {
        if matches!(ch, '\'' | '\\') {
            query.push('\\');
        }
        query.push(ch);
    }
    query.push('\'');
}
603
604#[derive(Serialize, Deserialize, Clone, Debug)]
606pub struct Pair {
607 pub id: String,
608 pub value: Vec<u8>,
609 #[serde(skip_serializing_if = "Option::is_none")]
610 pub store_id: Option<String>,
611}
612
613impl CosmosEntity for Pair {
614 type Entity = String;
615
616 fn partition_key(&self) -> Self::Entity {
617 self.store_id.clone().unwrap_or_else(|| self.id.clone())
618 }
619}
620
621#[derive(Serialize, Deserialize, Clone, Debug)]
623pub struct Counter {
624 pub id: String,
625 pub value: i64,
626 #[serde(skip_serializing_if = "Option::is_none")]
627 pub store_id: Option<String>,
628}
629
630impl CosmosEntity for Counter {
631 type Entity = String;
632
633 fn partition_key(&self) -> Self::Entity {
634 self.store_id.clone().unwrap_or_else(|| self.id.clone())
635 }
636}
637
638#[derive(Serialize, Deserialize, Clone, Debug)]
640pub struct Key {
641 pub id: String,
642 #[serde(skip_serializing_if = "Option::is_none")]
643 pub store_id: Option<String>,
644}
645
646impl CosmosEntity for Key {
647 type Entity = String;
648
649 fn partition_key(&self) -> Self::Entity {
650 self.store_id.clone().unwrap_or_else(|| self.id.clone())
651 }
652}