# matekasse/venv/lib/python3.11/site-packages/cachelib/file.py
#
# 337 lines
# 12 KiB
# Python
# Raw Normal View History
#
# 2023-07-28 21:30:45 +00:00
import errno
import logging
import os
import platform
import stat
import struct
import tempfile
import typing as _t
from contextlib import contextmanager
from hashlib import md5
from pathlib import Path
from time import sleep
from time import time
from cachelib.base import BaseCache
from cachelib.serializers import FileSystemSerializer
class FileSystemCache(BaseCache):
    """A cache that stores the items on the file system. This cache depends
    on being the only user of the `cache_dir`. Make absolutely sure that
    nobody but this cache stores files there or otherwise the cache will
    randomly delete files therein.

    Every entry lives in its own file, named after the hash of its key. A
    file starts with the entry's expiry time packed with ``struct`` format
    ``"I"`` (0 means "never expires"), followed by the value as written by
    :attr:`serializer`. When a threshold is set, a management entry
    (``_fs_count_file``) stores the number of cache files so the threshold
    can be checked without listing the directory on every write.

    :param cache_dir: the directory where cache files are stored.
    :param threshold: the maximum number of items the cache stores before
                      it starts deleting some. A threshold value of 0
                      indicates no threshold.
    :param default_timeout: the default timeout that is used if no timeout is
                            specified on :meth:`~BaseCache.set`. A timeout of
                            0 indicates that the cache never expires.
    :param mode: the file mode wanted for the cache files, default 0600
                 (``stat.S_IWRITE`` on Windows).
    :param hash_method: Default hashlib.md5. The hash method used to
                        generate the filename for cached results.
    """

    #: used for temporary files by the FileSystemCache
    _fs_transaction_suffix = ".__wz_cache"
    #: keep amount of files in a cache element
    _fs_count_file = "__wz_cache_count"

    # NOTE(review): presumably a pickle-based serializer (confirm in
    # cachelib.serializers); if so, cache_dir must not be writable by
    # untrusted users, since file contents are deserialized on read.
    serializer = FileSystemSerializer()

    def __init__(
        self,
        cache_dir: str,
        threshold: int = 500,
        default_timeout: int = 300,
        mode: _t.Optional[int] = None,
        hash_method: _t.Any = md5,
    ):
        BaseCache.__init__(self, default_timeout)
        self._path = cache_dir
        self._threshold = threshold
        self._hash_method = hash_method

        # Mode set by user takes precedence. If no mode has been given, we
        # need to set the correct default based on user platform.
        self._mode = mode
        if self._mode is None:
            self._mode = self._get_compatible_platform_mode()

        # An already existing directory is fine; any other error is not.
        try:
            os.makedirs(self._path)
        except OSError as ex:
            if ex.errno != errno.EEXIST:
                raise

        # If there are many files and a zero threshold,
        # the list_dir can slow initialisation massively
        if self._threshold != 0:
            self._update_count(value=sum(1 for _ in self._list_dir()))

    def _get_compatible_platform_mode(self) -> int:
        """Return the default cache file mode for the current platform."""
        if platform.system() == "Windows":
            return stat.S_IWRITE
        return 0o600  # *nix systems

    @property
    def _file_count(self) -> int:
        """Number of cache files as recorded in the management entry."""
        return self.get(self._fs_count_file) or 0

    def _update_count(
        self, delta: _t.Optional[int] = None, value: _t.Optional[int] = None
    ) -> None:
        """Adjust the stored file count by ``delta`` or overwrite it with ``value``.

        ``delta`` takes precedence whenever it is given. A zero delta leaves
        the count untouched. Without a delta the count is set to ``value``
        (``None`` meaning 0). The stored count is never allowed below 0.
        Does nothing when the cache has no threshold.
        """
        # If we have no threshold, don't count files
        if self._threshold == 0:
            return
        if delta is not None:
            # A zero delta used to fall through to ``value or 0`` and reset
            # the counter (e.g. when clear() failed on its first file).
            if delta == 0:
                return
            new_count = self._file_count + delta
        else:
            new_count = value or 0
        # The counter is only an estimate shared between writers; a negative
        # number of files is never meaningful.
        self.set(self._fs_count_file, max(new_count, 0), mgmt_element=True)

    def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
        """Turn a relative timeout into an absolute expiry timestamp.

        0 is kept as-is and means "never expires". The base-class
        normalisation runs first; presumably it maps ``None`` to the default
        timeout (see BaseCache).
        """
        timeout = BaseCache._normalize_timeout(self, timeout)
        if timeout != 0:
            timeout = int(time()) + timeout
        return int(timeout)

    def _is_mgmt(self, name: str) -> bool:
        """Return True if the bare filename ``name`` is a management file.

        Management files are the counter file and in-flight temporary files.
        """
        fshash = os.path.basename(self._get_filename(self._fs_count_file))
        return name == fshash or name.endswith(self._fs_transaction_suffix)

    def _list_dir(self) -> _t.Generator[str, None, None]:
        """Return a generator of fully qualified cache filenames.

        Management files are excluded.
        """
        return (
            os.path.join(self._path, fn)
            for fn in os.listdir(self._path)
            if not self._is_mgmt(fn)
        )

    def _over_threshold(self) -> bool:
        """Return True if a threshold is set and the file count exceeds it."""
        return self._threshold != 0 and self._file_count > self._threshold

    def _remove_expired(self, now: float) -> None:
        """Delete every cache file whose (non-zero) expiry is before ``now``."""
        for fname in self._list_dir():
            try:
                with self._safe_stream_open(fname, "rb") as f:
                    expires = struct.unpack("I", f.read(4))[0]
                # Remove only after the handle is closed: Windows refuses to
                # delete a file that is still open.
                if expires != 0 and expires < now:
                    os.remove(fname)
                    self._update_count(delta=-1)
            except FileNotFoundError:
                pass
            except (OSError, EOFError, struct.error):
                logging.warning(
                    "Exception raised while handling cache file '%s'",
                    fname,
                    exc_info=True,
                )

    def _remove_older(self) -> bool:
        """Evict files by ascending stored expiry until under the threshold.

        NOTE(review): entries stored with timeout 0 have expiry 0 and are
        therefore evicted first.

        :return: False if a file could not be removed, True otherwise.
        """
        exp_fname_tuples = []
        for fname in self._list_dir():
            try:
                with self._safe_stream_open(fname, "rb") as f:
                    timestamp = struct.unpack("I", f.read(4))[0]
                exp_fname_tuples.append((timestamp, fname))
            except FileNotFoundError:
                pass
            except (OSError, EOFError, struct.error):
                logging.warning(
                    "Exception raised while handling cache file '%s'",
                    fname,
                    exc_info=True,
                )

        for _, fname in sorted(exp_fname_tuples, key=lambda item: item[0]):
            try:
                os.remove(fname)
                self._update_count(delta=-1)
            except FileNotFoundError:
                pass
            except OSError:
                logging.warning(
                    "Exception raised while handling cache file '%s'",
                    fname,
                    exc_info=True,
                )
                return False
            if not self._over_threshold():
                break
        return True

    def _prune(self) -> None:
        """Bring the cache under the threshold.

        Expired entries are dropped first; the oldest remaining ones are
        evicted only if that was not enough.
        """
        if self._over_threshold():
            now = time()
            self._remove_expired(now)
            # if still over threshold
            if self._over_threshold():
                self._remove_older()

    def clear(self) -> bool:
        """Delete all cache files; management files are kept.

        :return: True on success. Returns False as soon as a file cannot be
                 removed; in that case the counter is decremented by the
                 number of files actually removed so far.
        """
        removed = 0
        for fname in self._list_dir():
            try:
                os.remove(fname)
            except FileNotFoundError:
                # Already gone (e.g. removed concurrently): not counted.
                pass
            except OSError:
                logging.warning(
                    "Exception raised while handling cache file '%s'",
                    fname,
                    exc_info=True,
                )
                self._update_count(delta=-removed)
                return False
            else:
                removed += 1
        self._update_count(value=0)
        return True

    def _get_filename(self, key: str) -> str:
        """Map ``key`` to its cache file path (hash of the UTF-8 key).

        :raises TypeError: if ``key`` is not a string.
        """
        if not isinstance(key, str):
            raise TypeError(f"Key must be a string, received type {type(key)}")
        bkey = key.encode("utf-8")  # XXX unicode review
        bkey_hash = self._hash_method(bkey).hexdigest()
        return os.path.join(self._path, bkey_hash)

    def get(self, key: str) -> _t.Any:
        """Return the cached value for ``key``.

        Returns None if the entry is missing, expired or unreadable.
        """
        filename = self._get_filename(key)
        try:
            with self._safe_stream_open(filename, "rb") as f:
                pickle_time = struct.unpack("I", f.read(4))[0]
                if pickle_time == 0 or pickle_time >= time():
                    return self.serializer.load(f)
        except FileNotFoundError:
            pass
        except (OSError, EOFError, struct.error):
            logging.warning(
                "Exception raised while handling cache file '%s'",
                filename,
                exc_info=True,
            )
        return None

    def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
        """Store ``value`` only if no file exists yet for ``key``.

        NOTE(review): the existence check and the write are not atomic.
        """
        filename = self._get_filename(key)
        if not os.path.exists(filename):
            return self.set(key, value, timeout)
        return False

    def set(
        self,
        key: str,
        value: _t.Any,
        timeout: _t.Optional[int] = None,
        mgmt_element: bool = False,
    ) -> bool:
        """Write ``value`` for ``key``.

        The data goes to a temporary file in ``cache_dir``, which is then
        moved over the target with ``os.replace``.

        :param mgmt_element: internal flag for the counter entry. Such an
                             entry never expires, does not trigger pruning
                             and is not counted.
        :return: True if the file was written and is non-empty, False on
                 an OSError (which is logged).
        """
        if mgmt_element:
            # Management elements have no timeout.
            timeout = 0
        else:
            # Don't prune on management element update, to avoid a loop
            # (pruning updates the counter, which is itself a set()).
            self._prune()
        timeout = self._normalize_timeout(timeout)
        filename = self._get_filename(key)
        overwrite = os.path.isfile(filename)
        tmp: _t.Optional[str] = None
        try:
            fd, tmp = tempfile.mkstemp(
                suffix=self._fs_transaction_suffix, dir=self._path
            )
            with os.fdopen(fd, "wb") as f:
                f.write(struct.pack("I", timeout))
                self.serializer.dump(value, f)
            self._run_safely(os.replace, tmp, filename)
            # The temporary file has become ``filename``; nothing to clean up.
            tmp = None
            self._run_safely(os.chmod, filename, self._mode)
            fsize = Path(filename).stat().st_size
        except OSError:
            logging.warning(
                "Exception raised while handling cache file '%s'",
                filename,
                exc_info=True,
            )
            return False
        else:
            # Management elements should not count towards threshold
            if not overwrite and not mgmt_element:
                self._update_count(delta=1)
            return fsize > 0  # function should fail if file is empty
        finally:
            if tmp is not None:
                # Writing, serializing or renaming failed: remove the
                # orphaned temporary file instead of leaking it in cache_dir.
                try:
                    os.remove(tmp)
                except OSError:
                    pass

    def delete(self, key: str, mgmt_element: bool = False) -> bool:
        """Remove the file for ``key``.

        A missing file counts as deleted.

        :return: False only if removal failed with an OSError (logged).
        """
        filename = self._get_filename(key)
        try:
            os.remove(filename)
        except FileNotFoundError:  # if file doesn't exist we consider it deleted
            return True
        except OSError:
            logging.warning(
                "Exception raised while handling cache file '%s'",
                filename,
                exc_info=True,
            )
            return False
        else:
            # Management elements should not count towards threshold
            if not mgmt_element:
                self._update_count(delta=-1)
            return True

    def has(self, key: str) -> bool:
        """Return True if a non-expired, readable entry exists for ``key``."""
        filename = self._get_filename(key)
        try:
            with self._safe_stream_open(filename, "rb") as f:
                pickle_time = struct.unpack("I", f.read(4))[0]
            return pickle_time == 0 or pickle_time >= time()
        except FileNotFoundError:  # if there is no file there is no key
            return False
        except (OSError, EOFError, struct.error):
            logging.warning(
                "Exception raised while handling cache file '%s'",
                filename,
                exc_info=True,
            )
            return False

    def _run_safely(self, fn: _t.Callable, *args: _t.Any, **kwargs: _t.Any) -> _t.Any:
        """On Windows os.replace, os.chmod and open can yield
        permission errors if executed by two different processes.

        On Windows, ``fn`` is retried with exponential backoff (starting at
        1 ms) while it raises PermissionError, for roughly 10 s of total
        sleep. After that the last PermissionError is re-raised. On other
        platforms ``fn`` is called once.
        """
        if platform.system() != "Windows":
            return fn(*args, **kwargs)

        wait_step = 0.001
        max_sleep_time = 10.0
        total_sleep_time = 0.0
        while True:
            try:
                return fn(*args, **kwargs)
            except PermissionError:
                if total_sleep_time >= max_sleep_time:
                    # Swallowing the error (returning None) made set() report
                    # success after a failed os.replace.
                    raise
                sleep(wait_step)
                total_sleep_time += wait_step
                wait_step *= 2

    @contextmanager
    def _safe_stream_open(self, path: str, mode: str) -> _t.Generator:
        """Open ``path`` via :meth:`_run_safely` and close it on exit."""
        fs = self._run_safely(open, path, mode)
        if fs is None:
            # Defensive: only reachable if _run_safely is overridden to
            # swallow errors.
            raise OSError(f"could not open {path!r}")
        try:
            yield fs
        finally:
            fs.close()