from __future__ import annotations

__all__ = (
    "AsyncCacheInfo",
    "AsyncCacheParameters",
    "AsyncLRUCacheWrapper",
    "cache",
    "lru_cache",
    "reduce",
)

import functools
import sys
from collections import OrderedDict
from collections.abc import (
    AsyncIterable,
    Awaitable,
    Callable,
    Coroutine,
    Hashable,
    Iterable,
)
from functools import update_wrapper
from inspect import iscoroutinefunction
from typing import (
    Any,
    Generic,
    NamedTuple,
    TypedDict,
    TypeVar,
    cast,
    final,
    overload,
)
from weakref import WeakKeyDictionary

from ._core._synchronization import Lock
from .lowlevel import RunVar, checkpoint

if sys.version_info >= (3, 11):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

T = TypeVar("T")
S = TypeVar("S")
P = ParamSpec("P")

# Per-event-loop storage: maps each wrapper to its ordered cache of
# key -> (value, None) for completed results, or key -> (initial_missing, Lock)
# for a placeholder whose value is still being computed.
lru_cache_items: RunVar[
    WeakKeyDictionary[
        AsyncLRUCacheWrapper[Any, Any],
        OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
    ]
] = RunVar("lru_cache_items")


class _InitialMissingType:
    """Sentinel type marking "no value yet" (and used as a key separator)."""


initial_missing: _InitialMissingType = _InitialMissingType()


class AsyncCacheInfo(NamedTuple):
    """Cache statistics returned by :meth:`AsyncLRUCacheWrapper.cache_info`."""

    hits: int
    misses: int
    maxsize: int | None
    currsize: int


class AsyncCacheParameters(TypedDict):
    """Configuration returned by :meth:`AsyncLRUCacheWrapper.cache_parameters`."""

    maxsize: int | None
    typed: bool
    always_checkpoint: bool


@final
class AsyncLRUCacheWrapper(Generic[P, T]):
    """
    Callable wrapper that memoizes the results of a coroutine function.

    Concurrent calls with the same arguments share one computation: the first caller
    publishes a placeholder holding a lock, and later callers wait on that lock.
    """

    def __init__(
        self,
        func: Callable[..., Awaitable[T]],
        maxsize: int | None,
        typed: bool,
        always_checkpoint: bool,
    ):
        self.__wrapped__ = func
        self._hits: int = 0
        self._misses: int = 0
        # Negative sizes are clamped to 0, which disables caching entirely
        self._maxsize = max(maxsize, 0) if maxsize is not None else None
        # Number of completed (value-holding) entries; in-flight placeholders are
        # not counted
        self._currsize: int = 0
        self._typed = typed
        self._always_checkpoint = always_checkpoint
        update_wrapper(self, func)

    def cache_info(self) -> AsyncCacheInfo:
        """Return the hit/miss counters, the configured size and the current size."""
        return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)

    def cache_parameters(self) -> AsyncCacheParameters:
        """Return the parameters this wrapper was created with."""
        return {
            "maxsize": self._maxsize,
            "typed": self._typed,
            "always_checkpoint": self._always_checkpoint,
        }

    def cache_clear(self) -> None:
        """
        Drop this wrapper's cached entries for the current event loop and reset the
        statistics counters.
        """
        if cache := lru_cache_items.get(None):
            cache.pop(self, None)

        self._hits = self._misses = self._currsize = 0

    def _store(
        self,
        cache_entry: OrderedDict[
            Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]
        ],
        key: Hashable,
        value: Any,
    ) -> None:
        """Store a computed value, evicting the LRU completed entry if full."""
        current = cache_entry.get(key)
        if current is None or current[1] is not None:
            # A brand-new completed entry is being added (it replaces either nothing
            # or a placeholder), so it needs a slot
            if self._maxsize is not None and self._currsize >= self._maxsize:
                # Evict the least recently used *completed* entry. Placeholders
                # belong to calls still in progress and are not counted in
                # _currsize, so evicting one would desynchronise the counter.
                victim = next(
                    (
                        old_key
                        for old_key, (_, old_lock) in cache_entry.items()
                        if old_lock is None
                    ),
                    initial_missing,
                )
                if victim is not initial_missing:
                    del cache_entry[victim]
            else:
                self._currsize += 1

        cache_entry[key] = value, None

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Return the cached result for these arguments, computing it if needed."""
        # Easy case first: if maxsize == 0, no caching is done
        if self._maxsize == 0:
            # Count the miss before awaiting (as the caching path does) so that
            # calls which raise are still reflected in cache_info()
            self._misses += 1
            return await self.__wrapped__(*args, **kwargs)

        # The key is constructed as a flat tuple to avoid memory overhead
        key: tuple[Any, ...] = args
        if kwargs:
            # initial_missing is used as a separator
            key += (initial_missing,) + sum(kwargs.items(), ())

        if self._typed:
            key += tuple(type(arg) for arg in args)
            if kwargs:
                key += (initial_missing,) + tuple(type(val) for val in kwargs.values())

        try:
            cache = lru_cache_items.get()
        except LookupError:
            cache = WeakKeyDictionary()
            lru_cache_items.set(cache)

        try:
            cache_entry = cache[self]
        except KeyError:
            cache_entry = cache[self] = OrderedDict()

        entry: tuple[_InitialMissingType, Lock] | tuple[Any, None]
        try:
            entry = cache_entry[key]
        except KeyError:
            # We're the first task to call this function with these arguments:
            # publish a placeholder so that concurrent callers wait on our lock
            entry = cache_entry[key] = (
                initial_missing,
                Lock(fast_acquire=not self._always_checkpoint),
            )

        cached_value, lock = entry
        if lock is None:
            # The value was already cached
            self._hits += 1
            cache_entry.move_to_end(key)
            if self._always_checkpoint:
                await checkpoint()

            return cast(T, cached_value)

        async with lock:
            # Re-read the entry: while we waited it may have been filled, evicted,
            # or discarded after a failed call. Indexing directly would raise
            # KeyError in the latter two cases.
            current = cache_entry.get(key)
            if current is not None and current[1] is None:
                # Another task filled the cache while we were waiting for the lock
                self._hits += 1
                cache_entry.move_to_end(key)
                return cast(T, current[0])

            if current is None:
                # Our placeholder vanished; re-publish it so that new callers queue
                # on the lock we hold instead of starting a parallel computation
                entry = cache_entry[key] = (initial_missing, lock)

            self._misses += 1
            try:
                value = await self.__wrapped__(*args, **kwargs)
            except BaseException:
                # Don't leave a placeholder behind for a failed or cancelled call:
                # it would linger in the cache indefinitely. Only remove it if it is
                # still the placeholder we own (waiters will re-publish their own).
                if cache_entry.get(key) is entry:
                    del cache_entry[key]

                raise

            self._store(cache_entry, key, value)

        return value


class _LRUCacheWrapper(Generic[T]):
    """Decorator returned by ``lru_cache(...)`` when called without a function."""

    def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
        self._maxsize = maxsize
        self._typed = typed
        self._always_checkpoint = always_checkpoint

    @overload
    def __call__(  # type: ignore[overload-overlap]
        self, func: Callable[P, Coroutine[Any, Any, T]], /
    ) -> AsyncLRUCacheWrapper[P, T]: ...

    @overload
    def __call__(
        self, func: Callable[..., T], /
    ) -> functools._lru_cache_wrapper[T]: ...

    def __call__(
        self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
    ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
        # Coroutine functions get the async-aware cache; anything else falls back
        # to the standard library implementation
        if iscoroutinefunction(f):
            return AsyncLRUCacheWrapper(
                f, self._maxsize, self._typed, self._always_checkpoint
            )

        return functools.lru_cache(maxsize=self._maxsize, typed=self._typed)(f)  # type: ignore[arg-type]


@overload
def cache(  # type: ignore[overload-overlap]
    func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...


@overload
def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...


def cache(
    func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
    """
    A convenient shortcut for :func:`lru_cache` with ``maxsize=None``.

    This is the asynchronous equivalent to :func:`functools.cache`.

    """
    return lru_cache(maxsize=None)(func)


@overload
def lru_cache(
    *, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
) -> _LRUCacheWrapper[Any]: ...


@overload
def lru_cache(  # type: ignore[overload-overlap]
    func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...


@overload
def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...


def lru_cache(
    func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
    /,
    *,
    maxsize: int | None = 128,
    typed: bool = False,
    always_checkpoint: bool = False,
) -> (
    AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
):
    """
    An asynchronous version of :func:`functools.lru_cache`.

    If a synchronous function is passed, the standard library
    :func:`functools.lru_cache` is applied instead (``always_checkpoint`` has no
    effect in that case).

    :param maxsize: the maximum number of results to keep; ``None`` means the cache
        can grow without bound, and ``0`` (or a negative value) disables caching
    :param typed: if ``True``, arguments of different types are cached separately
        (for example, ``f(3)`` and ``f(3.0)``)
    :param always_checkpoint: if ``True``, every call to the cached function will be
        guaranteed to yield control to the event loop at least once

    .. note:: Caches and locks are managed on a per-event loop basis.

    """
    if func is None:
        return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)

    if not callable(func):
        raise TypeError("the first argument must be callable")

    return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)


@overload
async def reduce(
    function: Callable[[T, S], Awaitable[T]],
    iterable: Iterable[S] | AsyncIterable[S],
    /,
    initial: T,
) -> T: ...


@overload
async def reduce(
    function: Callable[[T, T], Awaitable[T]],
    iterable: Iterable[T] | AsyncIterable[T],
    /,
) -> T: ...


async def reduce(  # type: ignore[misc]
    function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
    iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
    /,
    initial: T | _InitialMissingType = initial_missing,
) -> T:
    """
    Asynchronous version of :func:`functools.reduce`.

    :param function: a coroutine function that takes two arguments: the accumulated
        value and the next element from the iterable
    :param iterable: an iterable or async iterable
    :param initial: the initial value (if missing, the first element of the iterable is
        used as the initial value)

    """
    element: Any
    function_called = False
    if isinstance(iterable, AsyncIterable):
        async_it = iterable.__aiter__()
        if initial is initial_missing:
            try:
                value = cast(T, await async_it.__anext__())
            except StopAsyncIteration:
                raise TypeError(
                    "reduce() of empty sequence with no initial value"
                ) from None
        else:
            value = cast(T, initial)

        async for element in async_it:
            value = await function(value, element)
            function_called = True
    elif isinstance(iterable, Iterable):
        it = iter(iterable)
        if initial is initial_missing:
            try:
                value = cast(T, next(it))
            except StopIteration:
                raise TypeError(
                    "reduce() of empty sequence with no initial value"
                ) from None
        else:
            value = cast(T, initial)

        for element in it:
            value = await function(value, element)
            function_called = True
    else:
        raise TypeError("reduce() argument 2 must be an iterable or async iterable")

    # Make sure there is at least one checkpoint, even if an empty iterable and an
    # initial value were given
    if not function_called:
        await checkpoint()

    return value