spin_key_value_azure/
store.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use azure_data_cosmos::{
4    prelude::{
5        AuthorizationToken, CollectionClient, CosmosClient, CosmosClientBuilder, Operation, Query,
6    },
7    CosmosEntity,
8};
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use spin_factor_key_value::{log_cas_error, log_error, Cas, Error, Store, StoreManager, SwapError};
12use std::sync::{Arc, Mutex};
13
14pub struct KeyValueAzureCosmos {
15    client: CollectionClient,
16    /// An optional app id
17    ///
18    /// If provided, the store will handle multiple stores per container using a
19    /// partition key of `/$app_id/$store_name`, otherwise there will be one container
20    /// per store, and the partition key will be `/id`.
21    app_id: Option<String>,
22}
23
/// Literal authentication options for the Azure Cosmos Key / Value store, as
/// supplied through the runtime config.
#[derive(Debug, Clone)]
pub struct KeyValueAzureCosmosRuntimeConfigOptions {
    /// Key handed to `AuthorizationToken::primary_key` when the client is built.
    key: String,
}
29
30impl KeyValueAzureCosmosRuntimeConfigOptions {
31    pub fn new(key: String) -> Self {
32        Self { key }
33    }
34}
35
36/// Azure Cosmos Key / Value enumeration for the possible authentication options
37#[derive(Clone, Debug)]
38pub enum KeyValueAzureCosmosAuthOptions {
39    /// Runtime Config values indicates the account and key have been specified directly
40    RuntimeConfigValues(KeyValueAzureCosmosRuntimeConfigOptions),
41    /// Environmental indicates that the environment variables of the process should be used to
42    /// create the TokenCredential for the Cosmos client. This will use the Azure Rust SDK's
43    /// DefaultCredentialChain to derive the TokenCredential based on what environment variables
44    /// have been set.
45    ///
46    /// Service Principal with client secret:
47    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
48    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
49    /// - `AZURE_CLIENT_SECRET`: one of the service principal's secrets.
50    ///
51    /// Service Principal with certificate:
52    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
53    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
54    /// - `AZURE_CLIENT_CERTIFICATE_PATH`: path to a PEM or PKCS12 certificate file including the private key.
55    /// - `AZURE_CLIENT_CERTIFICATE_PASSWORD`: (optional) password for the certificate file.
56    ///
57    /// Workload Identity (Kubernetes, injected by the Workload Identity mutating webhook):
58    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
59    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
60    /// - `AZURE_FEDERATED_TOKEN_FILE`: TokenFilePath is the path of a file containing a Kubernetes service account token.
61    ///
62    /// Managed Identity (User Assigned or System Assigned identities):
63    /// - `AZURE_CLIENT_ID`: (optional) if using a user assigned identity, this will be the client ID of the identity.
64    ///
65    /// Azure CLI:
66    /// - `AZURE_TENANT_ID`: (optional) use a specific tenant via the Azure CLI.
67    ///
68    /// Common across each:
69    /// - `AZURE_AUTHORITY_HOST`: (optional) the host for the identity provider. For example, for Azure public cloud the host defaults to "https://login.microsoftonline.com".
70    ///   See also: https://github.com/Azure/azure-sdk-for-rust/blob/main/sdk/identity/README.md
71    Environmental,
72}
73
74impl KeyValueAzureCosmos {
75    pub fn new(
76        account: String,
77        database: String,
78        container: String,
79        auth_options: KeyValueAzureCosmosAuthOptions,
80        app_id: Option<String>,
81    ) -> Result<Self> {
82        let token = match auth_options {
83            KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => {
84                AuthorizationToken::primary_key(config.key).map_err(log_error)?
85            }
86            KeyValueAzureCosmosAuthOptions::Environmental => {
87                AuthorizationToken::from_token_credential(
88                    azure_identity::create_default_credential()?,
89                )
90            }
91        };
92        let cosmos_client = cosmos_client(account, token)?;
93        let database_client = cosmos_client.database_client(database);
94        let client = database_client.collection_client(container);
95
96        Ok(Self { client, app_id })
97    }
98}
99
100fn cosmos_client(account: impl Into<String>, token: AuthorizationToken) -> Result<CosmosClient> {
101    if cfg!(feature = "connection-pooling") {
102        let client = reqwest::ClientBuilder::new()
103            .build()
104            .context("failed to build reqwest client")?;
105        let transport_options = azure_core::TransportOptions::new(std::sync::Arc::new(client));
106        Ok(CosmosClientBuilder::new(account, token)
107            .transport(transport_options)
108            .build())
109    } else {
110        Ok(CosmosClient::new(account, token))
111    }
112}
113
114#[async_trait]
115impl StoreManager for KeyValueAzureCosmos {
116    async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
117        Ok(Arc::new(AzureCosmosStore {
118            client: self.client.clone(),
119            store_id: self.app_id.as_ref().map(|i| format!("{i}/{name}")),
120        }))
121    }
122
123    fn is_defined(&self, _store_name: &str) -> bool {
124        true
125    }
126
127    fn summary(&self, _store_name: &str) -> Option<String> {
128        let database = self.client.database_client().database_name();
129        let collection = self.client.collection_name();
130        Some(format!(
131            "Azure CosmosDB database: {database}, collection: {collection}"
132        ))
133    }
134}
135
136#[derive(Clone)]
137struct AzureCosmosStore {
138    client: CollectionClient,
139    /// An optional store id to use as a partition key for all operations.
140    ///
141    /// If the store ID is not set, the store will use `/id` (the row key) as
142    /// the partition key. For example, if `store.set("my_key", "my_value")` is
143    /// called, the partition key will be `my_key` if the store ID is set to
144    /// `None`. If the store ID is set to `Some("myappid/default"), the
145    /// partition key will be `myappid/default`.
146    store_id: Option<String>,
147}
148
149#[async_trait]
150impl Store for AzureCosmosStore {
151    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
152        let pair = self.get_entity::<Pair>(key).await?;
153        Ok(pair.map(|p| p.value))
154    }
155
156    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
157        let illegal_chars = ['/', '\\', '?', '#'];
158
159        if key.contains(|c| illegal_chars.contains(&c)) {
160            return Err(Error::Other(format!(
161                "Key contains an illegal character. Keys must not include any of: {}",
162                illegal_chars.iter().collect::<String>()
163            )));
164        }
165
166        let pair = Pair {
167            id: key.to_string(),
168            value: value.to_vec(),
169            store_id: self.store_id.clone(),
170        };
171        self.client
172            .create_document(pair)
173            .is_upsert(true)
174            .await
175            .map_err(log_error)?;
176        Ok(())
177    }
178
179    async fn delete(&self, key: &str) -> Result<(), Error> {
180        let document_client = self
181            .client
182            .document_client(key, &self.store_id.clone().unwrap_or(key.to_string()))
183            .map_err(log_error)?;
184        if let Err(e) = document_client.delete_document().await {
185            if e.as_http_error().map(|e| e.status() != 404).unwrap_or(true) {
186                return Err(log_error(e));
187            }
188        }
189        Ok(())
190    }
191
192    async fn exists(&self, key: &str) -> Result<bool, Error> {
193        Ok(self.get_entity::<Key>(key).await?.is_some())
194    }
195
196    async fn get_keys(&self) -> Result<Vec<String>, Error> {
197        self.get_keys().await
198    }
199
200    async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
201        let stmt = Query::new(self.get_in_query(keys));
202        let query = self
203            .client
204            .query_documents(stmt)
205            .query_cross_partition(true);
206
207        let mut res = Vec::new();
208        let mut stream = query.into_stream::<Pair>();
209        while let Some(resp) = stream.next().await {
210            let resp = resp.map_err(log_error)?;
211            res.extend(
212                resp.results
213                    .into_iter()
214                    .map(|(pair, _)| (pair.id, Some(pair.value))),
215            );
216        }
217        Ok(res)
218    }
219
220    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
221        for (key, value) in key_values {
222            self.set(key.as_ref(), &value).await?
223        }
224        Ok(())
225    }
226
227    async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
228        for key in keys {
229            self.delete(key.as_ref()).await?
230        }
231        Ok(())
232    }
233
234    /// Increments a numerical value.
235    ///
236    /// The initial value for the item must be set through this interface, as this sets the
237    /// number value if it does not exist. If the value was previously set using
238    /// the `set` interface, this will fail due to a type mismatch.
239    // TODO: The function should parse the new value from the return response
240    // rather than sending an additional new request. However, the current SDK
241    // version does not support this.
242    async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
243        let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
244        match self
245            .client
246            .document_client(&key, &self.store_id.clone().unwrap_or(key.to_string()))
247            .map_err(log_error)?
248            .patch_document(operations)
249            .await
250        {
251            Err(e) => {
252                if e.as_http_error()
253                    .map(|e| e.status() == 404)
254                    .unwrap_or(false)
255                {
256                    let counter = Counter {
257                        id: key.clone(),
258                        value: delta,
259                        store_id: self.store_id.clone(),
260                    };
261                    if let Err(e) = self.client.create_document(counter).is_upsert(false).await {
262                        if e.as_http_error()
263                            .map(|e| e.status())
264                            .unwrap_or(azure_core::StatusCode::Continue)
265                            == 409
266                        {
267                            // Conflict trying to create counter, retry increment
268                            self.increment(key, delta).await?;
269                        } else {
270                            return Err(log_error(e));
271                        }
272                    }
273                    Ok(delta)
274                } else {
275                    Err(log_error(e))
276                }
277            }
278            Ok(_) => self
279                .get_entity::<Counter>(key.as_ref())
280                .await?
281                .map(|c| c.value)
282                .ok_or(Error::Other(
283                    "increment returned an empty value after patching, which indicates a bug"
284                        .to_string(),
285                )),
286        }
287    }
288
289    async fn new_compare_and_swap(
290        &self,
291        bucket_rep: u32,
292        key: &str,
293    ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
294        Ok(Arc::new(CompareAndSwap {
295            key: key.to_string(),
296            client: self.client.clone(),
297            etag: Mutex::new(None),
298            bucket_rep,
299            store_id: self.store_id.clone(),
300        }))
301    }
302}
303
304struct CompareAndSwap {
305    key: String,
306    client: CollectionClient,
307    bucket_rep: u32,
308    etag: Mutex<Option<String>>,
309    store_id: Option<String>,
310}
311
312impl CompareAndSwap {
313    fn get_query(&self) -> String {
314        let mut query = format!("SELECT * FROM c WHERE c.id='{}'", self.key);
315        self.append_store_id(&mut query, true);
316        query
317    }
318
319    fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
320        append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
321    }
322}
323
324#[async_trait]
325impl Cas for CompareAndSwap {
326    /// `current` will fetch the current value for the key and store the etag for the record. The
327    /// etag will be used to perform and optimistic concurrency update using the `if-match` header.
328    async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
329        let mut stream = self
330            .client
331            .query_documents(Query::new(self.get_query()))
332            .query_cross_partition(true)
333            .max_item_count(1)
334            .into_stream::<Pair>();
335
336        let current_value: Option<(Vec<u8>, Option<String>)> = match stream.next().await {
337            Some(r) => {
338                let r = r.map_err(log_error)?;
339                match r.results.first() {
340                    Some((item, Some(attr))) => {
341                        Some((item.clone().value, Some(attr.etag().to_string())))
342                    }
343                    Some((item, None)) => Some((item.clone().value, None)),
344                    _ => None,
345                }
346            }
347            None => None,
348        };
349
350        match current_value {
351            Some((value, etag)) => {
352                self.etag.lock().unwrap().clone_from(&etag);
353                Ok(Some(value))
354            }
355            None => Ok(None),
356        }
357    }
358
359    /// `swap` updates the value for the key using the etag saved in the `current` function for
360    /// optimistic concurrency.
361    async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
362        let pair = Pair {
363            id: self.key.clone(),
364            value,
365            store_id: self.store_id.clone(),
366        };
367
368        let doc_client = self
369            .client
370            .document_client(&self.key, &pair.partition_key())
371            .map_err(log_cas_error)?;
372
373        let etag_value = self.etag.lock().unwrap().clone();
374        match etag_value {
375            Some(etag) => {
376                // attempt to replace the document if the etag matches
377                doc_client
378                    .replace_document(pair)
379                    .if_match_condition(azure_core::request_options::IfMatchCondition::Match(etag))
380                    .await
381                    .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
382                    .map(drop)
383            }
384            None => {
385                // if we have no etag, then we assume the document does not yet exist and must insert; no upserts.
386                self.client
387                    .create_document(pair)
388                    .await
389                    .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
390                    .map(drop)
391            }
392        }
393    }
394
395    async fn bucket_rep(&self) -> u32 {
396        self.bucket_rep
397    }
398
399    async fn key(&self) -> String {
400        self.key.clone()
401    }
402}
403
404impl AzureCosmosStore {
405    async fn get_entity<F>(&self, key: &str) -> Result<Option<F>, Error>
406    where
407        F: CosmosEntity + Send + Sync + serde::de::DeserializeOwned + Clone,
408    {
409        let query = self
410            .client
411            .query_documents(Query::new(self.get_query(key)))
412            .query_cross_partition(true)
413            .max_item_count(1);
414
415        // There can be no duplicated keys, so we create the stream and only take the first result.
416        let mut stream = query.into_stream::<F>();
417        let Some(res) = stream.next().await else {
418            return Ok(None);
419        };
420        Ok(res
421            .map_err(log_error)?
422            .results
423            .first()
424            .map(|(p, _)| p.clone()))
425    }
426
427    async fn get_keys(&self) -> Result<Vec<String>, Error> {
428        let query = self
429            .client
430            .query_documents(Query::new(self.get_keys_query()))
431            .query_cross_partition(true);
432        let mut res = Vec::new();
433
434        let mut stream = query.into_stream::<Key>();
435        while let Some(resp) = stream.next().await {
436            let resp = resp.map_err(log_error)?;
437            res.extend(resp.results.into_iter().map(|(key, _)| key.id));
438        }
439
440        Ok(res)
441    }
442
443    fn get_query(&self, key: &str) -> String {
444        let mut query = format!("SELECT * FROM c WHERE c.id='{}'", key);
445        self.append_store_id(&mut query, true);
446        query
447    }
448
449    fn get_keys_query(&self) -> String {
450        let mut query = "SELECT * FROM c".to_owned();
451        self.append_store_id(&mut query, false);
452        query
453    }
454
455    fn get_in_query(&self, keys: Vec<String>) -> String {
456        let in_clause: String = keys
457            .into_iter()
458            .map(|k| format!("'{k}'"))
459            .collect::<Vec<String>>()
460            .join(", ");
461
462        let mut query = format!("SELECT * FROM c WHERE c.id IN ({})", in_clause);
463        self.append_store_id(&mut query, true);
464        query
465    }
466
467    fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
468        append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
469    }
470}
471
/// Appends an optional store id condition to the query.
///
/// Does nothing when `store_id` is `None`. Otherwise appends
/// ` AND c.store_id='<id>'` when `condition_already_exists` is true, or
/// ` WHERE c.store_id='<id>'` when it is false.
///
/// Backslashes and single quotes in the store id are backslash-escaped so the
/// id cannot terminate the string literal early.
fn append_store_id_condition(
    query: &mut String,
    store_id: Option<&str>,
    condition_already_exists: bool,
) {
    let Some(store_id) = store_id else {
        return;
    };

    query.push_str(if condition_already_exists {
        " AND"
    } else {
        " WHERE"
    });
    query.push_str(" c.store_id='");
    for ch in store_id.chars() {
        if matches!(ch, '\\' | '\'') {
            query.push('\\');
        }
        query.push(ch);
    }
    query.push('\'');
}
489
490// Pair structure for key value operations
491#[derive(Serialize, Deserialize, Clone, Debug)]
492pub struct Pair {
493    pub id: String,
494    pub value: Vec<u8>,
495    #[serde(skip_serializing_if = "Option::is_none")]
496    pub store_id: Option<String>,
497}
498
499impl CosmosEntity for Pair {
500    type Entity = String;
501
502    fn partition_key(&self) -> Self::Entity {
503        self.store_id.clone().unwrap_or_else(|| self.id.clone())
504    }
505}
506
507// Counter structure for increment operations
508#[derive(Serialize, Deserialize, Clone, Debug)]
509pub struct Counter {
510    pub id: String,
511    pub value: i64,
512    #[serde(skip_serializing_if = "Option::is_none")]
513    pub store_id: Option<String>,
514}
515
516impl CosmosEntity for Counter {
517    type Entity = String;
518
519    fn partition_key(&self) -> Self::Entity {
520        self.store_id.clone().unwrap_or_else(|| self.id.clone())
521    }
522}
523
524// Key structure for operations with generic value types
525#[derive(Serialize, Deserialize, Clone, Debug)]
526pub struct Key {
527    pub id: String,
528    #[serde(skip_serializing_if = "Option::is_none")]
529    pub store_id: Option<String>,
530}
531
532impl CosmosEntity for Key {
533    type Entity = String;
534
535    fn partition_key(&self) -> Self::Entity {
536        self.store_id.clone().unwrap_or_else(|| self.id.clone())
537    }
538}