traversals.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024
  1. # sql/traversals.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: allow-untyped-defs, allow-untyped-calls
  8. from __future__ import annotations
  9. from collections import deque
  10. import collections.abc as collections_abc
  11. import itertools
  12. from itertools import zip_longest
  13. import operator
  14. import typing
  15. from typing import Any
  16. from typing import Callable
  17. from typing import Deque
  18. from typing import Dict
  19. from typing import Iterable
  20. from typing import Optional
  21. from typing import Set
  22. from typing import Tuple
  23. from typing import Type
  24. from . import operators
  25. from .cache_key import HasCacheKey
  26. from .visitors import _TraverseInternalsType
  27. from .visitors import anon_map
  28. from .visitors import ExternallyTraversible
  29. from .visitors import HasTraversalDispatch
  30. from .visitors import HasTraverseInternals
  31. from .. import util
  32. from ..util import langhelpers
  33. from ..util.typing import Self
  # Results produced by the individual comparison visitors.  The comparison
  # driver tests for a failure with an identity check ("is COMPARE_FAILED"),
  # so these must stay the bool singletons.
  COMPARE_FAILED = False
  COMPARE_SUCCEEDED = True

  # Sentinel a ``compare_<visit_name>`` hook returns to accept a pair of
  # elements without descending into their _traverse_internals.
  SKIP_TRAVERSE = util.symbol("skip_traverse")
  def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
      """Return True if ``obj1`` and ``obj2`` are structurally equivalent.

      When ``use_proxies`` is passed and truthy, column identity/lineage
      comparison (:class:`ColIdentityComparatorStrategy`) is used; otherwise
      plain structural comparison (:class:`TraversalComparatorStrategy`).
      All keyword arguments are forwarded to the strategy's ``compare()``.
      """
      strategy_cls = (
          ColIdentityComparatorStrategy
          if kw.get("use_proxies", False)
          else TraversalComparatorStrategy
      )
      return strategy_cls().compare(obj1, obj2, **kw)
  def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
      """Eagerly generate traversal helpers for a class hierarchy.

      For each class yielded by ``util.walk_subclasses(target_hierarchy)``
      that defines both ``_generate_cache_attrs`` and ``_traverse_internals``,
      generate its cache attributes and then its copy-internals and
      get-children dispatch functions.
      """
      dispatchers = (
          (_copy_internals, "_generated_copy_internals_traversal"),
          (_get_children, "_generated_get_children_traversal"),
      )
      for klass in util.walk_subclasses(target_hierarchy):
          if not (
              hasattr(klass, "_generate_cache_attrs")
              and hasattr(klass, "_traverse_internals")
          ):
              continue

          klass._generate_cache_attrs()
          for traversal, generated_name in dispatchers:
              traversal.generate_dispatch(
                  klass, klass._traverse_internals, generated_name
              )
  class HasShallowCopy(HasTraverseInternals):
      """attribute-wide operations that are useful for classes that use
      __slots__ and therefore can't operate on their attributes in a dictionary.

      Each operation is backed by a function generated from the class'
      ``_traverse_internals`` the first time it is used on that class, and
      then stored on the class itself.
      """

      __slots__ = ()

      if typing.TYPE_CHECKING:

          def _generated_shallow_copy_traversal(self, other: Self) -> None: ...

          def _generated_shallow_from_dict_traversal(
              self, d: Dict[str, Any]
          ) -> None: ...

          def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ...

      @classmethod
      def _generate_shallow_copy(
          cls,
          internal_dispatch: _TraverseInternalsType,
          method_name: str,
      ) -> Callable[[Self, Self], None]:
          """Build ``method_name(self, other)``, which copies every attribute
          named in ``internal_dispatch`` from ``self`` onto ``other``."""
          code = (
              "\n".join(
                  f"    other.{attrname} = self.{attrname}"
                  for attrname, _ in internal_dispatch
              )
              # with no attributes the "def" would otherwise have no body,
              # which does not compile
              or "    pass"
          )
          meth_text = f"def {method_name}(self, other):\n{code}\n"
          return langhelpers._exec_code_in_env(meth_text, {}, method_name)

      @classmethod
      def _generate_shallow_to_dict(
          cls,
          internal_dispatch: _TraverseInternalsType,
          method_name: str,
      ) -> Callable[[Self], Dict[str, Any]]:
          """Build ``method_name(self)``, which returns a dict mapping each
          attribute name in ``internal_dispatch`` to its current value."""
          code = ",\n".join(
              f"    '{attrname}': self.{attrname}"
              for attrname, _ in internal_dispatch
          )
          # an empty ``code`` still yields the valid "return {}"
          meth_text = f"def {method_name}(self):\n    return {{{code}}}\n"
          return langhelpers._exec_code_in_env(meth_text, {}, method_name)

      @classmethod
      def _generate_shallow_from_dict(
          cls,
          internal_dispatch: _TraverseInternalsType,
          method_name: str,
      ) -> Callable[[Self, Dict[str, Any]], None]:
          """Build ``method_name(self, d)``, which assigns each attribute
          named in ``internal_dispatch`` from the same-named key of ``d``."""
          code = (
              "\n".join(
                  f"    self.{attrname} = d['{attrname}']"
                  for attrname, _ in internal_dispatch
              )
              # with no attributes the "def" would otherwise have no body,
              # which does not compile
              or "    pass"
          )
          meth_text = f"def {method_name}(self, d):\n{code}\n"
          return langhelpers._exec_code_in_env(meth_text, {}, method_name)

      def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
          """Populate this object's traversed attributes from ``d``."""
          cls = self.__class__

          shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
          try:
              # look only in this class' own __dict__ so that a subclass
              # never reuses a function generated for its parent's
              # _traverse_internals
              shallow_from_dict = cls.__dict__[
                  "_generated_shallow_from_dict_traversal"
              ]
          except KeyError:
              shallow_from_dict = self._generate_shallow_from_dict(
                  cls._traverse_internals,
                  "_generated_shallow_from_dict_traversal",
              )

              cls._generated_shallow_from_dict_traversal = shallow_from_dict  # type: ignore  # noqa: E501

          shallow_from_dict(self, d)

      def _shallow_to_dict(self) -> Dict[str, Any]:
          """Return a dict of this object's traversed attributes."""
          cls = self.__class__

          shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]

          try:
              # class-local lookup only; see _shallow_from_dict
              shallow_to_dict = cls.__dict__[
                  "_generated_shallow_to_dict_traversal"
              ]
          except KeyError:
              shallow_to_dict = self._generate_shallow_to_dict(
                  cls._traverse_internals, "_generated_shallow_to_dict_traversal"
              )

              cls._generated_shallow_to_dict_traversal = shallow_to_dict  # type: ignore  # noqa: E501
          return shallow_to_dict(self)

      def _shallow_copy_to(self, other: Self) -> None:
          """Copy this object's traversed attributes onto ``other``."""
          cls = self.__class__

          shallow_copy: Callable[[Self, Self], None]
          try:
              # class-local lookup only; see _shallow_from_dict
              shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
          except KeyError:
              shallow_copy = self._generate_shallow_copy(
                  cls._traverse_internals, "_generated_shallow_copy_traversal"
              )

              cls._generated_shallow_copy_traversal = shallow_copy  # type: ignore  # noqa: E501
          shallow_copy(self, other)

      def _clone(self, **kw: Any) -> Self:
          """Create a shallow copy"""
          # __new__ bypasses __init__; the traversed attributes are then
          # copied over directly
          c = self.__class__.__new__(self.__class__)
          self._shallow_copy_to(c)
          return c
  class GenerativeOnTraversal(HasShallowCopy):
      """Generative behavior implemented on top of the shallow-copy
      traversal rather than by copying ``__dict__``.

      .. seealso::

          :class:`sqlalchemy.sql.base.Generative`

      """

      __slots__ = ()

      def _generate(self) -> Self:
          """Return a new instance of this class (constructed without
          calling ``__init__``) that shares this object's traversed
          attribute values."""
          klass = self.__class__
          duplicate = klass.__new__(klass)
          self._shallow_copy_to(duplicate)
          return duplicate
  162. def _clone(element, **kw):
  163. return element._clone()
  class HasCopyInternals(HasTraverseInternals):
      """Mixin providing ``_copy_internals()`` driven by the
      ``_traverse_internals`` declaration of the class."""

      __slots__ = ()

      def _clone(self, **kw):
          raise NotImplementedError()

      def _copy_internals(
          self, *, omit_attrs: Iterable[str] = (), **kw: Any
      ) -> None:
          """Reassign internal elements to be clones of themselves.

          Called during a copy-and-traverse operation on newly
          shallow-copied elements to create a deep copy.

          The given clone function should be used, which may be applying
          additional transformations to the element (i.e. replacement
          traversal, cloned traversal, annotations).

          :param omit_attrs: names of traversed attributes to leave
            untouched.  Any iterable of strings is accepted, including
            one-shot iterators.
          :param kw: passed through to each copy-internals dispatch method.
          """
          try:
              traverse_internals = self._traverse_internals
          except AttributeError:
              # user-defined classes may not have a _traverse_internals
              return

          # materialize once: an "in" test against a one-shot iterator would
          # consume it, and a frozenset also gives O(1) membership tests
          omitted = frozenset(omit_attrs)

          for attrname, obj, meth in _copy_internals.run_generated_dispatch(
              self, traverse_internals, "_generated_copy_internals_traversal"
          ):
              if attrname in omitted or obj is None:
                  continue

              result = meth(attrname, self, obj, **kw)
              # a None result means "keep the current value"
              if result is not None:
                  setattr(self, attrname, result)
  class _CopyInternalsTraversal(HasTraversalDispatch):
      """Generate a _copy_internals internal traversal dispatch for classes
      with a _traverse_internals collection.

      Each ``visit_<symbol>`` method receives the value of one traversed
      attribute and returns a copy of it in which contained elements have
      been passed through ``clone``.
      """

      def visit_clauseelement(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          return clone(element, **kw)

      def visit_clauseelement_list(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          return [clone(item, **kw) for item in element]

      def visit_clauseelement_tuple(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          return tuple(clone(item, **kw) for item in element)

      def visit_executable_options(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          # options are copied exactly like a tuple of clause elements
          return self.visit_clauseelement_tuple(
              attrname, parent, element, clone=clone, **kw
          )

      def visit_clauseelement_unordered_set(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          return {clone(item, **kw) for item in element}

      def visit_clauseelement_tuples(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          # list of tuples -> list of tuples of clones
          return [
              tuple(clone(member, **kw) for member in row) for row in element
          ]

      def visit_string_clauseelement_dict(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          # string keys are kept as-is; only the values are cloned
          return {name: clone(item, **kw) for name, item in element.items()}

      def visit_setup_join_tuple(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          def clone_if_present(obj):
              return None if obj is None else clone(obj, **kw)

          # each entry is (target, onclause, from_, flags); flags are kept
          return tuple(
              (
                  clone_if_present(target),
                  clone_if_present(onclause),
                  clone_if_present(from_),
                  flags,
              )
              for target, onclause, from_, flags in element
          )

      def visit_memoized_select_entities(self, attrname, parent, element, **kw):
          return self.visit_clauseelement_tuple(attrname, parent, element, **kw)

      def visit_dml_ordered_values(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          def clone_key(key):
              # plain (e.g. string) keys are kept; expression keys are cloned
              if hasattr(key, "__clause_element__"):
                  return clone(key, **kw)
              return key

          # sequence of 2-tuples; values are always cloned
          return [(clone_key(key), clone(value, **kw)) for key, value in element]

      def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
          def clone_key(key):
              # plain (e.g. string) keys are kept; expression keys are cloned
              if hasattr(key, "__clause_element__"):
                  return clone(key, **kw)
              return key

          return {
              clone_key(key): clone(value, **kw)
              for key, value in element.items()
          }

      def visit_dml_multi_values(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          # sequence of sequences, each sequence contains a list/dict/tuple
          def clone_if_expression(obj):
              if hasattr(obj, "__clause_element__"):
                  return clone(obj, **kw)
              return obj

          def copy(row):
              if isinstance(row, (list, tuple)):
                  return [clone_if_expression(value) for value in row]
              if isinstance(row, dict):
                  return {
                      clone_if_expression(key): clone_if_expression(value)
                      for key, value in row.items()
                  }
              # TODO: use abc classes
              assert False

          return [[copy(row) for row in sequence] for sequence in element]

      def visit_propagate_attrs(
          self, attrname, parent, element, clone=_clone, **kw
      ):
          # propagate_attrs are shared, never cloned
          return element


  _copy_internals = _CopyInternalsTraversal()
  301. def _flatten_clauseelement(element):
  302. while hasattr(element, "__clause_element__") and not getattr(
  303. element, "is_clause_element", False
  304. ):
  305. element = element.__clause_element__()
  306. return element
  class _GetChildrenTraversal(HasTraversalDispatch):
      """Generate a _children_traversal internal traversal dispatch for classes
      with a _traverse_internals collection.

      Each ``visit_<symbol>`` method receives the value of one traversed
      attribute and returns an iterable of the child elements found in it.
      """

      def visit_has_cache_key(self, element, **kw):
          # the GetChildren traversal refers explicitly to ClauseElement
          # structures.  Within these, a plain HasCacheKey is not a
          # ClauseElement, so don't include these.
          return ()

      def visit_clauseelement(self, element, **kw):
          # a single element is its own only child
          return (element,)

      def visit_clauseelement_list(self, element, **kw):
          return element

      def visit_clauseelement_tuple(self, element, **kw):
          return element

      def visit_clauseelement_tuples(self, element, **kw):
          # flatten a sequence of tuples into one stream of elements
          return itertools.chain.from_iterable(element)

      def visit_fromclause_canonical_column_collection(self, element, **kw):
          return ()

      def visit_string_clauseelement_dict(self, element, **kw):
          # the string keys are not children; only the values are
          return element.values()

      def visit_fromclause_ordered_set(self, element, **kw):
          return element

      def visit_clauseelement_unordered_set(self, element, **kw):
          return element

      def visit_setup_join_tuple(self, element, **kw):
          for target, onclause, from_, _flags in element:
              if from_ is not None:
                  yield from_

              if not isinstance(target, str):
                  yield _flatten_clauseelement(target)

              if onclause is None or isinstance(onclause, str):
                  continue
              yield _flatten_clauseelement(onclause)

      def visit_memoized_select_entities(self, element, **kw):
          return self.visit_clauseelement_tuple(element, **kw)

      def visit_dml_ordered_values(self, element, **kw):
          for key, value in element:
              # plain keys are not children; expression keys are
              if hasattr(key, "__clause_element__"):
                  yield key
              yield value

      def visit_dml_values(self, element, **kw):
          expr_keys = {k for k in element if hasattr(k, "__clause_element__")}
          plain_keys = set(element) - expr_keys

          # first the values stored under plain keys, in sorted key order
          for key in sorted(plain_keys):
              yield element[key]

          # then each expression key followed by its value
          for key in expr_keys:
              yield key
              yield element[key]

      def visit_dml_multi_values(self, element, **kw):
          return ()

      def visit_propagate_attrs(self, element, **kw):
          return ()


  _get_children = _GetChildrenTraversal()
  @util.preload_module("sqlalchemy.sql.elements")
  def _resolve_name_for_compare(element, name, anon_map, **kw):
      """Return ``name``, resolved through ``anon_map`` when it is an
      anonymous label so that anonymous names can be compared."""
      anonymous_label_cls = util.preloaded.sql_elements._anonymous_label
      if not isinstance(name, anonymous_label_cls):
          return name
      return name.apply_map(anon_map)
  class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
      """Structural comparison of two traversible structures.

      Pairs of elements are consumed first-in-first-out from ``stack``.  For
      each pair, an optional ``compare_<visit_name>`` hook runs first.  Then
      every attribute named in ``_traverse_internals`` is compared using the
      ``visit_<symbol>`` method registered for that attribute's traversal
      symbol.  Visitors either return ``COMPARE_FAILED`` to abort, or push
      child pairs onto ``stack`` for later comparison.
      """

      __slots__ = "stack", "cache", "anon_map"

      def __init__(self):
          # pending (left, right) pairs; appended on the right and consumed
          # from the left, so despite the name this is a FIFO queue
          self.stack: Deque[
              Tuple[
                  Optional[ExternallyTraversible],
                  Optional[ExternallyTraversible],
              ]
          ] = deque()
          # pairs already compared, so a repeated pair is skipped
          self.cache = set()

      def _memoized_attr_anon_map(self):
          # value for the ``anon_map`` slot: index 0 is used for the left
          # structure, index 1 for the right
          return (anon_map(), anon_map())

      def compare(
          self,
          obj1: ExternallyTraversible,
          obj2: ExternallyTraversible,
          **kw: Any,
      ) -> bool:
          """Return True if ``obj1`` and ``obj2`` compare as equivalent.

          ``compare_annotations`` (default False) controls whether
          ``_annotations`` attributes take part.  All keyword arguments are
          passed on to the ``compare_*`` hooks and ``visit_*`` methods.
          """
          stack = self.stack
          cache = self.cache

          compare_annotations = kw.get("compare_annotations", False)

          stack.append((obj1, obj2))

          while stack:
              left, right = stack.popleft()

              if left is right:
                  continue
              elif left is None or right is None:
                  # we know they are different so no match
                  return False
              elif (left, right) in cache:
                  continue
              cache.add((left, right))

              visit_name = left.__visit_name__
              if visit_name != right.__visit_name__:
                  return False

              meth = getattr(self, "compare_%s" % visit_name, None)

              if meth:
                  attributes_compared = meth(left, right, **kw)
                  if attributes_compared is COMPARE_FAILED:
                      return False
                  elif attributes_compared is SKIP_TRAVERSE:
                      continue

                  # attributes_compared is returned as a list of attribute
                  # names that were "handled" by the comparison method above.
                  # remaining attribute names in the _traverse_internals
                  # will be compared.
              else:
                  attributes_compared = ()

              for (
                  (left_attrname, left_visit_sym),
                  (right_attrname, right_visit_sym),
              ) in zip_longest(
                  left._traverse_internals,
                  right._traverse_internals,
                  fillvalue=(None, None),
              ):
                  if not compare_annotations and (
                      (left_attrname == "_annotations")
                      or (right_attrname == "_annotations")
                  ):
                      continue

                  if (
                      left_attrname != right_attrname
                      or left_visit_sym is not right_visit_sym
                  ):
                      return False
                  elif left_attrname in attributes_compared:
                      continue

                  assert left_visit_sym is not None
                  assert left_attrname is not None
                  assert right_attrname is not None

                  dispatch = self.dispatch(left_visit_sym)
                  assert dispatch is not None, (
                      f"{self.__class__} has no dispatch for "
                      f"'{self._dispatch_lookup[left_visit_sym]}'"
                  )
                  left_child = operator.attrgetter(left_attrname)(left)
                  right_child = operator.attrgetter(right_attrname)(right)
                  if left_child is None:
                      if right_child is not None:
                          return False
                      else:
                          continue
                  elif right_child is None:
                      return False

                  comparison = dispatch(
                      left_attrname, left, left_child, right, right_child, **kw
                  )
                  if comparison is COMPARE_FAILED:
                      return False

          return True

      def compare_inner(self, obj1, obj2, **kw):
          """Compare two sub-structures with a fresh comparator of the same
          class, independent of this comparator's queue and cache."""
          comparator = self.__class__()
          return comparator.compare(obj1, obj2, **kw)

      def visit_has_cache_key(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
              self.anon_map[1], []
          ):
              return COMPARE_FAILED

      def visit_propagate_attrs(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return self.compare_inner(
              left.get("plugin_subject", None), right.get("plugin_subject", None)
          )

      def visit_has_cache_key_list(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for l, r in zip_longest(left, right, fillvalue=None):
              if l is None:
                  if r is not None:
                      return COMPARE_FAILED
                  else:
                      continue
              elif r is None:
                  return COMPARE_FAILED

              if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
                  self.anon_map[1], []
              ):
                  return COMPARE_FAILED

      def visit_executable_options(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for l, r in zip_longest(left, right, fillvalue=None):
              if l is None:
                  if r is not None:
                      return COMPARE_FAILED
                  else:
                      continue
              elif r is None:
                  return COMPARE_FAILED

              # options without a cache key are compared with plain equality
              if (
                  l._gen_cache_key(self.anon_map[0], [])
                  if l._is_has_cache_key
                  else l
              ) != (
                  r._gen_cache_key(self.anon_map[1], [])
                  if r._is_has_cache_key
                  else r
              ):
                  return COMPARE_FAILED

      def visit_clauseelement(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          self.stack.append((left, right))

      def visit_fromclause_canonical_column_collection(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for lcol, rcol in zip_longest(left, right, fillvalue=None):
              self.stack.append((lcol, rcol))

      def visit_fromclause_derived_column_collection(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          pass

      def visit_string_clauseelement_dict(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for lstr, rstr in zip_longest(
              sorted(left), sorted(right), fillvalue=None
          ):
              if lstr != rstr:
                  return COMPARE_FAILED
              self.stack.append((left[lstr], right[rstr]))

      def visit_clauseelement_tuples(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for ltup, rtup in zip_longest(left, right, fillvalue=None):
              if ltup is None or rtup is None:
                  return COMPARE_FAILED

              for l, r in zip_longest(ltup, rtup, fillvalue=None):
                  self.stack.append((l, r))

      def visit_clauseelement_list(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for l, r in zip_longest(left, right, fillvalue=None):
              self.stack.append((l, r))

      def visit_clauseelement_tuple(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for l, r in zip_longest(left, right, fillvalue=None):
              self.stack.append((l, r))

      def _compare_unordered_sequences(self, seq1, seq2, **kw):
          """Return True if ``seq1`` and ``seq2`` hold pairwise-equivalent
          elements, in any order.

          Each element of ``seq2`` can be matched at most once, so repeated
          elements are matched individually and both sequences must have the
          same length.  Two ``None`` sequences compare equal; a ``None``
          against a non-``None`` sequence does not.
          """
          if seq1 is None or seq2 is None:
              return seq1 is seq2

          candidates = list(seq2)
          if len(seq1) != len(candidates):
              return False

          for clause in seq1:
              for idx, other_clause in enumerate(candidates):
                  if self.compare_inner(clause, other_clause, **kw):
                      # consume the match so that it can't pair with another
                      # element of seq1
                      del candidates[idx]
                      break
              else:
                  return False

          return True

      def visit_clauseelement_unordered_set(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return self._compare_unordered_sequences(left, right, **kw)

      def visit_fromclause_ordered_set(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for l, r in zip_longest(left, right, fillvalue=None):
              self.stack.append((l, r))

      def visit_string(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_string_list(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_string_multi_dict(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for lk, rk in zip_longest(
              sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
          ):
              if lk != rk:
                  return COMPARE_FAILED

              lv, rv = left[lk], right[rk]

              # it is the values, not the containing dicts, that may be
              # HasCacheKey objects
              lhc = isinstance(lv, HasCacheKey)
              rhc = isinstance(rv, HasCacheKey)
              if lhc and rhc:
                  if lv._gen_cache_key(
                      self.anon_map[0], []
                  ) != rv._gen_cache_key(self.anon_map[1], []):
                      return COMPARE_FAILED
              elif lhc != rhc:
                  return COMPARE_FAILED
              elif lv != rv:
                  return COMPARE_FAILED

      def visit_multi(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          lhc = isinstance(left, HasCacheKey)
          rhc = isinstance(right, HasCacheKey)
          if lhc and rhc:
              if left._gen_cache_key(
                  self.anon_map[0], []
              ) != right._gen_cache_key(self.anon_map[1], []):
                  return COMPARE_FAILED
          elif lhc != rhc:
              return COMPARE_FAILED
          else:
              return left == right

      def visit_anon_name(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return _resolve_name_for_compare(
              left_parent, left, self.anon_map[0], **kw
          ) == _resolve_name_for_compare(
              right_parent, right, self.anon_map[1], **kw
          )

      def visit_boolean(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_operator(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_type(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left._compare_type_affinity(right)

      def visit_plain_dict(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_dialect_options(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_annotations_key(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          if left and right:
              return (
                  left_parent._annotations_cache_key
                  == right_parent._annotations_cache_key
              )
          else:
              return left == right

      def visit_with_context_options(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          # callables are compared by their code objects, not identity
          return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
              (fn.__code__, c_key) for fn, c_key in right
          )

      def visit_plain_obj(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_named_ddl_element(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          # guard both sides before dereferencing ``.name``
          if left is None or right is None:
              return left is right

          return left.name == right.name

      def visit_prefix_sequence(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for (l_clause, l_str), (r_clause, r_str) in zip_longest(
              left, right, fillvalue=(None, None)
          ):
              if l_str != r_str:
                  return COMPARE_FAILED
              else:
                  self.stack.append((l_clause, r_clause))

      def visit_setup_join_tuple(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          # TODO: look at attrname for "legacy_join" and use different structure
          for (
              (l_target, l_onclause, l_from, l_flags),
              (r_target, r_onclause, r_from, r_flags),
          ) in zip_longest(left, right, fillvalue=(None, None, None, None)):
              if l_flags != r_flags:
                  return COMPARE_FAILED
              self.stack.append((l_target, r_target))
              self.stack.append((l_onclause, r_onclause))
              self.stack.append((l_from, r_from))

      def visit_memoized_select_entities(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return self.visit_clauseelement_tuple(
              attrname, left_parent, left, right_parent, right, **kw
          )

      def visit_table_hint_list(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
          right_keys = sorted(
              right, key=lambda elem: (elem[0].fullname, elem[1])
          )
          for (ltable, ldialect), (rtable, rdialect) in zip_longest(
              left_keys, right_keys, fillvalue=(None, None)
          ):
              if ldialect != rdialect:
                  return COMPARE_FAILED
              elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
                  return COMPARE_FAILED
              else:
                  self.stack.append((ltable, rtable))

      def visit_statement_hint_list(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          return left == right

      def visit_unknown_structure(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          raise NotImplementedError()

      def visit_dml_ordered_values(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          # sequence of (key, value) pairs; both the keys and the values
          # must match, position by position
          for (lk, lv), (rk, rv) in zip_longest(
              left, right, fillvalue=(None, None)
          ):
              if not self._compare_dml_values_or_ce(lk, rk, **kw):
                  return COMPARE_FAILED
              if not self._compare_dml_values_or_ce(lv, rv, **kw):
                  return COMPARE_FAILED

      def _compare_dml_values_or_ce(self, lv, rv, **kw):
          """Compare two DML keys or values.

          Objects with ``__clause_element__`` are compared structurally via
          :meth:`compare_inner`; anything else (e.g. string keys, literal
          values) is compared with ``==``.  Mixing the two kinds never
          matches.
          """
          lvce = hasattr(lv, "__clause_element__")
          rvce = hasattr(rv, "__clause_element__")
          if lvce != rvce:
            return False
          if lvce:
              return bool(self.compare_inner(lv, rv, **kw))
          # plain values must not be sent through compare_inner(): unless
          # they are the same object it reads ``__visit_name__``, which plain
          # Python values do not have
          return not (lv != rv)

      def visit_dml_values(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          if left is None or right is None or len(left) != len(right):
              return COMPARE_FAILED

          if isinstance(left, collections_abc.Sequence):
              for lv, rv in zip(left, right):
                  if not self._compare_dml_values_or_ce(lv, rv, **kw):
                      return COMPARE_FAILED
          elif isinstance(right, collections_abc.Sequence):
              return COMPARE_FAILED
          else:
              # dictionaries guaranteed to support insert ordering in
              # py37 so that we can compare the keys in order.  without
              # this, we can't compare SQL expression keys because we don't
              # know which key is which
              for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
                  if not self._compare_dml_values_or_ce(lk, rk, **kw):
                      return COMPARE_FAILED
                  if not self._compare_dml_values_or_ce(lv, rv, **kw):
                      return COMPARE_FAILED

      def visit_dml_multi_values(
          self, attrname, left_parent, left, right_parent, right, **kw
      ):
          for lseq, rseq in zip_longest(left, right, fillvalue=None):
              if lseq is None or rseq is None:
                  return COMPARE_FAILED

              for ld, rd in zip_longest(lseq, rseq, fillvalue=None):
                  if (
                      self.visit_dml_values(
                          attrname, left_parent, ld, right_parent, rd, **kw
                      )
                      is COMPARE_FAILED
                  ):
                      return COMPARE_FAILED

      def compare_expression_clauselist(self, left, right, **kw):
          if left.operator is right.operator:
              if operators.is_associative(left.operator):
                  # associative operators: operand order doesn't matter
                  if self._compare_unordered_sequences(
                      left.clauses, right.clauses, **kw
                  ):
                      return ["operator", "clauses"]
                  else:
                      return COMPARE_FAILED
              else:
                  return ["operator"]
          else:
              return COMPARE_FAILED

      def compare_clauselist(self, left, right, **kw):
          return self.compare_expression_clauselist(left, right, **kw)

      def compare_binary(self, left, right, **kw):
          if left.operator == right.operator:
              if operators.is_commutative(left.operator):
                  # commutative operators: accept either operand order
                  if (
                      self.compare_inner(left.left, right.left, **kw)
                      and self.compare_inner(left.right, right.right, **kw)
                  ) or (
                      self.compare_inner(left.left, right.right, **kw)
                      and self.compare_inner(left.right, right.left, **kw)
                  ):
                      return ["operator", "negate", "left", "right"]
                  else:
                      return COMPARE_FAILED
              else:
                  return ["operator", "negate"]
          else:
              return COMPARE_FAILED

      def compare_bindparam(self, left, right, **kw):
          compare_keys = kw.pop("compare_keys", True)
          compare_values = kw.pop("compare_values", True)

          if compare_values:
              omit = []
          else:
              # this means, "skip these, we already compared"
              omit = ["callable", "value"]

          if not compare_keys:
              omit.append("key")

          return omit
  class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
      """Comparator that treats columns as equal by lineage or identity
      rather than by structure."""

      def compare_column_element(
          self, left, right, use_proxies=True, equivalents=(), **kw
      ):
          """Compare ColumnElements using proxies and equivalent collections.

          This is a comparison strategy specific to the ORM.

          The candidates are ``right`` itself plus, when ``right`` is a key of
          ``equivalents``, the columns listed for it there.  ``left`` matches
          if any candidate shares lineage with it (only when ``use_proxies``)
          or hashes the same as it.
          """
          to_compare = (right,)
          if equivalents and right in equivalents:
              to_compare = equivalents[right].union(to_compare)

          for oth in to_compare:
              if use_proxies and left.shares_lineage(oth):
                  return SKIP_TRAVERSE
              # compare against each candidate, not just ``right``, so that
              # matches found through ``equivalents`` are honored
              elif hash(left) == hash(oth):
                  return SKIP_TRAVERSE

          return COMPARE_FAILED

      def compare_column(self, left, right, **kw):
          return self.compare_column_element(left, right, **kw)

      def compare_label(self, left, right, **kw):
          return self.compare_column_element(left, right, **kw)

      def compare_table(self, left, right, **kw):
          # tables compare on identity, since it's not really feasible to
          # compare them column by column with the above rules
          return SKIP_TRAVERSE if left is right else COMPARE_FAILED
  836. def compare_table(self, left, right, **kw):
  837. # tables compare on identity, since it's not really feasible to
  838. # compare them column by column with the above rules
  839. return SKIP_TRAVERSE if left is right else COMPARE_FAILED