identity.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # orm/identity.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from __future__ import annotations
  8. from typing import Any
  9. from typing import cast
  10. from typing import Dict
  11. from typing import Iterable
  12. from typing import Iterator
  13. from typing import List
  14. from typing import NoReturn
  15. from typing import Optional
  16. from typing import Set
  17. from typing import Tuple
  18. from typing import TYPE_CHECKING
  19. from typing import TypeVar
  20. import weakref
  21. from . import util as orm_util
  22. from .. import exc as sa_exc
  23. if TYPE_CHECKING:
  24. from ._typing import _IdentityKeyType
  25. from .state import InstanceState
  26. _T = TypeVar("_T", bound=Any)
  27. _O = TypeVar("_O", bound=object)
  28. class IdentityMap:
  29. _wr: weakref.ref[IdentityMap]
  30. _dict: Dict[_IdentityKeyType[Any], Any]
  31. _modified: Set[InstanceState[Any]]
  32. def __init__(self) -> None:
  33. self._dict = {}
  34. self._modified = set()
  35. self._wr = weakref.ref(self)
  36. def _kill(self) -> None:
  37. self._add_unpresent = _killed # type: ignore
  38. def all_states(self) -> List[InstanceState[Any]]:
  39. raise NotImplementedError()
  40. def contains_state(self, state: InstanceState[Any]) -> bool:
  41. raise NotImplementedError()
  42. def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
  43. raise NotImplementedError()
  44. def safe_discard(self, state: InstanceState[Any]) -> None:
  45. raise NotImplementedError()
  46. def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
  47. raise NotImplementedError()
  48. def get(
  49. self, key: _IdentityKeyType[_O], default: Optional[_O] = None
  50. ) -> Optional[_O]:
  51. raise NotImplementedError()
  52. def fast_get_state(
  53. self, key: _IdentityKeyType[_O]
  54. ) -> Optional[InstanceState[_O]]:
  55. raise NotImplementedError()
  56. def keys(self) -> Iterable[_IdentityKeyType[Any]]:
  57. return self._dict.keys()
  58. def values(self) -> Iterable[object]:
  59. raise NotImplementedError()
  60. def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]:
  61. raise NotImplementedError()
  62. def add(self, state: InstanceState[Any]) -> bool:
  63. raise NotImplementedError()
  64. def _fast_discard(self, state: InstanceState[Any]) -> None:
  65. raise NotImplementedError()
  66. def _add_unpresent(
  67. self, state: InstanceState[Any], key: _IdentityKeyType[Any]
  68. ) -> None:
  69. """optional inlined form of add() which can assume item isn't present
  70. in the map"""
  71. self.add(state)
  72. def _manage_incoming_state(self, state: InstanceState[Any]) -> None:
  73. state._instance_dict = self._wr
  74. if state.modified:
  75. self._modified.add(state)
  76. def _manage_removed_state(self, state: InstanceState[Any]) -> None:
  77. del state._instance_dict
  78. if state.modified:
  79. self._modified.discard(state)
  80. def _dirty_states(self) -> Set[InstanceState[Any]]:
  81. return self._modified
  82. def check_modified(self) -> bool:
  83. """return True if any InstanceStates present have been marked
  84. as 'modified'.
  85. """
  86. return bool(self._modified)
  87. def has_key(self, key: _IdentityKeyType[Any]) -> bool:
  88. return key in self
  89. def __len__(self) -> int:
  90. return len(self._dict)
  91. class WeakInstanceDict(IdentityMap):
  92. _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]]
  93. def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
  94. state = cast("InstanceState[_O]", self._dict[key])
  95. o = state.obj()
  96. if o is None:
  97. raise KeyError(key)
  98. return o
  99. def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
  100. try:
  101. if key in self._dict:
  102. state = self._dict[key]
  103. o = state.obj()
  104. else:
  105. return False
  106. except KeyError:
  107. return False
  108. else:
  109. return o is not None
  110. def contains_state(self, state: InstanceState[Any]) -> bool:
  111. if state.key in self._dict:
  112. if TYPE_CHECKING:
  113. assert state.key is not None
  114. try:
  115. return self._dict[state.key] is state
  116. except KeyError:
  117. return False
  118. else:
  119. return False
  120. def replace(
  121. self, state: InstanceState[Any]
  122. ) -> Optional[InstanceState[Any]]:
  123. assert state.key is not None
  124. if state.key in self._dict:
  125. try:
  126. existing = existing_non_none = self._dict[state.key]
  127. except KeyError:
  128. # catch gc removed the key after we just checked for it
  129. existing = None
  130. else:
  131. if existing_non_none is not state:
  132. self._manage_removed_state(existing_non_none)
  133. else:
  134. return None
  135. else:
  136. existing = None
  137. self._dict[state.key] = state
  138. self._manage_incoming_state(state)
  139. return existing
  140. def add(self, state: InstanceState[Any]) -> bool:
  141. key = state.key
  142. assert key is not None
  143. # inline of self.__contains__
  144. if key in self._dict:
  145. try:
  146. existing_state = self._dict[key]
  147. except KeyError:
  148. # catch gc removed the key after we just checked for it
  149. pass
  150. else:
  151. if existing_state is not state:
  152. o = existing_state.obj()
  153. if o is not None:
  154. raise sa_exc.InvalidRequestError(
  155. "Can't attach instance "
  156. "%s; another instance with key %s is already "
  157. "present in this session."
  158. % (orm_util.state_str(state), state.key)
  159. )
  160. else:
  161. return False
  162. self._dict[key] = state
  163. self._manage_incoming_state(state)
  164. return True
  165. def _add_unpresent(
  166. self, state: InstanceState[Any], key: _IdentityKeyType[Any]
  167. ) -> None:
  168. # inlined form of add() called by loading.py
  169. self._dict[key] = state
  170. state._instance_dict = self._wr
  171. def fast_get_state(
  172. self, key: _IdentityKeyType[_O]
  173. ) -> Optional[InstanceState[_O]]:
  174. return self._dict.get(key)
  175. def get(
  176. self, key: _IdentityKeyType[_O], default: Optional[_O] = None
  177. ) -> Optional[_O]:
  178. if key not in self._dict:
  179. return default
  180. try:
  181. state = cast("InstanceState[_O]", self._dict[key])
  182. except KeyError:
  183. # catch gc removed the key after we just checked for it
  184. return default
  185. else:
  186. o = state.obj()
  187. if o is None:
  188. return default
  189. return o
  190. def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]:
  191. values = self.all_states()
  192. result = []
  193. for state in values:
  194. value = state.obj()
  195. key = state.key
  196. assert key is not None
  197. if value is not None:
  198. result.append((key, value))
  199. return result
  200. def values(self) -> List[object]:
  201. values = self.all_states()
  202. result = []
  203. for state in values:
  204. value = state.obj()
  205. if value is not None:
  206. result.append(value)
  207. return result
  208. def __iter__(self) -> Iterator[_IdentityKeyType[Any]]:
  209. return iter(self.keys())
  210. def all_states(self) -> List[InstanceState[Any]]:
  211. return list(self._dict.values())
  212. def _fast_discard(self, state: InstanceState[Any]) -> None:
  213. # used by InstanceState for state being
  214. # GC'ed, inlines _managed_removed_state
  215. key = state.key
  216. assert key is not None
  217. try:
  218. st = self._dict[key]
  219. except KeyError:
  220. # catch gc removed the key after we just checked for it
  221. pass
  222. else:
  223. if st is state:
  224. self._dict.pop(key, None)
  225. def discard(self, state: InstanceState[Any]) -> None:
  226. self.safe_discard(state)
  227. def safe_discard(self, state: InstanceState[Any]) -> None:
  228. key = state.key
  229. if key in self._dict:
  230. assert key is not None
  231. try:
  232. st = self._dict[key]
  233. except KeyError:
  234. # catch gc removed the key after we just checked for it
  235. pass
  236. else:
  237. if st is state:
  238. self._dict.pop(key, None)
  239. self._manage_removed_state(state)
  240. def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn:
  241. # external function to avoid creating cycles when assigned to
  242. # the IdentityMap
  243. raise sa_exc.InvalidRequestError(
  244. "Object %s cannot be converted to 'persistent' state, as this "
  245. "identity map is no longer valid. Has the owning Session "
  246. "been closed?" % orm_util.state_str(state),
  247. code="lkrp",
  248. )