1use core::str;
2use std::{
3 collections::HashMap,
4 sync::{Arc, Mutex},
5};
6
7use anyhow::Result;
8use aws_config::{BehaviorVersion, Region, SdkConfig};
9use aws_credential_types::Credentials;
10use aws_sdk_dynamodb::{
11 config::{ProvideCredentials, SharedCredentialsProvider},
12 operation::{
13 batch_get_item::BatchGetItemOutput, batch_write_item::BatchWriteItemOutput,
14 get_item::GetItemOutput,
15 },
16 primitives::Blob,
17 types::{
18 AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update,
19 WriteRequest,
20 },
21 Client,
22};
23use spin_core::async_trait;
24use spin_factor_key_value::{
25 log_error, log_error_v3, v3, Cas, Error, Store, StoreManager, SwapError,
26};
27
28pub struct KeyValueAwsDynamo {
29 region: String,
31 consistent_read: bool,
33 table: Arc<String>,
35 client: async_once_cell::Lazy<
37 Client,
38 std::pin::Pin<Box<dyn std::future::Future<Output = Client> + Send>>,
39 >,
40}
41
/// Static AWS credentials supplied through runtime configuration.
///
/// `Debug` is implemented by hand so that the secret key and session token
/// never end up in logs or error messages; only the access key ID is shown.
#[derive(Clone)]
pub struct KeyValueAwsDynamoRuntimeConfigOptions {
    /// AWS access key ID.
    access_key: String,
    /// AWS secret access key.
    secret_key: String,
    /// Optional session token (passed to the SDK as the session token).
    token: Option<String>,
}

impl std::fmt::Debug for KeyValueAwsDynamoRuntimeConfigOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";
        f.debug_struct("KeyValueAwsDynamoRuntimeConfigOptions")
            .field("access_key", &self.access_key)
            .field("secret_key", &REDACTED)
            // Keep whether a token is present visible, but never its value.
            .field("token", &self.token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
49
50impl KeyValueAwsDynamoRuntimeConfigOptions {
51 pub fn new(access_key: String, secret_key: String, token: Option<String>) -> Self {
52 Self {
53 access_key,
54 secret_key,
55 token,
56 }
57 }
58}
59
60impl ProvideCredentials for KeyValueAwsDynamoRuntimeConfigOptions {
61 fn provide_credentials<'a>(
62 &'a self,
63 ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
64 where
65 Self: 'a,
66 {
67 aws_credential_types::provider::future::ProvideCredentials::ready(Ok(Credentials::new(
68 self.access_key.clone(),
69 self.secret_key.clone(),
70 self.token.clone(),
71 None, "spin_custom_aws_provider",
73 )))
74 }
75}
76
77#[derive(Clone, Debug)]
79pub enum KeyValueAwsDynamoAuthOptions {
80 RuntimeConfigValues(KeyValueAwsDynamoRuntimeConfigOptions),
82 Environmental,
89}
90
91impl KeyValueAwsDynamo {
92 pub fn new(
93 region: String,
94 consistent_read: bool,
95 table: String,
96 auth_options: KeyValueAwsDynamoAuthOptions,
97 ) -> Result<Self> {
98 let region_clone = region.clone();
99 let client_fut = Box::pin(async move {
100 let sdk_config = match auth_options {
101 KeyValueAwsDynamoAuthOptions::RuntimeConfigValues(config) => SdkConfig::builder()
102 .credentials_provider(SharedCredentialsProvider::new(config))
103 .region(Region::new(region_clone))
104 .behavior_version(BehaviorVersion::latest())
105 .build(),
106 KeyValueAwsDynamoAuthOptions::Environmental => {
107 aws_config::load_defaults(BehaviorVersion::latest()).await
108 }
109 };
110 Client::new(&sdk_config)
111 });
112
113 Ok(Self {
114 region,
115 consistent_read,
116 table: Arc::new(table),
117 client: async_once_cell::Lazy::from_future(client_fut),
118 })
119 }
120}
121
122#[async_trait]
123impl StoreManager for KeyValueAwsDynamo {
124 async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
125 Ok(Arc::new(AwsDynamoStore {
126 client: self.client.get_unpin().await.clone(),
127 table: self.table.clone(),
128 consistent_read: self.consistent_read,
129 }))
130 }
131
132 fn is_defined(&self, _store_name: &str) -> bool {
133 true
134 }
135
136 fn summary(&self, _store_name: &str) -> Option<String> {
137 Some(format!(
138 "AWS DynamoDB region: {}, table: {}",
139 self.region, self.table
140 ))
141 }
142}
143
144struct AwsDynamoStore {
145 client: Client,
147 table: Arc<String>,
148 consistent_read: bool,
149}
150
151#[derive(Debug, Clone)]
152enum CasState {
153 Versioned(String),
155 Unversioned(Blob),
157 Unset,
159 Unknown,
161}
162
163struct CompareAndSwap {
164 key: String,
165 client: Client,
166 table: Arc<String>,
167 bucket_rep: u32,
168 state: Mutex<CasState>,
169}
170
/// Attribute used as the item key in every `key(...)` lookup.
const PK: &str = "PK";
/// Attribute holding the stored value as a binary blob.
const VAL: &str = "VAL";
/// Numeric attribute incremented by every compare-and-swap write.
const VER: &str = "VER";
177
178#[async_trait]
179impl Store for AwsDynamoStore {
180 async fn get(&self, key: &str, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
181 let response = self
182 .client
183 .get_item()
184 .consistent_read(self.consistent_read)
185 .table_name(self.table.as_str())
186 .key(
187 PK,
188 aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
189 )
190 .projection_expression(VAL)
191 .send()
192 .await
193 .map_err(log_error)?;
194
195 let mut byte_count = std::mem::size_of::<Option<Vec<u8>>>();
196 let item = response.item.and_then(|mut item| {
197 if let Some(AttributeValue::B(val)) = item.remove(VAL) {
198 let val = val.into_inner();
199 byte_count += val.len();
200 Some(val)
201 } else {
202 None
203 }
204 });
205
206 if byte_count > max_result_bytes {
210 return Err(Error::Other(format!(
211 "query result exceeds limit of {max_result_bytes} bytes"
212 )));
213 }
214
215 Ok(item)
216 }
217
218 async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
219 self.client
220 .put_item()
221 .table_name(self.table.as_str())
222 .item(PK, AttributeValue::S(key.to_string()))
223 .item(VAL, AttributeValue::B(Blob::new(value)))
224 .send()
225 .await
226 .map_err(log_error)?;
227 Ok(())
228 }
229
230 async fn delete(&self, key: &str) -> Result<(), Error> {
231 self.client
232 .delete_item()
233 .table_name(self.table.as_str())
234 .key(PK, AttributeValue::S(key.to_string()))
235 .send()
236 .await
237 .map_err(log_error)?;
238 Ok(())
239 }
240
241 async fn exists(&self, key: &str) -> Result<bool, Error> {
242 let GetItemOutput { item, .. } = self
243 .client
244 .get_item()
245 .consistent_read(self.consistent_read)
246 .table_name(self.table.as_str())
247 .key(
248 PK,
249 aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
250 )
251 .projection_expression(PK)
252 .send()
253 .await
254 .map_err(log_error)?;
255
256 Ok(item.map(|item| item.contains_key(PK)).unwrap_or(false))
257 }
258
259 async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
260 let mut primary_keys = Vec::new();
261
262 let mut scan_paginator = self
263 .client
264 .scan()
265 .table_name(self.table.as_str())
266 .projection_expression(PK)
267 .into_paginator()
268 .send();
269
270 let mut byte_count = std::mem::size_of::<Vec<String>>();
271 while let Some(output) = scan_paginator.next().await {
272 let scan_output = output.map_err(log_error)?;
273 if let Some(items) = scan_output.items {
274 for mut item in items {
275 if let Some(AttributeValue::S(pk)) = item.remove(PK) {
276 byte_count += std::mem::size_of::<String>() + pk.len();
277 if byte_count > max_result_bytes {
278 return Err(Error::Other(format!(
279 "query result exceeds limit of {max_result_bytes} bytes"
280 )));
281 }
282 primary_keys.push(pk);
283 }
284 }
285 }
286 }
287
288 Ok(primary_keys)
289 }
290
291 async fn get_keys_async(
292 &self,
293 max_result_bytes: usize,
294 ) -> (
295 tokio::sync::mpsc::Receiver<String>,
296 tokio::sync::oneshot::Receiver<Result<(), v3::Error>>,
297 ) {
298 let (keys_tx, keys_rx) = tokio::sync::mpsc::channel(4);
299 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
300
301 let mut scan_paginator = self
302 .client
303 .scan()
304 .table_name(self.table.as_str())
305 .projection_expression(PK)
306 .into_paginator()
307 .send();
308
309 let the_work = async move {
310 while let Some(output) = scan_paginator.next().await {
311 let scan_output = output.map_err(log_error_v3)?;
312 if let Some(items) = scan_output.items {
313 for mut item in items {
314 if let Some(AttributeValue::S(pk)) = item.remove(PK) {
315 if pk.len() > max_result_bytes {
316 return Err(v3::Error::Other(format!(
317 "key exceeds limit of {max_result_bytes} bytes"
318 )));
319 }
320 keys_tx.send(pk).await.map_err(log_error_v3)?;
321 }
322 }
323 }
324 }
325
326 Ok(())
327 };
328 tokio::spawn(async move {
329 let res = the_work.await;
330 _ = err_tx.send(res);
331 });
332
333 (keys_rx, err_rx)
334 }
335
336 async fn get_many(
337 &self,
338 keys: Vec<String>,
339 max_result_bytes: usize,
340 ) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
341 let mut results = Vec::with_capacity(keys.len());
342 let mut keys_and_attributes_builder = KeysAndAttributes::builder()
343 .projection_expression(format!("{PK},{VAL}"))
344 .consistent_read(self.consistent_read);
345 for key in keys {
346 keys_and_attributes_builder = keys_and_attributes_builder.keys(HashMap::from_iter([(
347 PK.to_owned(),
348 AttributeValue::S(key),
349 )]))
350 }
351 let mut request_items = Some(HashMap::from_iter([(
352 self.table.to_string(),
353 keys_and_attributes_builder.build().map_err(log_error)?,
354 )]));
355
356 let mut byte_count = 0;
357 while request_items.is_some() {
358 let BatchGetItemOutput {
359 responses,
360 unprocessed_keys,
361 ..
362 } = self
363 .client
364 .batch_get_item()
365 .set_request_items(request_items)
366 .send()
367 .await
368 .map_err(log_error)?;
369
370 if let Some(items) =
371 responses.and_then(|mut responses| responses.remove(self.table.as_str()))
372 {
373 for mut item in items {
374 match (item.remove(PK), item.remove(VAL)) {
375 (Some(AttributeValue::S(pk)), Some(AttributeValue::B(val))) => {
376 let val = val.into_inner();
377 byte_count += std::mem::size_of::<(String, Option<Vec<u8>>)>()
378 + pk.len()
379 + val.len();
380 results.push((pk, Some(val)));
381 }
382 (Some(AttributeValue::S(pk)), None) => {
383 byte_count +=
384 std::mem::size_of::<(String, Option<Vec<u8>>)>() + pk.len();
385 results.push((pk, None));
386 }
387 _ => (),
388 }
389 if byte_count > max_result_bytes {
390 return Err(Error::Other(format!(
391 "query result exceeds limit of {max_result_bytes} bytes"
392 )));
393 }
394 }
395 }
396
397 request_items = unprocessed_keys.filter(|unprocessed| !unprocessed.is_empty());
398 }
399
400 Ok(results)
401 }
402
403 async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
404 let mut data = Vec::with_capacity(key_values.len());
405 for (key, val) in key_values {
406 data.push(
407 WriteRequest::builder()
408 .put_request(
409 PutRequest::builder()
410 .item(PK, AttributeValue::S(key))
411 .item(VAL, AttributeValue::B(Blob::new(val)))
412 .build()
413 .map_err(log_error)?,
414 )
415 .build(),
416 )
417 }
418
419 let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));
420
421 while request_items.is_some() {
422 let BatchWriteItemOutput {
423 unprocessed_items, ..
424 } = self
425 .client
426 .batch_write_item()
427 .set_request_items(request_items)
428 .send()
429 .await
430 .map_err(log_error)?;
431
432 request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
433 }
434
435 Ok(())
436 }
437
438 async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
439 let mut data = Vec::with_capacity(keys.len());
440 for key in keys {
441 data.push(
442 WriteRequest::builder()
443 .delete_request(
444 DeleteRequest::builder()
445 .key(PK, AttributeValue::S(key))
446 .build()
447 .map_err(log_error)?,
448 )
449 .build(),
450 )
451 }
452
453 let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));
454
455 while request_items.is_some() {
456 let BatchWriteItemOutput {
457 unprocessed_items, ..
458 } = self
459 .client
460 .batch_write_item()
461 .set_request_items(request_items)
462 .send()
463 .await
464 .map_err(log_error)?;
465
466 request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
467 }
468
469 Ok(())
470 }
471
472 async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
473 let GetItemOutput { item, .. } = self
474 .client
475 .get_item()
476 .consistent_read(true)
477 .table_name(self.table.as_str())
478 .key(PK, AttributeValue::S(key.clone()))
479 .projection_expression(VAL)
480 .send()
481 .await
482 .map_err(log_error)?;
483
484 let old_val = match item {
485 Some(mut current_item) => match current_item.remove(VAL) {
486 Some(AttributeValue::B(val)) => Some(
488 str::from_utf8(&val.into_inner())
489 .map_err(log_error)?
490 .parse::<i64>()
491 .map_err(log_error)?,
492 ),
493 _ => None,
494 },
495 None => None,
496 };
497
498 let new_val = old_val.unwrap_or(0) + delta;
499
500 let mut update = Update::builder()
501 .table_name(self.table.as_str())
502 .key(PK, AttributeValue::S(key))
503 .update_expression("SET #VAL = :new_val")
504 .expression_attribute_names("#VAL", VAL)
505 .expression_attribute_values(
506 ":new_val",
507 AttributeValue::B(Blob::new(new_val.to_string().as_bytes())),
508 );
509
510 if let Some(old_val) = old_val {
511 update = update
512 .condition_expression("#VAL = :old_val")
513 .expression_attribute_values(
514 ":old_val",
515 AttributeValue::B(Blob::new(old_val.to_string().as_bytes())),
516 )
517 } else {
518 update = update.condition_expression("attribute_not_exists (#VAL)")
519 }
520
521 self.client
522 .transact_write_items()
523 .transact_items(
524 TransactWriteItem::builder()
525 .update(update.build().map_err(log_error)?)
526 .build(),
527 )
528 .send()
529 .await
530 .map_err(log_error)?;
531
532 Ok(new_val)
533 }
534
535 async fn new_compare_and_swap(
536 &self,
537 bucket_rep: u32,
538 key: &str,
539 ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
540 Ok(Arc::new(CompareAndSwap {
541 key: key.to_string(),
542 client: self.client.clone(),
543 table: self.table.clone(),
544 state: Mutex::new(CasState::Unknown),
545 bucket_rep,
546 }))
547 }
548}
549
550#[async_trait]
551impl Cas for CompareAndSwap {
552 async fn current(&self, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
553 let GetItemOutput { item, .. } = self
554 .client
555 .get_item()
556 .consistent_read(true)
557 .table_name(self.table.as_str())
558 .key(PK, AttributeValue::S(self.key.clone()))
559 .projection_expression(format!("{VAL},{VER}"))
560 .send()
561 .await
562 .map_err(log_error)?;
563
564 let mut byte_count = std::mem::size_of::<Option<Vec<u8>>>();
565 let value = match item {
566 Some(mut current_item) => match (current_item.remove(VAL), current_item.remove(VER)) {
567 (Some(AttributeValue::B(val)), Some(AttributeValue::N(ver))) => {
568 let val = val.into_inner();
569 byte_count += val.len();
570
571 self.state
572 .lock()
573 .unwrap()
574 .clone_from(&CasState::Versioned(ver));
575
576 Some(val)
577 }
578 (Some(AttributeValue::B(val)), _) => {
579 self.state
580 .lock()
581 .unwrap()
582 .clone_from(&CasState::Unversioned(val.clone()));
583
584 let val = val.into_inner();
585 byte_count += val.len();
586
587 Some(val)
588 }
589 (_, _) => {
590 self.state.lock().unwrap().clone_from(&CasState::Unset);
591 None
592 }
593 },
594 None => {
595 self.state.lock().unwrap().clone_from(&CasState::Unset);
596 None
597 }
598 };
599
600 if byte_count > max_result_bytes {
604 return Err(Error::Other(format!(
605 "query result exceeds limit of {max_result_bytes} bytes"
606 )));
607 }
608
609 Ok(value)
610 }
611
612 async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
615 let mut update = Update::builder()
616 .table_name(self.table.as_str())
617 .key(PK, AttributeValue::S(self.key.clone()))
618 .update_expression("SET #VAL = :val ADD #VER :increment")
619 .expression_attribute_names("#VAL", VAL)
620 .expression_attribute_names("#VER", VER)
621 .expression_attribute_values(":val", AttributeValue::B(Blob::new(value)))
622 .expression_attribute_values(":increment", AttributeValue::N("1".to_owned()));
623
624 let state = self.state.lock().unwrap().clone();
625 match state {
626 CasState::Versioned(version) => {
627 update = update
628 .condition_expression("#VER = :ver")
629 .expression_attribute_values(":ver", AttributeValue::N(version));
630 }
631 CasState::Unversioned(old_val) => {
632 update = update
633 .condition_expression("#VAL = :old_val")
634 .expression_attribute_values(":old_val", AttributeValue::B(old_val));
635 }
636 CasState::Unset => {
637 update = update.condition_expression("attribute_not_exists (#VAL)");
638 }
639 CasState::Unknown => (),
640 };
641
642 self.client
643 .transact_write_items()
644 .transact_items(
645 TransactWriteItem::builder()
646 .update(
647 update
648 .build()
649 .map_err(|e| SwapError::Other(format!("{e:?}")))?,
650 )
651 .build(),
652 )
653 .send()
654 .await
655 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))?;
656
657 Ok(())
658 }
659
660 async fn bucket_rep(&self) -> u32 {
661 self.bucket_rep
662 }
663
664 async fn key(&self) -> String {
665 self.key.clone()
666 }
667}