_collections.py 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717
  1. # util/_collections.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: allow-untyped-defs, allow-untyped-calls
  8. """Collection classes and helpers."""
  9. from __future__ import annotations
  10. import operator
  11. import threading
  12. import types
  13. import typing
  14. from typing import Any
  15. from typing import Callable
  16. from typing import cast
  17. from typing import Container
  18. from typing import Dict
  19. from typing import FrozenSet
  20. from typing import Generic
  21. from typing import Iterable
  22. from typing import Iterator
  23. from typing import List
  24. from typing import Mapping
  25. from typing import NoReturn
  26. from typing import Optional
  27. from typing import overload
  28. from typing import Sequence
  29. from typing import Set
  30. from typing import Tuple
  31. from typing import TypeVar
  32. from typing import Union
  33. from typing import ValuesView
  34. import weakref
  35. from ._has_cy import HAS_CYEXTENSION
  36. from .typing import is_non_string_iterable
  37. from .typing import Literal
  38. from .typing import Protocol
  39. if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
  40. from ._py_collections import immutabledict as immutabledict
  41. from ._py_collections import IdentitySet as IdentitySet
  42. from ._py_collections import ReadOnlyContainer as ReadOnlyContainer
  43. from ._py_collections import ImmutableDictBase as ImmutableDictBase
  44. from ._py_collections import OrderedSet as OrderedSet
  45. from ._py_collections import unique_list as unique_list
  46. else:
  47. from sqlalchemy.cyextension.immutabledict import (
  48. ReadOnlyContainer as ReadOnlyContainer,
  49. )
  50. from sqlalchemy.cyextension.immutabledict import (
  51. ImmutableDictBase as ImmutableDictBase,
  52. )
  53. from sqlalchemy.cyextension.immutabledict import (
  54. immutabledict as immutabledict,
  55. )
  56. from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet
  57. from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet
  58. from sqlalchemy.cyextension.collections import ( # noqa
  59. unique_list as unique_list,
  60. )
  61. _T = TypeVar("_T", bound=Any)
  62. _KT = TypeVar("_KT", bound=Any)
  63. _VT = TypeVar("_VT", bound=Any)
  64. _T_co = TypeVar("_T_co", covariant=True)
  65. EMPTY_SET: FrozenSet[Any] = frozenset()
  66. NONE_SET: FrozenSet[Any] = frozenset([None])
  67. def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
  68. """merge two lists, maintaining ordering as much as possible.
  69. this is to reconcile vars(cls) with cls.__annotations__.
  70. Example::
  71. >>> a = ["__tablename__", "id", "x", "created_at"]
  72. >>> b = ["id", "name", "data", "y", "created_at"]
  73. >>> merge_lists_w_ordering(a, b)
  74. ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']
  75. This is not necessarily the ordering that things had on the class,
  76. in this case the class is::
  77. class User(Base):
  78. __tablename__ = "users"
  79. id: Mapped[int] = mapped_column(primary_key=True)
  80. name: Mapped[str]
  81. data: Mapped[Optional[str]]
  82. x = Column(Integer)
  83. y: Mapped[int]
  84. created_at: Mapped[datetime.datetime] = mapped_column()
  85. But things are *mostly* ordered.
  86. The algorithm could also be done by creating a partial ordering for
  87. all items in both lists and then using topological_sort(), but that
  88. is too much overhead.
  89. Background on how I came up with this is at:
  90. https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae
  91. """
  92. overlap = set(a).intersection(b)
  93. result = []
  94. current, other = iter(a), iter(b)
  95. while True:
  96. for element in current:
  97. if element in overlap:
  98. overlap.discard(element)
  99. other, current = current, other
  100. break
  101. result.append(element)
  102. else:
  103. result.extend(other)
  104. break
  105. return result
  106. def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
  107. if not d:
  108. return EMPTY_DICT
  109. elif isinstance(d, immutabledict):
  110. return d
  111. else:
  112. return immutabledict(d)
  113. EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
  114. class FacadeDict(ImmutableDictBase[_KT, _VT]):
  115. """A dictionary that is not publicly mutable."""
  116. def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
  117. new = ImmutableDictBase.__new__(cls)
  118. return new
  119. def copy(self) -> NoReturn:
  120. raise NotImplementedError(
  121. "an immutabledict shouldn't need to be copied. use dict(d) "
  122. "if you need a mutable dictionary."
  123. )
  124. def __reduce__(self) -> Any:
  125. return FacadeDict, (dict(self),)
  126. def _insert_item(self, key: _KT, value: _VT) -> None:
  127. """insert an item into the dictionary directly."""
  128. dict.__setitem__(self, key, value)
  129. def __repr__(self) -> str:
  130. return "FacadeDict(%s)" % dict.__repr__(self)
  131. _DT = TypeVar("_DT", bound=Any)
  132. _F = TypeVar("_F", bound=Any)
  133. class Properties(Generic[_T]):
  134. """Provide a __getattr__/__setattr__ interface over a dict."""
  135. __slots__ = ("_data",)
  136. _data: Dict[str, _T]
  137. def __init__(self, data: Dict[str, _T]):
  138. object.__setattr__(self, "_data", data)
  139. def __len__(self) -> int:
  140. return len(self._data)
  141. def __iter__(self) -> Iterator[_T]:
  142. return iter(list(self._data.values()))
  143. def __dir__(self) -> List[str]:
  144. return dir(super()) + [str(k) for k in self._data.keys()]
  145. def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
  146. return list(self) + list(other)
  147. def __setitem__(self, key: str, obj: _T) -> None:
  148. self._data[key] = obj
  149. def __getitem__(self, key: str) -> _T:
  150. return self._data[key]
  151. def __delitem__(self, key: str) -> None:
  152. del self._data[key]
  153. def __setattr__(self, key: str, obj: _T) -> None:
  154. self._data[key] = obj
  155. def __getstate__(self) -> Dict[str, Any]:
  156. return {"_data": self._data}
  157. def __setstate__(self, state: Dict[str, Any]) -> None:
  158. object.__setattr__(self, "_data", state["_data"])
  159. def __getattr__(self, key: str) -> _T:
  160. try:
  161. return self._data[key]
  162. except KeyError:
  163. raise AttributeError(key)
  164. def __contains__(self, key: str) -> bool:
  165. return key in self._data
  166. def as_readonly(self) -> ReadOnlyProperties[_T]:
  167. """Return an immutable proxy for this :class:`.Properties`."""
  168. return ReadOnlyProperties(self._data)
  169. def update(self, value: Dict[str, _T]) -> None:
  170. self._data.update(value)
  171. @overload
  172. def get(self, key: str) -> Optional[_T]: ...
  173. @overload
  174. def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...
  175. def get(
  176. self, key: str, default: Optional[Union[_DT, _T]] = None
  177. ) -> Optional[Union[_T, _DT]]:
  178. if key in self:
  179. return self[key]
  180. else:
  181. return default
  182. def keys(self) -> List[str]:
  183. return list(self._data)
  184. def values(self) -> List[_T]:
  185. return list(self._data.values())
  186. def items(self) -> List[Tuple[str, _T]]:
  187. return list(self._data.items())
  188. def has_key(self, key: str) -> bool:
  189. return key in self._data
  190. def clear(self) -> None:
  191. self._data.clear()
  192. class OrderedProperties(Properties[_T]):
  193. """Provide a __getattr__/__setattr__ interface with an OrderedDict
  194. as backing store."""
  195. __slots__ = ()
  196. def __init__(self):
  197. Properties.__init__(self, OrderedDict())
  198. class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
  199. """Provide immutable dict/object attribute to an underlying dictionary."""
  200. __slots__ = ()
  201. def _ordered_dictionary_sort(d, key=None):
  202. """Sort an OrderedDict in-place."""
  203. items = [(k, d[k]) for k in sorted(d, key=key)]
  204. d.clear()
  205. d.update(items)
  206. OrderedDict = dict
  207. sort_dictionary = _ordered_dictionary_sort
  208. class WeakSequence(Sequence[_T]):
  209. def __init__(self, __elements: Sequence[_T] = ()):
  210. # adapted from weakref.WeakKeyDictionary, prevent reference
  211. # cycles in the collection itself
  212. def _remove(item, selfref=weakref.ref(self)):
  213. self = selfref()
  214. if self is not None:
  215. self._storage.remove(item)
  216. self._remove = _remove
  217. self._storage = [
  218. weakref.ref(element, _remove) for element in __elements
  219. ]
  220. def append(self, item):
  221. self._storage.append(weakref.ref(item, self._remove))
  222. def __len__(self):
  223. return len(self._storage)
  224. def __iter__(self):
  225. return (
  226. obj for obj in (ref() for ref in self._storage) if obj is not None
  227. )
  228. def __getitem__(self, index):
  229. try:
  230. obj = self._storage[index]
  231. except KeyError:
  232. raise IndexError("Index %s out of range" % index)
  233. else:
  234. return obj()
  235. class OrderedIdentitySet(IdentitySet):
  236. def __init__(self, iterable: Optional[Iterable[Any]] = None):
  237. IdentitySet.__init__(self)
  238. self._members = OrderedDict()
  239. if iterable:
  240. for o in iterable:
  241. self.add(o)
  242. class PopulateDict(Dict[_KT, _VT]):
  243. """A dict which populates missing values via a creation function.
  244. Note the creation function takes a key, unlike
  245. collections.defaultdict.
  246. """
  247. def __init__(self, creator: Callable[[_KT], _VT]):
  248. self.creator = creator
  249. def __missing__(self, key: Any) -> Any:
  250. self[key] = val = self.creator(key)
  251. return val
  252. class WeakPopulateDict(Dict[_KT, _VT]):
  253. """Like PopulateDict, but assumes a self + a method and does not create
  254. a reference cycle.
  255. """
  256. def __init__(self, creator_method: types.MethodType):
  257. self.creator = creator_method.__func__
  258. weakself = creator_method.__self__
  259. self.weakself = weakref.ref(weakself)
  260. def __missing__(self, key: Any) -> Any:
  261. self[key] = val = self.creator(self.weakself(), key)
  262. return val
  263. # Define collections that are capable of storing
  264. # ColumnElement objects as hashable keys/elements.
  265. # At this point, these are mostly historical, things
  266. # used to be more complicated.
  267. column_set = set
  268. column_dict = dict
  269. ordered_column_set = OrderedSet
  270. class UniqueAppender(Generic[_T]):
  271. """Appends items to a collection ensuring uniqueness.
  272. Additional appends() of the same object are ignored. Membership is
  273. determined by identity (``is a``) not equality (``==``).
  274. """
  275. __slots__ = "data", "_data_appender", "_unique"
  276. data: Union[Iterable[_T], Set[_T], List[_T]]
  277. _data_appender: Callable[[_T], None]
  278. _unique: Dict[int, Literal[True]]
  279. def __init__(
  280. self,
  281. data: Union[Iterable[_T], Set[_T], List[_T]],
  282. via: Optional[str] = None,
  283. ):
  284. self.data = data
  285. self._unique = {}
  286. if via:
  287. self._data_appender = getattr(data, via)
  288. elif hasattr(data, "append"):
  289. self._data_appender = cast("List[_T]", data).append
  290. elif hasattr(data, "add"):
  291. self._data_appender = cast("Set[_T]", data).add
  292. def append(self, item: _T) -> None:
  293. id_ = id(item)
  294. if id_ not in self._unique:
  295. self._data_appender(item)
  296. self._unique[id_] = True
  297. def __iter__(self) -> Iterator[_T]:
  298. return iter(self.data)
  299. def coerce_generator_arg(arg: Any) -> List[Any]:
  300. if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
  301. return list(arg[0])
  302. else:
  303. return cast("List[Any]", arg)
  304. def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
  305. if x is None:
  306. return default # type: ignore
  307. if not is_non_string_iterable(x):
  308. return [x]
  309. elif isinstance(x, list):
  310. return x
  311. else:
  312. return list(x)
  313. def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
  314. r"""return True if any items of set\_ are present in iterable.
  315. Goes through special effort to ensure __hash__ is not called
  316. on items in iterable that don't support it.
  317. """
  318. return any(i in set_ for i in iterable if i.__hash__)
  319. def to_set(x):
  320. if x is None:
  321. return set()
  322. if not isinstance(x, set):
  323. return set(to_list(x))
  324. else:
  325. return x
  326. def to_column_set(x: Any) -> Set[Any]:
  327. if x is None:
  328. return column_set()
  329. if not isinstance(x, column_set):
  330. return column_set(to_list(x))
  331. else:
  332. return x
  333. def update_copy(
  334. d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
  335. ) -> Dict[Any, Any]:
  336. """Copy the given dict and update with the given values."""
  337. d = d.copy()
  338. if _new:
  339. d.update(_new)
  340. d.update(**kw)
  341. return d
  342. def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
  343. """Given an iterator of which further sub-elements may also be
  344. iterators, flatten the sub-elements into a single iterator.
  345. """
  346. elem: _T
  347. for elem in x:
  348. if not isinstance(elem, str) and hasattr(elem, "__iter__"):
  349. yield from flatten_iterator(elem)
  350. else:
  351. yield elem
  352. class LRUCache(typing.MutableMapping[_KT, _VT]):
  353. """Dictionary with 'squishy' removal of least
  354. recently used items.
  355. Note that either get() or [] should be used here, but
  356. generally its not safe to do an "in" check first as the dictionary
  357. can change subsequent to that call.
  358. """
  359. __slots__ = (
  360. "capacity",
  361. "threshold",
  362. "size_alert",
  363. "_data",
  364. "_counter",
  365. "_mutex",
  366. )
  367. capacity: int
  368. threshold: float
  369. size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]
  370. def __init__(
  371. self,
  372. capacity: int = 100,
  373. threshold: float = 0.5,
  374. size_alert: Optional[Callable[..., None]] = None,
  375. ):
  376. self.capacity = capacity
  377. self.threshold = threshold
  378. self.size_alert = size_alert
  379. self._counter = 0
  380. self._mutex = threading.Lock()
  381. self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}
  382. def _inc_counter(self):
  383. self._counter += 1
  384. return self._counter
  385. @overload
  386. def get(self, key: _KT) -> Optional[_VT]: ...
  387. @overload
  388. def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...
  389. def get(
  390. self, key: _KT, default: Optional[Union[_VT, _T]] = None
  391. ) -> Optional[Union[_VT, _T]]:
  392. item = self._data.get(key)
  393. if item is not None:
  394. item[2][0] = self._inc_counter()
  395. return item[1]
  396. else:
  397. return default
  398. def __getitem__(self, key: _KT) -> _VT:
  399. item = self._data[key]
  400. item[2][0] = self._inc_counter()
  401. return item[1]
  402. def __iter__(self) -> Iterator[_KT]:
  403. return iter(self._data)
  404. def __len__(self) -> int:
  405. return len(self._data)
  406. def values(self) -> ValuesView[_VT]:
  407. return typing.ValuesView({k: i[1] for k, i in self._data.items()})
  408. def __setitem__(self, key: _KT, value: _VT) -> None:
  409. self._data[key] = (key, value, [self._inc_counter()])
  410. self._manage_size()
  411. def __delitem__(self, __v: _KT) -> None:
  412. del self._data[__v]
  413. @property
  414. def size_threshold(self) -> float:
  415. return self.capacity + self.capacity * self.threshold
  416. def _manage_size(self) -> None:
  417. if not self._mutex.acquire(False):
  418. return
  419. try:
  420. size_alert = bool(self.size_alert)
  421. while len(self) > self.capacity + self.capacity * self.threshold:
  422. if size_alert:
  423. size_alert = False
  424. self.size_alert(self) # type: ignore
  425. by_counter = sorted(
  426. self._data.values(),
  427. key=operator.itemgetter(2),
  428. reverse=True,
  429. )
  430. for item in by_counter[self.capacity :]:
  431. try:
  432. del self._data[item[0]]
  433. except KeyError:
  434. # deleted elsewhere; skip
  435. continue
  436. finally:
  437. self._mutex.release()
  438. class _CreateFuncType(Protocol[_T_co]):
  439. def __call__(self) -> _T_co: ...
  440. class _ScopeFuncType(Protocol):
  441. def __call__(self) -> Any: ...
  442. class ScopedRegistry(Generic[_T]):
  443. """A Registry that can store one or multiple instances of a single
  444. class on the basis of a "scope" function.
  445. The object implements ``__call__`` as the "getter", so by
  446. calling ``myregistry()`` the contained object is returned
  447. for the current scope.
  448. :param createfunc:
  449. a callable that returns a new object to be placed in the registry
  450. :param scopefunc:
  451. a callable that will return a key to store/retrieve an object.
  452. """
  453. __slots__ = "createfunc", "scopefunc", "registry"
  454. createfunc: _CreateFuncType[_T]
  455. scopefunc: _ScopeFuncType
  456. registry: Any
  457. def __init__(
  458. self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
  459. ):
  460. """Construct a new :class:`.ScopedRegistry`.
  461. :param createfunc: A creation function that will generate
  462. a new value for the current scope, if none is present.
  463. :param scopefunc: A function that returns a hashable
  464. token representing the current scope (such as, current
  465. thread identifier).
  466. """
  467. self.createfunc = createfunc
  468. self.scopefunc = scopefunc
  469. self.registry = {}
  470. def __call__(self) -> _T:
  471. key = self.scopefunc()
  472. try:
  473. return self.registry[key] # type: ignore[no-any-return]
  474. except KeyError:
  475. return self.registry.setdefault(key, self.createfunc()) # type: ignore[no-any-return] # noqa: E501
  476. def has(self) -> bool:
  477. """Return True if an object is present in the current scope."""
  478. return self.scopefunc() in self.registry
  479. def set(self, obj: _T) -> None:
  480. """Set the value for the current scope."""
  481. self.registry[self.scopefunc()] = obj
  482. def clear(self) -> None:
  483. """Clear the current scope, if any."""
  484. try:
  485. del self.registry[self.scopefunc()]
  486. except KeyError:
  487. pass
  488. class ThreadLocalRegistry(ScopedRegistry[_T]):
  489. """A :class:`.ScopedRegistry` that uses a ``threading.local()``
  490. variable for storage.
  491. """
  492. def __init__(self, createfunc: Callable[[], _T]):
  493. self.createfunc = createfunc
  494. self.registry = threading.local()
  495. def __call__(self) -> _T:
  496. try:
  497. return self.registry.value # type: ignore[no-any-return]
  498. except AttributeError:
  499. val = self.registry.value = self.createfunc()
  500. return val
  501. def has(self) -> bool:
  502. return hasattr(self.registry, "value")
  503. def set(self, obj: _T) -> None:
  504. self.registry.value = obj
  505. def clear(self) -> None:
  506. try:
  507. del self.registry.value
  508. except AttributeError:
  509. pass
  510. def has_dupes(sequence, target):
  511. """Given a sequence and search object, return True if there's more
  512. than one, False if zero or one of them.
  513. """
  514. # compare to .index version below, this version introduces less function
  515. # overhead and is usually the same speed. At 15000 items (way bigger than
  516. # a relationship-bound collection in memory usually is) it begins to
  517. # fall behind the other version only by microseconds.
  518. c = 0
  519. for item in sequence:
  520. if item is target:
  521. c += 1
  522. if c > 1:
  523. return True
  524. return False
# .index() version, kept for reference. The two __contains__ calls, as well
# as .index() and isinstance(), slow this down.
# def has_dupes(sequence, target):
#     if target not in sequence:
#         return False
#     elif not isinstance(sequence, collections_abc.Sequence):
#         return False
#
#     idx = sequence.index(target)
#     return target in sequence[idx + 1:]