_py_collections.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. # util/_py_collections.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: allow-untyped-defs, allow-untyped-calls
  8. from __future__ import annotations
  9. from itertools import filterfalse
  10. from typing import AbstractSet
  11. from typing import Any
  12. from typing import Callable
  13. from typing import cast
  14. from typing import Collection
  15. from typing import Dict
  16. from typing import Iterable
  17. from typing import Iterator
  18. from typing import List
  19. from typing import Mapping
  20. from typing import NoReturn
  21. from typing import Optional
  22. from typing import Set
  23. from typing import Tuple
  24. from typing import TYPE_CHECKING
  25. from typing import TypeVar
  26. from typing import Union
  27. from ..util.typing import Self
# Module-level type variables used to parameterize the generic collection
# classes and the helper function defined in this module.  ``bound=Any``
# leaves each of them effectively unconstrained.
_T = TypeVar("_T", bound=Any)  # element type (OrderedSet, unique_list)
_S = TypeVar("_S", bound=Any)  # element type of the other set-op operand
_KT = TypeVar("_KT", bound=Any)  # key type of ImmutableDictBase/immutabledict
_VT = TypeVar("_VT", bound=Any)  # value type of ImmutableDictBase/immutabledict
  32. class ReadOnlyContainer:
  33. __slots__ = ()
  34. def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
  35. raise TypeError(
  36. "%s object is immutable and/or readonly" % self.__class__.__name__
  37. )
  38. def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
  39. raise TypeError("%s object is immutable" % self.__class__.__name__)
  40. def __delitem__(self, key: Any) -> NoReturn:
  41. self._readonly()
  42. def __setitem__(self, key: Any, value: Any) -> NoReturn:
  43. self._readonly()
  44. def __setattr__(self, key: str, value: Any) -> NoReturn:
  45. self._readonly()
  46. class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]):
  47. if TYPE_CHECKING:
  48. def __new__(cls, *args: Any) -> Self: ...
  49. def __init__(cls, *args: Any): ...
  50. def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
  51. self._immutable()
  52. def clear(self) -> NoReturn:
  53. self._readonly()
  54. def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
  55. self._readonly()
  56. def popitem(self) -> NoReturn:
  57. self._readonly()
  58. def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
  59. self._readonly()
  60. def update(self, *arg: Any, **kw: Any) -> NoReturn:
  61. self._readonly()
  62. class immutabledict(ImmutableDictBase[_KT, _VT]):
  63. def __new__(cls, *args):
  64. new = ImmutableDictBase.__new__(cls)
  65. dict.__init__(new, *args)
  66. return new
  67. def __init__(
  68. self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]]
  69. ):
  70. pass
  71. def __reduce__(self):
  72. return immutabledict, (dict(self),)
  73. def union(
  74. self, __d: Optional[Mapping[_KT, _VT]] = None
  75. ) -> immutabledict[_KT, _VT]:
  76. if not __d:
  77. return self
  78. new = ImmutableDictBase.__new__(self.__class__)
  79. dict.__init__(new, self)
  80. dict.update(new, __d)
  81. return new
  82. def _union_w_kw(
  83. self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT
  84. ) -> immutabledict[_KT, _VT]:
  85. # not sure if C version works correctly w/ this yet
  86. if not __d and not kw:
  87. return self
  88. new = ImmutableDictBase.__new__(self.__class__)
  89. dict.__init__(new, self)
  90. if __d:
  91. dict.update(new, __d)
  92. dict.update(new, kw)
  93. return new
  94. def merge_with(
  95. self, *dicts: Optional[Mapping[_KT, _VT]]
  96. ) -> immutabledict[_KT, _VT]:
  97. new = None
  98. for d in dicts:
  99. if d:
  100. if new is None:
  101. new = ImmutableDictBase.__new__(self.__class__)
  102. dict.__init__(new, self)
  103. dict.update(new, d)
  104. if new is None:
  105. return self
  106. return new
  107. def __repr__(self) -> str:
  108. return "immutabledict(%s)" % dict.__repr__(self)
  109. # PEP 584
  110. def __ior__(self, __value: Any) -> NoReturn: # type: ignore
  111. self._readonly()
  112. def __or__( # type: ignore[override]
  113. self, __value: Mapping[_KT, _VT]
  114. ) -> immutabledict[_KT, _VT]:
  115. return immutabledict(
  116. super().__or__(__value), # type: ignore[call-overload]
  117. )
  118. def __ror__( # type: ignore[override]
  119. self, __value: Mapping[_KT, _VT]
  120. ) -> immutabledict[_KT, _VT]:
  121. return immutabledict(
  122. super().__ror__(__value), # type: ignore[call-overload]
  123. )
  124. class OrderedSet(Set[_T]):
  125. __slots__ = ("_list",)
  126. _list: List[_T]
  127. def __init__(self, d: Optional[Iterable[_T]] = None) -> None:
  128. if d is not None:
  129. self._list = unique_list(d)
  130. super().update(self._list)
  131. else:
  132. self._list = []
  133. def copy(self) -> OrderedSet[_T]:
  134. cp = self.__class__()
  135. cp._list = self._list.copy()
  136. set.update(cp, cp._list)
  137. return cp
  138. def add(self, element: _T) -> None:
  139. if element not in self:
  140. self._list.append(element)
  141. super().add(element)
  142. def remove(self, element: _T) -> None:
  143. super().remove(element)
  144. self._list.remove(element)
  145. def pop(self) -> _T:
  146. try:
  147. value = self._list.pop()
  148. except IndexError:
  149. raise KeyError("pop from an empty set") from None
  150. super().remove(value)
  151. return value
  152. def insert(self, pos: int, element: _T) -> None:
  153. if element not in self:
  154. self._list.insert(pos, element)
  155. super().add(element)
  156. def discard(self, element: _T) -> None:
  157. if element in self:
  158. self._list.remove(element)
  159. super().remove(element)
  160. def clear(self) -> None:
  161. super().clear()
  162. self._list = []
  163. def __getitem__(self, key: int) -> _T:
  164. return self._list[key]
  165. def __iter__(self) -> Iterator[_T]:
  166. return iter(self._list)
  167. def __add__(self, other: Iterator[_T]) -> OrderedSet[_T]:
  168. return self.union(other)
  169. def __repr__(self) -> str:
  170. return "%s(%r)" % (self.__class__.__name__, self._list)
  171. __str__ = __repr__
  172. def update(self, *iterables: Iterable[_T]) -> None:
  173. for iterable in iterables:
  174. for e in iterable:
  175. if e not in self:
  176. self._list.append(e)
  177. super().add(e)
  178. def __ior__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
  179. self.update(other)
  180. return self
  181. def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
  182. result: OrderedSet[Union[_T, _S]] = self.copy()
  183. result.update(*other)
  184. return result
  185. def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
  186. return self.union(other)
  187. def intersection(self, *other: Iterable[Any]) -> OrderedSet[_T]:
  188. other_set: Set[Any] = set()
  189. other_set.update(*other)
  190. return self.__class__(a for a in self if a in other_set)
  191. def __and__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
  192. return self.intersection(other)
  193. def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]:
  194. collection: Collection[_T]
  195. if isinstance(other, set):
  196. collection = other_set = other
  197. elif isinstance(other, Collection):
  198. collection = other
  199. other_set = set(other)
  200. else:
  201. collection = list(other)
  202. other_set = set(collection)
  203. result = self.__class__(a for a in self if a not in other_set)
  204. result.update(a for a in collection if a not in self)
  205. return result
  206. def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
  207. return cast(OrderedSet[Union[_T, _S]], self).symmetric_difference(
  208. other
  209. )
  210. def difference(self, *other: Iterable[Any]) -> OrderedSet[_T]:
  211. other_set = super().difference(*other)
  212. return self.__class__(a for a in self._list if a in other_set)
  213. def __sub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]:
  214. return self.difference(other)
  215. def intersection_update(self, *other: Iterable[Any]) -> None:
  216. super().intersection_update(*other)
  217. self._list = [a for a in self._list if a in self]
  218. def __iand__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
  219. self.intersection_update(other)
  220. return self
  221. def symmetric_difference_update(self, other: Iterable[Any]) -> None:
  222. collection = other if isinstance(other, Collection) else list(other)
  223. super().symmetric_difference_update(collection)
  224. self._list = [a for a in self._list if a in self]
  225. self._list += [a for a in collection if a in self]
  226. def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
  227. self.symmetric_difference_update(other)
  228. return cast(OrderedSet[Union[_T, _S]], self)
  229. def difference_update(self, *other: Iterable[Any]) -> None:
  230. super().difference_update(*other)
  231. self._list = [a for a in self._list if a in self]
  232. def __isub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]: # type: ignore # noqa: E501
  233. self.difference_update(other)
  234. return self
  235. class IdentitySet:
  236. """A set that considers only object id() for uniqueness.
  237. This strategy has edge cases for builtin types- it's possible to have
  238. two 'foo' strings in one of these sets, for example. Use sparingly.
  239. """
  240. _members: Dict[int, Any]
  241. def __init__(self, iterable: Optional[Iterable[Any]] = None):
  242. self._members = dict()
  243. if iterable:
  244. self.update(iterable)
  245. def add(self, value: Any) -> None:
  246. self._members[id(value)] = value
  247. def __contains__(self, value: Any) -> bool:
  248. return id(value) in self._members
  249. def remove(self, value: Any) -> None:
  250. del self._members[id(value)]
  251. def discard(self, value: Any) -> None:
  252. try:
  253. self.remove(value)
  254. except KeyError:
  255. pass
  256. def pop(self) -> Any:
  257. try:
  258. pair = self._members.popitem()
  259. return pair[1]
  260. except KeyError:
  261. raise KeyError("pop from an empty set")
  262. def clear(self) -> None:
  263. self._members.clear()
  264. def __eq__(self, other: Any) -> bool:
  265. if isinstance(other, IdentitySet):
  266. return self._members == other._members
  267. else:
  268. return False
  269. def __ne__(self, other: Any) -> bool:
  270. if isinstance(other, IdentitySet):
  271. return self._members != other._members
  272. else:
  273. return True
  274. def issubset(self, iterable: Iterable[Any]) -> bool:
  275. if isinstance(iterable, self.__class__):
  276. other = iterable
  277. else:
  278. other = self.__class__(iterable)
  279. if len(self) > len(other):
  280. return False
  281. for m in filterfalse(
  282. other._members.__contains__, iter(self._members.keys())
  283. ):
  284. return False
  285. return True
  286. def __le__(self, other: Any) -> bool:
  287. if not isinstance(other, IdentitySet):
  288. return NotImplemented
  289. return self.issubset(other)
  290. def __lt__(self, other: Any) -> bool:
  291. if not isinstance(other, IdentitySet):
  292. return NotImplemented
  293. return len(self) < len(other) and self.issubset(other)
  294. def issuperset(self, iterable: Iterable[Any]) -> bool:
  295. if isinstance(iterable, self.__class__):
  296. other = iterable
  297. else:
  298. other = self.__class__(iterable)
  299. if len(self) < len(other):
  300. return False
  301. for m in filterfalse(
  302. self._members.__contains__, iter(other._members.keys())
  303. ):
  304. return False
  305. return True
  306. def __ge__(self, other: Any) -> bool:
  307. if not isinstance(other, IdentitySet):
  308. return NotImplemented
  309. return self.issuperset(other)
  310. def __gt__(self, other: Any) -> bool:
  311. if not isinstance(other, IdentitySet):
  312. return NotImplemented
  313. return len(self) > len(other) and self.issuperset(other)
  314. def union(self, iterable: Iterable[Any]) -> IdentitySet:
  315. result = self.__class__()
  316. members = self._members
  317. result._members.update(members)
  318. result._members.update((id(obj), obj) for obj in iterable)
  319. return result
  320. def __or__(self, other: Any) -> IdentitySet:
  321. if not isinstance(other, IdentitySet):
  322. return NotImplemented
  323. return self.union(other)
  324. def update(self, iterable: Iterable[Any]) -> None:
  325. self._members.update((id(obj), obj) for obj in iterable)
  326. def __ior__(self, other: Any) -> IdentitySet:
  327. if not isinstance(other, IdentitySet):
  328. return NotImplemented
  329. self.update(other)
  330. return self
  331. def difference(self, iterable: Iterable[Any]) -> IdentitySet:
  332. result = self.__new__(self.__class__)
  333. other: Collection[Any]
  334. if isinstance(iterable, self.__class__):
  335. other = iterable._members
  336. else:
  337. other = {id(obj) for obj in iterable}
  338. result._members = {
  339. k: v for k, v in self._members.items() if k not in other
  340. }
  341. return result
  342. def __sub__(self, other: IdentitySet) -> IdentitySet:
  343. if not isinstance(other, IdentitySet):
  344. return NotImplemented
  345. return self.difference(other)
  346. def difference_update(self, iterable: Iterable[Any]) -> None:
  347. self._members = self.difference(iterable)._members
  348. def __isub__(self, other: IdentitySet) -> IdentitySet:
  349. if not isinstance(other, IdentitySet):
  350. return NotImplemented
  351. self.difference_update(other)
  352. return self
  353. def intersection(self, iterable: Iterable[Any]) -> IdentitySet:
  354. result = self.__new__(self.__class__)
  355. other: Collection[Any]
  356. if isinstance(iterable, self.__class__):
  357. other = iterable._members
  358. else:
  359. other = {id(obj) for obj in iterable}
  360. result._members = {
  361. k: v for k, v in self._members.items() if k in other
  362. }
  363. return result
  364. def __and__(self, other: IdentitySet) -> IdentitySet:
  365. if not isinstance(other, IdentitySet):
  366. return NotImplemented
  367. return self.intersection(other)
  368. def intersection_update(self, iterable: Iterable[Any]) -> None:
  369. self._members = self.intersection(iterable)._members
  370. def __iand__(self, other: IdentitySet) -> IdentitySet:
  371. if not isinstance(other, IdentitySet):
  372. return NotImplemented
  373. self.intersection_update(other)
  374. return self
  375. def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet:
  376. result = self.__new__(self.__class__)
  377. if isinstance(iterable, self.__class__):
  378. other = iterable._members
  379. else:
  380. other = {id(obj): obj for obj in iterable}
  381. result._members = {
  382. k: v for k, v in self._members.items() if k not in other
  383. }
  384. result._members.update(
  385. (k, v) for k, v in other.items() if k not in self._members
  386. )
  387. return result
  388. def __xor__(self, other: IdentitySet) -> IdentitySet:
  389. if not isinstance(other, IdentitySet):
  390. return NotImplemented
  391. return self.symmetric_difference(other)
  392. def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
  393. self._members = self.symmetric_difference(iterable)._members
  394. def __ixor__(self, other: IdentitySet) -> IdentitySet:
  395. if not isinstance(other, IdentitySet):
  396. return NotImplemented
  397. self.symmetric_difference(other)
  398. return self
  399. def copy(self) -> IdentitySet:
  400. result = self.__new__(self.__class__)
  401. result._members = self._members.copy()
  402. return result
  403. __copy__ = copy
  404. def __len__(self) -> int:
  405. return len(self._members)
  406. def __iter__(self) -> Iterator[Any]:
  407. return iter(self._members.values())
  408. def __hash__(self) -> NoReturn:
  409. raise TypeError("set objects are unhashable")
  410. def __repr__(self) -> str:
  411. return "%s(%r)" % (type(self).__name__, list(self._members.values()))
  412. def unique_list(
  413. seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
  414. ) -> List[_T]:
  415. seen: Set[Any] = set()
  416. seen_add = seen.add
  417. if not hashfunc:
  418. return [x for x in seq if x not in seen and not seen_add(x)]
  419. else:
  420. return [
  421. x
  422. for x in seq
  423. if hashfunc(x) not in seen and not seen_add(hashfunc(x))
  424. ]