| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024 |
- # sql/traversals.py
- # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- # mypy: allow-untyped-defs, allow-untyped-calls
- from __future__ import annotations
- from collections import deque
- import collections.abc as collections_abc
- import itertools
- from itertools import zip_longest
- import operator
- import typing
- from typing import Any
- from typing import Callable
- from typing import Deque
- from typing import Dict
- from typing import Iterable
- from typing import Optional
- from typing import Set
- from typing import Tuple
- from typing import Type
- from . import operators
- from .cache_key import HasCacheKey
- from .visitors import _TraverseInternalsType
- from .visitors import anon_map
- from .visitors import ExternallyTraversible
- from .visitors import HasTraversalDispatch
- from .visitors import HasTraverseInternals
- from .. import util
- from ..util import langhelpers
- from ..util.typing import Self
- SKIP_TRAVERSE = util.symbol("skip_traverse")
- COMPARE_FAILED = False
- COMPARE_SUCCEEDED = True
- def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
- strategy: TraversalComparatorStrategy
- if kw.get("use_proxies", False):
- strategy = ColIdentityComparatorStrategy()
- else:
- strategy = TraversalComparatorStrategy()
- return strategy.compare(obj1, obj2, **kw)
- def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
- for cls in util.walk_subclasses(target_hierarchy):
- if hasattr(cls, "_generate_cache_attrs") and hasattr(
- cls, "_traverse_internals"
- ):
- cls._generate_cache_attrs()
- _copy_internals.generate_dispatch(
- cls,
- cls._traverse_internals,
- "_generated_copy_internals_traversal",
- )
- _get_children.generate_dispatch(
- cls,
- cls._traverse_internals,
- "_generated_get_children_traversal",
- )
- class HasShallowCopy(HasTraverseInternals):
- """attribute-wide operations that are useful for classes that use
- __slots__ and therefore can't operate on their attributes in a dictionary.
- """
- __slots__ = ()
- if typing.TYPE_CHECKING:
- def _generated_shallow_copy_traversal(self, other: Self) -> None: ...
- def _generated_shallow_from_dict_traversal(
- self, d: Dict[str, Any]
- ) -> None: ...
- def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ...
- @classmethod
- def _generate_shallow_copy(
- cls,
- internal_dispatch: _TraverseInternalsType,
- method_name: str,
- ) -> Callable[[Self, Self], None]:
- code = "\n".join(
- f" other.{attrname} = self.{attrname}"
- for attrname, _ in internal_dispatch
- )
- meth_text = f"def {method_name}(self, other):\n{code}\n"
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- @classmethod
- def _generate_shallow_to_dict(
- cls,
- internal_dispatch: _TraverseInternalsType,
- method_name: str,
- ) -> Callable[[Self], Dict[str, Any]]:
- code = ",\n".join(
- f" '{attrname}': self.{attrname}"
- for attrname, _ in internal_dispatch
- )
- meth_text = f"def {method_name}(self):\n return {{{code}}}\n"
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- @classmethod
- def _generate_shallow_from_dict(
- cls,
- internal_dispatch: _TraverseInternalsType,
- method_name: str,
- ) -> Callable[[Self, Dict[str, Any]], None]:
- code = "\n".join(
- f" self.{attrname} = d['{attrname}']"
- for attrname, _ in internal_dispatch
- )
- meth_text = f"def {method_name}(self, d):\n{code}\n"
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
- cls = self.__class__
- shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
- try:
- shallow_from_dict = cls.__dict__[
- "_generated_shallow_from_dict_traversal"
- ]
- except KeyError:
- shallow_from_dict = self._generate_shallow_from_dict(
- cls._traverse_internals,
- "_generated_shallow_from_dict_traversal",
- )
- cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa: E501
- shallow_from_dict(self, d)
- def _shallow_to_dict(self) -> Dict[str, Any]:
- cls = self.__class__
- shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]
- try:
- shallow_to_dict = cls.__dict__[
- "_generated_shallow_to_dict_traversal"
- ]
- except KeyError:
- shallow_to_dict = self._generate_shallow_to_dict(
- cls._traverse_internals, "_generated_shallow_to_dict_traversal"
- )
- cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa: E501
- return shallow_to_dict(self)
- def _shallow_copy_to(self, other: Self) -> None:
- cls = self.__class__
- shallow_copy: Callable[[Self, Self], None]
- try:
- shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
- except KeyError:
- shallow_copy = self._generate_shallow_copy(
- cls._traverse_internals, "_generated_shallow_copy_traversal"
- )
- cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501
- shallow_copy(self, other)
- def _clone(self, **kw: Any) -> Self:
- """Create a shallow copy"""
- c = self.__class__.__new__(self.__class__)
- self._shallow_copy_to(c)
- return c
- class GenerativeOnTraversal(HasShallowCopy):
- """Supplies Generative behavior but making use of traversals to shallow
- copy.
- .. seealso::
- :class:`sqlalchemy.sql.base.Generative`
- """
- __slots__ = ()
- def _generate(self) -> Self:
- cls = self.__class__
- s = cls.__new__(cls)
- self._shallow_copy_to(s)
- return s
- def _clone(element, **kw):
- return element._clone()
- class HasCopyInternals(HasTraverseInternals):
- __slots__ = ()
- def _clone(self, **kw):
- raise NotImplementedError()
- def _copy_internals(
- self, *, omit_attrs: Iterable[str] = (), **kw: Any
- ) -> None:
- """Reassign internal elements to be clones of themselves.
- Called during a copy-and-traverse operation on newly
- shallow-copied elements to create a deep copy.
- The given clone function should be used, which may be applying
- additional transformations to the element (i.e. replacement
- traversal, cloned traversal, annotations).
- """
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return
- for attrname, obj, meth in _copy_internals.run_generated_dispatch(
- self, traverse_internals, "_generated_copy_internals_traversal"
- ):
- if attrname in omit_attrs:
- continue
- if obj is not None:
- result = meth(attrname, self, obj, **kw)
- if result is not None:
- setattr(self, attrname, result)
- class _CopyInternalsTraversal(HasTraversalDispatch):
- """Generate a _copy_internals internal traversal dispatch for classes
- with a _traverse_internals collection."""
- def visit_clauseelement(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return clone(element, **kw)
- def visit_clauseelement_list(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return [clone(clause, **kw) for clause in element]
- def visit_clauseelement_tuple(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return tuple([clone(clause, **kw) for clause in element])
- def visit_executable_options(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return tuple([clone(clause, **kw) for clause in element])
- def visit_clauseelement_unordered_set(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return {clone(clause, **kw) for clause in element}
- def visit_clauseelement_tuples(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return [
- tuple(clone(tup_elem, **kw) for tup_elem in elem)
- for elem in element
- ]
- def visit_string_clauseelement_dict(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return {key: clone(value, **kw) for key, value in element.items()}
- def visit_setup_join_tuple(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return tuple(
- (
- clone(target, **kw) if target is not None else None,
- clone(onclause, **kw) if onclause is not None else None,
- clone(from_, **kw) if from_ is not None else None,
- flags,
- )
- for (target, onclause, from_, flags) in element
- )
- def visit_memoized_select_entities(self, attrname, parent, element, **kw):
- return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
- def visit_dml_ordered_values(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- # sequence of 2-tuples
- return [
- (
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key
- ),
- clone(value, **kw),
- )
- for key, value in element
- ]
- def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
- return {
- (
- clone(key, **kw) if hasattr(key, "__clause_element__") else key
- ): clone(value, **kw)
- for key, value in element.items()
- }
- def visit_dml_multi_values(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- # sequence of sequences, each sequence contains a list/dict/tuple
- def copy(elem):
- if isinstance(elem, (list, tuple)):
- return [
- (
- clone(value, **kw)
- if hasattr(value, "__clause_element__")
- else value
- )
- for value in elem
- ]
- elif isinstance(elem, dict):
- return {
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key
- ): (
- clone(value, **kw)
- if hasattr(value, "__clause_element__")
- else value
- )
- for key, value in elem.items()
- }
- else:
- # TODO: use abc classes
- assert False
- return [
- [copy(sub_element) for sub_element in sequence]
- for sequence in element
- ]
- def visit_propagate_attrs(
- self, attrname, parent, element, clone=_clone, **kw
- ):
- return element
- _copy_internals = _CopyInternalsTraversal()
- def _flatten_clauseelement(element):
- while hasattr(element, "__clause_element__") and not getattr(
- element, "is_clause_element", False
- ):
- element = element.__clause_element__()
- return element
- class _GetChildrenTraversal(HasTraversalDispatch):
- """Generate a _children_traversal internal traversal dispatch for classes
- with a _traverse_internals collection."""
- def visit_has_cache_key(self, element, **kw):
- # the GetChildren traversal refers explicitly to ClauseElement
- # structures. Within these, a plain HasCacheKey is not a
- # ClauseElement, so don't include these.
- return ()
- def visit_clauseelement(self, element, **kw):
- return (element,)
- def visit_clauseelement_list(self, element, **kw):
- return element
- def visit_clauseelement_tuple(self, element, **kw):
- return element
- def visit_clauseelement_tuples(self, element, **kw):
- return itertools.chain.from_iterable(element)
- def visit_fromclause_canonical_column_collection(self, element, **kw):
- return ()
- def visit_string_clauseelement_dict(self, element, **kw):
- return element.values()
- def visit_fromclause_ordered_set(self, element, **kw):
- return element
- def visit_clauseelement_unordered_set(self, element, **kw):
- return element
- def visit_setup_join_tuple(self, element, **kw):
- for target, onclause, from_, flags in element:
- if from_ is not None:
- yield from_
- if not isinstance(target, str):
- yield _flatten_clauseelement(target)
- if onclause is not None and not isinstance(onclause, str):
- yield _flatten_clauseelement(onclause)
- def visit_memoized_select_entities(self, element, **kw):
- return self.visit_clauseelement_tuple(element, **kw)
- def visit_dml_ordered_values(self, element, **kw):
- for k, v in element:
- if hasattr(k, "__clause_element__"):
- yield k
- yield v
- def visit_dml_values(self, element, **kw):
- expr_values = {k for k in element if hasattr(k, "__clause_element__")}
- str_values = expr_values.symmetric_difference(element)
- for k in sorted(str_values):
- yield element[k]
- for k in expr_values:
- yield k
- yield element[k]
- def visit_dml_multi_values(self, element, **kw):
- return ()
- def visit_propagate_attrs(self, element, **kw):
- return ()
- _get_children = _GetChildrenTraversal()
- @util.preload_module("sqlalchemy.sql.elements")
- def _resolve_name_for_compare(element, name, anon_map, **kw):
- if isinstance(name, util.preloaded.sql_elements._anonymous_label):
- name = name.apply_map(anon_map)
- return name
- class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
- __slots__ = "stack", "cache", "anon_map"
- def __init__(self):
- self.stack: Deque[
- Tuple[
- Optional[ExternallyTraversible],
- Optional[ExternallyTraversible],
- ]
- ] = deque()
- self.cache = set()
- def _memoized_attr_anon_map(self):
- return (anon_map(), anon_map())
- def compare(
- self,
- obj1: ExternallyTraversible,
- obj2: ExternallyTraversible,
- **kw: Any,
- ) -> bool:
- stack = self.stack
- cache = self.cache
- compare_annotations = kw.get("compare_annotations", False)
- stack.append((obj1, obj2))
- while stack:
- left, right = stack.popleft()
- if left is right:
- continue
- elif left is None or right is None:
- # we know they are different so no match
- return False
- elif (left, right) in cache:
- continue
- cache.add((left, right))
- visit_name = left.__visit_name__
- if visit_name != right.__visit_name__:
- return False
- meth = getattr(self, "compare_%s" % visit_name, None)
- if meth:
- attributes_compared = meth(left, right, **kw)
- if attributes_compared is COMPARE_FAILED:
- return False
- elif attributes_compared is SKIP_TRAVERSE:
- continue
- # attributes_compared is returned as a list of attribute
- # names that were "handled" by the comparison method above.
- # remaining attribute names in the _traverse_internals
- # will be compared.
- else:
- attributes_compared = ()
- for (
- (left_attrname, left_visit_sym),
- (right_attrname, right_visit_sym),
- ) in zip_longest(
- left._traverse_internals,
- right._traverse_internals,
- fillvalue=(None, None),
- ):
- if not compare_annotations and (
- (left_attrname == "_annotations")
- or (right_attrname == "_annotations")
- ):
- continue
- if (
- left_attrname != right_attrname
- or left_visit_sym is not right_visit_sym
- ):
- return False
- elif left_attrname in attributes_compared:
- continue
- assert left_visit_sym is not None
- assert left_attrname is not None
- assert right_attrname is not None
- dispatch = self.dispatch(left_visit_sym)
- assert dispatch is not None, (
- f"{self.__class__} has no dispatch for "
- f"'{self._dispatch_lookup[left_visit_sym]}'"
- )
- left_child = operator.attrgetter(left_attrname)(left)
- right_child = operator.attrgetter(right_attrname)(right)
- if left_child is None:
- if right_child is not None:
- return False
- else:
- continue
- elif right_child is None:
- return False
- comparison = dispatch(
- left_attrname, left, left_child, right, right_child, **kw
- )
- if comparison is COMPARE_FAILED:
- return False
- return True
- def compare_inner(self, obj1, obj2, **kw):
- comparator = self.__class__()
- return comparator.compare(obj1, obj2, **kw)
- def visit_has_cache_key(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
- self.anon_map[1], []
- ):
- return COMPARE_FAILED
- def visit_propagate_attrs(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return self.compare_inner(
- left.get("plugin_subject", None), right.get("plugin_subject", None)
- )
- def visit_has_cache_key_list(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for l, r in zip_longest(left, right, fillvalue=None):
- if l is None:
- if r is not None:
- return COMPARE_FAILED
- else:
- continue
- elif r is None:
- return COMPARE_FAILED
- if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
- self.anon_map[1], []
- ):
- return COMPARE_FAILED
- def visit_executable_options(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for l, r in zip_longest(left, right, fillvalue=None):
- if l is None:
- if r is not None:
- return COMPARE_FAILED
- else:
- continue
- elif r is None:
- return COMPARE_FAILED
- if (
- l._gen_cache_key(self.anon_map[0], [])
- if l._is_has_cache_key
- else l
- ) != (
- r._gen_cache_key(self.anon_map[1], [])
- if r._is_has_cache_key
- else r
- ):
- return COMPARE_FAILED
- def visit_clauseelement(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- self.stack.append((left, right))
- def visit_fromclause_canonical_column_collection(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for lcol, rcol in zip_longest(left, right, fillvalue=None):
- self.stack.append((lcol, rcol))
- def visit_fromclause_derived_column_collection(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- pass
- def visit_string_clauseelement_dict(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for lstr, rstr in zip_longest(
- sorted(left), sorted(right), fillvalue=None
- ):
- if lstr != rstr:
- return COMPARE_FAILED
- self.stack.append((left[lstr], right[rstr]))
- def visit_clauseelement_tuples(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for ltup, rtup in zip_longest(left, right, fillvalue=None):
- if ltup is None or rtup is None:
- return COMPARE_FAILED
- for l, r in zip_longest(ltup, rtup, fillvalue=None):
- self.stack.append((l, r))
- def visit_clauseelement_list(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for l, r in zip_longest(left, right, fillvalue=None):
- self.stack.append((l, r))
- def visit_clauseelement_tuple(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for l, r in zip_longest(left, right, fillvalue=None):
- self.stack.append((l, r))
- def _compare_unordered_sequences(self, seq1, seq2, **kw):
- if seq1 is None:
- return seq2 is None
- completed: Set[object] = set()
- for clause in seq1:
- for other_clause in set(seq2).difference(completed):
- if self.compare_inner(clause, other_clause, **kw):
- completed.add(other_clause)
- break
- return len(completed) == len(seq1) == len(seq2)
- def visit_clauseelement_unordered_set(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return self._compare_unordered_sequences(left, right, **kw)
- def visit_fromclause_ordered_set(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for l, r in zip_longest(left, right, fillvalue=None):
- self.stack.append((l, r))
- def visit_string(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_string_list(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_string_multi_dict(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for lk, rk in zip_longest(
- sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
- ):
- if lk != rk:
- return COMPARE_FAILED
- lv, rv = left[lk], right[rk]
- lhc = isinstance(left, HasCacheKey)
- rhc = isinstance(right, HasCacheKey)
- if lhc and rhc:
- if lv._gen_cache_key(
- self.anon_map[0], []
- ) != rv._gen_cache_key(self.anon_map[1], []):
- return COMPARE_FAILED
- elif lhc != rhc:
- return COMPARE_FAILED
- elif lv != rv:
- return COMPARE_FAILED
- def visit_multi(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- lhc = isinstance(left, HasCacheKey)
- rhc = isinstance(right, HasCacheKey)
- if lhc and rhc:
- if left._gen_cache_key(
- self.anon_map[0], []
- ) != right._gen_cache_key(self.anon_map[1], []):
- return COMPARE_FAILED
- elif lhc != rhc:
- return COMPARE_FAILED
- else:
- return left == right
- def visit_anon_name(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return _resolve_name_for_compare(
- left_parent, left, self.anon_map[0], **kw
- ) == _resolve_name_for_compare(
- right_parent, right, self.anon_map[1], **kw
- )
- def visit_boolean(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_operator(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_type(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left._compare_type_affinity(right)
- def visit_plain_dict(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_dialect_options(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_annotations_key(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- if left and right:
- return (
- left_parent._annotations_cache_key
- == right_parent._annotations_cache_key
- )
- else:
- return left == right
- def visit_with_context_options(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
- (fn.__code__, c_key) for fn, c_key in right
- )
- def visit_plain_obj(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_named_ddl_element(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- if left is None:
- if right is not None:
- return COMPARE_FAILED
- return left.name == right.name
- def visit_prefix_sequence(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for (l_clause, l_str), (r_clause, r_str) in zip_longest(
- left, right, fillvalue=(None, None)
- ):
- if l_str != r_str:
- return COMPARE_FAILED
- else:
- self.stack.append((l_clause, r_clause))
- def visit_setup_join_tuple(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- # TODO: look at attrname for "legacy_join" and use different structure
- for (
- (l_target, l_onclause, l_from, l_flags),
- (r_target, r_onclause, r_from, r_flags),
- ) in zip_longest(left, right, fillvalue=(None, None, None, None)):
- if l_flags != r_flags:
- return COMPARE_FAILED
- self.stack.append((l_target, r_target))
- self.stack.append((l_onclause, r_onclause))
- self.stack.append((l_from, r_from))
- def visit_memoized_select_entities(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return self.visit_clauseelement_tuple(
- attrname, left_parent, left, right_parent, right, **kw
- )
- def visit_table_hint_list(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
- right_keys = sorted(
- right, key=lambda elem: (elem[0].fullname, elem[1])
- )
- for (ltable, ldialect), (rtable, rdialect) in zip_longest(
- left_keys, right_keys, fillvalue=(None, None)
- ):
- if ldialect != rdialect:
- return COMPARE_FAILED
- elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
- return COMPARE_FAILED
- else:
- self.stack.append((ltable, rtable))
- def visit_statement_hint_list(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- return left == right
- def visit_unknown_structure(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- raise NotImplementedError()
- def visit_dml_ordered_values(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- # sequence of tuple pairs
- for (lk, lv), (rk, rv) in zip_longest(
- left, right, fillvalue=(None, None)
- ):
- if not self._compare_dml_values_or_ce(lk, rk, **kw):
- return COMPARE_FAILED
- def _compare_dml_values_or_ce(self, lv, rv, **kw):
- lvce = hasattr(lv, "__clause_element__")
- rvce = hasattr(rv, "__clause_element__")
- if lvce != rvce:
- return False
- elif lvce and not self.compare_inner(lv, rv, **kw):
- return False
- elif not lvce and lv != rv:
- return False
- elif not self.compare_inner(lv, rv, **kw):
- return False
- return True
- def visit_dml_values(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- if left is None or right is None or len(left) != len(right):
- return COMPARE_FAILED
- if isinstance(left, collections_abc.Sequence):
- for lv, rv in zip(left, right):
- if not self._compare_dml_values_or_ce(lv, rv, **kw):
- return COMPARE_FAILED
- elif isinstance(right, collections_abc.Sequence):
- return COMPARE_FAILED
- else:
- # dictionaries guaranteed to support insert ordering in
- # py37 so that we can compare the keys in order. without
- # this, we can't compare SQL expression keys because we don't
- # know which key is which
- for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
- if not self._compare_dml_values_or_ce(lk, rk, **kw):
- return COMPARE_FAILED
- if not self._compare_dml_values_or_ce(lv, rv, **kw):
- return COMPARE_FAILED
- def visit_dml_multi_values(
- self, attrname, left_parent, left, right_parent, right, **kw
- ):
- for lseq, rseq in zip_longest(left, right, fillvalue=None):
- if lseq is None or rseq is None:
- return COMPARE_FAILED
- for ld, rd in zip_longest(lseq, rseq, fillvalue=None):
- if (
- self.visit_dml_values(
- attrname, left_parent, ld, right_parent, rd, **kw
- )
- is COMPARE_FAILED
- ):
- return COMPARE_FAILED
- def compare_expression_clauselist(self, left, right, **kw):
- if left.operator is right.operator:
- if operators.is_associative(left.operator):
- if self._compare_unordered_sequences(
- left.clauses, right.clauses, **kw
- ):
- return ["operator", "clauses"]
- else:
- return COMPARE_FAILED
- else:
- return ["operator"]
- else:
- return COMPARE_FAILED
- def compare_clauselist(self, left, right, **kw):
- return self.compare_expression_clauselist(left, right, **kw)
- def compare_binary(self, left, right, **kw):
- if left.operator == right.operator:
- if operators.is_commutative(left.operator):
- if (
- self.compare_inner(left.left, right.left, **kw)
- and self.compare_inner(left.right, right.right, **kw)
- ) or (
- self.compare_inner(left.left, right.right, **kw)
- and self.compare_inner(left.right, right.left, **kw)
- ):
- return ["operator", "negate", "left", "right"]
- else:
- return COMPARE_FAILED
- else:
- return ["operator", "negate"]
- else:
- return COMPARE_FAILED
- def compare_bindparam(self, left, right, **kw):
- compare_keys = kw.pop("compare_keys", True)
- compare_values = kw.pop("compare_values", True)
- if compare_values:
- omit = []
- else:
- # this means, "skip these, we already compared"
- omit = ["callable", "value"]
- if not compare_keys:
- omit.append("key")
- return omit
- class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
- def compare_column_element(
- self, left, right, use_proxies=True, equivalents=(), **kw
- ):
- """Compare ColumnElements using proxies and equivalent collections.
- This is a comparison strategy specific to the ORM.
- """
- to_compare = (right,)
- if equivalents and right in equivalents:
- to_compare = equivalents[right].union(to_compare)
- for oth in to_compare:
- if use_proxies and left.shares_lineage(oth):
- return SKIP_TRAVERSE
- elif hash(left) == hash(right):
- return SKIP_TRAVERSE
- else:
- return COMPARE_FAILED
- def compare_column(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
- def compare_label(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
- def compare_table(self, left, right, **kw):
- # tables compare on identity, since it's not really feasible to
- # compare them column by column with the above rules
- return SKIP_TRAVERSE if left is right else COMPARE_FAILED
|