# File-viewer metadata: 227 lines, 8.3 KiB, Python
import datetime
import logging
import typing as _t

from cachelib.base import BaseCache
from cachelib.serializers import DynamoDbSerializer

# Item attribute that records when the entry was written (ISO-8601 string).
CREATED_AT_FIELD = "created_at"
# Item attribute that holds the serialized cached value.
RESPONSE_FIELD = "response"
class DynamoDbCache(BaseCache):
    """
    Implementation of cachelib.BaseCache that uses an AWS DynamoDb table
    as the backend.

    Your server process will require dynamodb:GetItem and dynamodb:PutItem
    IAM permissions on the cache table for plain reads and writes. Note
    that ``delete`` additionally issues a DeleteItem call, ``clear`` scans
    the table and issues batched deletes, and the constructor loads the
    table description and, when the table does not exist, creates it and
    enables TTL on it; grant the matching permissions if you rely on any
    of those code paths.

    Limitations: DynamoDB table items are limited to 400 KB in size.  Since
    this class stores cached items in a table, the max size of a cache entry
    will be slightly less than 400 KB, since the cache key and expiration
    time fields are also part of the item.

    :param table_name: The name of the DynamoDB table to use
    :param default_timeout: Set the timeout in seconds after which cache entries
                            expire
    :param key_field: The name of the hash_key attribute in the DynamoDb
                      table. This must be a string attribute.
    :param expiration_time_field: The name of the table attribute to store the
                                  expiration time in.  This will be an int
                                  attribute. The timestamp will be stored as
                                  seconds past the epoch.  If you configure
                                  this as the TTL field, then DynamoDB will
                                  automatically delete expired entries.
    :param key_prefix: A prefix that should be added to all keys.
    :param kwargs: Extra keyword arguments passed through unchanged to
                   ``boto3.resource("dynamodb", ...)``.
    """

    serializer = DynamoDbSerializer()

    def __init__(
        self,
        table_name: _t.Optional[str] = "python-cache",
        default_timeout: int = 300,
        key_field: _t.Optional[str] = "cache_key",
        expiration_time_field: _t.Optional[str] = "expiration_time",
        key_prefix: _t.Optional[str] = None,
        **kwargs: _t.Any
    ):
        super().__init__(default_timeout)

        try:
            import boto3  # type: ignore

            # Import the conditions module explicitly instead of relying on
            # it having been loaded as a side effect of creating a resource.
            from boto3.dynamodb.conditions import Attr  # type: ignore
        except ImportError as err:
            raise RuntimeError("no boto3 module found") from err

        self._table_name = table_name
        self._key_field = key_field
        self._expiration_time_field = expiration_time_field
        self.key_prefix = key_prefix or ""
        self._dynamo = boto3.resource("dynamodb", **kwargs)
        self._attr = Attr

        client = self._dynamo.meta.client
        try:
            self._table = self._dynamo.Table(table_name)
            self._table.load()
        except client.exceptions.ResourceNotFoundException:
            # Only a genuinely missing table triggers creation; any other
            # failure (credentials, throttling, network) propagates as-is
            # instead of being masked by a doomed create_table call.
            table = self._dynamo.create_table(
                AttributeDefinitions=[
                    {"AttributeName": key_field, "AttributeType": "S"}
                ],
                TableName=table_name,
                KeySchema=[
                    {"AttributeName": key_field, "KeyType": "HASH"},
                ],
                BillingMode="PAY_PER_REQUEST",
            )
            table.wait_until_exists()
            # Reuse the resource's own client rather than building a second
            # one just to switch TTL on.
            client.update_time_to_live(
                TableName=table_name,
                TimeToLiveSpecification={
                    "Enabled": True,
                    "AttributeName": expiration_time_field,
                },
            )
            self._table = self._dynamo.Table(table_name)
            self._table.load()

    def _utcnow(self) -> _t.Any:
        """Return a tz-aware UTC datetime representing the current time"""
        # datetime.utcnow() is deprecated; now(tz) yields the same aware value.
        return datetime.datetime.now(datetime.timezone.utc)

    def _get_item(self, key: str, attributes: _t.Optional[list] = None) -> _t.Any:
        """
        Get an item from the cache table, optionally limiting the returned
        attributes.

        :param key: The cache key of the item to fetch

        :param attributes: An optional list of attributes to fetch.  If not
                           given, all attributes are fetched.  The
                           expiration_time field will always be added to the
                           list of fetched attributes.
        :return: The table item for key if it exists and is not expired, else
                 None
        """
        request: _t.Dict[str, _t.Any] = {"Key": {self._key_field: key}}
        if attributes:
            wanted = list(attributes)
            if self._expiration_time_field not in wanted:
                wanted.append(self._expiration_time_field)
            # Refer to attributes through placeholders so that names which
            # collide with DynamoDB reserved words (e.g. "ttl") still work.
            placeholders = {
                "#attr{}".format(index): name for index, name in enumerate(wanted)
            }
            request["ProjectionExpression"] = ",".join(placeholders)
            request["ExpressionAttributeNames"] = placeholders

        cache_item = self._table.get_item(**request).get("Item")
        # Compare against None rather than truthiness: a projection that
        # matches no stored attribute yields an empty (falsy) item even
        # though the entry exists.
        if cache_item is None:
            return None

        now = int(self._utcnow().timestamp())
        expires_at = cache_item.get(self._expiration_time_field)
        # Entries without an expiration attribute never expire.
        if expires_at is not None and expires_at <= now:
            return None
        return cache_item

    def get(self, key: str) -> _t.Any:
        """
        Get a cache item

        :param key: The cache key of the item to fetch
        :return: cache value if not expired, else None
        """
        cache_item = self._get_item(self.key_prefix + key)
        if cache_item is None:
            return None
        return self.serializer.loads(cache_item[RESPONSE_FIELD])

    def delete(self, key: str) -> bool:
        """
        Deletes an item from the cache.  This is a no-op if the item doesn't
        exist

        :param key: Key of the item to delete.
        :return: True if the key existed and was deleted
        """
        try:
            # The condition makes the delete fail when the key is absent,
            # which is how "nothing was deleted" is detected.
            self._table.delete_item(
                Key={self._key_field: self.key_prefix + key},
                ConditionExpression=self._attr(self._key_field).exists(),
            )
        except self._dynamo.meta.client.exceptions.ConditionalCheckFailedException:
            return False
        return True

    def _set(
        self,
        key: str,
        value: _t.Any,
        timeout: _t.Optional[int] = None,
        overwrite: _t.Optional[bool] = True,
    ) -> _t.Any:
        """
        Store a cache item, with the option to not overwrite existing items

        :param key: Cache key to use
        :param value: a serializable object
        :param timeout: The timeout in seconds for the cached item, to override
                        the default
        :param overwrite: If true, overwrite any existing cache item with key.
                          If false, the new value will only be stored if no
                          non-expired cache item exists with key.
        :return: True if the new item was stored.
        """
        timeout = self._normalize_timeout(timeout)
        now = self._utcnow()

        put_options: _t.Dict[str, _t.Any] = {}
        if not overwrite:
            # Cause the put to fail if a non-expired item with this key
            # already exists
            key_is_free = self._attr(self._key_field).not_exists()
            has_expired = self._attr(self._expiration_time_field).lte(
                int(now.timestamp())
            )
            put_options["ConditionExpression"] = key_is_free | has_expired

        try:
            item = {
                self._key_field: key,
                CREATED_AT_FIELD: now.isoformat(),
                RESPONSE_FIELD: self.serializer.dumps(value),
            }
            # A non-positive timeout stores the item without an expiration
            # attribute, i.e. it is never considered expired.
            if timeout > 0:
                expiration_time = now + datetime.timedelta(seconds=timeout)
                item[self._expiration_time_field] = int(expiration_time.timestamp())
            self._table.put_item(Item=item, **put_options)
        except self._dynamo.meta.client.exceptions.ConditionalCheckFailedException:
            # Expected outcome of add() when a live entry already exists.
            return False
        except Exception:
            # Keep the best-effort contract (report failure via the return
            # value) but leave a trace so real backend errors are visible.
            logging.getLogger(__name__).warning(
                "DynamoDbCache failed to store key %r", key, exc_info=True
            )
            return False
        return True

    def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
        """
        Store a cache item, replacing any existing item with the same key

        :param key: Cache key to use (the key prefix is added automatically)
        :param value: a serializable object
        :param timeout: The timeout in seconds for the cached item, to override
                        the default
        :return: True if the item was stored.
        """
        return self._set(self.key_prefix + key, value, timeout=timeout, overwrite=True)

    def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
        """
        Store a cache item only if no non-expired item exists with the key

        :param key: Cache key to use (the key prefix is added automatically)
        :param value: a serializable object
        :param timeout: The timeout in seconds for the cached item, to override
                        the default
        :return: True if the item was stored.
        """
        return self._set(self.key_prefix + key, value, timeout=timeout, overwrite=False)

    def has(self, key: str) -> bool:
        """
        Check whether a non-expired cache item exists

        :param key: Cache key to look up (the key prefix is added automatically)
        :return: True if a non-expired item is stored under key
        """
        # Project the key attribute as well: it is present on every item, so
        # entries stored without an expiration time are still detected.
        probe = self._get_item(
            self.key_prefix + key,
            [self._key_field, self._expiration_time_field],
        )
        return probe is not None

    def clear(self) -> bool:
        """
        Delete every item in the cache table

        The table is scanned page by page, fetching only the key attribute,
        and each key found is deleted through a batch writer.

        :return: Always True
        """
        paginator = self._dynamo.meta.client.get_paginator("scan")
        # Use a placeholder for the key attribute so that reserved words are
        # accepted as key field names.
        pages = paginator.paginate(
            TableName=self._table_name,
            ProjectionExpression="#cache_key",
            ExpressionAttributeNames={"#cache_key": self._key_field},
        )

        with self._table.batch_writer() as batch:
            for page in pages:
                for item in page["Items"]:
                    batch.delete_item(Key=item)

        return True