spin_key_value_azure/
store.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use azure_data_cosmos::{
4    prelude::{
5        AuthorizationToken, CollectionClient, CosmosClient, CosmosClientBuilder, Operation, Query,
6    },
7    CosmosEntity,
8};
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use spin_factor_key_value::{log_cas_error, log_error, Cas, Error, Store, StoreManager, SwapError};
12use std::sync::{Arc, Mutex};
13
14pub struct KeyValueAzureCosmos {
15    client: CollectionClient,
16    /// An optional app id
17    ///
18    /// If provided, the store will handle multiple stores per container using a
19    /// partition key of `/$app_id/$store_name`, otherwise there will be one container
20    /// per store, and the partition key will be `/id`.
21    app_id: Option<String>,
22}
23
/// Azure Cosmos Key / Value runtime config literal options for authentication.
///
/// `Debug` is implemented by hand so the account key is never written to logs or
/// error messages.
#[derive(Clone)]
pub struct KeyValueAzureCosmosRuntimeConfigOptions {
    /// The Cosmos DB account key, used as the primary key for the authorization token.
    key: String,
}

impl KeyValueAzureCosmosRuntimeConfigOptions {
    /// Creates runtime-config auth options from the given account `key`.
    pub fn new(key: String) -> Self {
        Self { key }
    }
}

impl std::fmt::Debug for KeyValueAzureCosmosRuntimeConfigOptions {
    /// Formats the options with the secret key redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyValueAzureCosmosRuntimeConfigOptions")
            .field("key", &"<redacted>")
            .finish()
    }
}
35
36/// Azure Cosmos Key / Value enumeration for the possible authentication options
37#[derive(Clone, Debug)]
38pub enum KeyValueAzureCosmosAuthOptions {
39    /// Runtime Config values indicates the account and key have been specified directly
40    RuntimeConfigValues(KeyValueAzureCosmosRuntimeConfigOptions),
41    /// Environmental indicates that the environment variables of the process should be used to
42    /// create the TokenCredential for the Cosmos client. This will use the Azure Rust SDK's
43    /// DefaultCredentialChain to derive the TokenCredential based on what environment variables
44    /// have been set.
45    ///
46    /// Service Principal with client secret:
47    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
48    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
49    /// - `AZURE_CLIENT_SECRET`: one of the service principal's secrets.
50    ///
51    /// Service Principal with certificate:
52    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
53    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
54    /// - `AZURE_CLIENT_CERTIFICATE_PATH`: path to a PEM or PKCS12 certificate file including the private key.
55    /// - `AZURE_CLIENT_CERTIFICATE_PASSWORD`: (optional) password for the certificate file.
56    ///
57    /// Workload Identity (Kubernetes, injected by the Workload Identity mutating webhook):
58    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
59    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
60    /// - `AZURE_FEDERATED_TOKEN_FILE`: TokenFilePath is the path of a file containing a Kubernetes service account token.
61    ///
62    /// Managed Identity (User Assigned or System Assigned identities):
63    /// - `AZURE_CLIENT_ID`: (optional) if using a user assigned identity, this will be the client ID of the identity.
64    ///
65    /// Azure CLI:
66    /// - `AZURE_TENANT_ID`: (optional) use a specific tenant via the Azure CLI.
67    ///
68    /// Common across each:
69    /// - `AZURE_AUTHORITY_HOST`: (optional) the host for the identity provider. For example, for Azure public cloud the host defaults to "https://login.microsoftonline.com".
70    ///   See also: https://github.com/Azure/azure-sdk-for-rust/blob/main/sdk/identity/README.md
71    Environmental,
72}
73
74impl KeyValueAzureCosmos {
75    pub fn new(
76        account: String,
77        database: String,
78        container: String,
79        auth_options: KeyValueAzureCosmosAuthOptions,
80        app_id: Option<String>,
81    ) -> Result<Self> {
82        let token = match auth_options {
83            KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => {
84                AuthorizationToken::primary_key(config.key).map_err(log_error)?
85            }
86            KeyValueAzureCosmosAuthOptions::Environmental => {
87                AuthorizationToken::from_token_credential(
88                    azure_identity::create_default_credential()?,
89                )
90            }
91        };
92        let cosmos_client = cosmos_client(account, token)?;
93        let database_client = cosmos_client.database_client(database);
94        let client = database_client.collection_client(container);
95
96        Ok(Self { client, app_id })
97    }
98}
99
100fn cosmos_client(account: impl Into<String>, token: AuthorizationToken) -> Result<CosmosClient> {
101    if cfg!(feature = "connection-pooling") {
102        let client = reqwest::ClientBuilder::new()
103            .build()
104            .context("failed to build reqwest client")?;
105        let transport_options = azure_core::TransportOptions::new(std::sync::Arc::new(client));
106        Ok(CosmosClientBuilder::new(account, token)
107            .transport(transport_options)
108            .build())
109    } else {
110        Ok(CosmosClient::new(account, token))
111    }
112}
113
114#[async_trait]
115impl StoreManager for KeyValueAzureCosmos {
116    async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
117        Ok(Arc::new(AzureCosmosStore {
118            client: self.client.clone(),
119            store_id: self.app_id.as_ref().map(|i| format!("{i}/{name}")),
120        }))
121    }
122
123    fn is_defined(&self, _store_name: &str) -> bool {
124        true
125    }
126
127    fn summary(&self, _store_name: &str) -> Option<String> {
128        let database = self.client.database_client().database_name();
129        let collection = self.client.collection_name();
130        Some(format!(
131            "Azure CosmosDB database: {database}, collection: {collection}"
132        ))
133    }
134}
135
136#[derive(Clone)]
137struct AzureCosmosStore {
138    client: CollectionClient,
139    /// An optional store id to use as a partition key for all operations.
140    ///
141    /// If the store ID is not set, the store will use `/id` (the row key) as
142    /// the partition key. For example, if `store.set("my_key", "my_value")` is
143    /// called, the partition key will be `my_key` if the store ID is set to
144    /// `None`. If the store ID is set to `Some("myappid/default"), the
145    /// partition key will be `myappid/default`.
146    store_id: Option<String>,
147}
148
149#[async_trait]
150impl Store for AzureCosmosStore {
151    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
152        let pair = self.get_entity::<Pair>(key).await?;
153        Ok(pair.map(|p| p.value))
154    }
155
156    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
157        let illegal_chars = ['/', '\\', '?', '#'];
158
159        if key.contains(|c| illegal_chars.contains(&c)) {
160            return Err(Error::Other(format!(
161                "Key contains an illegal character. Keys must not include any of: {}",
162                illegal_chars.iter().collect::<String>()
163            )));
164        }
165
166        let pair = Pair {
167            id: key.to_string(),
168            value: value.to_vec(),
169            store_id: self.store_id.clone(),
170        };
171        self.client
172            .create_document(pair)
173            .is_upsert(true)
174            .await
175            .map_err(log_error)?;
176        Ok(())
177    }
178
179    async fn delete(&self, key: &str) -> Result<(), Error> {
180        let document_client = self
181            .client
182            .document_client(key, &self.store_id.clone().unwrap_or(key.to_string()))
183            .map_err(log_error)?;
184        if let Err(e) = document_client.delete_document().await {
185            if e.as_http_error().map(|e| e.status() != 404).unwrap_or(true) {
186                return Err(log_error(e));
187            }
188        }
189        Ok(())
190    }
191
192    async fn exists(&self, key: &str) -> Result<bool, Error> {
193        let mut stream = self
194            .client
195            .query_documents(Query::new(self.get_id_query(key)))
196            .query_cross_partition(true)
197            .max_item_count(1)
198            .into_stream::<Key>();
199
200        match stream.next().await {
201            Some(Ok(res)) => Ok(!res.results.is_empty()),
202            Some(Err(e)) => Err(log_error(e)),
203            None => Ok(false),
204        }
205    }
206
207    async fn get_keys(&self) -> Result<Vec<String>, Error> {
208        self.get_keys().await
209    }
210
211    async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
212        let stmt = Query::new(self.get_in_query(keys));
213        let query = self
214            .client
215            .query_documents(stmt)
216            .query_cross_partition(true);
217
218        let mut res = Vec::new();
219        let mut stream = query.into_stream::<Pair>();
220        while let Some(resp) = stream.next().await {
221            let resp = resp.map_err(log_error)?;
222            res.extend(
223                resp.results
224                    .into_iter()
225                    .map(|(pair, _)| (pair.id, Some(pair.value))),
226            );
227        }
228        Ok(res)
229    }
230
231    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
232        for (key, value) in key_values {
233            self.set(key.as_ref(), &value).await?
234        }
235        Ok(())
236    }
237
238    async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
239        for key in keys {
240            self.delete(key.as_ref()).await?
241        }
242        Ok(())
243    }
244
245    /// Increments a numerical value.
246    ///
247    /// The initial value for the item must be set through this interface, as this sets the
248    /// number value if it does not exist. If the value was previously set using
249    /// the `set` interface, this will fail due to a type mismatch.
250    // TODO: The function should parse the new value from the return response
251    // rather than sending an additional new request. However, the current SDK
252    // version does not support this.
253    async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
254        let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
255        match self
256            .client
257            .document_client(&key, &self.store_id.clone().unwrap_or(key.to_string()))
258            .map_err(log_error)?
259            .patch_document(operations)
260            .await
261        {
262            Err(e) => {
263                if e.as_http_error()
264                    .map(|e| e.status() == 404)
265                    .unwrap_or(false)
266                {
267                    let counter = Counter {
268                        id: key.clone(),
269                        value: delta,
270                        store_id: self.store_id.clone(),
271                    };
272                    if let Err(e) = self.client.create_document(counter).is_upsert(false).await {
273                        if e.as_http_error()
274                            .map(|e| e.status())
275                            .unwrap_or(azure_core::StatusCode::Continue)
276                            == 409
277                        {
278                            // Conflict trying to create counter, retry increment
279                            self.increment(key, delta).await?;
280                        } else {
281                            return Err(log_error(e));
282                        }
283                    }
284                    Ok(delta)
285                } else {
286                    Err(log_error(e))
287                }
288            }
289            Ok(_) => self
290                .get_entity::<Counter>(key.as_ref())
291                .await?
292                .map(|c| c.value)
293                .ok_or(Error::Other(
294                    "increment returned an empty value after patching, which indicates a bug"
295                        .to_string(),
296                )),
297        }
298    }
299
300    async fn new_compare_and_swap(
301        &self,
302        bucket_rep: u32,
303        key: &str,
304    ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
305        Ok(Arc::new(CompareAndSwap {
306            key: key.to_string(),
307            client: self.client.clone(),
308            etag: Mutex::new(None),
309            bucket_rep,
310            store_id: self.store_id.clone(),
311        }))
312    }
313}
314
315struct CompareAndSwap {
316    key: String,
317    client: CollectionClient,
318    bucket_rep: u32,
319    etag: Mutex<Option<String>>,
320    store_id: Option<String>,
321}
322
323impl CompareAndSwap {
324    fn get_query(&self) -> String {
325        let mut query = format!("SELECT * FROM c WHERE c.id='{}'", self.key);
326        self.append_store_id(&mut query, true);
327        query
328    }
329
330    fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
331        append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
332    }
333}
334
335#[async_trait]
336impl Cas for CompareAndSwap {
337    /// `current` will fetch the current value for the key and store the etag for the record. The
338    /// etag will be used to perform and optimistic concurrency update using the `if-match` header.
339    async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
340        let mut stream = self
341            .client
342            .query_documents(Query::new(self.get_query()))
343            .query_cross_partition(true)
344            .max_item_count(1)
345            .into_stream::<Pair>();
346
347        let current_value: Option<(Vec<u8>, Option<String>)> = match stream.next().await {
348            Some(r) => {
349                let r = r.map_err(log_error)?;
350                match r.results.first() {
351                    Some((item, Some(attr))) => {
352                        Some((item.clone().value, Some(attr.etag().to_string())))
353                    }
354                    Some((item, None)) => Some((item.clone().value, None)),
355                    _ => None,
356                }
357            }
358            None => None,
359        };
360
361        match current_value {
362            Some((value, etag)) => {
363                self.etag.lock().unwrap().clone_from(&etag);
364                Ok(Some(value))
365            }
366            None => Ok(None),
367        }
368    }
369
370    /// `swap` updates the value for the key using the etag saved in the `current` function for
371    /// optimistic concurrency.
372    async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
373        let pair = Pair {
374            id: self.key.clone(),
375            value,
376            store_id: self.store_id.clone(),
377        };
378
379        let doc_client = self
380            .client
381            .document_client(&self.key, &pair.partition_key())
382            .map_err(log_cas_error)?;
383
384        let etag_value = self.etag.lock().unwrap().clone();
385        match etag_value {
386            Some(etag) => {
387                // attempt to replace the document if the etag matches
388                doc_client
389                    .replace_document(pair)
390                    .if_match_condition(azure_core::request_options::IfMatchCondition::Match(etag))
391                    .await
392                    .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
393                    .map(drop)
394            }
395            None => {
396                // if we have no etag, then we assume the document does not yet exist and must insert; no upserts.
397                self.client
398                    .create_document(pair)
399                    .await
400                    .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
401                    .map(drop)
402            }
403        }
404    }
405
406    async fn bucket_rep(&self) -> u32 {
407        self.bucket_rep
408    }
409
410    async fn key(&self) -> String {
411        self.key.clone()
412    }
413}
414
415impl AzureCosmosStore {
416    async fn get_entity<F>(&self, key: &str) -> Result<Option<F>, Error>
417    where
418        F: CosmosEntity + Send + Sync + serde::de::DeserializeOwned + Clone,
419    {
420        let query = self
421            .client
422            .query_documents(Query::new(self.get_query(key)))
423            .query_cross_partition(true)
424            .max_item_count(1);
425
426        // There can be no duplicated keys, so we create the stream and only take the first result.
427        let mut stream = query.into_stream::<F>();
428        let Some(res) = stream.next().await else {
429            return Ok(None);
430        };
431        Ok(res
432            .map_err(log_error)?
433            .results
434            .first()
435            .map(|(p, _)| p.clone()))
436    }
437
438    async fn get_keys(&self) -> Result<Vec<String>, Error> {
439        let query = self
440            .client
441            .query_documents(Query::new(self.get_keys_query()))
442            .query_cross_partition(true);
443        let mut res = Vec::new();
444
445        let mut stream = query.into_stream::<Key>();
446        while let Some(resp) = stream.next().await {
447            let resp = resp.map_err(log_error)?;
448            res.extend(resp.results.into_iter().map(|(key, _)| key.id));
449        }
450
451        Ok(res)
452    }
453
454    fn get_query(&self, key: &str) -> String {
455        let mut query = format!("SELECT * FROM c WHERE c.id='{key}'");
456        self.append_store_id(&mut query, true);
457        query
458    }
459
460    fn get_id_query(&self, key: &str) -> String {
461        let mut query = format!("SELECT c.id, c.store_id FROM c WHERE c.id='{key}'");
462        self.append_store_id(&mut query, true);
463        query
464    }
465
466    fn get_keys_query(&self) -> String {
467        let mut query = "SELECT c.id, c.store_id FROM c".to_owned();
468        self.append_store_id(&mut query, false);
469        query
470    }
471
472    fn get_in_query(&self, keys: Vec<String>) -> String {
473        let in_clause: String = keys
474            .into_iter()
475            .map(|k| format!("'{k}'"))
476            .collect::<Vec<String>>()
477            .join(", ");
478
479        let mut query = format!("SELECT * FROM c WHERE c.id IN ({in_clause})");
480        self.append_store_id(&mut query, true);
481        query
482    }
483
484    fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
485        append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
486    }
487}
488
489/// Appends an option store id condition to the query.
490fn append_store_id_condition(
491    query: &mut String,
492    store_id: Option<&str>,
493    condition_already_exists: bool,
494) {
495    if let Some(s) = store_id {
496        if condition_already_exists {
497            query.push_str(" AND");
498        } else {
499            query.push_str(" WHERE");
500        }
501        query.push_str(" c.store_id='");
502        query.push_str(s);
503        query.push('\'')
504    }
505}
506
507// Pair structure for key value operations
508#[derive(Serialize, Deserialize, Clone, Debug)]
509pub struct Pair {
510    pub id: String,
511    pub value: Vec<u8>,
512    #[serde(skip_serializing_if = "Option::is_none")]
513    pub store_id: Option<String>,
514}
515
516impl CosmosEntity for Pair {
517    type Entity = String;
518
519    fn partition_key(&self) -> Self::Entity {
520        self.store_id.clone().unwrap_or_else(|| self.id.clone())
521    }
522}
523
524// Counter structure for increment operations
525#[derive(Serialize, Deserialize, Clone, Debug)]
526pub struct Counter {
527    pub id: String,
528    pub value: i64,
529    #[serde(skip_serializing_if = "Option::is_none")]
530    pub store_id: Option<String>,
531}
532
533impl CosmosEntity for Counter {
534    type Entity = String;
535
536    fn partition_key(&self) -> Self::Entity {
537        self.store_id.clone().unwrap_or_else(|| self.id.clone())
538    }
539}
540
541// Key structure for operations with generic value types
542#[derive(Serialize, Deserialize, Clone, Debug)]
543pub struct Key {
544    pub id: String,
545    #[serde(skip_serializing_if = "Option::is_none")]
546    pub store_id: Option<String>,
547}
548
549impl CosmosEntity for Key {
550    type Entity = String;
551
552    fn partition_key(&self) -> Self::Entity {
553        self.store_id.clone().unwrap_or_else(|| self.id.clone())
554    }
555}