348 lines
9.8 KiB
Python
348 lines
9.8 KiB
Python
from __future__ import annotations
|
|
|
|
__all__ = (
|
|
"AsyncCacheInfo",
|
|
"AsyncCacheParameters",
|
|
"AsyncLRUCacheWrapper",
|
|
"cache",
|
|
"lru_cache",
|
|
"reduce",
|
|
)
|
|
|
|
import functools
|
|
import sys
|
|
from collections import OrderedDict
|
|
from collections.abc import (
|
|
AsyncIterable,
|
|
Awaitable,
|
|
Callable,
|
|
Coroutine,
|
|
Hashable,
|
|
Iterable,
|
|
)
|
|
from functools import update_wrapper
|
|
from inspect import iscoroutinefunction
|
|
from typing import (
|
|
Any,
|
|
Generic,
|
|
NamedTuple,
|
|
TypedDict,
|
|
TypeVar,
|
|
cast,
|
|
final,
|
|
overload,
|
|
)
|
|
from weakref import WeakKeyDictionary
|
|
|
|
from ._core._synchronization import Lock
|
|
from .lowlevel import RunVar, checkpoint
|
|
|
|
if sys.version_info >= (3, 11):
|
|
from typing import ParamSpec
|
|
else:
|
|
from typing_extensions import ParamSpec
|
|
|
|
T = TypeVar("T")
|
|
S = TypeVar("S")
|
|
P = ParamSpec("P")
|
|
lru_cache_items: RunVar[
|
|
WeakKeyDictionary[
|
|
AsyncLRUCacheWrapper[Any, Any],
|
|
OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
|
|
]
|
|
] = RunVar("lru_cache_items")
|
|
|
|
|
|
class _InitialMissingType:
|
|
pass
|
|
|
|
|
|
initial_missing: _InitialMissingType = _InitialMissingType()
|
|
|
|
|
|
class AsyncCacheInfo(NamedTuple):
|
|
hits: int
|
|
misses: int
|
|
maxsize: int | None
|
|
currsize: int
|
|
|
|
|
|
class AsyncCacheParameters(TypedDict):
|
|
maxsize: int | None
|
|
typed: bool
|
|
always_checkpoint: bool
|
|
|
|
|
|
@final
|
|
class AsyncLRUCacheWrapper(Generic[P, T]):
|
|
def __init__(
|
|
self,
|
|
func: Callable[..., Awaitable[T]],
|
|
maxsize: int | None,
|
|
typed: bool,
|
|
always_checkpoint: bool,
|
|
):
|
|
self.__wrapped__ = func
|
|
self._hits: int = 0
|
|
self._misses: int = 0
|
|
self._maxsize = max(maxsize, 0) if maxsize is not None else None
|
|
self._currsize: int = 0
|
|
self._typed = typed
|
|
self._always_checkpoint = always_checkpoint
|
|
update_wrapper(self, func)
|
|
|
|
def cache_info(self) -> AsyncCacheInfo:
|
|
return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)
|
|
|
|
def cache_parameters(self) -> AsyncCacheParameters:
|
|
return {
|
|
"maxsize": self._maxsize,
|
|
"typed": self._typed,
|
|
"always_checkpoint": self._always_checkpoint,
|
|
}
|
|
|
|
def cache_clear(self) -> None:
|
|
if cache := lru_cache_items.get(None):
|
|
cache.pop(self, None)
|
|
self._hits = self._misses = self._currsize = 0
|
|
|
|
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
# Easy case first: if maxsize == 0, no caching is done
|
|
if self._maxsize == 0:
|
|
value = await self.__wrapped__(*args, **kwargs)
|
|
self._misses += 1
|
|
return value
|
|
|
|
# The key is constructed as a flat tuple to avoid memory overhead
|
|
key: tuple[Any, ...] = args
|
|
if kwargs:
|
|
# initial_missing is used as a separator
|
|
key += (initial_missing,) + sum(kwargs.items(), ())
|
|
|
|
if self._typed:
|
|
key += tuple(type(arg) for arg in args)
|
|
if kwargs:
|
|
key += (initial_missing,) + tuple(type(val) for val in kwargs.values())
|
|
|
|
try:
|
|
cache = lru_cache_items.get()
|
|
except LookupError:
|
|
cache = WeakKeyDictionary()
|
|
lru_cache_items.set(cache)
|
|
|
|
try:
|
|
cache_entry = cache[self]
|
|
except KeyError:
|
|
cache_entry = cache[self] = OrderedDict()
|
|
|
|
cached_value: T | _InitialMissingType
|
|
try:
|
|
cached_value, lock = cache_entry[key]
|
|
except KeyError:
|
|
# We're the first task to call this function
|
|
cached_value, lock = (
|
|
initial_missing,
|
|
Lock(fast_acquire=not self._always_checkpoint),
|
|
)
|
|
cache_entry[key] = cached_value, lock
|
|
|
|
if lock is None:
|
|
# The value was already cached
|
|
self._hits += 1
|
|
cache_entry.move_to_end(key)
|
|
if self._always_checkpoint:
|
|
await checkpoint()
|
|
|
|
return cast(T, cached_value)
|
|
|
|
async with lock:
|
|
# Check if another task filled the cache while we acquired the lock
|
|
if (cached_value := cache_entry[key][0]) is initial_missing:
|
|
self._misses += 1
|
|
if self._maxsize is not None and self._currsize >= self._maxsize:
|
|
cache_entry.popitem(last=False)
|
|
else:
|
|
self._currsize += 1
|
|
|
|
value = await self.__wrapped__(*args, **kwargs)
|
|
cache_entry[key] = value, None
|
|
else:
|
|
# Another task filled the cache while we were waiting for the lock
|
|
self._hits += 1
|
|
cache_entry.move_to_end(key)
|
|
value = cast(T, cached_value)
|
|
|
|
return value
|
|
|
|
|
|
class _LRUCacheWrapper(Generic[T]):
|
|
def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
|
|
self._maxsize = maxsize
|
|
self._typed = typed
|
|
self._always_checkpoint = always_checkpoint
|
|
|
|
@overload
|
|
def __call__( # type: ignore[overload-overlap]
|
|
self, func: Callable[P, Coroutine[Any, Any, T]], /
|
|
) -> AsyncLRUCacheWrapper[P, T]: ...
|
|
|
|
@overload
|
|
def __call__(
|
|
self, func: Callable[..., T], /
|
|
) -> functools._lru_cache_wrapper[T]: ...
|
|
|
|
def __call__(
|
|
self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
|
|
) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
|
|
if iscoroutinefunction(f):
|
|
return AsyncLRUCacheWrapper(
|
|
f, self._maxsize, self._typed, self._always_checkpoint
|
|
)
|
|
|
|
return functools.lru_cache(maxsize=self._maxsize, typed=self._typed)(f) # type: ignore[arg-type]
|
|
|
|
|
|
@overload
|
|
def cache( # type: ignore[overload-overlap]
|
|
func: Callable[P, Coroutine[Any, Any, T]], /
|
|
) -> AsyncLRUCacheWrapper[P, T]: ...
|
|
|
|
|
|
@overload
|
|
def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
|
|
|
|
|
|
def cache(
|
|
func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
|
|
) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
|
|
"""
|
|
A convenient shortcut for :func:`lru_cache` with ``maxsize=None``.
|
|
|
|
This is the asynchronous equivalent to :func:`functools.cache`.
|
|
|
|
"""
|
|
return lru_cache(maxsize=None)(func)
|
|
|
|
|
|
@overload
|
|
def lru_cache(
|
|
*, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
|
|
) -> _LRUCacheWrapper[Any]: ...
|
|
|
|
|
|
@overload
|
|
def lru_cache( # type: ignore[overload-overlap]
|
|
func: Callable[P, Coroutine[Any, Any, T]], /
|
|
) -> AsyncLRUCacheWrapper[P, T]: ...
|
|
|
|
|
|
@overload
|
|
def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
|
|
|
|
|
|
def lru_cache(
|
|
func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
|
|
/,
|
|
*,
|
|
maxsize: int | None = 128,
|
|
typed: bool = False,
|
|
always_checkpoint: bool = False,
|
|
) -> (
|
|
AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
|
|
):
|
|
"""
|
|
An asynchronous version of :func:`functools.lru_cache`.
|
|
|
|
If a synchronous function is passed, the standard library
|
|
:func:`functools.lru_cache` is applied instead.
|
|
|
|
:param always_checkpoint: if ``True``, every call to the cached function will be
|
|
guaranteed to yield control to the event loop at least once
|
|
|
|
.. note:: Caches and locks are managed on a per-event loop basis.
|
|
|
|
"""
|
|
if func is None:
|
|
return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)
|
|
|
|
if not callable(func):
|
|
raise TypeError("the first argument must be callable")
|
|
|
|
return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)
|
|
|
|
|
|
@overload
|
|
async def reduce(
|
|
function: Callable[[T, S], Awaitable[T]],
|
|
iterable: Iterable[S] | AsyncIterable[S],
|
|
/,
|
|
initial: T,
|
|
) -> T: ...
|
|
|
|
|
|
@overload
|
|
async def reduce(
|
|
function: Callable[[T, T], Awaitable[T]],
|
|
iterable: Iterable[T] | AsyncIterable[T],
|
|
/,
|
|
) -> T: ...
|
|
|
|
|
|
async def reduce( # type: ignore[misc]
|
|
function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
|
|
iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
|
|
/,
|
|
initial: T | _InitialMissingType = initial_missing,
|
|
) -> T:
|
|
"""
|
|
Asynchronous version of :func:`functools.reduce`.
|
|
|
|
:param function: a coroutine function that takes two arguments: the accumulated
|
|
value and the next element from the iterable
|
|
:param iterable: an iterable or async iterable
|
|
:param initial: the initial value (if missing, the first element of the iterable is
|
|
used as the initial value)
|
|
|
|
"""
|
|
element: Any
|
|
function_called = False
|
|
if isinstance(iterable, AsyncIterable):
|
|
async_it = iterable.__aiter__()
|
|
if initial is initial_missing:
|
|
try:
|
|
value = cast(T, await async_it.__anext__())
|
|
except StopAsyncIteration:
|
|
raise TypeError(
|
|
"reduce() of empty sequence with no initial value"
|
|
) from None
|
|
else:
|
|
value = cast(T, initial)
|
|
|
|
async for element in async_it:
|
|
value = await function(value, element)
|
|
function_called = True
|
|
elif isinstance(iterable, Iterable):
|
|
it = iter(iterable)
|
|
if initial is initial_missing:
|
|
try:
|
|
value = cast(T, next(it))
|
|
except StopIteration:
|
|
raise TypeError(
|
|
"reduce() of empty sequence with no initial value"
|
|
) from None
|
|
else:
|
|
value = cast(T, initial)
|
|
|
|
for element in it:
|
|
value = await function(value, element)
|
|
function_called = True
|
|
else:
|
|
raise TypeError("reduce() argument 2 must be an iterable or async iterable")
|
|
|
|
# Make sure there is at least one checkpoint, even if an empty iterable and an
|
|
# initial value were given
|
|
if not function_called:
|
|
await checkpoint()
|
|
|
|
return value
|