base.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # ext/asyncio/base.py
  2. # Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from __future__ import annotations
  8. import abc
  9. import functools
  10. from typing import Any
  11. from typing import AsyncGenerator
  12. from typing import AsyncIterator
  13. from typing import Awaitable
  14. from typing import Callable
  15. from typing import ClassVar
  16. from typing import Dict
  17. from typing import Generator
  18. from typing import Generic
  19. from typing import NoReturn
  20. from typing import Optional
  21. from typing import overload
  22. from typing import Tuple
  23. from typing import TypeVar
  24. import weakref
  25. from . import exc as async_exc
  26. from ... import util
  27. from ...util.typing import Literal
  28. from ...util.typing import Self
  29. _T = TypeVar("_T", bound=Any)
  30. _T_co = TypeVar("_T_co", bound=Any, covariant=True)
  31. _PT = TypeVar("_PT", bound=Any)
  32. class ReversibleProxy(Generic[_PT]):
  33. _proxy_objects: ClassVar[
  34. Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
  35. ] = {}
  36. __slots__ = ("__weakref__",)
  37. @overload
  38. def _assign_proxied(self, target: _PT) -> _PT: ...
  39. @overload
  40. def _assign_proxied(self, target: None) -> None: ...
  41. def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
  42. if target is not None:
  43. target_ref: weakref.ref[_PT] = weakref.ref(
  44. target, ReversibleProxy._target_gced
  45. )
  46. proxy_ref = weakref.ref(
  47. self,
  48. functools.partial(ReversibleProxy._target_gced, target_ref),
  49. )
  50. ReversibleProxy._proxy_objects[target_ref] = proxy_ref
  51. return target
  52. @classmethod
  53. def _target_gced(
  54. cls,
  55. ref: weakref.ref[_PT],
  56. proxy_ref: Optional[weakref.ref[Self]] = None, # noqa: U100
  57. ) -> None:
  58. cls._proxy_objects.pop(ref, None)
  59. @classmethod
  60. def _regenerate_proxy_for_target(
  61. cls, target: _PT, **additional_kw: Any
  62. ) -> Self:
  63. raise NotImplementedError()
  64. @overload
  65. @classmethod
  66. def _retrieve_proxy_for_target(
  67. cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
  68. ) -> Self: ...
  69. @overload
  70. @classmethod
  71. def _retrieve_proxy_for_target(
  72. cls, target: _PT, regenerate: bool = True, **additional_kw: Any
  73. ) -> Optional[Self]: ...
  74. @classmethod
  75. def _retrieve_proxy_for_target(
  76. cls, target: _PT, regenerate: bool = True, **additional_kw: Any
  77. ) -> Optional[Self]:
  78. try:
  79. proxy_ref = cls._proxy_objects[weakref.ref(target)]
  80. except KeyError:
  81. pass
  82. else:
  83. proxy = proxy_ref()
  84. if proxy is not None:
  85. return proxy # type: ignore
  86. if regenerate:
  87. return cls._regenerate_proxy_for_target(target, **additional_kw)
  88. else:
  89. return None
  90. class StartableContext(Awaitable[_T_co], abc.ABC):
  91. __slots__ = ()
  92. @abc.abstractmethod
  93. async def start(self, is_ctxmanager: bool = False) -> _T_co:
  94. raise NotImplementedError()
  95. def __await__(self) -> Generator[Any, Any, _T_co]:
  96. return self.start().__await__()
  97. async def __aenter__(self) -> _T_co:
  98. return await self.start(is_ctxmanager=True)
  99. @abc.abstractmethod
  100. async def __aexit__(
  101. self, type_: Any, value: Any, traceback: Any
  102. ) -> Optional[bool]:
  103. pass
  104. def _raise_for_not_started(self) -> NoReturn:
  105. raise async_exc.AsyncContextNotStarted(
  106. "%s context has not been started and object has not been awaited."
  107. % (self.__class__.__name__)
  108. )
  109. class GeneratorStartableContext(StartableContext[_T_co]):
  110. __slots__ = ("gen",)
  111. gen: AsyncGenerator[_T_co, Any]
  112. def __init__(
  113. self,
  114. func: Callable[..., AsyncIterator[_T_co]],
  115. args: Tuple[Any, ...],
  116. kwds: Dict[str, Any],
  117. ):
  118. self.gen = func(*args, **kwds) # type: ignore
  119. async def start(self, is_ctxmanager: bool = False) -> _T_co:
  120. try:
  121. start_value = await util.anext_(self.gen)
  122. except StopAsyncIteration:
  123. raise RuntimeError("generator didn't yield") from None
  124. # if not a context manager, then interrupt the generator, don't
  125. # let it complete. this step is technically not needed, as the
  126. # generator will close in any case at gc time. not clear if having
  127. # this here is a good idea or not (though it helps for clarity IMO)
  128. if not is_ctxmanager:
  129. await self.gen.aclose()
  130. return start_value
  131. async def __aexit__(
  132. self, typ: Any, value: Any, traceback: Any
  133. ) -> Optional[bool]:
  134. # vendored from contextlib.py
  135. if typ is None:
  136. try:
  137. await util.anext_(self.gen)
  138. except StopAsyncIteration:
  139. return False
  140. else:
  141. raise RuntimeError("generator didn't stop")
  142. else:
  143. if value is None:
  144. # Need to force instantiation so we can reliably
  145. # tell if we get the same exception back
  146. value = typ()
  147. try:
  148. await self.gen.athrow(value)
  149. except StopAsyncIteration as exc:
  150. # Suppress StopIteration *unless* it's the same exception that
  151. # was passed to throw(). This prevents a StopIteration
  152. # raised inside the "with" statement from being suppressed.
  153. return exc is not value
  154. except RuntimeError as exc:
  155. # Don't re-raise the passed in exception. (issue27122)
  156. if exc is value:
  157. return False
  158. # Avoid suppressing if a Stop(Async)Iteration exception
  159. # was passed to athrow() and later wrapped into a RuntimeError
  160. # (see PEP 479 for sync generators; async generators also
  161. # have this behavior). But do this only if the exception
  162. # wrapped
  163. # by the RuntimeError is actully Stop(Async)Iteration (see
  164. # issue29692).
  165. if (
  166. isinstance(value, (StopIteration, StopAsyncIteration))
  167. and exc.__cause__ is value
  168. ):
  169. return False
  170. raise
  171. except BaseException as exc:
  172. # only re-raise if it's *not* the exception that was
  173. # passed to throw(), because __exit__() must not raise
  174. # an exception unless __exit__() itself failed. But throw()
  175. # has to raise the exception to signal propagation, so this
  176. # fixes the impedance mismatch between the throw() protocol
  177. # and the __exit__() protocol.
  178. if exc is not value:
  179. raise
  180. return False
  181. raise RuntimeError("generator didn't stop after athrow()")
  182. def asyncstartablecontext(
  183. func: Callable[..., AsyncIterator[_T_co]],
  184. ) -> Callable[..., GeneratorStartableContext[_T_co]]:
  185. """@asyncstartablecontext decorator.
  186. the decorated function can be called either as ``async with fn()``, **or**
  187. ``await fn()``. This is decidedly different from what
  188. ``@contextlib.asynccontextmanager`` supports, and the usage pattern
  189. is different as well.
  190. Typical usage:
  191. .. sourcecode:: text
  192. @asyncstartablecontext
  193. async def some_async_generator(<arguments>):
  194. <setup>
  195. try:
  196. yield <value>
  197. except GeneratorExit:
  198. # return value was awaited, no context manager is present
  199. # and caller will .close() the resource explicitly
  200. pass
  201. else:
  202. <context manager cleanup>
  203. Above, ``GeneratorExit`` is caught if the function were used as an
  204. ``await``. In this case, it's essential that the cleanup does **not**
  205. occur, so there should not be a ``finally`` block.
  206. If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__``
  207. and we were invoked as a context manager, and cleanup should proceed.
  208. """
  209. @functools.wraps(func)
  210. def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]:
  211. return GeneratorStartableContext(func, args, kwds)
  212. return helper
  213. class ProxyComparable(ReversibleProxy[_PT]):
  214. __slots__ = ()
  215. @util.ro_non_memoized_property
  216. def _proxied(self) -> _PT:
  217. raise NotImplementedError()
  218. def __hash__(self) -> int:
  219. return id(self)
  220. def __eq__(self, other: Any) -> bool:
  221. return (
  222. isinstance(other, self.__class__)
  223. and self._proxied == other._proxied
  224. )
  225. def __ne__(self, other: Any) -> bool:
  226. return (
  227. not isinstance(other, self.__class__)
  228. or self._proxied != other._proxied
  229. )