1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use azure_data_cosmos::{
4 prelude::{
5 AuthorizationToken, CollectionClient, CosmosClient, CosmosClientBuilder, Operation, Query,
6 },
7 CosmosEntity,
8};
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use spin_factor_key_value::{log_cas_error, log_error, Cas, Error, Store, StoreManager, SwapError};
12use std::sync::{Arc, Mutex};
13
14pub struct KeyValueAzureCosmos {
15 client: CollectionClient,
16 app_id: Option<String>,
22}
23
/// Authentication settings supplied through runtime configuration.
#[derive(Clone)]
pub struct KeyValueAzureCosmosRuntimeConfigOptions {
    /// Cosmos DB account key, used as the primary key for the authorization token.
    key: String,
}

impl KeyValueAzureCosmosRuntimeConfigOptions {
    /// Creates options holding the given account `key`.
    pub fn new(key: String) -> Self {
        KeyValueAzureCosmosRuntimeConfigOptions { key }
    }
}

// Hand-written so the account key never ends up in logs or error messages
// produced via `{:?}`; a derived `Debug` would print the secret verbatim.
impl std::fmt::Debug for KeyValueAzureCosmosRuntimeConfigOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyValueAzureCosmosRuntimeConfigOptions")
            .field("key", &"<redacted>")
            .finish()
    }
}
35
/// Selects how `KeyValueAzureCosmos::new` obtains the Cosmos DB authorization token.
#[derive(Clone, Debug)]
pub enum KeyValueAzureCosmosAuthOptions {
    /// Use the account key carried in runtime configuration
    /// (passed to `AuthorizationToken::primary_key`).
    RuntimeConfigValues(KeyValueAzureCosmosRuntimeConfigOptions),
    /// Use the credential returned by `azure_identity::create_default_credential`
    /// (wrapped with `AuthorizationToken::from_token_credential`).
    Environmental,
}
73
74impl KeyValueAzureCosmos {
75 pub fn new(
76 account: String,
77 database: String,
78 container: String,
79 auth_options: KeyValueAzureCosmosAuthOptions,
80 app_id: Option<String>,
81 ) -> Result<Self> {
82 let token = match auth_options {
83 KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => {
84 AuthorizationToken::primary_key(config.key).map_err(log_error)?
85 }
86 KeyValueAzureCosmosAuthOptions::Environmental => {
87 AuthorizationToken::from_token_credential(
88 azure_identity::create_default_credential()?,
89 )
90 }
91 };
92 let cosmos_client = cosmos_client(account, token)?;
93 let database_client = cosmos_client.database_client(database);
94 let client = database_client.collection_client(container);
95
96 Ok(Self { client, app_id })
97 }
98}
99
100fn cosmos_client(account: impl Into<String>, token: AuthorizationToken) -> Result<CosmosClient> {
101 if cfg!(feature = "connection-pooling") {
102 let client = reqwest::ClientBuilder::new()
103 .build()
104 .context("failed to build reqwest client")?;
105 let transport_options = azure_core::TransportOptions::new(std::sync::Arc::new(client));
106 Ok(CosmosClientBuilder::new(account, token)
107 .transport(transport_options)
108 .build())
109 } else {
110 Ok(CosmosClient::new(account, token))
111 }
112}
113
114#[async_trait]
115impl StoreManager for KeyValueAzureCosmos {
116 async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
117 Ok(Arc::new(AzureCosmosStore {
118 client: self.client.clone(),
119 store_id: self.app_id.as_ref().map(|i| format!("{i}/{name}")),
120 }))
121 }
122
123 fn is_defined(&self, _store_name: &str) -> bool {
124 true
125 }
126
127 fn summary(&self, _store_name: &str) -> Option<String> {
128 let database = self.client.database_client().database_name();
129 let collection = self.client.collection_name();
130 Some(format!(
131 "Azure CosmosDB database: {database}, collection: {collection}"
132 ))
133 }
134}
135
136#[derive(Clone)]
137struct AzureCosmosStore {
138 client: CollectionClient,
139 store_id: Option<String>,
147}
148
149#[async_trait]
150impl Store for AzureCosmosStore {
151 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
152 let pair = self.get_entity::<Pair>(key).await?;
153 Ok(pair.map(|p| p.value))
154 }
155
156 async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
157 let illegal_chars = ['/', '\\', '?', '#'];
158
159 if key.contains(|c| illegal_chars.contains(&c)) {
160 return Err(Error::Other(format!(
161 "Key contains an illegal character. Keys must not include any of: {}",
162 illegal_chars.iter().collect::<String>()
163 )));
164 }
165
166 let pair = Pair {
167 id: key.to_string(),
168 value: value.to_vec(),
169 store_id: self.store_id.clone(),
170 };
171 self.client
172 .create_document(pair)
173 .is_upsert(true)
174 .await
175 .map_err(log_error)?;
176 Ok(())
177 }
178
179 async fn delete(&self, key: &str) -> Result<(), Error> {
180 let document_client = self
181 .client
182 .document_client(key, &self.store_id.clone().unwrap_or(key.to_string()))
183 .map_err(log_error)?;
184 if let Err(e) = document_client.delete_document().await {
185 if e.as_http_error().map(|e| e.status() != 404).unwrap_or(true) {
186 return Err(log_error(e));
187 }
188 }
189 Ok(())
190 }
191
192 async fn exists(&self, key: &str) -> Result<bool, Error> {
193 Ok(self.get_entity::<Key>(key).await?.is_some())
194 }
195
196 async fn get_keys(&self) -> Result<Vec<String>, Error> {
197 self.get_keys().await
198 }
199
200 async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
201 let stmt = Query::new(self.get_in_query(keys));
202 let query = self
203 .client
204 .query_documents(stmt)
205 .query_cross_partition(true);
206
207 let mut res = Vec::new();
208 let mut stream = query.into_stream::<Pair>();
209 while let Some(resp) = stream.next().await {
210 let resp = resp.map_err(log_error)?;
211 res.extend(
212 resp.results
213 .into_iter()
214 .map(|(pair, _)| (pair.id, Some(pair.value))),
215 );
216 }
217 Ok(res)
218 }
219
220 async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
221 for (key, value) in key_values {
222 self.set(key.as_ref(), &value).await?
223 }
224 Ok(())
225 }
226
227 async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
228 for key in keys {
229 self.delete(key.as_ref()).await?
230 }
231 Ok(())
232 }
233
234 async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
243 let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
244 match self
245 .client
246 .document_client(&key, &self.store_id.clone().unwrap_or(key.to_string()))
247 .map_err(log_error)?
248 .patch_document(operations)
249 .await
250 {
251 Err(e) => {
252 if e.as_http_error()
253 .map(|e| e.status() == 404)
254 .unwrap_or(false)
255 {
256 let counter = Counter {
257 id: key.clone(),
258 value: delta,
259 store_id: self.store_id.clone(),
260 };
261 if let Err(e) = self.client.create_document(counter).is_upsert(false).await {
262 if e.as_http_error()
263 .map(|e| e.status())
264 .unwrap_or(azure_core::StatusCode::Continue)
265 == 409
266 {
267 self.increment(key, delta).await?;
269 } else {
270 return Err(log_error(e));
271 }
272 }
273 Ok(delta)
274 } else {
275 Err(log_error(e))
276 }
277 }
278 Ok(_) => self
279 .get_entity::<Counter>(key.as_ref())
280 .await?
281 .map(|c| c.value)
282 .ok_or(Error::Other(
283 "increment returned an empty value after patching, which indicates a bug"
284 .to_string(),
285 )),
286 }
287 }
288
289 async fn new_compare_and_swap(
290 &self,
291 bucket_rep: u32,
292 key: &str,
293 ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
294 Ok(Arc::new(CompareAndSwap {
295 key: key.to_string(),
296 client: self.client.clone(),
297 etag: Mutex::new(None),
298 bucket_rep,
299 store_id: self.store_id.clone(),
300 }))
301 }
302}
303
304struct CompareAndSwap {
305 key: String,
306 client: CollectionClient,
307 bucket_rep: u32,
308 etag: Mutex<Option<String>>,
309 store_id: Option<String>,
310}
311
312impl CompareAndSwap {
313 fn get_query(&self) -> String {
314 let mut query = format!("SELECT * FROM c WHERE c.id='{}'", self.key);
315 self.append_store_id(&mut query, true);
316 query
317 }
318
319 fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
320 append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
321 }
322}
323
324#[async_trait]
325impl Cas for CompareAndSwap {
326 async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
329 let mut stream = self
330 .client
331 .query_documents(Query::new(self.get_query()))
332 .query_cross_partition(true)
333 .max_item_count(1)
334 .into_stream::<Pair>();
335
336 let current_value: Option<(Vec<u8>, Option<String>)> = match stream.next().await {
337 Some(r) => {
338 let r = r.map_err(log_error)?;
339 match r.results.first() {
340 Some((item, Some(attr))) => {
341 Some((item.clone().value, Some(attr.etag().to_string())))
342 }
343 Some((item, None)) => Some((item.clone().value, None)),
344 _ => None,
345 }
346 }
347 None => None,
348 };
349
350 match current_value {
351 Some((value, etag)) => {
352 self.etag.lock().unwrap().clone_from(&etag);
353 Ok(Some(value))
354 }
355 None => Ok(None),
356 }
357 }
358
359 async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
362 let pair = Pair {
363 id: self.key.clone(),
364 value,
365 store_id: self.store_id.clone(),
366 };
367
368 let doc_client = self
369 .client
370 .document_client(&self.key, &pair.partition_key())
371 .map_err(log_cas_error)?;
372
373 let etag_value = self.etag.lock().unwrap().clone();
374 match etag_value {
375 Some(etag) => {
376 doc_client
378 .replace_document(pair)
379 .if_match_condition(azure_core::request_options::IfMatchCondition::Match(etag))
380 .await
381 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
382 .map(drop)
383 }
384 None => {
385 self.client
387 .create_document(pair)
388 .await
389 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
390 .map(drop)
391 }
392 }
393 }
394
395 async fn bucket_rep(&self) -> u32 {
396 self.bucket_rep
397 }
398
399 async fn key(&self) -> String {
400 self.key.clone()
401 }
402}
403
404impl AzureCosmosStore {
405 async fn get_entity<F>(&self, key: &str) -> Result<Option<F>, Error>
406 where
407 F: CosmosEntity + Send + Sync + serde::de::DeserializeOwned + Clone,
408 {
409 let query = self
410 .client
411 .query_documents(Query::new(self.get_query(key)))
412 .query_cross_partition(true)
413 .max_item_count(1);
414
415 let mut stream = query.into_stream::<F>();
417 let Some(res) = stream.next().await else {
418 return Ok(None);
419 };
420 Ok(res
421 .map_err(log_error)?
422 .results
423 .first()
424 .map(|(p, _)| p.clone()))
425 }
426
427 async fn get_keys(&self) -> Result<Vec<String>, Error> {
428 let query = self
429 .client
430 .query_documents(Query::new(self.get_keys_query()))
431 .query_cross_partition(true);
432 let mut res = Vec::new();
433
434 let mut stream = query.into_stream::<Key>();
435 while let Some(resp) = stream.next().await {
436 let resp = resp.map_err(log_error)?;
437 res.extend(resp.results.into_iter().map(|(key, _)| key.id));
438 }
439
440 Ok(res)
441 }
442
443 fn get_query(&self, key: &str) -> String {
444 let mut query = format!("SELECT * FROM c WHERE c.id='{}'", key);
445 self.append_store_id(&mut query, true);
446 query
447 }
448
449 fn get_keys_query(&self) -> String {
450 let mut query = "SELECT * FROM c".to_owned();
451 self.append_store_id(&mut query, false);
452 query
453 }
454
455 fn get_in_query(&self, keys: Vec<String>) -> String {
456 let in_clause: String = keys
457 .into_iter()
458 .map(|k| format!("'{k}'"))
459 .collect::<Vec<String>>()
460 .join(", ");
461
462 let mut query = format!("SELECT * FROM c WHERE c.id IN ({})", in_clause);
463 self.append_store_id(&mut query, true);
464 query
465 }
466
467 fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
468 append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
469 }
470}
471
/// Appends a `c.store_id='<store_id>'` filter to `query` when `store_id` is set.
///
/// `condition_already_exists` selects the connective: ` AND` when the query
/// already has a `WHERE` clause, ` WHERE` otherwise. Backslashes and single quotes
/// in `store_id` are escaped so the value always stays inside the string literal.
/// When `store_id` is `None`, `query` is left untouched.
fn append_store_id_condition(
    query: &mut String,
    store_id: Option<&str>,
    condition_already_exists: bool,
) {
    let Some(store_id) = store_id else {
        return;
    };
    query.push_str(if condition_already_exists {
        " AND"
    } else {
        " WHERE"
    });
    query.push_str(" c.store_id='");
    // Backslashes first, so the backslashes added for quotes are not doubled.
    query.push_str(&store_id.replace('\\', "\\\\").replace('\'', "\\'"));
    query.push('\'');
}
489
490#[derive(Serialize, Deserialize, Clone, Debug)]
492pub struct Pair {
493 pub id: String,
494 pub value: Vec<u8>,
495 #[serde(skip_serializing_if = "Option::is_none")]
496 pub store_id: Option<String>,
497}
498
499impl CosmosEntity for Pair {
500 type Entity = String;
501
502 fn partition_key(&self) -> Self::Entity {
503 self.store_id.clone().unwrap_or_else(|| self.id.clone())
504 }
505}
506
507#[derive(Serialize, Deserialize, Clone, Debug)]
509pub struct Counter {
510 pub id: String,
511 pub value: i64,
512 #[serde(skip_serializing_if = "Option::is_none")]
513 pub store_id: Option<String>,
514}
515
516impl CosmosEntity for Counter {
517 type Entity = String;
518
519 fn partition_key(&self) -> Self::Entity {
520 self.store_id.clone().unwrap_or_else(|| self.id.clone())
521 }
522}
523
524#[derive(Serialize, Deserialize, Clone, Debug)]
526pub struct Key {
527 pub id: String,
528 #[serde(skip_serializing_if = "Option::is_none")]
529 pub store_id: Option<String>,
530}
531
532impl CosmosEntity for Key {
533 type Entity = String;
534
535 fn partition_key(&self) -> Self::Entity {
536 self.store_id.clone().unwrap_or_else(|| self.id.clone())
537 }
538}