1use core::str;
2use std::{
3 collections::HashMap,
4 sync::{Arc, Mutex},
5};
6
7use anyhow::Result;
8use aws_config::{BehaviorVersion, Region, SdkConfig};
9use aws_credential_types::Credentials;
10use aws_sdk_dynamodb::{
11 config::{ProvideCredentials, SharedCredentialsProvider},
12 operation::{
13 batch_get_item::BatchGetItemOutput, batch_write_item::BatchWriteItemOutput,
14 get_item::GetItemOutput,
15 },
16 primitives::Blob,
17 types::{
18 AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update,
19 WriteRequest,
20 },
21 Client,
22};
23use spin_core::async_trait;
24use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};
25
/// Key-value store manager backed by a single AWS DynamoDB table.
pub struct KeyValueAwsDynamo {
    /// AWS region name. Shown by `summary`, and applied to the SDK config when
    /// credentials come from runtime config.
    region: String,
    /// Value passed as the `consistent_read` flag by the `get`, `exists` and
    /// `get_many` operations of the stores this manager hands out.
    consistent_read: bool,
    /// Name of the DynamoDB table; shared with every store through the `Arc`.
    table: Arc<String>,
    /// DynamoDB client wrapped in `async_once_cell::Lazy`; it is produced by
    /// awaiting the boxed future supplied at construction time.
    client: async_once_cell::Lazy<
        Client,
        std::pin::Pin<Box<dyn std::future::Future<Output = Client> + Send>>,
    >,
}
39
/// Static AWS credentials supplied explicitly rather than discovered from the
/// environment.
///
/// `Debug` is implemented by hand instead of derived so that the secret key and
/// the session token can never leak into logs or panic messages.
#[derive(Clone)]
pub struct KeyValueAwsDynamoRuntimeConfigOptions {
    /// AWS access key id.
    access_key: String,
    /// AWS secret access key.
    secret_key: String,
    /// Optional session token accompanying the key pair.
    token: Option<String>,
}

impl std::fmt::Debug for KeyValueAwsDynamoRuntimeConfigOptions {
    /// Prints the access key id but replaces the secret key and the token (when
    /// present) with a `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyValueAwsDynamoRuntimeConfigOptions")
            .field("access_key", &self.access_key)
            .field("secret_key", &"<redacted>")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
47
48impl KeyValueAwsDynamoRuntimeConfigOptions {
49 pub fn new(access_key: String, secret_key: String, token: Option<String>) -> Self {
50 Self {
51 access_key,
52 secret_key,
53 token,
54 }
55 }
56}
57
58impl ProvideCredentials for KeyValueAwsDynamoRuntimeConfigOptions {
59 fn provide_credentials<'a>(
60 &'a self,
61 ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
62 where
63 Self: 'a,
64 {
65 aws_credential_types::provider::future::ProvideCredentials::ready(Ok(Credentials::new(
66 self.access_key.clone(),
67 self.secret_key.clone(),
68 self.token.clone(),
69 None, "spin_custom_aws_provider",
71 )))
72 }
73}
74
/// Selects how the DynamoDB client obtains its SDK configuration and credentials.
#[derive(Clone, Debug)]
pub enum KeyValueAwsDynamoAuthOptions {
    /// Credentials were supplied explicitly; they are installed as the SDK
    /// config's credentials provider together with the configured region.
    RuntimeConfigValues(KeyValueAwsDynamoRuntimeConfigOptions),
    /// The SDK config is assembled by `aws_config::load_defaults`.
    Environmental,
}
88
89impl KeyValueAwsDynamo {
90 pub fn new(
91 region: String,
92 consistent_read: bool,
93 table: String,
94 auth_options: KeyValueAwsDynamoAuthOptions,
95 ) -> Result<Self> {
96 let region_clone = region.clone();
97 let client_fut = Box::pin(async move {
98 let sdk_config = match auth_options {
99 KeyValueAwsDynamoAuthOptions::RuntimeConfigValues(config) => SdkConfig::builder()
100 .credentials_provider(SharedCredentialsProvider::new(config))
101 .region(Region::new(region_clone))
102 .behavior_version(BehaviorVersion::latest())
103 .build(),
104 KeyValueAwsDynamoAuthOptions::Environmental => {
105 aws_config::load_defaults(BehaviorVersion::latest()).await
106 }
107 };
108 Client::new(&sdk_config)
109 });
110
111 Ok(Self {
112 region,
113 consistent_read,
114 table: Arc::new(table),
115 client: async_once_cell::Lazy::from_future(client_fut),
116 })
117 }
118}
119
120#[async_trait]
121impl StoreManager for KeyValueAwsDynamo {
122 async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
123 Ok(Arc::new(AwsDynamoStore {
124 client: self.client.get_unpin().await.clone(),
125 table: self.table.clone(),
126 consistent_read: self.consistent_read,
127 }))
128 }
129
130 fn is_defined(&self, _store_name: &str) -> bool {
131 true
132 }
133
134 fn summary(&self, _store_name: &str) -> Option<String> {
135 Some(format!(
136 "AWS DynamoDB region: {}, table: {}",
137 self.region, self.table
138 ))
139 }
140}
141
/// `Store` implementation that keeps each key as one item of a DynamoDB table.
struct AwsDynamoStore {
    /// DynamoDB client used for every request.
    client: Client,
    /// Name of the table that holds the items.
    table: Arc<String>,
    /// `consistent_read` flag sent with `get`, `exists` and `get_many` requests.
    consistent_read: bool,
}
148
/// What a `CompareAndSwap` handle learned from its most recent `current` call;
/// it decides which condition `swap` attaches to the write.
#[derive(Debug, Clone)]
enum CasState {
    /// The item had a binary value and a numeric version; holds that version.
    Versioned(String),
    /// The item had a binary value but no numeric version; holds the value so
    /// the swap can compare against it instead.
    Unversioned(Blob),
    /// No item, or no binary value, was found.
    Unset,
    /// Initial state: `current` has not been called, so `swap` adds no condition.
    Unknown,
}
160
/// Compare-and-swap handle for a single key of the DynamoDB table.
struct CompareAndSwap {
    /// Key this handle operates on.
    key: String,
    /// DynamoDB client used for the read and the conditional write.
    client: Client,
    /// Name of the table that holds the item.
    table: Arc<String>,
    /// Opaque value supplied by the caller and returned unchanged by `bucket_rep`.
    bucket_rep: u32,
    /// Outcome of the latest `current` call; the mutex allows updating it through `&self`.
    state: Mutex<CasState>,
}
168
/// Name of the string attribute used as the item's primary key.
const PK: &str = "PK";
/// Name of the binary attribute that stores the item's value.
const VAL: &str = "VAL";
/// Name of the numeric attribute incremented by every compare-and-swap write.
const VER: &str = "VER";
175
176#[async_trait]
177impl Store for AwsDynamoStore {
178 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
179 let response = self
180 .client
181 .get_item()
182 .consistent_read(self.consistent_read)
183 .table_name(self.table.as_str())
184 .key(
185 PK,
186 aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
187 )
188 .projection_expression(VAL)
189 .send()
190 .await
191 .map_err(log_error)?;
192
193 let item = response.item.and_then(|mut item| {
194 if let Some(AttributeValue::B(val)) = item.remove(VAL) {
195 Some(val.into_inner())
196 } else {
197 None
198 }
199 });
200
201 Ok(item)
202 }
203
204 async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
205 self.client
206 .put_item()
207 .table_name(self.table.as_str())
208 .item(PK, AttributeValue::S(key.to_string()))
209 .item(VAL, AttributeValue::B(Blob::new(value)))
210 .send()
211 .await
212 .map_err(log_error)?;
213 Ok(())
214 }
215
216 async fn delete(&self, key: &str) -> Result<(), Error> {
217 self.client
218 .delete_item()
219 .table_name(self.table.as_str())
220 .key(PK, AttributeValue::S(key.to_string()))
221 .send()
222 .await
223 .map_err(log_error)?;
224 Ok(())
225 }
226
227 async fn exists(&self, key: &str) -> Result<bool, Error> {
228 let GetItemOutput { item, .. } = self
229 .client
230 .get_item()
231 .consistent_read(self.consistent_read)
232 .table_name(self.table.as_str())
233 .key(
234 PK,
235 aws_sdk_dynamodb::types::AttributeValue::S(key.to_string()),
236 )
237 .projection_expression(PK)
238 .send()
239 .await
240 .map_err(log_error)?;
241
242 Ok(item.map(|item| item.contains_key(PK)).unwrap_or(false))
243 }
244
245 async fn get_keys(&self) -> Result<Vec<String>, Error> {
246 let mut primary_keys = Vec::new();
247
248 let mut scan_paginator = self
249 .client
250 .scan()
251 .table_name(self.table.as_str())
252 .projection_expression(PK)
253 .into_paginator()
254 .send();
255
256 while let Some(output) = scan_paginator.next().await {
257 let scan_output = output.map_err(log_error)?;
258 if let Some(items) = scan_output.items {
259 for mut item in items {
260 if let Some(AttributeValue::S(pk)) = item.remove(PK) {
261 primary_keys.push(pk);
262 }
263 }
264 }
265 }
266
267 Ok(primary_keys)
268 }
269
270 async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
271 let mut results = Vec::with_capacity(keys.len());
272 let mut keys_and_attributes_builder = KeysAndAttributes::builder()
273 .projection_expression(format!("{PK},{VAL}"))
274 .consistent_read(self.consistent_read);
275 for key in keys {
276 keys_and_attributes_builder = keys_and_attributes_builder.keys(HashMap::from_iter([(
277 PK.to_owned(),
278 AttributeValue::S(key),
279 )]))
280 }
281 let mut request_items = Some(HashMap::from_iter([(
282 self.table.to_string(),
283 keys_and_attributes_builder.build().map_err(log_error)?,
284 )]));
285
286 while request_items.is_some() {
287 let BatchGetItemOutput {
288 responses,
289 unprocessed_keys,
290 ..
291 } = self
292 .client
293 .batch_get_item()
294 .set_request_items(request_items)
295 .send()
296 .await
297 .map_err(log_error)?;
298
299 if let Some(items) =
300 responses.and_then(|mut responses| responses.remove(self.table.as_str()))
301 {
302 for mut item in items {
303 match (item.remove(PK), item.remove(VAL)) {
304 (Some(AttributeValue::S(pk)), Some(AttributeValue::B(val))) => {
305 results.push((pk, Some(val.into_inner())));
306 }
307 (Some(AttributeValue::S(pk)), None) => {
308 results.push((pk, None));
309 }
310 _ => (),
311 }
312 }
313 }
314
315 request_items = unprocessed_keys.filter(|unprocessed| !unprocessed.is_empty());
316 }
317
318 Ok(results)
319 }
320
321 async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
322 let mut data = Vec::with_capacity(key_values.len());
323 for (key, val) in key_values {
324 data.push(
325 WriteRequest::builder()
326 .put_request(
327 PutRequest::builder()
328 .item(PK, AttributeValue::S(key))
329 .item(VAL, AttributeValue::B(Blob::new(val)))
330 .build()
331 .map_err(log_error)?,
332 )
333 .build(),
334 )
335 }
336
337 let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));
338
339 while request_items.is_some() {
340 let BatchWriteItemOutput {
341 unprocessed_items, ..
342 } = self
343 .client
344 .batch_write_item()
345 .set_request_items(request_items)
346 .send()
347 .await
348 .map_err(log_error)?;
349
350 request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
351 }
352
353 Ok(())
354 }
355
356 async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
357 let mut data = Vec::with_capacity(keys.len());
358 for key in keys {
359 data.push(
360 WriteRequest::builder()
361 .delete_request(
362 DeleteRequest::builder()
363 .key(PK, AttributeValue::S(key))
364 .build()
365 .map_err(log_error)?,
366 )
367 .build(),
368 )
369 }
370
371 let mut request_items = Some(HashMap::from_iter([(self.table.to_string(), data)]));
372
373 while request_items.is_some() {
374 let BatchWriteItemOutput {
375 unprocessed_items, ..
376 } = self
377 .client
378 .batch_write_item()
379 .set_request_items(request_items)
380 .send()
381 .await
382 .map_err(log_error)?;
383
384 request_items = unprocessed_items.filter(|unprocessed| !unprocessed.is_empty());
385 }
386
387 Ok(())
388 }
389
390 async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
391 let GetItemOutput { item, .. } = self
392 .client
393 .get_item()
394 .consistent_read(true)
395 .table_name(self.table.as_str())
396 .key(PK, AttributeValue::S(key.clone()))
397 .projection_expression(VAL)
398 .send()
399 .await
400 .map_err(log_error)?;
401
402 let old_val = match item {
403 Some(mut current_item) => match current_item.remove(VAL) {
404 Some(AttributeValue::B(val)) => Some(
406 str::from_utf8(&val.into_inner())
407 .map_err(log_error)?
408 .parse::<i64>()
409 .map_err(log_error)?,
410 ),
411 _ => None,
412 },
413 None => None,
414 };
415
416 let new_val = old_val.unwrap_or(0) + delta;
417
418 let mut update = Update::builder()
419 .table_name(self.table.as_str())
420 .key(PK, AttributeValue::S(key))
421 .update_expression("SET #VAL = :new_val")
422 .expression_attribute_names("#VAL", VAL)
423 .expression_attribute_values(
424 ":new_val",
425 AttributeValue::B(Blob::new(new_val.to_string().as_bytes())),
426 );
427
428 if let Some(old_val) = old_val {
429 update = update
430 .condition_expression("#VAL = :old_val")
431 .expression_attribute_values(
432 ":old_val",
433 AttributeValue::B(Blob::new(old_val.to_string().as_bytes())),
434 )
435 } else {
436 update = update.condition_expression("attribute_not_exists (#VAL)")
437 }
438
439 self.client
440 .transact_write_items()
441 .transact_items(
442 TransactWriteItem::builder()
443 .update(update.build().map_err(log_error)?)
444 .build(),
445 )
446 .send()
447 .await
448 .map_err(log_error)?;
449
450 Ok(new_val)
451 }
452
453 async fn new_compare_and_swap(
454 &self,
455 bucket_rep: u32,
456 key: &str,
457 ) -> Result<Arc<dyn spin_factor_key_value::Cas>, Error> {
458 Ok(Arc::new(CompareAndSwap {
459 key: key.to_string(),
460 client: self.client.clone(),
461 table: self.table.clone(),
462 state: Mutex::new(CasState::Unknown),
463 bucket_rep,
464 }))
465 }
466}
467
468#[async_trait]
469impl Cas for CompareAndSwap {
470 async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
471 let GetItemOutput { item, .. } = self
472 .client
473 .get_item()
474 .consistent_read(true)
475 .table_name(self.table.as_str())
476 .key(PK, AttributeValue::S(self.key.clone()))
477 .projection_expression(format!("{VAL},{VER}"))
478 .send()
479 .await
480 .map_err(log_error)?;
481
482 match item {
483 Some(mut current_item) => match (current_item.remove(VAL), current_item.remove(VER)) {
484 (Some(AttributeValue::B(val)), Some(AttributeValue::N(ver))) => {
485 self.state
486 .lock()
487 .unwrap()
488 .clone_from(&CasState::Versioned(ver));
489
490 Ok(Some(val.into_inner()))
491 }
492 (Some(AttributeValue::B(val)), _) => {
493 self.state
494 .lock()
495 .unwrap()
496 .clone_from(&CasState::Unversioned(val.clone()));
497
498 Ok(Some(val.into_inner()))
499 }
500 (_, _) => {
501 self.state.lock().unwrap().clone_from(&CasState::Unset);
502 Ok(None)
503 }
504 },
505 None => {
506 self.state.lock().unwrap().clone_from(&CasState::Unset);
507 Ok(None)
508 }
509 }
510 }
511
512 async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
515 let mut update = Update::builder()
516 .table_name(self.table.as_str())
517 .key(PK, AttributeValue::S(self.key.clone()))
518 .update_expression("SET #VAL = :val ADD #VER :increment")
519 .expression_attribute_names("#VAL", VAL)
520 .expression_attribute_names("#VER", VER)
521 .expression_attribute_values(":val", AttributeValue::B(Blob::new(value)))
522 .expression_attribute_values(":increment", AttributeValue::N("1".to_owned()));
523
524 let state = self.state.lock().unwrap().clone();
525 match state {
526 CasState::Versioned(version) => {
527 update = update
528 .condition_expression("#VER = :ver")
529 .expression_attribute_values(":ver", AttributeValue::N(version));
530 }
531 CasState::Unversioned(old_val) => {
532 update = update
533 .condition_expression("#VAL = :old_val")
534 .expression_attribute_values(":old_val", AttributeValue::B(old_val));
535 }
536 CasState::Unset => {
537 update = update.condition_expression("attribute_not_exists (#VAL)");
538 }
539 CasState::Unknown => (),
540 };
541
542 self.client
543 .transact_write_items()
544 .transact_items(
545 TransactWriteItem::builder()
546 .update(
547 update
548 .build()
549 .map_err(|e| SwapError::Other(format!("{e:?}")))?,
550 )
551 .build(),
552 )
553 .send()
554 .await
555 .map_err(|e| SwapError::CasFailed(format!("{e:?}")))?;
556
557 Ok(())
558 }
559
    /// Returns the `bucket_rep` value this handle was created with.
    async fn bucket_rep(&self) -> u32 {
        self.bucket_rep
    }
563
    /// Returns an owned copy of the key this handle operates on.
    async fn key(&self) -> String {
        self.key.clone()
    }
567}