//! DynamoDB-backed key-value store implementation for Spin
//! (`spin_key_value_aws`, `store.rs`).
1use core::str;
2use std::{
3    collections::HashMap,
4    sync::{Arc, Mutex},
5};
6
7use anyhow::Result;
8use aws_config::{BehaviorVersion, Region, SdkConfig};
9use aws_credential_types::Credentials;
10use aws_sdk_dynamodb::{
11    config::{ProvideCredentials, SharedCredentialsProvider},
12    operation::{
13        batch_get_item::BatchGetItemOutput, batch_write_item::BatchWriteItemOutput,
14        get_item::GetItemOutput,
15    },
16    primitives::Blob,
17    types::{
18        AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update,
19        WriteRequest,
20    },
21    Client,
22};
23use spin_core::async_trait;
24use spin_factor_key_value::{
25    log_error, log_error_v3, v3, Cas, Error, Store, StoreManager, SwapError,
26};
27
/// Manager for AWS DynamoDB-backed key-value stores.
///
/// Holds the connection settings plus a lazily-initialized DynamoDB client;
/// the client future is only driven the first time a store is requested.
pub struct KeyValueAwsDynamo {
    /// AWS region
    region: String,
    /// Whether to use strongly consistent reads
    consistent_read: bool,
    /// DynamoDB table, needs to be cloned when getting a store
    table: Arc<String>,
    /// DynamoDB client, built lazily so `new` can stay synchronous
    client: async_once_cell::Lazy<
        Client,
        std::pin::Pin<Box<dyn std::future::Future<Output = Client> + Send>>,
    >,
}
41
/// AWS Dynamo Key / Value runtime config literal options for authentication
#[derive(Clone, Debug)]
pub struct KeyValueAwsDynamoRuntimeConfigOptions {
    // AWS access key ID
    access_key: String,
    // AWS secret access key
    secret_key: String,
    // Optional session token for temporary credentials
    token: Option<String>,
}
49
50impl KeyValueAwsDynamoRuntimeConfigOptions {
51    pub fn new(access_key: String, secret_key: String, token: Option<String>) -> Self {
52        Self {
53            access_key,
54            secret_key,
55            token,
56        }
57    }
58}
59
60impl ProvideCredentials for KeyValueAwsDynamoRuntimeConfigOptions {
61    fn provide_credentials<'a>(
62        &'a self,
63    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
64    where
65        Self: 'a,
66    {
67        aws_credential_types::provider::future::ProvideCredentials::ready(Ok(Credentials::new(
68            self.access_key.clone(),
69            self.secret_key.clone(),
70            self.token.clone(),
71            None, // Optional expiration time
72            "spin_custom_aws_provider",
73        )))
74    }
75}
76
/// AWS Dynamo Key / Value enumeration for the possible authentication options
#[derive(Clone, Debug)]
pub enum KeyValueAwsDynamoAuthOptions {
    /// Runtime Config values indicates credentials have been specified directly
    RuntimeConfigValues(KeyValueAwsDynamoRuntimeConfigOptions),
    /// Environmental indicates that the environment variables of the process should be used to
    /// create the SDK Config for the Dynamo client. This will use the AWS Rust SDK's
    /// aws_config::load_defaults to derive credentials based on what environment variables
    /// have been set.
    ///
    /// See https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-authentication.html for options.
    Environmental,
}
90
91impl KeyValueAwsDynamo {
92    pub fn new(
93        region: String,
94        consistent_read: bool,
95        table: String,
96        auth_options: KeyValueAwsDynamoAuthOptions,
97    ) -> Result<Self> {
98        let region_clone = region.clone();
99        let client_fut = Box::pin(async move {
100            let sdk_config = match auth_options {
101                KeyValueAwsDynamoAuthOptions::RuntimeConfigValues(config) => SdkConfig::builder()
102                    .credentials_provider(SharedCredentialsProvider::new(config))
103                    .region(Region::new(region_clone))
104                    .behavior_version(BehaviorVersion::latest())
105                    .build(),
106                KeyValueAwsDynamoAuthOptions::Environmental => {
107                    aws_config::load_defaults(BehaviorVersion::latest()).await
108                }
109            };
110            Client::new(&sdk_config)
111        });
112
113        Ok(Self {
114            region,
115            consistent_read,
116            table: Arc::new(table),
117            client: async_once_cell::Lazy::from_future(client_fut),
118        })
119    }
120}
121
122#[async_trait]
123impl StoreManager for KeyValueAwsDynamo {
124    async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
125        Ok(Arc::new(AwsDynamoStore {
126            client: self.client.get_unpin().await.clone(),
127            table: self.table.clone(),
128            consistent_read: self.consistent_read,
129        }))
130    }
131
132    fn is_defined(&self, _store_name: &str) -> bool {
133        true
134    }
135
136    fn summary(&self, _store_name: &str) -> Option<String> {
137        Some(format!(
138            "AWS DynamoDB region: {}, table: {}",
139            self.region, self.table
140        ))
141    }
142}
143
/// A single DynamoDB-backed store handed out by `KeyValueAwsDynamo::get`.
struct AwsDynamoStore {
    // Client wraps an Arc so should be low cost to clone
    client: Client,
    // Table holding all items for this store
    table: Arc<String>,
    // Whether reads request DynamoDB strong consistency
    consistent_read: bool,
}
150
/// Snapshot of what `Cas::current` observed for a key; `Cas::swap` consults
/// this to build the matching DynamoDB condition expression.
#[derive(Debug, Clone)]
enum CasState {
    // Existing item with version
    Versioned(String),
    // Existing item without version
    Unversioned(Blob),
    // Item was missing when fetched during `current`, expected to be new
    Unset,
    // Potentially new item -- `current` was never called to fetch version
    Unknown,
}
162
/// An in-flight compare-and-swap operation for a single key.
struct CompareAndSwap {
    // Key the operation targets
    key: String,
    // Cloned DynamoDB client (cheap: wraps an Arc)
    client: Client,
    table: Arc<String>,
    // Opaque bucket identifier echoed back via `bucket_rep()`
    bucket_rep: u32,
    // Version/value observed by `current`, used by `swap` for its condition
    state: Mutex<CasState>,
}
170
/// Primary key in DynamoDB items used for querying items
const PK: &str = "PK";
/// Value key in DynamoDB items storing item value as binary
const VAL: &str = "VAL";
/// Version key in DynamoDB items used for atomic operations
const VER: &str = "VER";
177
178#[async_trait]
179impl Store for AwsDynamoStore {
180    async fn get(&self, key: &str, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
181        let response = self
182            .client
183            .get_item()
184            .consistent_read(self.consistent_read)
185            .table_name(self.table.as_str())
186            .key(
187                PK,
188                aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
189            )
190            .projection_expression(VAL)
191            .send()
192            .await
193            .map_err(log_error)?;
194
195        let mut byte_count = std::mem::size_of::<Option<Vec<u8>>>();
196        let item = response.item.and_then(|mut item| {
197            if let Some(AttributeValue::B(val)) = item.remove(VAL) {
198                let val = val.into_inner();
199                byte_count += val.len();
200                Some(val)
201            } else {
202                None
203            }
204        });
205
206        // Currently there's no way to stream a `get_item` result without
207        // buffering using `aws-sdk-dynamodb`, so the damage (in terms of host
208        // memory usage) is already done, but we can still enforce the limit:
209        if byte_count > max_result_bytes {
210            return Err(Error::Other(format!(
211                "query result exceeds limit of {max_result_bytes} bytes"
212            )));
213        }
214
215        Ok(item)
216    }
217
218    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
219        self.client
220            .put_item()
221            .table_name(self.table.as_str())
222            .item(PK, AttributeValue::S(key.to_string()))
223            .item(VAL, AttributeValue::B(Blob::new(value)))
224            .send()
225            .await
226            .map_err(log_error)?;
227        Ok(())
228    }
229
230    async fn delete(&self, key: &str) -> Result<(), Error> {
231        self.client
232            .delete_item()
233            .table_name(self.table.as_str())
234            .key(PK, AttributeValue::S(key.to_string()))
235            .send()
236            .await
237            .map_err(log_error)?;
238        Ok(())
239    }
240
241    async fn exists(&self, key: &str) -> Result<bool, Error> {
242        let GetItemOutput { item, .. } = self
243            .client
244            .get_item()
245            .consistent_read(self.consistent_read)
246            .table_name(self.table.as_str())
247            .key(
248                PK,
249                aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
250            )
251            .projection_expression(PK)
252            .send()
253            .await
254            .map_err(log_error)?;
255
256        Ok(item.map(|item| item.contains_key(PK)).unwrap_or(false))
257    }
258
259    async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
260        let mut primary_keys = Vec::new();
261
262        let mut scan_paginator = self
263            .client
264            .scan()
265            .table_name(self.table.as_str())
266            .projection_expression(PK)
267            .into_paginator()
268            .send();
269
270        let mut byte_count = std::mem::size_of::<Vec<String>>();
271        while let Some(output) = scan_paginator.next().await {
272            let scan_output = output.map_err(log_error)?;
273            if let Some(items) = scan_output.items {
274                for mut item in items {
275                    if let Some(AttributeValue::S(pk)) = item.remove(PK) {
276                        byte_count += std::mem::size_of::<String>() + pk.len();
277                        if byte_count > max_result_bytes {
278                            return Err(Error::Other(format!(
279                                "query result exceeds limit of {max_result_bytes} bytes"
280                            )));
281                        }
282                        primary_keys.push(pk);
283                    }
284                }
285            }
286        }
287
288        Ok(primary_keys)
289    }
290
    /// Streams every primary key in the table over an mpsc channel; the
    /// oneshot channel reports the final outcome of the background scan.
    ///
    /// NOTE(review): unlike `get_keys`, the limit here is applied per key
    /// (`pk.len() > max_result_bytes`), not cumulatively — presumably
    /// intentional for a streaming API, but worth confirming.
    async fn get_keys_async(
        &self,
        max_result_bytes: usize,
    ) -> (
        tokio::sync::mpsc::Receiver<String>,
        tokio::sync::oneshot::Receiver<Result<(), v3::Error>>,
    ) {
        // Small buffer: the scan task pauses when the consumer falls 4 keys behind.
        let (keys_tx, keys_rx) = tokio::sync::mpsc::channel(4);
        let (err_tx, err_rx) = tokio::sync::oneshot::channel();

        let mut scan_paginator = self
            .client
            .scan()
            .table_name(self.table.as_str())
            .projection_expression(PK)
            .into_paginator()
            .send();

        let the_work = async move {
            while let Some(output) = scan_paginator.next().await {
                let scan_output = output.map_err(log_error_v3)?;
                if let Some(items) = scan_output.items {
                    for mut item in items {
                        if let Some(AttributeValue::S(pk)) = item.remove(PK) {
                            if pk.len() > max_result_bytes {
                                return Err(v3::Error::Other(format!(
                                    "key exceeds limit of {max_result_bytes} bytes"
                                )));
                            }
                            // Errors (receiver dropped) abort the scan.
                            keys_tx.send(pk).await.map_err(log_error_v3)?;
                        }
                    }
                }
            }

            Ok(())
        };
        tokio::spawn(async move {
            let res = the_work.await;
            // The error receiver may already be gone; ignore send failure.
            _ = err_tx.send(res);
        });

        (keys_rx, err_rx)
    }
335
    /// Fetches many keys in bulk via `BatchGetItem`, returning each found key
    /// paired with its value (or `None` when the item lacks a binary VAL).
    /// Keys with no matching item are omitted from the result.
    ///
    /// NOTE(review): DynamoDB's BatchGetItem accepts at most 100 keys per
    /// request; calls with more keys will presumably be rejected by the
    /// service — confirm callers chunk appropriately.
    async fn get_many(
        &self,
        keys: Vec<String>,
        max_result_bytes: usize,
    ) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
        let mut results = Vec::with_capacity(keys.len());
        let mut keys_and_attributes_builder = KeysAndAttributes::builder()
            .projection_expression(format!("{PK},{VAL}"))
            .consistent_read(self.consistent_read);
        for key in keys {
            keys_and_attributes_builder = keys_and_attributes_builder.keys(HashMap::from_iter([(
                PK.to_owned(),
                AttributeValue::S(key),
            )]))
        }
        let mut request_items = Some(HashMap::from_iter([(
            self.table.to_string(),
            keys_and_attributes_builder.build().map_err(log_error)?,
        )]));

        let mut byte_count = 0;
        // DynamoDB may return a partial batch; loop until no keys remain unprocessed.
        while request_items.is_some() {
            let BatchGetItemOutput {
                responses,
                unprocessed_keys,
                ..
            } = self
                .client
                .batch_get_item()
                .set_request_items(request_items)
                .send()
                .await
                .map_err(log_error)?;

            if let Some(items) =
                responses.and_then(|mut responses| responses.remove(self.table.as_str()))
            {
                for mut item in items {
                    match (item.remove(PK), item.remove(VAL)) {
                        (Some(AttributeValue::S(pk)), Some(AttributeValue::B(val))) => {
                            let val = val.into_inner();
                            byte_count += std::mem::size_of::<(String, Option<Vec<u8>>)>()
                                + pk.len()
                                + val.len();
                            results.push((pk, Some(val)));
                        }
                        (Some(AttributeValue::S(pk)), None) => {
                            byte_count +=
                                std::mem::size_of::<(String, Option<Vec<u8>>)>() + pk.len();
                            results.push((pk, None));
                        }
                        // Items without a string PK or with a non-binary VAL are skipped.
                        _ => (),
                    }
                    if byte_count > max_result_bytes {
                        return Err(Error::Other(format!(
                            "query result exceeds limit of {max_result_bytes} bytes"
                        )));
                    }
                }
            }

            // Resubmit only the keys the service did not process this round.
            request_items = unprocessed_keys.filter(|unprocessed| !unprocessed.is_empty());
        }

        Ok(results)
    }
402
    /// Writes all key/value pairs via `BatchWriteItem`, resubmitting any
    /// writes the service reports as unprocessed until the batch completes.
    ///
    /// NOTE(review): BatchWriteItem accepts at most 25 write requests per
    /// call — presumably callers keep batches small; confirm upstream.
    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
        let mut data = Vec::with_capacity(key_values.len());
        for (key, val) in key_values {
            data.push(
                WriteRequest::builder()
                    .put_request(
                        PutRequest::builder()
                            .item(PK, AttributeValue::S(key))
                            .item(VAL, AttributeValue::B(Blob::new(val)))
                            .build()
                            .map_err(log_error)?,
                    )
                    .build(),
            )
        }

        let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));

        // Keep resubmitting until DynamoDB reports nothing left unprocessed.
        while request_items.is_some() {
            let BatchWriteItemOutput {
                unprocessed_items, ..
            } = self
                .client
                .batch_write_item()
                .set_request_items(request_items)
                .send()
                .await
                .map_err(log_error)?;

            request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
        }

        Ok(())
    }
437
    /// Deletes all the given keys via `BatchWriteItem`, resubmitting any
    /// deletes the service reports as unprocessed until the batch completes.
    ///
    /// NOTE(review): subject to the same 25-request-per-call BatchWriteItem
    /// limit as `set_many` — confirm callers chunk appropriately.
    async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
        let mut data = Vec::with_capacity(keys.len());
        for key in keys {
            data.push(
                WriteRequest::builder()
                    .delete_request(
                        DeleteRequest::builder()
                            .key(PK, AttributeValue::S(key))
                            .build()
                            .map_err(log_error)?,
                    )
                    .build(),
            )
        }

        let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));

        // Keep resubmitting until DynamoDB reports nothing left unprocessed.
        while request_items.is_some() {
            let BatchWriteItemOutput {
                unprocessed_items, ..
            } = self
                .client
                .batch_write_item()
                .set_request_items(request_items)
                .send()
                .await
                .map_err(log_error)?;

            request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
        }

        Ok(())
    }
471
    /// Atomically adds `delta` to the numeric value stored at `key` and
    /// returns the new value; a missing item starts from 0.
    ///
    /// The value is stored as the decimal string form of an `i64` inside the
    /// binary VAL attribute. Atomicity comes from a conditional transactional
    /// write: the update only succeeds if VAL still holds the value read
    /// below (or is still absent for a new item).
    ///
    /// NOTE(review): if a concurrent writer changes VAL between the read and
    /// the write, the condition fails and the error is surfaced to the caller
    /// — there is no retry loop here; confirm that is the intended contract.
    async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
        // Strongly consistent read so we see the latest committed value.
        let GetItemOutput { item, .. } = self
            .client
            .get_item()
            .consistent_read(true)
            .table_name(self.table.as_str())
            .key(PK, AttributeValue::S(key.clone()))
            .projection_expression(VAL)
            .send()
            .await
            .map_err(log_error)?;

        let old_val = match item {
            Some(mut current_item) => match current_item.remove(VAL) {
                // We're expecting i64, so technically we could transmute but seems risky...
                Some(AttributeValue::B(val)) => Some(
                    str::from_utf8(&val.into_inner())
                        .map_err(log_error)?
                        .parse::<i64>()
                        .map_err(log_error)?,
                ),
                _ => None,
            },
            None => None,
        };

        let new_val = old_val.unwrap_or(0) + delta;

        let mut update = Update::builder()
            .table_name(self.table.as_str())
            .key(PK, AttributeValue::S(key))
            .update_expression("SET #VAL = :new_val")
            .expression_attribute_names("#VAL", VAL)
            .expression_attribute_values(
                ":new_val",
                AttributeValue::B(Blob::new(new_val.to_string().as_bytes())),
            );

        // Guard against concurrent modification since the read above.
        if let Some(old_val) = old_val {
            update = update
                .condition_expression("#VAL = :old_val")
                .expression_attribute_values(
                    ":old_val",
                    AttributeValue::B(Blob::new(old_val.to_string().as_bytes())),
                )
        } else {
            update = update.condition_expression("attribute_not_exists (#VAL)")
        }

        self.client
            .transact_write_items()
            .transact_items(
                TransactWriteItem::builder()
                    .update(update.build().map_err(log_error)?)
                    .build(),
            )
            .send()
            .await
            .map_err(log_error)?;

        Ok(new_val)
    }
534
535    async fn new_compare_and_swap(
536        &self,
537        bucket_rep: u32,
538        key: &str,
539    ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
540        Ok(Arc::new(CompareAndSwap {
541            key: key.to_string(),
542            client: self.client.clone(),
543            table: self.table.clone(),
544            state: Mutex::new(CasState::Unknown),
545            bucket_rep,
546        }))
547    }
548}
549
550#[async_trait]
551impl Cas for CompareAndSwap {
    /// Reads the current value of the key (strongly consistent) and records
    /// the observed version/value in `self.state` for the later `swap`.
    async fn current(&self, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
        let GetItemOutput { item, .. } = self
            .client
            .get_item()
            .consistent_read(true)
            .table_name(self.table.as_str())
            .key(PK, AttributeValue::S(self.key.clone()))
            .projection_expression(format!("{VAL},{VER}"))
            .send()
            .await
            .map_err(log_error)?;

        let mut byte_count = std::mem::size_of::<Option<Vec<u8>>>();
        let value = match item {
            Some(mut current_item) => match (current_item.remove(VAL), current_item.remove(VER)) {
                // Item carries a version number: remember it for optimistic concurrency.
                (Some(AttributeValue::B(val)), Some(AttributeValue::N(ver))) => {
                    let val = val.into_inner();
                    byte_count += val.len();

                    self.state
                        .lock()
                        .unwrap()
                        .clone_from(&CasState::Versioned(ver));

                    Some(val)
                }
                // Item exists but was never versioned: remember the raw value
                // so `swap` can compare against it instead.
                (Some(AttributeValue::B(val)), _) => {
                    self.state
                        .lock()
                        .unwrap()
                        .clone_from(&CasState::Unversioned(val.clone()));

                    let val = val.into_inner();
                    byte_count += val.len();

                    Some(val)
                }
                // Item present but without a binary value: treat as unset.
                (_, _) => {
                    self.state.lock().unwrap().clone_from(&CasState::Unset);
                    None
                }
            },
            // No item at all: `swap` will require the key to still be absent.
            None => {
                self.state.lock().unwrap().clone_from(&CasState::Unset);
                None
            }
        };

        // Currently there's no way to stream a `get_item` result without
        // buffering using `aws-sdk-dynamodb`, so the damage (in terms of host
        // memory usage) is already done, but we can still enforce the limit:
        if byte_count > max_result_bytes {
            return Err(Error::Other(format!(
                "query result exceeds limit of {max_result_bytes} bytes"
            )));
        }

        Ok(value)
    }
611
    /// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for
    /// optimistic concurrency or the previous item value
    ///
    /// NOTE(review): every transact failure is reported as
    /// `SwapError::CasFailed`, including non-conditional errors such as
    /// throttling or network problems — confirm whether those should map to
    /// `SwapError::Other` instead.
    async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
        // Always bump the version counter so later CAS rounds can rely on it.
        let mut update = Update::builder()
            .table_name(self.table.as_str())
            .key(PK, AttributeValue::S(self.key.clone()))
            .update_expression("SET #VAL = :val ADD #VER :increment")
            .expression_attribute_names("#VAL", VAL)
            .expression_attribute_names("#VER", VER)
            .expression_attribute_values(":val", AttributeValue::B(Blob::new(value)))
            .expression_attribute_values(":increment", AttributeValue::N("1".to_owned()));

        // Build the condition matching whatever `current` observed (if called).
        let state = self.state.lock().unwrap().clone();
        match state {
            CasState::Versioned(version) => {
                update = update
                    .condition_expression("#VER = :ver")
                    .expression_attribute_values(":ver", AttributeValue::N(version));
            }
            CasState::Unversioned(old_val) => {
                update = update
                    .condition_expression("#VAL = :old_val")
                    .expression_attribute_values(":old_val", AttributeValue::B(old_val));
            }
            CasState::Unset => {
                update = update.condition_expression("attribute_not_exists (#VAL)");
            }
            // `current` was never called: write unconditionally.
            CasState::Unknown => (),
        };

        self.client
            .transact_write_items()
            .transact_items(
                TransactWriteItem::builder()
                    .update(
                        update
                            .build()
                            .map_err(|e| SwapError::Other(format!("{e:?}")))?,
                    )
                    .build(),
            )
            .send()
            .await
            .map_err(|e| SwapError::CasFailed(format!("{e:?}")))?;

        Ok(())
    }
659
    /// Returns the opaque bucket identifier this CAS handle was created with.
    async fn bucket_rep(&self) -> u32 {
        self.bucket_rep
    }
663
    /// Returns the key this compare-and-swap operation targets.
    async fn key(&self) -> String {
        self.key.clone()
    }
667}