| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717 |
- # util/_collections.py
- # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- # mypy: allow-untyped-defs, allow-untyped-calls
- """Collection classes and helpers."""
- from __future__ import annotations
- import operator
- import threading
- import types
- import typing
- from typing import Any
- from typing import Callable
- from typing import cast
- from typing import Container
- from typing import Dict
- from typing import FrozenSet
- from typing import Generic
- from typing import Iterable
- from typing import Iterator
- from typing import List
- from typing import Mapping
- from typing import NoReturn
- from typing import Optional
- from typing import overload
- from typing import Sequence
- from typing import Set
- from typing import Tuple
- from typing import TypeVar
- from typing import Union
- from typing import ValuesView
- import weakref
- from ._has_cy import HAS_CYEXTENSION
- from .typing import is_non_string_iterable
- from .typing import Literal
- from .typing import Protocol
- if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_collections import immutabledict as immutabledict
- from ._py_collections import IdentitySet as IdentitySet
- from ._py_collections import ReadOnlyContainer as ReadOnlyContainer
- from ._py_collections import ImmutableDictBase as ImmutableDictBase
- from ._py_collections import OrderedSet as OrderedSet
- from ._py_collections import unique_list as unique_list
- else:
- from sqlalchemy.cyextension.immutabledict import (
- ReadOnlyContainer as ReadOnlyContainer,
- )
- from sqlalchemy.cyextension.immutabledict import (
- ImmutableDictBase as ImmutableDictBase,
- )
- from sqlalchemy.cyextension.immutabledict import (
- immutabledict as immutabledict,
- )
- from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet
- from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet
- from sqlalchemy.cyextension.collections import ( # noqa
- unique_list as unique_list,
- )
- _T = TypeVar("_T", bound=Any)
- _KT = TypeVar("_KT", bound=Any)
- _VT = TypeVar("_VT", bound=Any)
- _T_co = TypeVar("_T_co", covariant=True)
- EMPTY_SET: FrozenSet[Any] = frozenset()
- NONE_SET: FrozenSet[Any] = frozenset([None])
- def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
- """merge two lists, maintaining ordering as much as possible.
- this is to reconcile vars(cls) with cls.__annotations__.
- Example::
- >>> a = ["__tablename__", "id", "x", "created_at"]
- >>> b = ["id", "name", "data", "y", "created_at"]
- >>> merge_lists_w_ordering(a, b)
- ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']
- This is not necessarily the ordering that things had on the class,
- in this case the class is::
- class User(Base):
- __tablename__ = "users"
- id: Mapped[int] = mapped_column(primary_key=True)
- name: Mapped[str]
- data: Mapped[Optional[str]]
- x = Column(Integer)
- y: Mapped[int]
- created_at: Mapped[datetime.datetime] = mapped_column()
- But things are *mostly* ordered.
- The algorithm could also be done by creating a partial ordering for
- all items in both lists and then using topological_sort(), but that
- is too much overhead.
- Background on how I came up with this is at:
- https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae
- """
- overlap = set(a).intersection(b)
- result = []
- current, other = iter(a), iter(b)
- while True:
- for element in current:
- if element in overlap:
- overlap.discard(element)
- other, current = current, other
- break
- result.append(element)
- else:
- result.extend(other)
- break
- return result
- def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
- if not d:
- return EMPTY_DICT
- elif isinstance(d, immutabledict):
- return d
- else:
- return immutabledict(d)
- EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
- class FacadeDict(ImmutableDictBase[_KT, _VT]):
- """A dictionary that is not publicly mutable."""
- def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
- new = ImmutableDictBase.__new__(cls)
- return new
- def copy(self) -> NoReturn:
- raise NotImplementedError(
- "an immutabledict shouldn't need to be copied. use dict(d) "
- "if you need a mutable dictionary."
- )
- def __reduce__(self) -> Any:
- return FacadeDict, (dict(self),)
- def _insert_item(self, key: _KT, value: _VT) -> None:
- """insert an item into the dictionary directly."""
- dict.__setitem__(self, key, value)
- def __repr__(self) -> str:
- return "FacadeDict(%s)" % dict.__repr__(self)
- _DT = TypeVar("_DT", bound=Any)
- _F = TypeVar("_F", bound=Any)
- class Properties(Generic[_T]):
- """Provide a __getattr__/__setattr__ interface over a dict."""
- __slots__ = ("_data",)
- _data: Dict[str, _T]
- def __init__(self, data: Dict[str, _T]):
- object.__setattr__(self, "_data", data)
- def __len__(self) -> int:
- return len(self._data)
- def __iter__(self) -> Iterator[_T]:
- return iter(list(self._data.values()))
- def __dir__(self) -> List[str]:
- return dir(super()) + [str(k) for k in self._data.keys()]
- def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
- return list(self) + list(other)
- def __setitem__(self, key: str, obj: _T) -> None:
- self._data[key] = obj
- def __getitem__(self, key: str) -> _T:
- return self._data[key]
- def __delitem__(self, key: str) -> None:
- del self._data[key]
- def __setattr__(self, key: str, obj: _T) -> None:
- self._data[key] = obj
- def __getstate__(self) -> Dict[str, Any]:
- return {"_data": self._data}
- def __setstate__(self, state: Dict[str, Any]) -> None:
- object.__setattr__(self, "_data", state["_data"])
- def __getattr__(self, key: str) -> _T:
- try:
- return self._data[key]
- except KeyError:
- raise AttributeError(key)
- def __contains__(self, key: str) -> bool:
- return key in self._data
- def as_readonly(self) -> ReadOnlyProperties[_T]:
- """Return an immutable proxy for this :class:`.Properties`."""
- return ReadOnlyProperties(self._data)
- def update(self, value: Dict[str, _T]) -> None:
- self._data.update(value)
- @overload
- def get(self, key: str) -> Optional[_T]: ...
- @overload
- def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...
- def get(
- self, key: str, default: Optional[Union[_DT, _T]] = None
- ) -> Optional[Union[_T, _DT]]:
- if key in self:
- return self[key]
- else:
- return default
- def keys(self) -> List[str]:
- return list(self._data)
- def values(self) -> List[_T]:
- return list(self._data.values())
- def items(self) -> List[Tuple[str, _T]]:
- return list(self._data.items())
- def has_key(self, key: str) -> bool:
- return key in self._data
- def clear(self) -> None:
- self._data.clear()
- class OrderedProperties(Properties[_T]):
- """Provide a __getattr__/__setattr__ interface with an OrderedDict
- as backing store."""
- __slots__ = ()
- def __init__(self):
- Properties.__init__(self, OrderedDict())
- class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
- """Provide immutable dict/object attribute to an underlying dictionary."""
- __slots__ = ()
- def _ordered_dictionary_sort(d, key=None):
- """Sort an OrderedDict in-place."""
- items = [(k, d[k]) for k in sorted(d, key=key)]
- d.clear()
- d.update(items)
- OrderedDict = dict
- sort_dictionary = _ordered_dictionary_sort
- class WeakSequence(Sequence[_T]):
- def __init__(self, __elements: Sequence[_T] = ()):
- # adapted from weakref.WeakKeyDictionary, prevent reference
- # cycles in the collection itself
- def _remove(item, selfref=weakref.ref(self)):
- self = selfref()
- if self is not None:
- self._storage.remove(item)
- self._remove = _remove
- self._storage = [
- weakref.ref(element, _remove) for element in __elements
- ]
- def append(self, item):
- self._storage.append(weakref.ref(item, self._remove))
- def __len__(self):
- return len(self._storage)
- def __iter__(self):
- return (
- obj for obj in (ref() for ref in self._storage) if obj is not None
- )
- def __getitem__(self, index):
- try:
- obj = self._storage[index]
- except KeyError:
- raise IndexError("Index %s out of range" % index)
- else:
- return obj()
- class OrderedIdentitySet(IdentitySet):
- def __init__(self, iterable: Optional[Iterable[Any]] = None):
- IdentitySet.__init__(self)
- self._members = OrderedDict()
- if iterable:
- for o in iterable:
- self.add(o)
- class PopulateDict(Dict[_KT, _VT]):
- """A dict which populates missing values via a creation function.
- Note the creation function takes a key, unlike
- collections.defaultdict.
- """
- def __init__(self, creator: Callable[[_KT], _VT]):
- self.creator = creator
- def __missing__(self, key: Any) -> Any:
- self[key] = val = self.creator(key)
- return val
- class WeakPopulateDict(Dict[_KT, _VT]):
- """Like PopulateDict, but assumes a self + a method and does not create
- a reference cycle.
- """
- def __init__(self, creator_method: types.MethodType):
- self.creator = creator_method.__func__
- weakself = creator_method.__self__
- self.weakself = weakref.ref(weakself)
- def __missing__(self, key: Any) -> Any:
- self[key] = val = self.creator(self.weakself(), key)
- return val
- # Define collections that are capable of storing
- # ColumnElement objects as hashable keys/elements.
- # At this point, these are mostly historical, things
- # used to be more complicated.
- column_set = set
- column_dict = dict
- ordered_column_set = OrderedSet
- class UniqueAppender(Generic[_T]):
- """Appends items to a collection ensuring uniqueness.
- Additional appends() of the same object are ignored. Membership is
- determined by identity (``is a``) not equality (``==``).
- """
- __slots__ = "data", "_data_appender", "_unique"
- data: Union[Iterable[_T], Set[_T], List[_T]]
- _data_appender: Callable[[_T], None]
- _unique: Dict[int, Literal[True]]
- def __init__(
- self,
- data: Union[Iterable[_T], Set[_T], List[_T]],
- via: Optional[str] = None,
- ):
- self.data = data
- self._unique = {}
- if via:
- self._data_appender = getattr(data, via)
- elif hasattr(data, "append"):
- self._data_appender = cast("List[_T]", data).append
- elif hasattr(data, "add"):
- self._data_appender = cast("Set[_T]", data).add
- def append(self, item: _T) -> None:
- id_ = id(item)
- if id_ not in self._unique:
- self._data_appender(item)
- self._unique[id_] = True
- def __iter__(self) -> Iterator[_T]:
- return iter(self.data)
- def coerce_generator_arg(arg: Any) -> List[Any]:
- if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
- return list(arg[0])
- else:
- return cast("List[Any]", arg)
- def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
- if x is None:
- return default # type: ignore
- if not is_non_string_iterable(x):
- return [x]
- elif isinstance(x, list):
- return x
- else:
- return list(x)
- def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
- r"""return True if any items of set\_ are present in iterable.
- Goes through special effort to ensure __hash__ is not called
- on items in iterable that don't support it.
- """
- return any(i in set_ for i in iterable if i.__hash__)
- def to_set(x):
- if x is None:
- return set()
- if not isinstance(x, set):
- return set(to_list(x))
- else:
- return x
- def to_column_set(x: Any) -> Set[Any]:
- if x is None:
- return column_set()
- if not isinstance(x, column_set):
- return column_set(to_list(x))
- else:
- return x
- def update_copy(
- d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
- ) -> Dict[Any, Any]:
- """Copy the given dict and update with the given values."""
- d = d.copy()
- if _new:
- d.update(_new)
- d.update(**kw)
- return d
- def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
- """Given an iterator of which further sub-elements may also be
- iterators, flatten the sub-elements into a single iterator.
- """
- elem: _T
- for elem in x:
- if not isinstance(elem, str) and hasattr(elem, "__iter__"):
- yield from flatten_iterator(elem)
- else:
- yield elem
- class LRUCache(typing.MutableMapping[_KT, _VT]):
- """Dictionary with 'squishy' removal of least
- recently used items.
- Note that either get() or [] should be used here, but
- generally its not safe to do an "in" check first as the dictionary
- can change subsequent to that call.
- """
- __slots__ = (
- "capacity",
- "threshold",
- "size_alert",
- "_data",
- "_counter",
- "_mutex",
- )
- capacity: int
- threshold: float
- size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]
- def __init__(
- self,
- capacity: int = 100,
- threshold: float = 0.5,
- size_alert: Optional[Callable[..., None]] = None,
- ):
- self.capacity = capacity
- self.threshold = threshold
- self.size_alert = size_alert
- self._counter = 0
- self._mutex = threading.Lock()
- self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}
- def _inc_counter(self):
- self._counter += 1
- return self._counter
- @overload
- def get(self, key: _KT) -> Optional[_VT]: ...
- @overload
- def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...
- def get(
- self, key: _KT, default: Optional[Union[_VT, _T]] = None
- ) -> Optional[Union[_VT, _T]]:
- item = self._data.get(key)
- if item is not None:
- item[2][0] = self._inc_counter()
- return item[1]
- else:
- return default
- def __getitem__(self, key: _KT) -> _VT:
- item = self._data[key]
- item[2][0] = self._inc_counter()
- return item[1]
- def __iter__(self) -> Iterator[_KT]:
- return iter(self._data)
- def __len__(self) -> int:
- return len(self._data)
- def values(self) -> ValuesView[_VT]:
- return typing.ValuesView({k: i[1] for k, i in self._data.items()})
- def __setitem__(self, key: _KT, value: _VT) -> None:
- self._data[key] = (key, value, [self._inc_counter()])
- self._manage_size()
- def __delitem__(self, __v: _KT) -> None:
- del self._data[__v]
- @property
- def size_threshold(self) -> float:
- return self.capacity + self.capacity * self.threshold
- def _manage_size(self) -> None:
- if not self._mutex.acquire(False):
- return
- try:
- size_alert = bool(self.size_alert)
- while len(self) > self.capacity + self.capacity * self.threshold:
- if size_alert:
- size_alert = False
- self.size_alert(self) # type: ignore
- by_counter = sorted(
- self._data.values(),
- key=operator.itemgetter(2),
- reverse=True,
- )
- for item in by_counter[self.capacity :]:
- try:
- del self._data[item[0]]
- except KeyError:
- # deleted elsewhere; skip
- continue
- finally:
- self._mutex.release()
- class _CreateFuncType(Protocol[_T_co]):
- def __call__(self) -> _T_co: ...
- class _ScopeFuncType(Protocol):
- def __call__(self) -> Any: ...
- class ScopedRegistry(Generic[_T]):
- """A Registry that can store one or multiple instances of a single
- class on the basis of a "scope" function.
- The object implements ``__call__`` as the "getter", so by
- calling ``myregistry()`` the contained object is returned
- for the current scope.
- :param createfunc:
- a callable that returns a new object to be placed in the registry
- :param scopefunc:
- a callable that will return a key to store/retrieve an object.
- """
- __slots__ = "createfunc", "scopefunc", "registry"
- createfunc: _CreateFuncType[_T]
- scopefunc: _ScopeFuncType
- registry: Any
- def __init__(
- self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
- ):
- """Construct a new :class:`.ScopedRegistry`.
- :param createfunc: A creation function that will generate
- a new value for the current scope, if none is present.
- :param scopefunc: A function that returns a hashable
- token representing the current scope (such as, current
- thread identifier).
- """
- self.createfunc = createfunc
- self.scopefunc = scopefunc
- self.registry = {}
- def __call__(self) -> _T:
- key = self.scopefunc()
- try:
- return self.registry[key] # type: ignore[no-any-return]
- except KeyError:
- return self.registry.setdefault(key, self.createfunc()) # type: ignore[no-any-return] # noqa: E501
- def has(self) -> bool:
- """Return True if an object is present in the current scope."""
- return self.scopefunc() in self.registry
- def set(self, obj: _T) -> None:
- """Set the value for the current scope."""
- self.registry[self.scopefunc()] = obj
- def clear(self) -> None:
- """Clear the current scope, if any."""
- try:
- del self.registry[self.scopefunc()]
- except KeyError:
- pass
- class ThreadLocalRegistry(ScopedRegistry[_T]):
- """A :class:`.ScopedRegistry` that uses a ``threading.local()``
- variable for storage.
- """
- def __init__(self, createfunc: Callable[[], _T]):
- self.createfunc = createfunc
- self.registry = threading.local()
- def __call__(self) -> _T:
- try:
- return self.registry.value # type: ignore[no-any-return]
- except AttributeError:
- val = self.registry.value = self.createfunc()
- return val
- def has(self) -> bool:
- return hasattr(self.registry, "value")
- def set(self, obj: _T) -> None:
- self.registry.value = obj
- def clear(self) -> None:
- try:
- del self.registry.value
- except AttributeError:
- pass
- def has_dupes(sequence, target):
- """Given a sequence and search object, return True if there's more
- than one, False if zero or one of them.
- """
- # compare to .index version below, this version introduces less function
- # overhead and is usually the same speed. At 15000 items (way bigger than
- # a relationship-bound collection in memory usually is) it begins to
- # fall behind the other version only by microseconds.
- c = 0
- for item in sequence:
- if item is target:
- c += 1
- if c > 1:
- return True
- return False
- # .index version. the two __contains__ calls as well
- # as .index() and isinstance() slow this down.
- # def has_dupes(sequence, target):
- # if target not in sequence:
- # return False
- # elif not isinstance(sequence, collections_abc.Sequence):
- # return False
- #
- # idx = sequence.index(target)
- # return target in sequence[idx + 1:]
|