_concurrency_py3k.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # util/_concurrency_py3k.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: allow-untyped-defs, allow-untyped-calls
  8. from __future__ import annotations
  9. import asyncio
  10. from contextvars import Context
  11. import sys
  12. import typing
  13. from typing import Any
  14. from typing import Awaitable
  15. from typing import Callable
  16. from typing import Coroutine
  17. from typing import Optional
  18. from typing import TYPE_CHECKING
  19. from typing import TypeVar
  20. from typing import Union
  21. from .langhelpers import memoized_property
  22. from .. import exc
  23. from ..util import py311
  24. from ..util.typing import Literal
  25. from ..util.typing import Protocol
  26. from ..util.typing import Self
  27. from ..util.typing import TypeGuard
# generic result type for await_only / await_fallback / greenlet_spawn
_T = TypeVar("_T")

if typing.TYPE_CHECKING:
    # Type-checking-only stand-in for the greenlet package: a Protocol that
    # declares just the members this module touches (dead, gr_context,
    # throw(), switch()).  At runtime the real greenlet package is imported
    # in the ``else`` branch below instead.

    class greenlet(Protocol):
        dead: bool
        gr_context: Optional[Context]

        def __init__(self, fn: Callable[..., Any], driver: greenlet): ...

        def throw(self, *arg: Any) -> Any:
            return None

        def switch(self, value: Any) -> Any:
            return None

    def getcurrent() -> greenlet: ...

else:
    from greenlet import getcurrent
    from greenlet import greenlet

# If greenlet.gr_context is present in the installed version of greenlet,
# it will be set with the current context on creation.
# Refs: https://github.com/python-greenlet/greenlet/pull/198
# Probed once, at import time, on whatever greenlet is current; the result
# decides whether _AsyncIoGreenlet copies its driver's gr_context.
_has_gr_context = hasattr(getcurrent(), "gr_context")
  46. def is_exit_exception(e: BaseException) -> bool:
  47. # note asyncio.CancelledError is already BaseException
  48. # so was an exit exception in any case
  49. return not isinstance(e, Exception) or isinstance(
  50. e, (asyncio.TimeoutError, asyncio.CancelledError)
  51. )
  52. # implementation based on snaury gist at
  53. # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
  54. # Issue for context: https://github.com/python-greenlet/greenlet/issues/173
  55. class _AsyncIoGreenlet(greenlet):
  56. dead: bool
  57. __sqlalchemy_greenlet_provider__ = True
  58. def __init__(self, fn: Callable[..., Any], driver: greenlet):
  59. greenlet.__init__(self, fn, driver)
  60. if _has_gr_context:
  61. self.gr_context = driver.gr_context
# covariant result type used only by the iscoroutine() type signature below
_T_co = TypeVar("_T_co", covariant=True)

if TYPE_CHECKING:
    # For type checkers only: declare iscoroutine() as a TypeGuard so that a
    # successful check narrows an Awaitable to a Coroutine (which has
    # .close()).  At runtime the name is simply asyncio.iscoroutine.

    def iscoroutine(
        awaitable: Awaitable[_T_co],
    ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ...

else:
    iscoroutine = asyncio.iscoroutine
  69. def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
  70. # https://docs.python.org/3/reference/datamodel.html#coroutine.close
  71. if iscoroutine(awaitable):
  72. awaitable.close()
  73. def in_greenlet() -> bool:
  74. current = getcurrent()
  75. return getattr(current, "__sqlalchemy_greenlet_provider__", False)
  76. def await_only(awaitable: Awaitable[_T]) -> _T:
  77. """Awaits an async function in a sync method.
  78. The sync method must be inside a :func:`greenlet_spawn` context.
  79. :func:`await_only` calls cannot be nested.
  80. :param awaitable: The coroutine to call.
  81. """
  82. # this is called in the context greenlet while running fn
  83. current = getcurrent()
  84. if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
  85. _safe_cancel_awaitable(awaitable)
  86. raise exc.MissingGreenlet(
  87. "greenlet_spawn has not been called; can't call await_only() "
  88. "here. Was IO attempted in an unexpected place?"
  89. )
  90. # returns the control to the driver greenlet passing it
  91. # a coroutine to run. Once the awaitable is done, the driver greenlet
  92. # switches back to this greenlet with the result of awaitable that is
  93. # then returned to the caller (or raised as error)
  94. return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
  95. def await_fallback(awaitable: Awaitable[_T]) -> _T:
  96. """Awaits an async function in a sync method.
  97. The sync method must be inside a :func:`greenlet_spawn` context.
  98. :func:`await_fallback` calls cannot be nested.
  99. :param awaitable: The coroutine to call.
  100. .. deprecated:: 2.0.24 The ``await_fallback()`` function will be removed
  101. in SQLAlchemy 2.1. Use :func:`_util.await_only` instead, running the
  102. function / program / etc. within a top-level greenlet that is set up
  103. using :func:`_util.greenlet_spawn`.
  104. """
  105. # this is called in the context greenlet while running fn
  106. current = getcurrent()
  107. if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
  108. loop = get_event_loop()
  109. if loop.is_running():
  110. _safe_cancel_awaitable(awaitable)
  111. raise exc.MissingGreenlet(
  112. "greenlet_spawn has not been called and asyncio event "
  113. "loop is already running; can't call await_fallback() here. "
  114. "Was IO attempted in an unexpected place?"
  115. )
  116. return loop.run_until_complete(awaitable)
  117. return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
async def greenlet_spawn(
    fn: Callable[..., _T],
    *args: Any,
    _require_await: bool = False,
    **kwargs: Any,
) -> _T:
    """Runs a sync function ``fn`` in a new greenlet.

    The sync function can then use :func:`await_only` to wait for async
    functions.

    Each time ``fn`` switches out with an awaitable (via :func:`await_only`),
    this coroutine awaits it and switches back into ``fn`` with the result.
    If awaiting it raises, that exception is thrown into ``fn`` instead. This
    repeats until the greenlet running ``fn`` is dead.

    :param fn: The sync callable to call.
    :param \\*args: Positional arguments to pass to the ``fn`` callable.
    :param _require_await: when True, raise :class:`.AwaitRequired` if
     ``fn`` completed without ever switching out to await something.
    :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.

    :return: the value returned by ``fn``.
    :raises AwaitRequired: see ``_require_await``.
    """

    result: Any
    # the current (driver) greenlet becomes the parent of the new one, so
    # control comes back here whenever fn switches out or finishes
    context = _AsyncIoGreenlet(fn, getcurrent())
    # runs the function synchronously in gl greenlet. If the execution
    # is interrupted by await_only, context is not dead and result is a
    # coroutine to wait. If the context is dead the function has
    # returned, and its result can be returned.
    switch_occurred = False
    result = context.switch(*args, **kwargs)
    while not context.dead:
        switch_occurred = True
        try:
            # wait for a coroutine from await_only and then return its
            # result back to it.
            value = await result
        except BaseException:
            # this allows an exception to be raised within
            # the moderated greenlet so that it can continue
            # its expected flow.
            result = context.throw(*sys.exc_info())
        else:
            result = context.switch(value)

    if _require_await and not switch_occurred:
        raise exc.AwaitRequired(
            "The current operation required an async execution but none was "
            "detected. This will usually happen when using a non compatible "
            "DBAPI driver. Please ensure that an async DBAPI is used."
        )
    return result  # type: ignore[no-any-return]
  159. class AsyncAdaptedLock:
  160. @memoized_property
  161. def mutex(self) -> asyncio.Lock:
  162. # there should not be a race here for coroutines creating the
  163. # new lock as we are not using await, so therefore no concurrency
  164. return asyncio.Lock()
  165. def __enter__(self) -> bool:
  166. # await is used to acquire the lock only after the first calling
  167. # coroutine has created the mutex.
  168. return await_fallback(self.mutex.acquire())
  169. def __exit__(self, *arg: Any, **kw: Any) -> None:
  170. self.mutex.release()
  171. def get_event_loop() -> asyncio.AbstractEventLoop:
  172. """vendor asyncio.get_event_loop() for python 3.7 and above.
  173. Python 3.10 deprecates get_event_loop() as a standalone.
  174. """
  175. try:
  176. return asyncio.get_running_loop()
  177. except RuntimeError:
  178. # avoid "During handling of the above exception, another exception..."
  179. pass
  180. return asyncio.get_event_loop_policy().get_event_loop()
  181. if not TYPE_CHECKING and py311:
  182. _Runner = asyncio.Runner
  183. else:
  184. class _Runner:
  185. """Runner implementation for test only"""
  186. _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]
  187. def __init__(self) -> None:
  188. self._loop = None
  189. def __enter__(self) -> Self:
  190. self._lazy_init()
  191. return self
  192. def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
  193. self.close()
  194. def close(self) -> None:
  195. if self._loop:
  196. try:
  197. self._loop.run_until_complete(
  198. self._loop.shutdown_asyncgens()
  199. )
  200. finally:
  201. self._loop.close()
  202. self._loop = False
  203. def get_loop(self) -> asyncio.AbstractEventLoop:
  204. """Return embedded event loop."""
  205. self._lazy_init()
  206. assert self._loop
  207. return self._loop
  208. def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
  209. self._lazy_init()
  210. assert self._loop
  211. return self._loop.run_until_complete(coro)
  212. def _lazy_init(self) -> None:
  213. if self._loop is False:
  214. raise RuntimeError("Runner is closed")
  215. if self._loop is None:
  216. self._loop = asyncio.new_event_loop()