1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use azure_data_cosmos::{
4 prelude::{
5 AuthorizationToken, CollectionClient, CosmosClient, CosmosClientBuilder, Operation, Query,
6 },
7 CosmosEntity,
8};
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use spin_factor_key_value::{log_cas_error, log_error, Cas, Error, Store, StoreManager, SwapError};
12use std::sync::{Arc, Mutex};
13
14pub struct KeyValueAzureCosmos {
15 client: CollectionClient,
16 app_id: Option<String>,
22}
23
/// Azure CosmosDB authentication settings supplied through runtime configuration.
///
/// `Debug` is implemented by hand (rather than derived) so that the key never ends up in
/// logs or error output, including when this value is printed as part of a wrapping type.
#[derive(Clone)]
pub struct KeyValueAzureCosmosRuntimeConfigOptions {
    /// Key passed to `AuthorizationToken::primary_key` when the client is built.
    key: String,
}

impl KeyValueAzureCosmosRuntimeConfigOptions {
    /// Wraps `key` so it can later be used to authorize requests.
    pub fn new(key: String) -> Self {
        Self { key }
    }
}

impl std::fmt::Debug for KeyValueAzureCosmosRuntimeConfigOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The key is a credential: show that the field exists, never its contents.
        f.debug_struct("KeyValueAzureCosmosRuntimeConfigOptions")
            .field("key", &"<redacted>")
            .finish()
    }
}
35
36#[derive(Clone, Debug)]
38pub enum KeyValueAzureCosmosAuthOptions {
39 RuntimeConfigValues(KeyValueAzureCosmosRuntimeConfigOptions),
41 Environmental,
72}
73
74impl KeyValueAzureCosmos {
75 pub fn new(
76 account: String,
77 database: String,
78 container: String,
79 auth_options: KeyValueAzureCosmosAuthOptions,
80 app_id: Option<String>,
81 ) -> Result<Self> {
82 let token = match auth_options {
83 KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => {
84 AuthorizationToken::primary_key(config.key).map_err(log_error)?
85 }
86 KeyValueAzureCosmosAuthOptions::Environmental => {
87 AuthorizationToken::from_token_credential(
88 azure_identity::create_default_credential()?,
89 )
90 }
91 };
92 let cosmos_client = cosmos_client(account, token)?;
93 let database_client = cosmos_client.database_client(database);
94 let client = database_client.collection_client(container);
95
96 Ok(Self { client, app_id })
97 }
98}
99
100fn cosmos_client(account: impl Into<String>, token: AuthorizationToken) -> Result<CosmosClient> {
101 if cfg!(feature = "connection-pooling") {
102 let client = reqwest::ClientBuilder::new()
103 .build()
104 .context("failed to build reqwest client")?;
105 let transport_options = azure_core::TransportOptions::new(std::sync::Arc::new(client));
106 Ok(CosmosClientBuilder::new(account, token)
107 .transport(transport_options)
108 .build())
109 } else {
110 Ok(CosmosClient::new(account, token))
111 }
112}
113
114#[async_trait]
115impl StoreManager for KeyValueAzureCosmos {
116 async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
117 Ok(Arc::new(AzureCosmosStore {
118 client: self.client.clone(),
119 store_id: self.app_id.as_ref().map(|i| format!("{i}/{name}")),
120 }))
121 }
122
123 fn is_defined(&self, _store_name: &str) -> bool {
124 true
125 }
126
127 fn summary(&self, _store_name: &str) -> Option<String> {
128 let database = self.client.database_client().database_name();
129 let collection = self.client.collection_name();
130 Some(format!(
131 "Azure CosmosDB database: {database}, collection: {collection}"
132 ))
133 }
134}
135
136#[derive(Clone)]
137struct AzureCosmosStore {
138 client: CollectionClient,
139 store_id: Option<String>,
147}
148
149#[async_trait]
150impl Store for AzureCosmosStore {
151 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
152 let pair = self.get_entity::<Pair>(key).await?;
153 Ok(pair.map(|p| p.value))
154 }
155
156 async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
157 let illegal_chars = ['/', '\\', '?', '#'];
158
159 if key.contains(|c| illegal_chars.contains(&c)) {
160 return Err(Error::Other(format!(
161 "Key contains an illegal character. Keys must not include any of: {}",
162 illegal_chars.iter().collect::<String>()
163 )));
164 }
165
166 let pair = Pair {
167 id: key.to_string(),
168 value: value.to_vec(),
169 store_id: self.store_id.clone(),
170 };
171 self.client
172 .create_document(pair)
173 .is_upsert(true)
174 .await
175 .map_err(log_error)?;
176 Ok(())
177 }
178
179 async fn delete(&self, key: &str) -> Result<(), Error> {
180 let document_client = self
181 .client
182 .document_client(key, &self.store_id.clone().unwrap_or(key.to_string()))
183 .map_err(log_error)?;
184 if let Err(e) = document_client.delete_document().await {
185 if e.as_http_error().map(|e| e.status() != 404).unwrap_or(true) {
186 return Err(log_error(e));
187 }
188 }
189 Ok(())
190 }
191
192 async fn exists(&self, key: &str) -> Result<bool, Error> {
193 let mut stream = self
194 .client
195 .query_documents(Query::new(self.get_id_query(key)))
196 .query_cross_partition(true)
197 .max_item_count(1)
198 .into_stream::<Key>();
199
200 match stream.next().await {
201 Some(Ok(res)) => Ok(!res.results.is_empty()),
202 Some(Err(e)) => Err(log_error(e)),
203 None => Ok(false),
204 }
205 }
206
207 async fn get_keys(&self) -> Result<Vec<String>, Error> {
208 self.get_keys().await
209 }
210
211 async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
212 let stmt = Query::new(self.get_in_query(keys));
213 let query = self
214 .client
215 .query_documents(stmt)
216 .query_cross_partition(true);
217
218 let mut res = Vec::new();
219 let mut stream = query.into_stream::<Pair>();
220 while let Some(resp) = stream.next().await {
221 let resp = resp.map_err(log_error)?;
222 res.extend(
223 resp.results
224 .into_iter()
225 .map(|(pair, _)| (pair.id, Some(pair.value))),
226 );
227 }
228 Ok(res)
229 }
230
231 async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
232 for (key, value) in key_values {
233 self.set(key.as_ref(), &value).await?
234 }
235 Ok(())
236 }
237
238 async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
239 for key in keys {
240 self.delete(key.as_ref()).await?
241 }
242 Ok(())
243 }
244
245 async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
254 let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
255 match self
256 .client
257 .document_client(&key, &self.store_id.clone().unwrap_or(key.to_string()))
258 .map_err(log_error)?
259 .patch_document(operations)
260 .await
261 {
262 Err(e) => {
263 if e.as_http_error()
264 .map(|e| e.status() == 404)
265 .unwrap_or(false)
266 {
267 let counter = Counter {
268 id: key.clone(),
269 value: delta,
270 store_id: self.store_id.clone(),
271 };
272 if let Err(e) = self.client.create_document(counter).is_upsert(false).await {
273 if e.as_http_error()
274 .map(|e| e.status())
275 .unwrap_or(azure_core::StatusCode::Continue)
276 == 409
277 {
278 self.increment(key, delta).await?;
280 } else {
281 return Err(log_error(e));
282 }
283 }
284 Ok(delta)
285 } else {
286 Err(log_error(e))
287 }
288 }
289 Ok(_) => self
290 .get_entity::<Counter>(key.as_ref())
291 .await?
292 .map(|c| c.value)
293 .ok_or(Error::Other(
294 "increment returned an empty value after patching, which indicates a bug"
295 .to_string(),
296 )),
297 }
298 }
299
300 async fn new_compare_and_swap(
301 &self,
302 bucket_rep: u32,
303 key: &str,
304 ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
305 Ok(Arc::new(CompareAndSwap {
306 key: key.to_string(),
307 client: self.client.clone(),
308 etag: Mutex::new(None),
309 bucket_rep,
310 store_id: self.store_id.clone(),
311 }))
312 }
313}
314
315struct CompareAndSwap {
316 key: String,
317 client: CollectionClient,
318 bucket_rep: u32,
319 etag: Mutex<Option<String>>,
320 store_id: Option<String>,
321}
322
323impl CompareAndSwap {
324 fn get_query(&self) -> String {
325 let mut query = format!("SELECT * FROM c WHERE c.id='{}'", self.key);
326 self.append_store_id(&mut query, true);
327 query
328 }
329
330 fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
331 append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
332 }
333}
334
335#[async_trait]
336impl Cas for CompareAndSwap {
337 async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
340 let mut stream = self
341 .client
342 .query_documents(Query::new(self.get_query()))
343 .query_cross_partition(true)
344 .max_item_count(1)
345 .into_stream::<Pair>();
346
347 let current_value: Option<(Vec<u8>, Option<String>)> = match stream.next().await {
348 Some(r) => {
349 let r = r.map_err(log_error)?;
350 match r.results.first() {
351 Some((item, Some(attr))) => {
352 Some((item.clone().value, Some(attr.etag().to_string())))
353 }
354 Some((item, None)) => Some((item.clone().value, None)),
355 _ => None,
356 }
357 }
358 None => None,
359 };
360
361 match current_value {
362 Some((value, etag)) => {
363 self.etag.lock().unwrap().clone_from(&etag);
364 Ok(Some(value))
365 }
366 None => Ok(None),
367 }
368 }
369
370 async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
373 let pair = Pair {
374 id: self.key.clone(),
375 value,
376 store_id: self.store_id.clone(),
377 };
378
379 let doc_client = self
380 .client
381 .document_client(&self.key, &pair.partition_key())
382 .map_err(log_cas_error)?;
383
384 let etag_value = self.etag.lock().unwrap().clone();
385 match etag_value {
386 Some(etag) => {
387 doc_client
389 .replace_document(pair)
390 .if_match_condition(azure_core::request_options::IfMatchCondition::Match(etag))
391 .await
392 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
393 .map(drop)
394 }
395 None => {
396 self.client
398 .create_document(pair)
399 .await
400 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
401 .map(drop)
402 }
403 }
404 }
405
406 async fn bucket_rep(&self) -> u32 {
407 self.bucket_rep
408 }
409
410 async fn key(&self) -> String {
411 self.key.clone()
412 }
413}
414
415impl AzureCosmosStore {
416 async fn get_entity<F>(&self, key: &str) -> Result<Option<F>, Error>
417 where
418 F: CosmosEntity + Send + Sync + serde::de::DeserializeOwned + Clone,
419 {
420 let query = self
421 .client
422 .query_documents(Query::new(self.get_query(key)))
423 .query_cross_partition(true)
424 .max_item_count(1);
425
426 let mut stream = query.into_stream::<F>();
428 let Some(res) = stream.next().await else {
429 return Ok(None);
430 };
431 Ok(res
432 .map_err(log_error)?
433 .results
434 .first()
435 .map(|(p, _)| p.clone()))
436 }
437
438 async fn get_keys(&self) -> Result<Vec<String>, Error> {
439 let query = self
440 .client
441 .query_documents(Query::new(self.get_keys_query()))
442 .query_cross_partition(true);
443 let mut res = Vec::new();
444
445 let mut stream = query.into_stream::<Key>();
446 while let Some(resp) = stream.next().await {
447 let resp = resp.map_err(log_error)?;
448 res.extend(resp.results.into_iter().map(|(key, _)| key.id));
449 }
450
451 Ok(res)
452 }
453
454 fn get_query(&self, key: &str) -> String {
455 let mut query = format!("SELECT * FROM c WHERE c.id='{key}'");
456 self.append_store_id(&mut query, true);
457 query
458 }
459
460 fn get_id_query(&self, key: &str) -> String {
461 let mut query = format!("SELECT c.id, c.store_id FROM c WHERE c.id='{key}'");
462 self.append_store_id(&mut query, true);
463 query
464 }
465
466 fn get_keys_query(&self) -> String {
467 let mut query = "SELECT c.id, c.store_id FROM c".to_owned();
468 self.append_store_id(&mut query, false);
469 query
470 }
471
472 fn get_in_query(&self, keys: Vec<String>) -> String {
473 let in_clause: String = keys
474 .into_iter()
475 .map(|k| format!("'{k}'"))
476 .collect::<Vec<String>>()
477 .join(", ");
478
479 let mut query = format!("SELECT * FROM c WHERE c.id IN ({in_clause})");
480 self.append_store_id(&mut query, true);
481 query
482 }
483
484 fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
485 append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
486 }
487}
488
/// Appends a `c.store_id='<id>'` filter to `query` when `store_id` is `Some`.
///
/// The filter is joined with ` AND` when `condition_already_exists` is true (the query
/// already has a `WHERE` clause) and introduced with ` WHERE` otherwise. With no store id
/// the query is left untouched.
///
/// Backslashes and single quotes in the id are escaped with a backslash so the id cannot
/// terminate the string literal or otherwise change the query.
fn append_store_id_condition(
    query: &mut String,
    store_id: Option<&str>,
    condition_already_exists: bool,
) {
    let Some(id) = store_id else {
        return;
    };

    query.push_str(if condition_already_exists {
        " AND"
    } else {
        " WHERE"
    });
    query.push_str(" c.store_id='");
    for c in id.chars() {
        if matches!(c, '\\' | '\'') {
            query.push('\\');
        }
        query.push(c);
    }
    query.push('\'');
}
506
507#[derive(Serialize, Deserialize, Clone, Debug)]
509pub struct Pair {
510 pub id: String,
511 pub value: Vec<u8>,
512 #[serde(skip_serializing_if = "Option::is_none")]
513 pub store_id: Option<String>,
514}
515
516impl CosmosEntity for Pair {
517 type Entity = String;
518
519 fn partition_key(&self) -> Self::Entity {
520 self.store_id.clone().unwrap_or_else(|| self.id.clone())
521 }
522}
523
524#[derive(Serialize, Deserialize, Clone, Debug)]
526pub struct Counter {
527 pub id: String,
528 pub value: i64,
529 #[serde(skip_serializing_if = "Option::is_none")]
530 pub store_id: Option<String>,
531}
532
533impl CosmosEntity for Counter {
534 type Entity = String;
535
536 fn partition_key(&self) -> Self::Entity {
537 self.store_id.clone().unwrap_or_else(|| self.id.clone())
538 }
539}
540
541#[derive(Serialize, Deserialize, Clone, Debug)]
543pub struct Key {
544 pub id: String,
545 #[serde(skip_serializing_if = "Option::is_none")]
546 pub store_id: Option<String>,
547}
548
549impl CosmosEntity for Key {
550 type Entity = String;
551
552 fn partition_key(&self) -> Self::Entity {
553 self.store_id.clone().unwrap_or_else(|| self.id.clone())
554 }
555}