Skip to main content

spin_key_value_azure/
store.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use azure_data_cosmos::{
4    prelude::{
5        AuthorizationToken, CollectionClient, CosmosClient, CosmosClientBuilder, Operation, Query,
6    },
7    CosmosEntity,
8};
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use spin_factor_key_value::{
12    log_cas_error, log_error, log_error_v3, v3, Cas, Error, Store, StoreManager, SwapError,
13};
14use std::sync::{Arc, Mutex};
15
/// Store manager that hands out key-value stores backed by a single Azure
/// Cosmos DB collection.
pub struct KeyValueAzureCosmos {
    /// Client for the Cosmos DB collection that all stores created by this
    /// manager read from and write to.
    client: CollectionClient,
    /// An optional app id
    ///
    /// If provided, the store will handle multiple stores per container using a
    /// partition key of `/$app_id/$store_name`, otherwise there will be one container
    /// per store, and the partition key will be `/id`.
    app_id: Option<String>,
}
25
/// Azure Cosmos Key / Value runtime config literal options for authentication.
///
/// Holds the account key that was supplied directly in the runtime
/// configuration. The key is a secret, so this type deliberately does not
/// derive `Debug`; see the manual implementation below.
#[derive(Clone)]
pub struct KeyValueAzureCosmosRuntimeConfigOptions {
    /// The Cosmos DB account key used to build the authorization token.
    key: String,
}

// `Debug` is written by hand so the account key can never leak into logs,
// error chains or panic messages through `{:?}` formatting.
impl std::fmt::Debug for KeyValueAzureCosmosRuntimeConfigOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyValueAzureCosmosRuntimeConfigOptions")
            .field("key", &"<redacted>")
            .finish()
    }
}

impl KeyValueAzureCosmosRuntimeConfigOptions {
    /// Creates the options from the given account `key`.
    pub fn new(key: String) -> Self {
        Self { key }
    }
}
37
/// Azure Cosmos Key / Value enumeration of the possible authentication options.
#[derive(Clone, Debug)]
pub enum KeyValueAzureCosmosAuthOptions {
    /// Runtime config values: the account key has been specified directly in
    /// the runtime configuration.
    RuntimeConfigValues(KeyValueAzureCosmosRuntimeConfigOptions),
    /// Environmental indicates that the environment variables of the process should be used to
    /// create the TokenCredential for the Cosmos client. This will use the Azure Rust SDK's
    /// DefaultCredentialChain to derive the TokenCredential based on what environment variables
    /// have been set.
    ///
    /// Service Principal with client secret:
    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
    /// - `AZURE_CLIENT_SECRET`: one of the service principal's secrets.
    ///
    /// Service Principal with certificate:
    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
    /// - `AZURE_CLIENT_CERTIFICATE_PATH`: path to a PEM or PKCS12 certificate file including the private key.
    /// - `AZURE_CLIENT_CERTIFICATE_PASSWORD`: (optional) password for the certificate file.
    ///
    /// Workload Identity (Kubernetes, injected by the Workload Identity mutating webhook):
    /// - `AZURE_TENANT_ID`: ID of the service principal's Azure tenant.
    /// - `AZURE_CLIENT_ID`: the service principal's client ID.
    /// - `AZURE_FEDERATED_TOKEN_FILE`: path of a file containing a Kubernetes service account token.
    ///
    /// Managed Identity (User Assigned or System Assigned identities):
    /// - `AZURE_CLIENT_ID`: (optional) if using a user assigned identity, this will be the client ID of the identity.
    ///
    /// Azure CLI:
    /// - `AZURE_TENANT_ID`: (optional) use a specific tenant via the Azure CLI.
    ///
    /// Common across each:
    /// - `AZURE_AUTHORITY_HOST`: (optional) the host for the identity provider. For example, for Azure public cloud the host defaults to "https://login.microsoftonline.com".
    ///
    /// See also: <https://github.com/Azure/azure-sdk-for-rust/blob/main/sdk/identity/README.md>
    Environmental,
}
75
76impl KeyValueAzureCosmos {
77    pub fn new(
78        account: String,
79        database: String,
80        container: String,
81        auth_options: KeyValueAzureCosmosAuthOptions,
82        app_id: Option<String>,
83    ) -> Result<Self> {
84        let token = match auth_options {
85            KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => {
86                AuthorizationToken::primary_key(config.key).map_err(log_error)?
87            }
88            KeyValueAzureCosmosAuthOptions::Environmental => {
89                AuthorizationToken::from_token_credential(
90                    azure_identity::create_default_credential()?,
91                )
92            }
93        };
94        let cosmos_client = cosmos_client(account, token)?;
95        let database_client = cosmos_client.database_client(database);
96        let client = database_client.collection_client(container);
97
98        Ok(Self { client, app_id })
99    }
100}
101
102fn cosmos_client(account: impl Into<String>, token: AuthorizationToken) -> Result<CosmosClient> {
103    if cfg!(feature = "connection-pooling") {
104        let client = reqwest::ClientBuilder::new()
105            .build()
106            .context("failed to build reqwest client")?;
107        let transport_options = azure_core::TransportOptions::new(std::sync::Arc::new(client));
108        Ok(CosmosClientBuilder::new(account, token)
109            .transport(transport_options)
110            .build())
111    } else {
112        Ok(CosmosClient::new(account, token))
113    }
114}
115
116#[async_trait]
117impl StoreManager for KeyValueAzureCosmos {
118    async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
119        Ok(Arc::new(AzureCosmosStore {
120            client: self.client.clone(),
121            store_id: self.app_id.as_ref().map(|i| format!("{i}/{name}")),
122        }))
123    }
124
125    fn is_defined(&self, _store_name: &str) -> bool {
126        true
127    }
128
129    fn summary(&self, _store_name: &str) -> Option<String> {
130        let database = self.client.database_client().database_name();
131        let collection = self.client.collection_name();
132        Some(format!(
133            "Azure CosmosDB database: {database}, collection: {collection}"
134        ))
135    }
136}
137
/// A single key-value store backed by a Cosmos DB collection.
#[derive(Clone)]
struct AzureCosmosStore {
    /// Client for the collection holding this store's documents.
    client: CollectionClient,
    /// An optional store id to use as a partition key for all operations.
    ///
    /// If the store ID is not set, the store will use `/id` (the row key) as
    /// the partition key. For example, if `store.set("my_key", "my_value")` is
    /// called, the partition key will be `my_key` if the store ID is set to
    /// `None`. If the store ID is set to `Some("myappid/default")`, the
    /// partition key will be `myappid/default`.
    store_id: Option<String>,
}
150
151#[async_trait]
152impl Store for AzureCosmosStore {
153    async fn get(&self, key: &str, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
154        let pair = self.get_entity::<Pair>(key).await?;
155        let value = pair.map(|p| p.value);
156
157        // Currently there's no way to stream a single query result using the
158        // `azure_data_cosmos` crate without buffering, so the damage (in terms
159        // of host memory usage) is already done, but we can still enforce the
160        // limit:
161        if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
162            > max_result_bytes
163        {
164            Err(Error::Other(format!(
165                "query result exceeds limit of {max_result_bytes} bytes"
166            )))
167        } else {
168            Ok(value)
169        }
170    }
171
172    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
173        let illegal_chars = ['/', '\\', '?', '#'];
174
175        if key.contains(|c| illegal_chars.contains(&c)) {
176            return Err(Error::Other(format!(
177                "Key contains an illegal character. Keys must not include any of: {}",
178                illegal_chars.iter().collect::<String>()
179            )));
180        }
181
182        let pair = Pair {
183            id: key.to_string(),
184            value: value.to_vec(),
185            store_id: self.store_id.clone(),
186        };
187        self.client
188            .create_document(pair)
189            .is_upsert(true)
190            .await
191            .map_err(log_error)?;
192        Ok(())
193    }
194
195    async fn delete(&self, key: &str) -> Result<(), Error> {
196        let document_client = self
197            .client
198            .document_client(key, &self.store_id.clone().unwrap_or(key.to_string()))
199            .map_err(log_error)?;
200        if let Err(e) = document_client.delete_document().await {
201            if e.as_http_error().map(|e| e.status() != 404).unwrap_or(true) {
202                return Err(log_error(e));
203            }
204        }
205        Ok(())
206    }
207
208    async fn exists(&self, key: &str) -> Result<bool, Error> {
209        let mut stream = self
210            .client
211            .query_documents(Query::new(self.get_id_query(key)))
212            .query_cross_partition(true)
213            .max_item_count(1)
214            .into_stream::<Key>();
215
216        match stream.next().await {
217            Some(Ok(res)) => Ok(!res.results.is_empty()),
218            Some(Err(e)) => Err(log_error(e)),
219            None => Ok(false),
220        }
221    }
222
223    async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
224        self.get_keys(max_result_bytes).await
225    }
226
227    async fn get_keys_async(
228        &self,
229        max_result_bytes: usize,
230    ) -> (
231        tokio::sync::mpsc::Receiver<String>,
232        tokio::sync::oneshot::Receiver<Result<(), v3::Error>>,
233    ) {
234        let (keys_tx, keys_rx) = tokio::sync::mpsc::channel(4);
235        let (err_tx, err_rx) = tokio::sync::oneshot::channel();
236
237        let query = self
238            .client
239            .query_documents(Query::new(self.get_keys_query()))
240            .query_cross_partition(true);
241
242        let the_work = async move {
243            let mut stream = query.into_stream::<Key>();
244            while let Some(resp) = stream.next().await {
245                let resp = resp.map_err(log_error_v3)?;
246
247                if resp.results.iter().map(|(k, _)| k.id.len()).sum::<usize>() > max_result_bytes {
248                    return Err(v3::Error::Other(format!(
249                        "query exceeds limit of {max_result_bytes} bytes"
250                    )));
251                }
252
253                for (key, _) in resp.results {
254                    keys_tx.send(key.id).await.map_err(log_error_v3)?;
255                }
256            }
257            Ok(())
258        };
259        tokio::spawn(async move {
260            let res = the_work.await;
261            _ = err_tx.send(res);
262        });
263
264        (keys_rx, err_rx)
265    }
266
267    async fn get_many(
268        &self,
269        keys: Vec<String>,
270        max_result_bytes: usize,
271    ) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
272        let stmt = Query::new(self.get_in_query(keys));
273        let query = self
274            .client
275            .query_documents(stmt)
276            .query_cross_partition(true);
277
278        let mut res = Vec::new();
279        let mut stream = query.into_stream::<Pair>();
280        let mut byte_count = std::mem::size_of::<Vec<(String, Option<Vec<u8>>)>>();
281        while let Some(resp) = stream.next().await {
282            let resp = resp.map_err(log_error)?.results;
283            byte_count += resp
284                .iter()
285                .map(|(pair, _)| {
286                    std::mem::size_of::<(String, Option<Vec<u8>>)>()
287                        + pair.id.len()
288                        + pair.value.len()
289                })
290                .sum::<usize>();
291            if byte_count > max_result_bytes {
292                return Err(Error::Other(format!(
293                    "query result exceeds limit of {max_result_bytes} bytes"
294                )));
295            }
296            res.extend(
297                resp.into_iter()
298                    .map(|(pair, _)| (pair.id, Some(pair.value))),
299            );
300        }
301        Ok(res)
302    }
303
304    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
305        for (key, value) in key_values {
306            self.set(key.as_ref(), &value).await?
307        }
308        Ok(())
309    }
310
311    async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
312        for key in keys {
313            self.delete(key.as_ref()).await?
314        }
315        Ok(())
316    }
317
318    /// Increments a numerical value.
319    ///
320    /// The initial value for the item must be set through this interface, as this sets the
321    /// number value if it does not exist. If the value was previously set using
322    /// the `set` interface, this will fail due to a type mismatch.
323    // TODO: The function should parse the new value from the return response
324    // rather than sending an additional new request. However, the current SDK
325    // version does not support this.
326    async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
327        let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
328        match self
329            .client
330            .document_client(&key, &self.store_id.clone().unwrap_or(key.to_string()))
331            .map_err(log_error)?
332            .patch_document(operations)
333            .await
334        {
335            Err(e) => {
336                if e.as_http_error()
337                    .map(|e| e.status() == 404)
338                    .unwrap_or(false)
339                {
340                    let counter = Counter {
341                        id: key.clone(),
342                        value: delta,
343                        store_id: self.store_id.clone(),
344                    };
345                    if let Err(e) = self.client.create_document(counter).is_upsert(false).await {
346                        if e.as_http_error()
347                            .map(|e| e.status())
348                            .unwrap_or(azure_core::StatusCode::Continue)
349                            == 409
350                        {
351                            // Conflict trying to create counter, retry increment
352                            self.increment(key, delta).await?;
353                        } else {
354                            return Err(log_error(e));
355                        }
356                    }
357                    Ok(delta)
358                } else {
359                    Err(log_error(e))
360                }
361            }
362            Ok(_) => self
363                .get_entity::<Counter>(key.as_ref())
364                .await?
365                .map(|c| c.value)
366                .ok_or(Error::Other(
367                    "increment returned an empty value after patching, which indicates a bug"
368                        .to_string(),
369                )),
370        }
371    }
372
373    async fn new_compare_and_swap(
374        &self,
375        bucket_rep: u32,
376        key: &str,
377    ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
378        Ok(Arc::new(CompareAndSwap {
379            key: key.to_string(),
380            client: self.client.clone(),
381            etag: Mutex::new(None),
382            bucket_rep,
383            store_id: self.store_id.clone(),
384        }))
385    }
386}
387
/// State for one compare-and-swap operation on a single key.
struct CompareAndSwap {
    /// The key (document id) this operation targets.
    key: String,
    /// Client for the collection holding the document.
    client: CollectionClient,
    /// Opaque bucket identifier handed back unchanged by `bucket_rep`.
    bucket_rep: u32,
    /// Etag recorded by `current` and used by `swap` as the `if-match`
    /// condition; `None` until a document has been observed.
    etag: Mutex<Option<String>>,
    /// Optional store id used as the partition key and as a query filter.
    store_id: Option<String>,
}
395
396impl CompareAndSwap {
397    fn get_query(&self) -> String {
398        let mut query = format!("SELECT * FROM c WHERE c.id='{}'", self.key);
399        self.append_store_id(&mut query, true);
400        query
401    }
402
403    fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
404        append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
405    }
406}
407
408#[async_trait]
409impl Cas for CompareAndSwap {
410    /// `current` will fetch the current value for the key and store the etag for the record. The
411    /// etag will be used to perform and optimistic concurrency update using the `if-match` header.
412    async fn current(&self, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
413        let mut stream = self
414            .client
415            .query_documents(Query::new(self.get_query()))
416            .query_cross_partition(true)
417            .max_item_count(1)
418            .into_stream::<Pair>();
419
420        let current_value: Option<(Vec<u8>, Option<String>)> = match stream.next().await {
421            Some(r) => {
422                let r = r.map_err(log_error)?;
423                match r.results.first() {
424                    Some((item, Some(attr))) => {
425                        Some((item.clone().value, Some(attr.etag().to_string())))
426                    }
427                    Some((item, None)) => Some((item.clone().value, None)),
428                    _ => None,
429                }
430            }
431            None => None,
432        };
433
434        let value = match current_value {
435            Some((value, etag)) => {
436                self.etag.lock().unwrap().clone_from(&etag);
437                Some(value)
438            }
439            None => None,
440        };
441
442        // Currently there's no way to stream a single query result using the
443        // `azure_data_cosmos` crate without buffering, so the damage (in terms
444        // of host memory usage) is already done, but we can still enforce the
445        // limit:
446        if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
447            > max_result_bytes
448        {
449            Err(Error::Other(format!(
450                "query result exceeds limit of {max_result_bytes} bytes"
451            )))
452        } else {
453            Ok(value)
454        }
455    }
456
457    /// `swap` updates the value for the key using the etag saved in the `current` function for
458    /// optimistic concurrency.
459    async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
460        let pair = Pair {
461            id: self.key.clone(),
462            value,
463            store_id: self.store_id.clone(),
464        };
465
466        let doc_client = self
467            .client
468            .document_client(&self.key, &pair.partition_key())
469            .map_err(log_cas_error)?;
470
471        let etag_value = self.etag.lock().unwrap().clone();
472        match etag_value {
473            Some(etag) => {
474                // attempt to replace the document if the etag matches
475                doc_client
476                    .replace_document(pair)
477                    .if_match_condition(azure_core::request_options::IfMatchCondition::Match(etag))
478                    .await
479                    .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
480                    .map(drop)
481            }
482            None => {
483                // if we have no etag, then we assume the document does not yet exist and must insert; no upserts.
484                self.client
485                    .create_document(pair)
486                    .await
487                    .map_err(|e| SwapError::CasFailed(format!("{e:?}")))
488                    .map(drop)
489            }
490        }
491    }
492
    /// Returns the bucket identifier this operation was created with.
    async fn bucket_rep(&self) -> u32 {
        self.bucket_rep
    }
496
    /// Returns a copy of the key this operation targets.
    async fn key(&self) -> String {
        self.key.clone()
    }
500}
501
502impl AzureCosmosStore {
503    async fn get_entity<F>(&self, key: &str) -> Result<Option<F>, Error>
504    where
505        F: CosmosEntity + Send + Sync + serde::de::DeserializeOwned + Clone,
506    {
507        let query = self
508            .client
509            .query_documents(Query::new(self.get_query(key)))
510            .query_cross_partition(true)
511            .max_item_count(1);
512
513        // There can be no duplicated keys, so we create the stream and only take the first result.
514        let mut stream = query.into_stream::<F>();
515        let Some(res) = stream.next().await else {
516            return Ok(None);
517        };
518        Ok(res
519            .map_err(log_error)?
520            .results
521            .first()
522            .map(|(p, _)| p.clone()))
523    }
524
525    async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
526        let query = self
527            .client
528            .query_documents(Query::new(self.get_keys_query()))
529            .query_cross_partition(true);
530        let mut res = Vec::new();
531
532        let mut stream = query.into_stream::<Key>();
533        let mut byte_count = std::mem::size_of::<Vec<String>>();
534        while let Some(resp) = stream.next().await {
535            let resp = resp.map_err(log_error)?.results;
536            byte_count += resp
537                .iter()
538                .map(|(key, _)| std::mem::size_of::<String>() + key.id.len())
539                .sum::<usize>();
540            if byte_count > max_result_bytes {
541                return Err(Error::Other(format!(
542                    "query result exceeds limit of {max_result_bytes} bytes"
543                )));
544            }
545            res.extend(resp.into_iter().map(|(key, _)| key.id));
546        }
547
548        Ok(res)
549    }
550
551    fn get_query(&self, key: &str) -> String {
552        let mut query = format!("SELECT * FROM c WHERE c.id='{key}'");
553        self.append_store_id(&mut query, true);
554        query
555    }
556
557    fn get_id_query(&self, key: &str) -> String {
558        let mut query = format!("SELECT c.id, c.store_id FROM c WHERE c.id='{key}'");
559        self.append_store_id(&mut query, true);
560        query
561    }
562
563    fn get_keys_query(&self) -> String {
564        let mut query = "SELECT c.id, c.store_id FROM c".to_owned();
565        self.append_store_id(&mut query, false);
566        query
567    }
568
569    fn get_in_query(&self, keys: Vec<String>) -> String {
570        let in_clause: String = keys
571            .into_iter()
572            .map(|k| format!("'{k}'"))
573            .collect::<Vec<String>>()
574            .join(", ");
575
576        let mut query = format!("SELECT * FROM c WHERE c.id IN ({in_clause})");
577        self.append_store_id(&mut query, true);
578        query
579    }
580
    /// Appends this store's id filter (if any) to `query`.
    ///
    /// `condition_already_exists` states whether `query` already has a
    /// `WHERE` clause; it is forwarded to `append_store_id_condition`.
    fn append_store_id(&self, query: &mut String, condition_already_exists: bool) {
        append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists);
    }
584}
585
/// Appends an optional store id condition to the query.
///
/// When `store_id` is `None` the query is left untouched. Otherwise
/// ` c.store_id='<id>'` is appended, introduced by ` AND` when
/// `condition_already_exists` is true (the query already has a `WHERE`
/// clause) or by ` WHERE` when it is false.
///
/// The id is placed inside a single-quoted string literal, so any `'` or `\`
/// it contains is backslash-escaped; an unescaped quote would end the literal
/// early and let the rest of the id be interpreted as query text.
fn append_store_id_condition(
    query: &mut String,
    store_id: Option<&str>,
    condition_already_exists: bool,
) {
    let Some(id) = store_id else {
        return;
    };

    query.push_str(if condition_already_exists {
        " AND"
    } else {
        " WHERE"
    });
    query.push_str(" c.store_id='");
    for ch in id.chars() {
        if matches!(ch, '\'' | '\\') {
            query.push('\\');
        }
        query.push(ch);
    }
    query.push('\'');
}
603
604// Pair structure for key value operations
605#[derive(Serialize, Deserialize, Clone, Debug)]
606pub struct Pair {
607    pub id: String,
608    pub value: Vec<u8>,
609    #[serde(skip_serializing_if = "Option::is_none")]
610    pub store_id: Option<String>,
611}
612
613impl CosmosEntity for Pair {
614    type Entity = String;
615
616    fn partition_key(&self) -> Self::Entity {
617        self.store_id.clone().unwrap_or_else(|| self.id.clone())
618    }
619}
620
621// Counter structure for increment operations
622#[derive(Serialize, Deserialize, Clone, Debug)]
623pub struct Counter {
624    pub id: String,
625    pub value: i64,
626    #[serde(skip_serializing_if = "Option::is_none")]
627    pub store_id: Option<String>,
628}
629
630impl CosmosEntity for Counter {
631    type Entity = String;
632
633    fn partition_key(&self) -> Self::Entity {
634        self.store_id.clone().unwrap_or_else(|| self.id.clone())
635    }
636}
637
638// Key structure for operations with generic value types
639#[derive(Serialize, Deserialize, Clone, Debug)]
640pub struct Key {
641    pub id: String,
642    #[serde(skip_serializing_if = "Option::is_none")]
643    pub store_id: Option<String>,
644}
645
646impl CosmosEntity for Key {
647    type Entity = String;
648
649    fn partition_key(&self) -> Self::Entity {
650        self.store_id.clone().unwrap_or_else(|| self.id.clone())
651    }
652}