collections.pyx 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. # cyextension/collections.pyx
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. cimport cython
  8. from cpython.long cimport PyLong_FromLongLong
  9. from cpython.set cimport PySet_Add
  10. from collections.abc import Collection
  11. from itertools import filterfalse
  cdef bint add_not_present(set seen, object item, hashfunc):
      # Record hashfunc(item) in ``seen`` and report whether it was new.
      # Returns True after adding the key when it was not yet in ``seen``;
      # returns False and leaves ``seen`` untouched otherwise.
      key = hashfunc(item)
      if key in seen:
          return False
      PySet_Add(seen, key)
      return True
  cdef list cunique_list(seq, hashfunc=None):
      # Return a new list with the items of ``seq`` in their original order,
      # keeping only the first occurrence of each.  Without ``hashfunc`` the
      # items themselves are compared (so they must be hashable); with a
      # truthy ``hashfunc``, equal hashfunc(item) results mark duplicates.
      cdef set seen = set()
      cdef list result = []
      if hashfunc:
          for item in seq:
              if add_not_present(seen, item, hashfunc):
                  result.append(item)
      else:
          for item in seq:
              if item not in seen:
                  PySet_Add(seen, item)
                  result.append(item)
      return result
  def unique_list(seq, hashfunc=None):
      """Return the items of ``seq`` without repeats, in first-seen order.

      :param seq: any iterable.
      :param hashfunc: optional callable; when given (and truthy), two items
          are duplicates if their ``hashfunc`` results are equal.
      :return: a new list.
      """
      cdef list deduped = cunique_list(seq, hashfunc)
      return deduped
  cdef class OrderedSet(set):
      """A ``set`` subclass that remembers the order elements were added in.

      Membership is tracked by the underlying ``set``.  Insertion order is
      kept in the private ``_list``, which every mutating method keeps in
      sync with the set contents.  Iteration, indexing and ``repr`` all
      follow ``_list``.
      """

      # Members in insertion order; must always hold exactly the set's members.
      cdef list _list

      @classmethod
      def __class_getitem__(cls, key):
          # Allow ``OrderedSet[X]`` subscripting; the subscript is ignored.
          return cls

      def __init__(self, d=None):
          set.__init__(self)
          if d is not None:
              # Drop repeats from ``d`` while keeping first-seen order.
              self._list = cunique_list(d)
              set.update(self, self._list)
          else:
              self._list = []

      cpdef OrderedSet copy(self):
          """Return a shallow copy with the same elements in the same order."""
          cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
          cp._list = list(self._list)
          set.update(cp, cp._list)
          return cp

      @cython.final
      cdef OrderedSet _from_list(self, list new_list):
          # Wrap ``new_list`` (not copied) in a new OrderedSet.  The list is
          # expected to be free of duplicates already; it is not re-checked.
          cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
          new._list = new_list
          set.update(new, new_list)
          return new

      def add(self, element):
          """Append ``element`` if it is not already present."""
          if element not in self:
              self._list.append(element)
              PySet_Add(self, element)

      def remove(self, element):
          """Remove ``element``; raise ``KeyError`` if it is not present."""
          # set.remove raises KeyError before _list is touched
          set.remove(self, element)
          self._list.remove(element)

      def pop(self):
          """Remove and return the most recently added element.

          Raises ``KeyError`` if the set is empty.
          """
          try:
              value = self._list.pop()
          except IndexError:
              raise KeyError("pop from an empty set") from None
          set.remove(self, value)
          return value

      def insert(self, Py_ssize_t pos, element):
          """Insert ``element`` at list position ``pos`` if not already present."""
          if element not in self:
              self._list.insert(pos, element)
              PySet_Add(self, element)

      def discard(self, element):
          """Remove ``element`` if present; do nothing otherwise."""
          if element in self:
              set.remove(self, element)
              self._list.remove(element)

      def clear(self):
          """Remove all elements."""
          set.clear(self)
          self._list = []

      def __getitem__(self, key):
          # Positional access (index or slice) into insertion order.
          return self._list[key]

      def __iter__(self):
          # Iterate in insertion order rather than set (hash) order.
          return iter(self._list)

      def __add__(self, other):
          return self.union(other)

      def __repr__(self):
          return "%s(%r)" % (self.__class__.__name__, self._list)

      __str__ = __repr__

      def update(self, *iterables):
          """Append each element of each iterable that is not yet present."""
          for iterable in iterables:
              for e in iterable:
                  if e not in self:
                      self._list.append(e)
                      set.add(self, e)

      def __ior__(self, iterable):
          self.update(iterable)
          return self

      def union(self, *other):
          """Return a copy of this set extended with the new elements of ``other``."""
          result = self.copy()
          result.update(*other)
          return result

      def __or__(self, other):
          return self.union(other)

      def intersection(self, *other):
          """Return elements also present in every ``other``, in this set's order."""
          cdef set other_set = set.intersection(self, *other)
          return self._from_list([a for a in self._list if a in other_set])

      def __and__(self, other):
          return self.intersection(other)

      def symmetric_difference(self, other):
          """Return elements in exactly one of this set and ``other``.

          This set's surviving elements come first, in order, followed by the
          new elements of ``other`` in ``other``'s iteration order.
          """
          cdef set other_set
          if isinstance(other, set):
              other_set = <set> other
              collection = other_set
          elif isinstance(other, Collection):
              collection = other
              other_set = set(other)
          else:
              # possibly a one-shot iterator: materialize it, since it is
              # walked twice below
              collection = list(other)
              other_set = set(collection)
          result = self._from_list([a for a in self._list if a not in other_set])
          result.update(a for a in collection if a not in self)
          return result

      def __xor__(self, other):
          return self.symmetric_difference(other)

      def difference(self, *other):
          """Return elements not present in any ``other``, in this set's order."""
          cdef set other_set = set.difference(self, *other)
          return self._from_list([a for a in self._list if a in other_set])

      def __sub__(self, other):
          return self.difference(other)

      def intersection_update(self, *other):
          """Keep only the elements also present in every ``other``."""
          set.intersection_update(self, *other)
          self._list = [a for a in self._list if a in self]

      def __iand__(self, other):
          self.intersection_update(other)
          return self

      cpdef symmetric_difference_update(self, other):
          """Update in place to the elements in exactly one of self and ``other``.

          Surviving elements keep their order; elements newly added from
          ``other`` are appended once each, in ``other``'s iteration order.
          """
          collection = other if isinstance(other, Collection) else list(other)
          set.symmetric_difference_update(self, collection)
          self._list = [a for a in self._list if a in self]
          # ``collection`` may repeat an element (e.g. the list [x, x]); the
          # set gains it only once, so _list must gain it only once too.
          self._list += [a for a in cunique_list(collection) if a in self]

      def __ixor__(self, other):
          self.symmetric_difference_update(other)
          return self

      def difference_update(self, *other):
          """Remove every element that appears in any ``other``."""
          set.difference_update(self, *other)
          self._list = [a for a in self._list if a in self]

      def __isub__(self, other):
          self.difference_update(other)
          return self
  cdef object cy_id(object item):
      # Turn the object's address into a Python int, used as the
      # identity-based dict key by IdentitySet.
      cdef void *address = <void *>item
      return PyLong_FromLongLong(<long long>address)
  # NOTE: Cython 0.x calls __add__, __sub__, etc. with the operands swapped
  # instead of calling the reflected method (__radd__, __rsub__, ...), so
  # those methods must also check that ``self`` is of the correct type.
  # This is fixed in Cython 3.x.  See:
  # https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods
  cdef class IdentitySet:
      """A set that considers only object id() for uniqueness.

      This strategy has edge cases for builtin types -- it's possible to have
      two 'foo' strings in one of these sets, for example.  Use sparingly.

      Members are stored in a dict keyed by ``cy_id(member)``; the members
      themselves are never hashed, so unhashable objects may be added.
      """

      # Maps cy_id(member) -> member.
      cdef dict _members

      def __init__(self, iterable=None):
          self._members = {}
          if iterable:
              self.update(iterable)

      def add(self, value):
          """Add ``value`` (by identity)."""
          self._members[cy_id(value)] = value

      def __contains__(self, value):
          return cy_id(value) in self._members

      cpdef remove(self, value):
          """Remove ``value``; raise ``KeyError`` if this object is not present."""
          del self._members[cy_id(value)]

      def discard(self, value):
          """Remove ``value`` if present; do nothing otherwise."""
          try:
              self.remove(value)
          except KeyError:
              pass

      def pop(self):
          """Remove and return a member.

          Raises ``KeyError`` if the set is empty.
          """
          cdef tuple pair
          try:
              pair = self._members.popitem()
          except KeyError:
              # re-raise with a set-style message, hiding the dict error
              raise KeyError("pop from an empty set") from None
          return pair[1]

      def clear(self):
          """Remove all members."""
          self._members.clear()

      def __eq__(self, other):
          cdef IdentitySet other_
          if isinstance(other, IdentitySet):
              other_ = other
              return self._members == other_._members
          else:
              return False

      def __ne__(self, other):
          cdef IdentitySet other_
          if isinstance(other, IdentitySet):
              other_ = other
              return self._members != other_._members
          else:
              return True

      cpdef issubset(self, iterable):
          """Return True if every member of this set is also in ``iterable``."""
          cdef IdentitySet other
          if isinstance(iterable, self.__class__):
              other = iterable
          else:
              other = self.__class__(iterable)

          if len(self) > len(other):
              return False
          for m in filterfalse(other._members.__contains__, self._members):
              return False
          return True

      def __le__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          return self.issubset(other)

      def __lt__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          return len(self) < len(other) and self.issubset(other)

      cpdef issuperset(self, iterable):
          """Return True if every member of ``iterable`` is also in this set."""
          cdef IdentitySet other
          if isinstance(iterable, self.__class__):
              other = iterable
          else:
              other = self.__class__(iterable)

          if len(self) < len(other):
              return False
          for m in filterfalse(self._members.__contains__, other._members):
              return False
          return True

      def __ge__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          return self.issuperset(other)

      def __gt__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          return len(self) > len(other) and self.issuperset(other)

      cpdef IdentitySet union(self, iterable):
          """Return a new set with the members of this set and of ``iterable``."""
          cdef IdentitySet result = self.__class__()
          result._members.update(self._members)
          result.update(iterable)
          return result

      def __or__(self, other):
          if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
              return NotImplemented
          return self.union(other)

      cpdef update(self, iterable):
          """Add every object of ``iterable`` (by identity)."""
          for obj in iterable:
              self._members[cy_id(obj)] = obj

      def __ior__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          self.update(other)
          return self

      cpdef IdentitySet difference(self, iterable):
          """Return a new set of members not (by identity) in ``iterable``."""
          cdef IdentitySet result = self.__new__(self.__class__)
          if isinstance(iterable, self.__class__):
              other = (<IdentitySet>iterable)._members
          else:
              other = {cy_id(obj) for obj in iterable}
          result._members = {k: v for k, v in self._members.items() if k not in other}
          return result

      def __sub__(self, other):
          if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
              return NotImplemented
          return self.difference(other)

      cpdef difference_update(self, iterable):
          """Remove members that are (by identity) in ``iterable``."""
          cdef IdentitySet other = self.difference(iterable)
          self._members = other._members

      def __isub__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          self.difference_update(other)
          return self

      cpdef IdentitySet intersection(self, iterable):
          """Return a new set of members that are (by identity) in ``iterable``."""
          cdef IdentitySet result = self.__new__(self.__class__)
          if isinstance(iterable, self.__class__):
              other = (<IdentitySet>iterable)._members
          else:
              other = {cy_id(obj) for obj in iterable}
          result._members = {k: v for k, v in self._members.items() if k in other}
          return result

      def __and__(self, other):
          if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
              return NotImplemented
          return self.intersection(other)

      cpdef intersection_update(self, iterable):
          """Keep only members that are (by identity) in ``iterable``."""
          cdef IdentitySet other = self.intersection(iterable)
          self._members = other._members

      def __iand__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          self.intersection_update(other)
          return self

      cpdef IdentitySet symmetric_difference(self, iterable):
          """Return a new set of objects in exactly one of self and ``iterable``."""
          cdef IdentitySet result = self.__new__(self.__class__)
          cdef dict other
          if isinstance(iterable, self.__class__):
              other = (<IdentitySet>iterable)._members
          else:
              other = {cy_id(obj): obj for obj in iterable}
          result._members = {k: v for k, v in self._members.items() if k not in other}
          result._members.update(
              [(k, v) for k, v in other.items() if k not in self._members]
          )
          return result

      def __xor__(self, other):
          if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
              return NotImplemented
          return self.symmetric_difference(other)

      cpdef symmetric_difference_update(self, iterable):
          """Update in place to the objects in exactly one of self and ``iterable``."""
          cdef IdentitySet other = self.symmetric_difference(iterable)
          self._members = other._members

      def __ixor__(self, other):
          if not isinstance(other, IdentitySet):
              return NotImplemented
          # ``^=`` must mutate self: use the *_update variant (the plain
          # symmetric_difference returns a new set and would be discarded).
          self.symmetric_difference_update(other)
          return self

      cpdef IdentitySet copy(self):
          """Return a shallow copy (same member objects, new container)."""
          cdef IdentitySet cp = self.__new__(self.__class__)
          cp._members = self._members.copy()
          return cp

      def __copy__(self):
          return self.copy()

      def __len__(self):
          return len(self._members)

      def __iter__(self):
          return iter(self._members.values())

      def __hash__(self):
          # Mutable container: explicitly unhashable.
          raise TypeError("set objects are unhashable")

      def __repr__(self):
          return "%s(%r)" % (type(self).__name__, list(self._members.values()))