194 lines
7.2 KiB
Python
194 lines
7.2 KiB
Python
import inspect
|
|
import sys
|
|
from dataclasses import dataclass, field
|
|
from functools import cached_property, partial
|
|
from typing import Any, Callable, List, Optional, Union
|
|
|
|
from fastapi._compat import ModelField
|
|
from fastapi.security.base import SecurityBase
|
|
from fastapi.types import DependencyCacheKey
|
|
from typing_extensions import Literal
|
|
|
|
if sys.version_info >= (3, 13): # pragma: no cover
|
|
from inspect import iscoroutinefunction
|
|
else: # pragma: no cover
|
|
from asyncio import iscoroutinefunction
|
|
|
|
|
|
def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any:
    """Return the callable found underneath ``partial`` and wrapper layers.

    ``functools.partial`` layers are removed first (via ``_impartial``), then
    the ``__wrapped__`` chain is followed with ``inspect.unwrap``.

    Args:
        call: The callable to inspect, or ``None``.

    Returns:
        The innermost object reached, or ``None`` when ``call`` is ``None``.
    """
    if call is not None:
        return inspect.unwrap(_impartial(call))
    return call  # pragma: no cover
|
|
|
|
|
|
def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
    """Strip every ``functools.partial`` layer and return what it wraps.

    Args:
        func: Any callable, possibly a (nested) ``partial`` object.

    Returns:
        ``func`` itself when it is not a ``partial``; otherwise the first
        non-``partial`` object reached by following ``.func``.
    """
    target = func
    while True:
        if not isinstance(target, partial):
            return target
        target = target.func
|
|
|
|
|
|
@dataclass
class Dependant:
    """Describe a callable (``call``) together with its parameters and sub-dependencies.

    Besides the plain data fields, the class exposes several
    ``cached_property`` attributes that introspect ``call`` (generator /
    async generator / coroutine detection, security-scheme detection) and
    derive values from the fields (merged OAuth scopes, cache key, scope).
    Each of these is computed once per instance, so mutating the fields after
    the first access does not refresh them.
    """

    # Parameter fields, grouped in one list per kind. The names suggest the
    # request location each group is read from; the lists are populated by
    # code outside this class.
    path_params: List[ModelField] = field(default_factory=list)
    query_params: List[ModelField] = field(default_factory=list)
    header_params: List[ModelField] = field(default_factory=list)
    cookie_params: List[ModelField] = field(default_factory=list)
    body_params: List[ModelField] = field(default_factory=list)
    # Nested dependants; consulted by ``_uses_scopes`` and
    # ``_security_dependencies``.
    dependencies: List["Dependant"] = field(default_factory=list)
    name: Optional[str] = None
    # The callable this dependant wraps; every introspection property below
    # treats ``None`` as "nothing to inspect".
    call: Optional[Callable[..., Any]] = None
    # Names of special parameters on ``call`` (``None`` when absent).
    request_param_name: Optional[str] = None
    websocket_param_name: Optional[str] = None
    http_connection_param_name: Optional[str] = None
    response_param_name: Optional[str] = None
    background_tasks_param_name: Optional[str] = None
    security_scopes_param_name: Optional[str] = None
    # Scopes declared on this dependant and scopes inherited from its parent;
    # merged by ``oauth_scopes``.
    own_oauth_scopes: Optional[List[str]] = None
    parent_oauth_scopes: Optional[List[str]] = None
    use_cache: bool = True
    path: Optional[str] = None
    # Explicit scope; when unset, ``computed_scope`` may derive one.
    scope: Union[Literal["function", "request"], None] = None

    @cached_property
    def oauth_scopes(self) -> List[str]:
        """Return parent scopes followed by own scopes not already listed.

        The parent list is copied, so the result never aliases
        ``parent_oauth_scopes``.
        """
        scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
        # This doesn't use a set to preserve order, just in case
        for scope in self.own_oauth_scopes or []:
            if scope not in scopes:
                scopes.append(scope)
        return scopes

    @cached_property
    def cache_key(self) -> DependencyCacheKey:
        """Return the tuple ``(call, scopes, scope)`` identifying this dependant.

        ``scopes`` is the sorted, de-duplicated tuple of ``oauth_scopes`` when
        ``_uses_scopes`` is true and an empty tuple otherwise. ``scope`` is
        ``computed_scope`` with ``None`` replaced by ``""``.
        """
        # ``oauth_scopes`` is always a list, so no ``or []`` fallback is needed.
        scopes_for_cache = (
            tuple(sorted(set(self.oauth_scopes))) if self._uses_scopes else ()
        )
        return (
            self.call,
            scopes_for_cache,
            self.computed_scope or "",
        )

    @cached_property
    def _uses_scopes(self) -> bool:
        """Tell whether scopes are relevant to this dependant or any sub-dependant.

        True when it declares its own scopes, has a security-scopes parameter,
        is itself a security scheme, or any entry of ``dependencies`` uses
        scopes (checked recursively, in that order).
        """
        return bool(
            self.own_oauth_scopes
            or self.security_scopes_param_name is not None
            or self._is_security_scheme
            or any(sub_dep._uses_scopes for sub_dep in self.dependencies)
        )

    @cached_property
    def _is_security_scheme(self) -> bool:
        """Tell whether the unwrapped ``call`` is a ``SecurityBase`` instance."""
        if self.call is None:
            return False  # pragma: no cover
        return isinstance(_unwrapped_call(self.call), SecurityBase)

    # Mainly to get the type of SecurityBase, but it's the same self.call
    @cached_property
    def _security_scheme(self) -> SecurityBase:
        """Return the unwrapped ``call`` typed as ``SecurityBase``.

        Raises:
            AssertionError: If the unwrapped ``call`` is not a ``SecurityBase``
                (only when assertions are enabled).
        """
        unwrapped = _unwrapped_call(self.call)
        assert isinstance(unwrapped, SecurityBase)
        return unwrapped

    @cached_property
    def _security_dependencies(self) -> List["Dependant"]:
        """Return the direct sub-dependants that are security schemes."""
        return [dep for dep in self.dependencies if dep._is_security_scheme]

    def _call_matches(
        self,
        direct_check: Callable[[Any], Any],
        dunder_check: Callable[[Any], Any],
    ) -> bool:
        """Probe ``call`` and its ``__call__`` attributes with the given predicates.

        This is the shared implementation behind ``is_gen_callable``,
        ``is_async_gen_callable`` and ``is_coroutine_callable``. The steps run
        in order and stop at the first decisive answer:

        1. ``direct_check`` on ``call`` with ``partial`` layers removed.
        2. ``direct_check`` on the fully unwrapped ``call``.
        3. If the unwrapped ``call`` is a class, the answer is ``False``.
        4. ``dunder_check`` on the ``__call__`` attribute of the object from
           step 1 (both partial-free and unwrapped forms of that attribute).
        5. The same for the ``__call__`` attribute of the object from step 2.

        Args:
            direct_check: Predicate applied to the callable itself.
            dunder_check: Predicate applied to ``__call__`` attributes.

        Returns:
            ``True`` if any check matched, ``False`` otherwise (including when
            ``call`` is ``None``).
        """
        target = self.call
        if target is None:
            return False  # pragma: no cover
        partial_free = _impartial(target)
        if direct_check(partial_free):
            return True
        # Unwrapped only after the cheaper check failed, so evaluation order
        # stays lazy.
        unwrapped = _unwrapped_call(target)
        if direct_check(unwrapped):
            return True
        # ``__call__`` looked up on a class is the method used by its
        # instances, not what runs when the class itself is called, so the
        # dunder checks below are skipped for classes.
        if inspect.isclass(unwrapped):
            return False
        for owner in (partial_free, unwrapped):
            dunder = getattr(owner, "__call__", None)  # noqa: B004
            if dunder is None:
                return False  # pragma: no cover
            if dunder_check(_impartial(dunder)) or dunder_check(
                _unwrapped_call(dunder)
            ):
                return True
        return False

    @cached_property
    def is_gen_callable(self) -> bool:
        """Tell whether ``call`` (or its ``__call__``) is a generator function."""
        return self._call_matches(
            inspect.isgeneratorfunction, inspect.isgeneratorfunction
        )

    @cached_property
    def is_async_gen_callable(self) -> bool:
        """Tell whether ``call`` (or its ``__call__``) is an async generator function."""
        return self._call_matches(
            inspect.isasyncgenfunction, inspect.isasyncgenfunction
        )

    @cached_property
    def is_coroutine_callable(self) -> bool:
        """Tell whether ``call`` (or its ``__call__``) is a coroutine function.

        For ``call`` itself the object must also satisfy ``inspect.isroutine``;
        its ``__call__`` attributes are tested with ``iscoroutinefunction``
        alone.
        """

        def is_coroutine_routine(candidate: Any) -> bool:
            return inspect.isroutine(candidate) and iscoroutinefunction(candidate)

        return self._call_matches(is_coroutine_routine, iscoroutinefunction)

    @cached_property
    def computed_scope(self) -> Union[str, None]:
        """Return the effective scope of this dependant.

        An explicitly set ``scope`` wins. Otherwise generator and async
        generator callables get ``"request"``, and everything else ``None``.
        """
        if self.scope:
            return self.scope
        if self.is_gen_callable or self.is_async_gen_callable:
            return "request"
        return None
|