spin_key_value_aws/
store.rs

1use core::str;
2use std::{
3    collections::HashMap,
4    sync::{Arc, Mutex},
5};
6
7use anyhow::Result;
8use aws_config::{BehaviorVersion, Region, SdkConfig};
9use aws_credential_types::Credentials;
10use aws_sdk_dynamodb::{
11    config::{ProvideCredentials, SharedCredentialsProvider},
12    operation::{
13        batch_get_item::BatchGetItemOutput, batch_write_item::BatchWriteItemOutput,
14        get_item::GetItemOutput,
15    },
16    primitives::Blob,
17    types::{
18        AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update,
19        WriteRequest,
20    },
21    Client,
22};
23use spin_core::async_trait;
24use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};
25
26pub struct KeyValueAwsDynamo {
27    /// AWS region
28    region: String,
29    /// Whether to use strongly consistent reads
30    consistent_read: bool,
31    /// DynamoDB table, needs to be cloned when getting a store
32    table: Arc<String>,
33    /// DynamoDB client
34    client: async_once_cell::Lazy<
35        Client,
36        std::pin::Pin<Box<dyn std::future::Future<Output = Client> + Send>>,
37    >,
38}
39
/// AWS Dynamo Key / Value credentials supplied literally in the runtime config.
///
/// The secret key and session token are redacted from this type's `Debug` output.
#[derive(Clone)]
pub struct KeyValueAwsDynamoRuntimeConfigOptions {
    /// AWS access key ID.
    access_key: String,
    /// AWS secret access key (never printed by `Debug`).
    secret_key: String,
    /// Optional AWS session token (never printed by `Debug`).
    token: Option<String>,
}

impl KeyValueAwsDynamoRuntimeConfigOptions {
    /// Creates a credential set from an access key, a secret key and an optional session token.
    pub fn new(access_key: String, secret_key: String, token: Option<String>) -> Self {
        Self {
            access_key,
            secret_key,
            token,
        }
    }
}

// `Debug` is implemented by hand instead of derived so that the secret key and session token
// never leak into logs or error messages when this value (or an enum wrapping it) is
// formatted with `{:?}`. A `Debug` impl is still required because `ProvideCredentials`
// demands it.
impl std::fmt::Debug for KeyValueAwsDynamoRuntimeConfigOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";
        f.debug_struct("KeyValueAwsDynamoRuntimeConfigOptions")
            .field("access_key", &self.access_key)
            .field("secret_key", &REDACTED)
            .field("token", &self.token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
57
58impl ProvideCredentials for KeyValueAwsDynamoRuntimeConfigOptions {
59    fn provide_credentials<'a>(
60        &'a self,
61    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
62    where
63        Self: 'a,
64    {
65        aws_credential_types::provider::future::ProvideCredentials::ready(Ok(Credentials::new(
66            self.access_key.clone(),
67            self.secret_key.clone(),
68            self.token.clone(),
69            None, // Optional expiration time
70            "spin_custom_aws_provider",
71        )))
72    }
73}
74
75/// AWS Dynamo Key / Value enumeration for the possible authentication options
76#[derive(Clone, Debug)]
77pub enum KeyValueAwsDynamoAuthOptions {
78    /// Runtime Config values indicates credentials have been specified directly
79    RuntimeConfigValues(KeyValueAwsDynamoRuntimeConfigOptions),
80    /// Environmental indicates that the environment variables of the process should be used to
81    /// create the SDK Config for the Dynamo client. This will use the AWS Rust SDK's
82    /// aws_config::load_defaults to derive credentials based on what environment variables
83    /// have been set.
84    ///
85    /// See https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-authentication.html for options.
86    Environmental,
87}
88
89impl KeyValueAwsDynamo {
90    pub fn new(
91        region: String,
92        consistent_read: bool,
93        table: String,
94        auth_options: KeyValueAwsDynamoAuthOptions,
95    ) -> Result<Self> {
96        let region_clone = region.clone();
97        let client_fut = Box::pin(async move {
98            let sdk_config = match auth_options {
99                KeyValueAwsDynamoAuthOptions::RuntimeConfigValues(config) => SdkConfig::builder()
100                    .credentials_provider(SharedCredentialsProvider::new(config))
101                    .region(Region::new(region_clone))
102                    .behavior_version(BehaviorVersion::latest())
103                    .build(),
104                KeyValueAwsDynamoAuthOptions::Environmental => {
105                    aws_config::load_defaults(BehaviorVersion::latest()).await
106                }
107            };
108            Client::new(&sdk_config)
109        });
110
111        Ok(Self {
112            region,
113            consistent_read,
114            table: Arc::new(table),
115            client: async_once_cell::Lazy::from_future(client_fut),
116        })
117    }
118}
119
120#[async_trait]
121impl StoreManager for KeyValueAwsDynamo {
122    async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
123        Ok(Arc::new(AwsDynamoStore {
124            client: self.client.get_unpin().await.clone(),
125            table: self.table.clone(),
126            consistent_read: self.consistent_read,
127        }))
128    }
129
130    fn is_defined(&self, _store_name: &str) -> bool {
131        true
132    }
133
134    fn summary(&self, _store_name: &str) -> Option<String> {
135        Some(format!(
136            "AWS DynamoDB region: {}, table: {}",
137            self.region, self.table
138        ))
139    }
140}
141
142struct AwsDynamoStore {
143    // Client wraps an Arc so should be low cost to clone
144    client: Client,
145    table: Arc<String>,
146    consistent_read: bool,
147}
148
149#[derive(Debug, Clone)]
150enum CasState {
151    // Existing item with version
152    Versioned(String),
153    // Existing item without version
154    Unversioned(Blob),
155    // Item was missing when fetched during `current`, expected to be new
156    Unset,
157    // Potentially new item -- `current` was never called to fetch version
158    Unknown,
159}
160
161struct CompareAndSwap {
162    key: String,
163    client: Client,
164    table: Arc<String>,
165    bucket_rep: u32,
166    state: Mutex<CasState>,
167}
168
/// Attribute name of the item's primary (partition) key; holds the store key as a string.
const PK: &str = "PK";
/// Attribute name holding the stored value as binary.
const VAL: &str = "VAL";
/// Attribute name holding the numeric version used for compare-and-swap.
const VER: &str = "VER";
175
176#[async_trait]
177impl Store for AwsDynamoStore {
178    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
179        let response = self
180            .client
181            .get_item()
182            .consistent_read(self.consistent_read)
183            .table_name(self.table.as_str())
184            .key(
185                PK,
186                aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
187            )
188            .projection_expression(VAL)
189            .send()
190            .await
191            .map_err(log_error)?;
192
193        let item = response.item.and_then(|mut item| {
194            if let Some(AttributeValue::B(val)) = item.remove(VAL) {
195                Some(val.into_inner())
196            } else {
197                None
198            }
199        });
200
201        Ok(item)
202    }
203
204    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
205        self.client
206            .put_item()
207            .table_name(self.table.as_str())
208            .item(PK, AttributeValue::S(key.to_string()))
209            .item(VAL, AttributeValue::B(Blob::new(value)))
210            .send()
211            .await
212            .map_err(log_error)?;
213        Ok(())
214    }
215
216    async fn delete(&self, key: &str) -> Result<(), Error> {
217        self.client
218            .delete_item()
219            .table_name(self.table.as_str())
220            .key(PK, AttributeValue::S(key.to_string()))
221            .send()
222            .await
223            .map_err(log_error)?;
224        Ok(())
225    }
226
227    async fn exists(&self, key: &str) -> Result<bool, Error> {
228        let GetItemOutput { item, .. } = self
229            .client
230            .get_item()
231            .consistent_read(self.consistent_read)
232            .table_name(self.table.as_str())
233            .key(
234                PK,
235                aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
236            )
237            .projection_expression(PK)
238            .send()
239            .await
240            .map_err(log_error)?;
241
242        Ok(item.map(|item| item.contains_key(PK)).unwrap_or(false))
243    }
244
245    async fn get_keys(&self) -> Result<Vec<String>, Error> {
246        let mut primary_keys = Vec::new();
247
248        let mut scan_paginator = self
249            .client
250            .scan()
251            .table_name(self.table.as_str())
252            .projection_expression(PK)
253            .into_paginator()
254            .send();
255
256        while let Some(output) = scan_paginator.next().await {
257            let scan_output = output.map_err(log_error)?;
258            if let Some(items) = scan_output.items {
259                for mut item in items {
260                    if let Some(AttributeValue::S(pk)) = item.remove(PK) {
261                        primary_keys.push(pk);
262                    }
263                }
264            }
265        }
266
267        Ok(primary_keys)
268    }
269
270    async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
271        let mut results = Vec::with_capacity(keys.len());
272        let mut keys_and_attributes_builder = KeysAndAttributes::builder()
273            .projection_expression(format!("{PK},{VAL}"))
274            .consistent_read(self.consistent_read);
275        for key in keys {
276            keys_and_attributes_builder = keys_and_attributes_builder.keys(HashMap::from_iter([(
277                PK.to_owned(),
278                AttributeValue::S(key),
279            )]))
280        }
281        let mut request_items = Some(HashMap::from_iter([(
282            self.table.to_string(),
283            keys_and_attributes_builder.build().map_err(log_error)?,
284        )]));
285
286        while request_items.is_some() {
287            let BatchGetItemOutput {
288                responses,
289                unprocessed_keys,
290                ..
291            } = self
292                .client
293                .batch_get_item()
294                .set_request_items(request_items)
295                .send()
296                .await
297                .map_err(log_error)?;
298
299            if let Some(items) =
300                responses.and_then(|mut responses| responses.remove(self.table.as_str()))
301            {
302                for mut item in items {
303                    match (item.remove(PK), item.remove(VAL)) {
304                        (Some(AttributeValue::S(pk)), Some(AttributeValue::B(val))) => {
305                            results.push((pk, Some(val.into_inner())));
306                        }
307                        (Some(AttributeValue::S(pk)), None) => {
308                            results.push((pk, None));
309                        }
310                        _ => (),
311                    }
312                }
313            }
314
315            request_items = unprocessed_keys.filter(|unprocessed| !unprocessed.is_empty());
316        }
317
318        Ok(results)
319    }
320
321    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
322        let mut data = Vec::with_capacity(key_values.len());
323        for (key, val) in key_values {
324            data.push(
325                WriteRequest::builder()
326                    .put_request(
327                        PutRequest::builder()
328                            .item(PK, AttributeValue::S(key))
329                            .item(VAL, AttributeValue::B(Blob::new(val)))
330                            .build()
331                            .map_err(log_error)?,
332                    )
333                    .build(),
334            )
335        }
336
337        let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));
338
339        while request_items.is_some() {
340            let BatchWriteItemOutput {
341                unprocessed_items, ..
342            } = self
343                .client
344                .batch_write_item()
345                .set_request_items(request_items)
346                .send()
347                .await
348                .map_err(log_error)?;
349
350            request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
351        }
352
353        Ok(())
354    }
355
356    async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
357        let mut data = Vec::with_capacity(keys.len());
358        for key in keys {
359            data.push(
360                WriteRequest::builder()
361                    .delete_request(
362                        DeleteRequest::builder()
363                            .key(PK, AttributeValue::S(key))
364                            .build()
365                            .map_err(log_error)?,
366                    )
367                    .build(),
368            )
369        }
370
371        let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));
372
373        while request_items.is_some() {
374            let BatchWriteItemOutput {
375                unprocessed_items, ..
376            } = self
377                .client
378                .batch_write_item()
379                .set_request_items(request_items)
380                .send()
381                .await
382                .map_err(log_error)?;
383
384            request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
385        }
386
387        Ok(())
388    }
389
390    async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
391        let GetItemOutput { item, .. } = self
392            .client
393            .get_item()
394            .consistent_read(true)
395            .table_name(self.table.as_str())
396            .key(PK, AttributeValue::S(key.clone()))
397            .projection_expression(VAL)
398            .send()
399            .await
400            .map_err(log_error)?;
401
402        let old_val = match item {
403            Some(mut current_item) => match current_item.remove(VAL) {
404                // We're expecting i64, so technically we could transmute but seems risky...
405                Some(AttributeValue::B(val)) => Some(
406                    str::from_utf8(&val.into_inner())
407                        .map_err(log_error)?
408                        .parse::<i64>()
409                        .map_err(log_error)?,
410                ),
411                _ => None,
412            },
413            None => None,
414        };
415
416        let new_val = old_val.unwrap_or(0) + delta;
417
418        let mut update = Update::builder()
419            .table_name(self.table.as_str())
420            .key(PK, AttributeValue::S(key))
421            .update_expression("SET #VAL = :new_val")
422            .expression_attribute_names("#VAL", VAL)
423            .expression_attribute_values(
424                ":new_val",
425                AttributeValue::B(Blob::new(new_val.to_string().as_bytes())),
426            );
427
428        if let Some(old_val) = old_val {
429            update = update
430                .condition_expression("#VAL = :old_val")
431                .expression_attribute_values(
432                    ":old_val",
433                    AttributeValue::B(Blob::new(old_val.to_string().as_bytes())),
434                )
435        } else {
436            update = update.condition_expression("attribute_not_exists (#VAL)")
437        }
438
439        self.client
440            .transact_write_items()
441            .transact_items(
442                TransactWriteItem::builder()
443                    .update(update.build().map_err(log_error)?)
444                    .build(),
445            )
446            .send()
447            .await
448            .map_err(log_error)?;
449
450        Ok(new_val)
451    }
452
453    async fn new_compare_and_swap(
454        &self,
455        bucket_rep: u32,
456        key: &str,
457    ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
458        Ok(Arc::new(CompareAndSwap {
459            key: key.to_string(),
460            client: self.client.clone(),
461            table: self.table.clone(),
462            state: Mutex::new(CasState::Unknown),
463            bucket_rep,
464        }))
465    }
466}
467
468#[async_trait]
469impl Cas for CompareAndSwap {
470    async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
471        let GetItemOutput { item, .. } = self
472            .client
473            .get_item()
474            .consistent_read(true)
475            .table_name(self.table.as_str())
476            .key(PK, AttributeValue::S(self.key.clone()))
477            .projection_expression(format!("{VAL},{VER}"))
478            .send()
479            .await
480            .map_err(log_error)?;
481
482        match item {
483            Some(mut current_item) => match (current_item.remove(VAL), current_item.remove(VER)) {
484                (Some(AttributeValue::B(val)), Some(AttributeValue::N(ver))) => {
485                    self.state
486                        .lock()
487                        .unwrap()
488                        .clone_from(&CasState::Versioned(ver));
489
490                    Ok(Some(val.into_inner()))
491                }
492                (Some(AttributeValue::B(val)), _) => {
493                    self.state
494                        .lock()
495                        .unwrap()
496                        .clone_from(&CasState::Unversioned(val.clone()));
497
498                    Ok(Some(val.into_inner()))
499                }
500                (_, _) => {
501                    self.state.lock().unwrap().clone_from(&CasState::Unset);
502                    Ok(None)
503                }
504            },
505            None => {
506                self.state.lock().unwrap().clone_from(&CasState::Unset);
507                Ok(None)
508            }
509        }
510    }
511
512    /// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for
513    /// optimistic concurrency or the previous item value
514    async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
515        let mut update = Update::builder()
516            .table_name(self.table.as_str())
517            .key(PK, AttributeValue::S(self.key.clone()))
518            .update_expression("SET #VAL = :val ADD #VER :increment")
519            .expression_attribute_names("#VAL", VAL)
520            .expression_attribute_names("#VER", VER)
521            .expression_attribute_values(":val", AttributeValue::B(Blob::new(value)))
522            .expression_attribute_values(":increment", AttributeValue::N("1".to_owned()));
523
524        let state = self.state.lock().unwrap().clone();
525        match state {
526            CasState::Versioned(version) => {
527                update = update
528                    .condition_expression("#VER = :ver")
529                    .expression_attribute_values(":ver", AttributeValue::N(version));
530            }
531            CasState::Unversioned(old_val) => {
532                update = update
533                    .condition_expression("#VAL = :old_val")
534                    .expression_attribute_values(":old_val", AttributeValue::B(old_val));
535            }
536            CasState::Unset => {
537                update = update.condition_expression("attribute_not_exists (#VAL)");
538            }
539            CasState::Unknown => (),
540        };
541
542        self.client
543            .transact_write_items()
544            .transact_items(
545                TransactWriteItem::builder()
546                    .update(
547                        update
548                            .build()
549                            .map_err(|e| SwapError::Other(format!("{e:?}")))?,
550                    )
551                    .build(),
552            )
553            .send()
554            .await
555            .map_err(|e| SwapError::CasFailed(format!("{e:?}")))?;
556
557        Ok(())
558    }
559
560    async fn bucket_rep(&self) -> u32 {
561        self.bucket_rep
562    }
563
564    async fn key(&self) -> String {
565        self.key.clone()
566    }
567}