asyncio.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
# testing/asyncio.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


# Functions and wrappers to run tests, fixtures, provisioning and
# setup/teardown in an asyncio event loop, conditionally based on the
# current DB driver being used for a test.

# Note that SQLAlchemy's asyncio integration also supports a method
# of running individual asyncio functions inside of separate event loops
# using "async_fallback" mode; however, running whole functions in the
# event loop is a more accurate test of how SQLAlchemy's asyncio features
# would run in the real world.
  16. from __future__ import annotations
  17. from functools import wraps
  18. import inspect
  19. from . import config
  20. from ..util.concurrency import _AsyncUtil
# May be set to False if the --disable-asyncio flag is passed to the test
# runner.  When False, _assume_async(), _maybe_async_provisioning() and
# _maybe_async() call their target function directly instead of routing
# it through _async_util.
ENABLE_ASYNCIO = True

# Shared _AsyncUtil instance used by the helpers in this module.
# _AsyncUtil has lazy init, so one is always created up front.
_async_util = _AsyncUtil()
  25. def _shutdown():
  26. """called when the test finishes"""
  27. _async_util.close()
  28. def _run_coroutine_function(fn, *args, **kwargs):
  29. return _async_util.run(fn, *args, **kwargs)
  30. def _assume_async(fn, *args, **kwargs):
  31. """Run a function in an asyncio loop unconditionally.
  32. This function is used for provisioning features like
  33. testing a database connection for server info.
  34. Note that for blocking IO database drivers, this means they block the
  35. event loop.
  36. """
  37. if not ENABLE_ASYNCIO:
  38. return fn(*args, **kwargs)
  39. return _async_util.run_in_greenlet(fn, *args, **kwargs)
  40. def _maybe_async_provisioning(fn, *args, **kwargs):
  41. """Run a function in an asyncio loop if any current drivers might need it.
  42. This function is used for provisioning features that take
  43. place outside of a specific database driver being selected, so if the
  44. current driver that happens to be used for the provisioning operation
  45. is an async driver, it will run in asyncio and not fail.
  46. Note that for blocking IO database drivers, this means they block the
  47. event loop.
  48. """
  49. if not ENABLE_ASYNCIO:
  50. return fn(*args, **kwargs)
  51. if config.any_async:
  52. return _async_util.run_in_greenlet(fn, *args, **kwargs)
  53. else:
  54. return fn(*args, **kwargs)
  55. def _maybe_async(fn, *args, **kwargs):
  56. """Run a function in an asyncio loop if the current selected driver is
  57. async.
  58. This function is used for test setup/teardown and tests themselves
  59. where the current DB driver is known.
  60. """
  61. if not ENABLE_ASYNCIO:
  62. return fn(*args, **kwargs)
  63. is_async = config._current.is_async
  64. if is_async:
  65. return _async_util.run_in_greenlet(fn, *args, **kwargs)
  66. else:
  67. return fn(*args, **kwargs)
  68. def _maybe_async_wrapper(fn):
  69. """Apply the _maybe_async function to an existing function and return
  70. as a wrapped callable, supporting generator functions as well.
  71. This is currently used for pytest fixtures that support generator use.
  72. """
  73. if inspect.isgeneratorfunction(fn):
  74. _stop = object()
  75. def call_next(gen):
  76. try:
  77. return next(gen)
  78. # can't raise StopIteration in an awaitable.
  79. except StopIteration:
  80. return _stop
  81. @wraps(fn)
  82. def wrap_fixture(*args, **kwargs):
  83. gen = fn(*args, **kwargs)
  84. while True:
  85. value = _maybe_async(call_next, gen)
  86. if value is _stop:
  87. break
  88. yield value
  89. else:
  90. @wraps(fn)
  91. def wrap_fixture(*args, **kwargs):
  92. return _maybe_async(fn, *args, **kwargs)
  93. return wrap_fixture