| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541 |
- # util/_py_collections.py
- # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- # mypy: allow-untyped-defs, allow-untyped-calls
- from __future__ import annotations
- from itertools import filterfalse
- from typing import AbstractSet
- from typing import Any
- from typing import Callable
- from typing import cast
- from typing import Collection
- from typing import Dict
- from typing import Iterable
- from typing import Iterator
- from typing import List
- from typing import Mapping
- from typing import NoReturn
- from typing import Optional
- from typing import Set
- from typing import Tuple
- from typing import TYPE_CHECKING
- from typing import TypeVar
- from typing import Union
- from ..util.typing import Self
# Module-level type variables used to parameterize the generic collection
# classes and helpers in this module.  Each keeps the literal
# ``X = TypeVar("X", ...)`` form that type checkers require (the string
# argument must match the variable name).  ``bound=Any`` places no
# restriction on the types they accept.
_T = TypeVar("_T", bound=Any)  # element type (OrderedSet, unique_list)
_S = TypeVar("_S", bound=Any)  # element type of the "other" operand in OrderedSet set algebra
_KT = TypeVar("_KT", bound=Any)  # key type for ImmutableDictBase / immutabledict
_VT = TypeVar("_VT", bound=Any)  # value type for ImmutableDictBase / immutabledict
class ReadOnlyContainer:
    """Mixin that rejects item assignment, item deletion and attribute
    assignment by raising ``TypeError``.

    The error message names the concrete class of the object.  Additional
    mutators can be disabled by routing them through :meth:`_readonly`
    (or :meth:`_immutable` for the stronger wording).
    """

    __slots__ = ()

    def __setattr__(self, key: str, value: Any) -> NoReturn:
        self._readonly()

    def __setitem__(self, key: Any, value: Any) -> NoReturn:
        self._readonly()

    def __delitem__(self, key: Any) -> NoReturn:
        self._readonly()

    def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
        """Raise ``TypeError`` saying this object is immutable.

        Any arguments are accepted and ignored.
        """
        class_name = self.__class__.__name__
        message = "%s object is immutable" % class_name
        raise TypeError(message)

    def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
        """Raise ``TypeError`` saying this object is immutable/readonly.

        Any arguments are accepted and ignored, so this can stand in for a
        mutator of any signature.
        """
        class_name = self.__class__.__name__
        message = "%s object is immutable and/or readonly" % class_name
        raise TypeError(message)
class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]):
    """Base for ``dict`` subclasses whose in-place mutators are disabled.

    ``update``, ``setdefault``, ``popitem``, ``pop`` and ``clear`` are
    overridden to call :meth:`_readonly`.  Item assignment/deletion and
    attribute assignment come from ``ReadOnlyContainer``.  In this class
    ``_readonly`` reports the "immutable" wording via ``_immutable``.
    """

    if TYPE_CHECKING:

        def __new__(cls, *args: Any) -> Self: ...

        def __init__(cls, *args: Any): ...

    def update(self, *arg: Any, **kw: Any) -> NoReturn:
        self._readonly()

    def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
        self._readonly()

    def popitem(self) -> NoReturn:
        self._readonly()

    def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
        self._readonly()

    def clear(self) -> NoReturn:
        self._readonly()

    def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
        # Use the plain "immutable" message instead of the inherited
        # "immutable and/or readonly" one.
        self._immutable()
class immutabledict(ImmutableDictBase[_KT, _VT]):
    """A ``dict`` whose contents cannot be changed after construction.

    Contents are loaded once, in ``__new__``.  The mutating ``dict`` methods
    are overridden (by ``ImmutableDictBase`` / ``ReadOnlyContainer`` and by
    ``__ior__`` below) to raise ``TypeError``.  :meth:`union`,
    :meth:`merge_with` and the ``|`` operator return new instances instead
    of modifying this one.
    """

    def __new__(cls, *args):
        new = ImmutableDictBase.__new__(cls)
        # Populate the storage at the plain ``dict`` level, since this
        # class's own ``__init__`` is a no-op.
        dict.__init__(new, *args)
        return new

    def __init__(
        self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]]
    ):
        # Contents were already loaded by __new__; nothing to do here.
        pass

    def __reduce__(self):
        # Pickle as a call to immutabledict with a plain-dict copy.
        # NOTE(review): subclasses therefore unpickle as plain immutabledict.
        return immutabledict, (dict(self),)

    def union(
        self, __d: Optional[Mapping[_KT, _VT]] = None
    ) -> immutabledict[_KT, _VT]:
        """Return a new instance containing self's items updated by *__d*.

        Returns ``self`` unchanged when *__d* is ``None`` or empty.
        """
        if not __d:
            return self

        new = ImmutableDictBase.__new__(self.__class__)
        dict.__init__(new, self)
        dict.update(new, __d)
        return new

    def _union_w_kw(
        self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT
    ) -> immutabledict[_KT, _VT]:
        """Like :meth:`union`, additionally applying keyword items last.

        Returns ``self`` unchanged when both *__d* and *kw* are empty.
        """
        # not sure if C version works correctly w/ this yet
        if not __d and not kw:
            return self

        new = ImmutableDictBase.__new__(self.__class__)
        dict.__init__(new, self)
        if __d:
            dict.update(new, __d)
        dict.update(new, kw)
        return new

    def merge_with(
        self, *dicts: Optional[Mapping[_KT, _VT]]
    ) -> immutabledict[_KT, _VT]:
        """Return a new instance with each mapping in *dicts* applied in
        order; later mappings win on key collisions.

        ``None`` and empty mappings are skipped.  If every argument is
        skipped, ``self`` is returned and no copy is made.
        """
        new = None
        for d in dicts:
            if d:
                if new is None:
                    # copy lazily, only once a non-empty mapping is seen
                    new = ImmutableDictBase.__new__(self.__class__)
                    dict.__init__(new, self)
                dict.update(new, d)
        if new is None:
            return self

        return new

    def __repr__(self) -> str:
        return "immutabledict(%s)" % dict.__repr__(self)

    # PEP 584
    def __ior__(self, __value: Any) -> NoReturn:  # type: ignore
        self._readonly()

    def __or__(  # type: ignore[override]
        self, __value: Mapping[_KT, _VT]
    ) -> immutabledict[_KT, _VT]:
        """``self | other``: return a new immutabledict of the merged items."""
        result = super().__or__(__value)  # type: ignore[call-overload]
        if result is NotImplemented:
            # dict.__or__ only accepts dict operands.  Propagate the
            # sentinel so Python can try the other operand's __ror__ or
            # raise a proper "unsupported operand" TypeError, instead of
            # passing NotImplemented into the constructor.
            return NotImplemented
        return immutabledict(result)

    def __ror__(  # type: ignore[override]
        self, __value: Mapping[_KT, _VT]
    ) -> immutabledict[_KT, _VT]:
        """``other | self``: return a new immutabledict of the merged items."""
        result = super().__ror__(__value)  # type: ignore[call-overload]
        if result is NotImplemented:
            # see __or__: let Python finish operator dispatch
            return NotImplemented
        return immutabledict(result)
class OrderedSet(Set[_T]):
    """A ``set`` subclass that remembers insertion order.

    Membership is tracked by the underlying ``set``.  A parallel list,
    ``_list``, records element order.  Iteration, indexing, ``repr()`` and
    the set-algebra methods follow that order, and every mutator keeps the
    two structures in sync.
    """

    __slots__ = ("_list",)

    _list: List[_T]

    def __init__(self, d: Optional[Iterable[_T]] = None) -> None:
        """Create the set, optionally from the elements of *d*.

        Later duplicates in *d* are dropped; the first occurrence keeps its
        position.
        """
        if d is not None:
            self._list = unique_list(d)
            super().update(self._list)
        else:
            self._list = []

    def copy(self) -> OrderedSet[_T]:
        """Return a shallow copy with the same elements in the same order."""
        cp = self.__class__()
        cp._list = self._list.copy()
        set.update(cp, cp._list)
        return cp

    def add(self, element: _T) -> None:
        """Append *element* to the end of the order if not already present."""
        if element not in self:
            self._list.append(element)
            super().add(element)

    def remove(self, element: _T) -> None:
        """Remove *element*; raise ``KeyError`` if it is not present."""
        # set.remove raises KeyError before the list is touched
        super().remove(element)
        self._list.remove(element)

    def pop(self) -> _T:
        """Remove and return the last element in order.

        Raises ``KeyError`` when empty, as ``set.pop`` does.
        """
        try:
            value = self._list.pop()
        except IndexError:
            raise KeyError("pop from an empty set") from None
        super().remove(value)
        return value

    def insert(self, pos: int, element: _T) -> None:
        """Insert *element* at list position *pos* if not already present."""
        if element not in self:
            self._list.insert(pos, element)
            super().add(element)

    def discard(self, element: _T) -> None:
        """Remove *element* if present; do nothing otherwise."""
        if element in self:
            self._list.remove(element)
            super().remove(element)

    def clear(self) -> None:
        super().clear()
        self._list = []

    def __getitem__(self, key: int) -> _T:
        return self._list[key]

    def __iter__(self) -> Iterator[_T]:
        return iter(self._list)

    def __add__(self, other: Iterable[_T]) -> OrderedSet[_T]:
        return self.union(other)

    def __repr__(self) -> str:
        return "%s(%r)" % (self.__class__.__name__, self._list)

    __str__ = __repr__

    def update(self, *iterables: Iterable[_T]) -> None:
        """Append each new element of *iterables*, in iteration order."""
        for iterable in iterables:
            for e in iterable:
                if e not in self:
                    self._list.append(e)
                    super().add(e)

    def __ior__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        self.update(other)
        return self  # type: ignore[return-value]

    def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
        """Return a new set: self's elements, then new ones from *other*."""
        result: OrderedSet[Union[_T, _S]] = self.copy()  # type: ignore
        result.update(*other)
        return result

    def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        return self.union(other)

    def intersection(self, *other: Iterable[Any]) -> OrderedSet[_T]:
        """Return a new set of self's elements that appear in *every*
        iterable in *other*, in self's order.

        With no arguments a copy of self is returned, matching
        ``set.intersection``.
        """
        # set.intersection() applies "present in all" across any number of
        # iterables.  Merging them into one set first (set.update) would
        # instead give "present in any" and, with no arguments, an empty
        # result.
        other_set = super().intersection(*other)
        return self.__class__(a for a in self._list if a in other_set)

    def __and__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
        return self.intersection(other)

    def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]:
        """Return a new set of elements in exactly one of self and *other*.

        Self's surviving elements come first, in self's order, followed by
        *other*'s new elements in *other*'s order.
        """
        collection: Collection[_T]
        if isinstance(other, set):
            collection = other_set = other
        elif isinstance(other, Collection):
            collection = other
            other_set = set(other)
        else:
            # a one-shot iterator must be materialized; it is read twice
            collection = list(other)
            other_set = set(collection)
        result = self.__class__(a for a in self if a not in other_set)
        result.update(a for a in collection if a not in self)
        return result

    def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        return cast(OrderedSet[Union[_T, _S]], self).symmetric_difference(
            other
        )

    def difference(self, *other: Iterable[Any]) -> OrderedSet[_T]:
        """Return a new set of self's elements not in any of *other*."""
        other_set = super().difference(*other)
        return self.__class__(a for a in self._list if a in other_set)

    def __sub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]:
        return self.difference(other)

    def intersection_update(self, *other: Iterable[Any]) -> None:
        """Keep only elements found in every iterable of *other*."""
        super().intersection_update(*other)
        self._list = [a for a in self._list if a in self]

    def __iand__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
        self.intersection_update(other)
        return self

    def symmetric_difference_update(self, other: Iterable[Any]) -> None:
        """Update self to the elements in exactly one of self and *other*.

        Surviving elements keep their order; newly added elements are
        appended in the order they first appear in *other*.
        """
        collection = other if isinstance(other, Collection) else list(other)
        super().symmetric_difference_update(collection)
        # drop elements the set update removed, preserving order
        self._list = [a for a in self._list if a in self]
        # Append the newly added elements.  unique_list() guards against
        # repeats within ``collection``: the set collapses them, and the
        # list must too, or len() and iteration would disagree.
        self._list += [a for a in unique_list(collection) if a in self]

    def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        self.symmetric_difference_update(other)
        return cast(OrderedSet[Union[_T, _S]], self)

    def difference_update(self, *other: Iterable[Any]) -> None:
        """Remove every element found in any iterable of *other*."""
        super().difference_update(*other)
        self._list = [a for a in self._list if a in self]

    def __isub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]:  # type: ignore # noqa: E501
        self.difference_update(other)
        return self
class IdentitySet:
    """A set that considers only object id() for uniqueness.

    This strategy has edge cases for builtin types - it's possible to have
    two 'foo' strings in one of these sets, for example. Use sparingly.

    The binary operators (``|``, ``&``, ``-``, ``^`` and their in-place
    forms) accept only ``IdentitySet`` operands and return
    ``NotImplemented`` otherwise.  The named methods accept any iterable.
    """

    # Maps id(obj) -> obj.  Keeping the object as the value also holds a
    # strong reference, so a member's id() cannot be reused while it is in
    # the set.
    _members: Dict[int, Any]

    def __init__(self, iterable: Optional[Iterable[Any]] = None):
        self._members = dict()
        if iterable:
            self.update(iterable)

    def add(self, value: Any) -> None:
        self._members[id(value)] = value

    def __contains__(self, value: Any) -> bool:
        return id(value) in self._members

    def remove(self, value: Any) -> None:
        """Remove *value*; raise ``KeyError`` if it is not a member."""
        del self._members[id(value)]

    def discard(self, value: Any) -> None:
        """Remove *value* if it is a member; do nothing otherwise."""
        try:
            self.remove(value)
        except KeyError:
            pass

    def pop(self) -> Any:
        """Remove and return a member; raise ``KeyError`` if empty."""
        try:
            pair = self._members.popitem()
        except KeyError:
            # Replace dict's message with a set-style one; the original
            # KeyError adds nothing useful as context.
            raise KeyError("pop from an empty set") from None
        return pair[1]

    def clear(self) -> None:
        self._members.clear()

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, IdentitySet):
            return self._members == other._members
        else:
            return False

    def __ne__(self, other: Any) -> bool:
        if isinstance(other, IdentitySet):
            return self._members != other._members
        else:
            return True

    def issubset(self, iterable: Iterable[Any]) -> bool:
        """Return True if every member of self is (by identity) also in
        *iterable*."""
        if isinstance(iterable, self.__class__):
            other = iterable
        else:
            other = self.__class__(iterable)

        # a larger set cannot be a subset; skip the scan
        if len(self) > len(other):
            return False
        for _ in filterfalse(
            other._members.__contains__, iter(self._members.keys())
        ):
            return False
        return True

    def __le__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) < len(other) and self.issubset(other)

    def issuperset(self, iterable: Iterable[Any]) -> bool:
        """Return True if every element of *iterable* is (by identity) a
        member of self."""
        if isinstance(iterable, self.__class__):
            other = iterable
        else:
            other = self.__class__(iterable)

        # a smaller set cannot be a superset; skip the scan
        if len(self) < len(other):
            return False

        for _ in filterfalse(
            self._members.__contains__, iter(other._members.keys())
        ):
            return False
        return True

    def __ge__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issuperset(other)

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) > len(other) and self.issuperset(other)

    def union(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set with the members of self and *iterable*."""
        result = self.__class__()
        members = self._members
        result._members.update(members)
        result._members.update((id(obj), obj) for obj in iterable)
        return result

    def __or__(self, other: Any) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.union(other)

    def update(self, iterable: Iterable[Any]) -> None:
        """Add every element of *iterable* to self in place."""
        self._members.update((id(obj), obj) for obj in iterable)

    def __ior__(self, other: Any) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.update(other)
        return self

    def difference(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set of self's members not in *iterable*."""
        # bypass __init__; _members is assigned directly below
        result = self.__new__(self.__class__)
        other: Collection[Any]

        if isinstance(iterable, self.__class__):
            other = iterable._members
        else:
            other = {id(obj) for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k not in other
        }
        return result

    def __sub__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.difference(other)

    def difference_update(self, iterable: Iterable[Any]) -> None:
        """Remove from self every member found in *iterable*."""
        self._members = self.difference(iterable)._members

    def __isub__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.difference_update(other)
        return self

    def intersection(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set of self's members also in *iterable*."""
        result = self.__new__(self.__class__)

        other: Collection[Any]

        if isinstance(iterable, self.__class__):
            other = iterable._members
        else:
            other = {id(obj) for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k in other
        }
        return result

    def __and__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.intersection(other)

    def intersection_update(self, iterable: Iterable[Any]) -> None:
        """Keep only the members of self also found in *iterable*."""
        self._members = self.intersection(iterable)._members

    def __iand__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.intersection_update(other)
        return self

    def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set of objects in exactly one of self and
        *iterable*."""
        result = self.__new__(self.__class__)
        if isinstance(iterable, self.__class__):
            other = iterable._members
        else:
            other = {id(obj): obj for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k not in other
        }
        result._members.update(
            (k, v) for k, v in other.items() if k not in self._members
        )
        return result

    def __xor__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.symmetric_difference(other)

    def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
        """Update self to the objects in exactly one of self and
        *iterable*."""
        self._members = self.symmetric_difference(iterable)._members

    def __ixor__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        # Must be the in-place variant: symmetric_difference() only builds
        # and discards a new set, which left ``a ^= b`` a no-op.
        self.symmetric_difference_update(other)
        return self

    def copy(self) -> IdentitySet:
        """Return a shallow copy (members are shared, not copied)."""
        result = self.__new__(self.__class__)
        result._members = self._members.copy()
        return result

    __copy__ = copy

    def __len__(self) -> int:
        return len(self._members)

    def __iter__(self) -> Iterator[Any]:
        return iter(self._members.values())

    def __hash__(self) -> NoReturn:
        raise TypeError("set objects are unhashable")

    def __repr__(self) -> str:
        return "%s(%r)" % (type(self).__name__, list(self._members.values()))
def unique_list(
    seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
) -> List[_T]:
    """Return the elements of *seq* with duplicates removed, keeping order.

    The first occurrence of each element is kept.  Without *hashfunc*,
    elements are de-duplicated by set membership, so they must be
    hashable.  With *hashfunc*, two elements count as duplicates when
    ``hashfunc`` returns equal values for them.

    :param seq: iterable to de-duplicate; it is consumed once.
    :param hashfunc: optional key function, called exactly once per element.
    :return: a new list.
    """
    seen: Set[Any] = set()
    seen_add = seen.add
    if not hashfunc:
        # seen_add() returns None, so ``not seen_add(x)`` is always True;
        # it records x as a side effect of the filter.
        return [x for x in seq if x not in seen and not seen_add(x)]

    result: List[_T] = []
    append = result.append
    for x in seq:
        # compute the key once; it is used for both the test and the record
        key = hashfunc(x)
        if key not in seen:
            seen_add(key)
            append(x)
    return result
|