state_changes.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # orm/state_changes.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. """State tracking utilities used by :class:`_orm.Session`."""
  8. from __future__ import annotations
  9. import contextlib
  10. from enum import Enum
  11. from typing import Any
  12. from typing import Callable
  13. from typing import cast
  14. from typing import Iterator
  15. from typing import NoReturn
  16. from typing import Optional
  17. from typing import Tuple
  18. from typing import TypeVar
  19. from typing import Union
  20. from .. import exc as sa_exc
  21. from .. import util
  22. from ..util.typing import Literal
# Type variable for the callables handled by ``_StateChange.declare_states``;
# that decorator factory is annotated ``Callable[[_F], _F]`` so a decorated
# method keeps its own signature from the type checker's point of view.
_F = TypeVar("_F", bound=Callable[..., Any])
  24. class _StateChangeState(Enum):
  25. pass
class _StateChangeStates(_StateChangeState):
    """Sentinel states that ``_StateChange`` itself understands."""

    # As ``prerequisite_states`` for ``declare_states``: no prerequisite
    # state is required.  As the value of ``_StateChange._next_state``: the
    # destination state is unrestricted (this is the class-level default).
    ANY = 1
    # As ``moves_to`` for ``declare_states``: the decorated method is not
    # expected to change state.  Also the class-level default of
    # ``_StateChange._state``.
    NO_CHANGE = 2
    # Stored in ``_StateChange._next_state`` while a method decorated with
    # ``declare_states`` is running (and restored to it when an
    # ``_expect_state`` block exits).
    CHANGE_IN_PROGRESS = 3
  30. class _StateChange:
  31. """Supplies state assertion decorators.
  32. The current use case is for the :class:`_orm.SessionTransaction` class. The
  33. :class:`_StateChange` class itself is agnostic of the
  34. :class:`_orm.SessionTransaction` class so could in theory be generalized
  35. for other systems as well.
  36. """
  37. _next_state: _StateChangeState = _StateChangeStates.ANY
  38. _state: _StateChangeState = _StateChangeStates.NO_CHANGE
  39. _current_fn: Optional[Callable[..., Any]] = None
  40. def _raise_for_prerequisite_state(
  41. self, operation_name: str, state: _StateChangeState
  42. ) -> NoReturn:
  43. raise sa_exc.IllegalStateChangeError(
  44. f"Can't run operation '{operation_name}()' when Session "
  45. f"is in state {state!r}",
  46. code="isce",
  47. )
  48. @classmethod
  49. def declare_states(
  50. cls,
  51. prerequisite_states: Union[
  52. Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
  53. ],
  54. moves_to: _StateChangeState,
  55. ) -> Callable[[_F], _F]:
  56. """Method decorator declaring valid states.
  57. :param prerequisite_states: sequence of acceptable prerequisite
  58. states. Can be the single constant _State.ANY to indicate no
  59. prerequisite state
  60. :param moves_to: the expected state at the end of the method, assuming
  61. no exceptions raised. Can be the constant _State.NO_CHANGE to
  62. indicate state should not change at the end of the method.
  63. """
  64. assert prerequisite_states, "no prequisite states sent"
  65. has_prerequisite_states = (
  66. prerequisite_states is not _StateChangeStates.ANY
  67. )
  68. prerequisite_state_collection = cast(
  69. "Tuple[_StateChangeState, ...]", prerequisite_states
  70. )
  71. expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE
  72. @util.decorator
  73. def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
  74. current_state = self._state
  75. if (
  76. has_prerequisite_states
  77. and current_state not in prerequisite_state_collection
  78. ):
  79. self._raise_for_prerequisite_state(fn.__name__, current_state)
  80. next_state = self._next_state
  81. existing_fn = self._current_fn
  82. expect_state = moves_to if expect_state_change else current_state
  83. if (
  84. # destination states are restricted
  85. next_state is not _StateChangeStates.ANY
  86. # method seeks to change state
  87. and expect_state_change
  88. # destination state incorrect
  89. and next_state is not expect_state
  90. ):
  91. if existing_fn and next_state in (
  92. _StateChangeStates.NO_CHANGE,
  93. _StateChangeStates.CHANGE_IN_PROGRESS,
  94. ):
  95. raise sa_exc.IllegalStateChangeError(
  96. f"Method '{fn.__name__}()' can't be called here; "
  97. f"method '{existing_fn.__name__}()' is already "
  98. f"in progress and this would cause an unexpected "
  99. f"state change to {moves_to!r}",
  100. code="isce",
  101. )
  102. else:
  103. raise sa_exc.IllegalStateChangeError(
  104. f"Cant run operation '{fn.__name__}()' here; "
  105. f"will move to state {moves_to!r} where we are "
  106. f"expecting {next_state!r}",
  107. code="isce",
  108. )
  109. self._current_fn = fn
  110. self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS
  111. try:
  112. ret_value = fn(self, *arg, **kw)
  113. except:
  114. raise
  115. else:
  116. if self._state is expect_state:
  117. return ret_value
  118. if self._state is current_state:
  119. raise sa_exc.IllegalStateChangeError(
  120. f"Method '{fn.__name__}()' failed to "
  121. "change state "
  122. f"to {moves_to!r} as expected",
  123. code="isce",
  124. )
  125. elif existing_fn:
  126. raise sa_exc.IllegalStateChangeError(
  127. f"While method '{existing_fn.__name__}()' was "
  128. "running, "
  129. f"method '{fn.__name__}()' caused an "
  130. "unexpected "
  131. f"state change to {self._state!r}",
  132. code="isce",
  133. )
  134. else:
  135. raise sa_exc.IllegalStateChangeError(
  136. f"Method '{fn.__name__}()' caused an unexpected "
  137. f"state change to {self._state!r}",
  138. code="isce",
  139. )
  140. finally:
  141. self._next_state = next_state
  142. self._current_fn = existing_fn
  143. return _go
  144. @contextlib.contextmanager
  145. def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
  146. """called within a method that changes states.
  147. method must also use the ``@declare_states()`` decorator.
  148. """
  149. assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, (
  150. "Unexpected call to _expect_state outside of "
  151. "state-changing method"
  152. )
  153. self._next_state = expected
  154. try:
  155. yield
  156. except:
  157. raise
  158. else:
  159. if self._state is not expected:
  160. raise sa_exc.IllegalStateChangeError(
  161. f"Unexpected state change to {self._state!r}", code="isce"
  162. )
  163. finally:
  164. self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS