| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379 |
- # orm/evaluator.py
- # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- # mypy: ignore-errors
- """Evaluation functions used **INTERNALLY** by ORM DML use cases.
- This module is **private, for internal use by SQLAlchemy**.
- .. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to
- ``_EvaluatorCompiler``.
- """
- from __future__ import annotations
- from typing import Type
- from . import exc as orm_exc
- from .base import LoaderCallableStatus
- from .base import PassiveFlag
- from .. import exc
- from .. import inspect
- from ..sql import and_
- from ..sql import operators
- from ..sql.sqltypes import Concatenable
- from ..sql.sqltypes import Integer
- from ..sql.sqltypes import Numeric
- from ..util import warn_deprecated
class UnevaluatableError(exc.InvalidRequestError):
    """Raised when a SQL expression cannot be evaluated in Python.

    The evaluator compiler in this module raises it for clause types,
    operators, columns and datatypes it has no Python evaluation for.
    """
class _NoObject(operators.ColumnOperators):
    """Operand placeholder for which every operation yields ``None``.

    Column evaluators in this module hand back the shared instance of this
    class when the object being evaluated is itself ``None``.
    """

    def operate(self, *args, **kwargs):
        """Ignore the operator and its operands; always return ``None``."""
        return None

    def reverse_operate(self, *args, **kwargs):
        """Reflected counterpart of ``operate()``; always return ``None``."""
        return None
class _ExpiredObject(operators.ColumnOperators):
    """Operand placeholder that survives every operation applied to it.

    Column evaluators in this module hand back the shared instance of this
    class when an attribute value is not available without a fetch; any
    operator applied to it returns the same instance again.
    """

    def operate(self, *args, **kwargs):
        """Ignore the operator and its operands; return this instance."""
        return self

    def reverse_operate(self, *args, **kwargs):
        """Reflected counterpart of ``operate()``; return this instance."""
        return self
# Module-wide singletons; evaluators compare against them with ``is``.
# _EXPIRED_OBJECT: returned when an attribute value could not be read
# without a fetch.
_EXPIRED_OBJECT = _ExpiredObject()
# _NO_OBJECT: returned when the object being evaluated is None.
_NO_OBJECT = _NoObject()
class _EvaluatorCompiler:
    """Compile SQL expression constructs into plain Python callables.

    ``process()`` dispatches on a clause's ``__visit_name__`` to the
    matching ``visit_<name>`` method.  Every visitor returns a function
    taking a single object and returning the value of the expression for
    that object.  Evaluators return ``None`` when an operand is ``None``
    and ``_EXPIRED_OBJECT`` as soon as any operand is ``_EXPIRED_OBJECT``.

    Constructs that have no visitor, or operators that have no
    ``visit_<operator>_binary_op`` / ``visit_<operator>_clauselist_op``
    method, raise :class:`UnevaluatableError`.
    """

    def __init__(self, target_cls=None):
        """Store the optional class that evaluated columns must belong to.

        :param target_cls: when given, ``visit_column`` rejects columns
         whose parent mapper class is not this class or one of its bases.
        """
        self.target_cls = target_cls

    def process(self, clause, *clauses):
        """Return an evaluator callable for ``clause``.

        Additional positional clauses are joined to the first with
        ``and_()`` before dispatching.

        :raises UnevaluatableError: if no ``visit_<__visit_name__>`` method
         exists for the clause.
        """
        if clauses:
            clause = and_(clause, *clauses)

        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
        if not meth:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__}"
            )
        return meth(clause)

    def visit_grouping(self, clause):
        """Evaluate a grouping as its inner element."""
        return self.process(clause.element)

    def visit_null(self, clause):
        """Return an evaluator that always yields ``None``."""
        return lambda obj: None

    def visit_false(self, clause):
        """Return an evaluator that always yields ``False``."""
        return lambda obj: False

    def visit_true(self, clause):
        """Return an evaluator that always yields ``True``."""
        return lambda obj: True

    def visit_column(self, clause):
        """Return an evaluator reading the mapped attribute for a column.

        The returned callable yields ``_NO_OBJECT`` when given ``None``,
        ``_EXPIRED_OBJECT`` when the attribute implementation reports
        ``PASSIVE_NO_RESULT`` for a no-fetch get, and the attribute value
        otherwise.

        :raises UnevaluatableError: if the column carries no
         ``parentmapper`` annotation, if ``target_cls`` is set and is not a
         subclass of the parent mapper's class, or if the column is not
         mapped to a property on that mapper.
        """
        try:
            parentmapper = clause._annotations["parentmapper"]
        except KeyError as ke:
            raise UnevaluatableError(
                f"Cannot evaluate column: {clause}"
            ) from ke

        if self.target_cls and not issubclass(
            self.target_cls, parentmapper.class_
        ):
            raise UnevaluatableError(
                "Can't evaluate criteria against "
                f"alternate class {parentmapper.class_}"
            )

        parentmapper._check_configure()

        # we'd like to use "proxy_key" annotation to get the "key", however
        # in relationship primaryjoin cases proxy_key is sometimes deannotated
        # and sometimes apparently not present in the first place (?).
        # While I can stop it from being deannotated (though need to see if
        # this breaks other things), not sure right now about cases where it's
        # not there in the first place.  can fix at some later point.
        # key = clause._annotations["proxy_key"]

        # for now, use the old way
        try:
            key = parentmapper._columntoproperty[clause].key
        except orm_exc.UnmappedColumnError as err:
            raise UnevaluatableError(
                f"Cannot evaluate expression: {err}"
            ) from err

        # note this used to fall back to a simple `getattr(obj, key)` evaluator
        # if impl was None; as of #8656, we ensure mappers are configured
        # so that impl is available
        impl = parentmapper.class_manager[key].impl

        def get_corresponding_attr(obj):
            if obj is None:
                return _NO_OBJECT
            state = inspect(obj)
            dict_ = state.dict

            # PASSIVE_NO_FETCH: read the value without fetching it
            value = impl.get(
                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
            )
            if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                return _EXPIRED_OBJECT
            return value

        return get_corresponding_attr

    def visit_tuple(self, clause):
        """Evaluate a tuple construct the same way as a clause list."""
        return self.visit_clauselist(clause)

    def visit_expression_clauselist(self, clause):
        """Evaluate an expression clause list as a plain clause list."""
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Compile each member clause, then dispatch on the list operator.

        The handler looked up is ``visit_<operator>_clauselist_op``, where
        ``<operator>`` is the operator function's name with trailing
        underscores removed.

        :raises UnevaluatableError: if no handler exists for the operator.
        """
        # distinct loop name so the ``clause`` argument is not shadowed
        evaluators = [
            self.process(sub_clause) for sub_clause in clause.clauses
        ]

        dispatch = (
            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
        )
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, evaluators, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate clauselist with operator {clause.operator}"
            )

    def visit_binary(self, clause):
        """Compile both operands, then dispatch on the binary operator.

        The handler looked up is ``visit_<operator>_binary_op``, where
        ``<operator>`` is the operator function's name with trailing
        underscores removed.

        :raises UnevaluatableError: if no handler exists for the operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, eval_left, eval_right, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__} with "
                f"operator {clause.operator}"
            )

    def visit_or_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator for an OR over ``evaluators``.

        Members are evaluated in order.  The first ``_EXPIRED_OBJECT``
        result is returned as is and the first truthy result yields
        ``True``.  If no member is truthy the result is ``None`` when any
        member evaluated to ``None``, else ``False``.
        """

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value:
                    return True
                has_null = has_null or value is None

            if has_null:
                return None
            return False

        return evaluate

    def visit_and_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator for an AND over ``evaluators``.

        Members are evaluated in order.  The first ``_EXPIRED_OBJECT``
        result is returned as is.  The first falsy result ends evaluation
        with ``None`` if that value is ``None`` (or ``_NO_OBJECT``) and
        ``False`` otherwise.  If every member is truthy the result is
        ``True``.
        """

        def evaluate(obj):
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT

                if not value:
                    if value is None or value is _NO_OBJECT:
                        return None
                    return False

            return True

        return evaluate

    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator producing a tuple of the member values.

        Evaluation stops early with ``_EXPIRED_OBJECT`` if a member is
        expired, or with ``None`` if a member is ``None`` or ``_NO_OBJECT``.
        """

        def evaluate(obj):
            values = []
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    return None
                values.append(value)
            return tuple(values)

        return evaluate

    def visit_custom_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate a custom operator through its ``python_impl`` callable.

        :raises UnevaluatableError: if the operator has no ``python_impl``.
        """
        if operator.python_impl:
            return self._straight_evaluate(
                operator, eval_left, eval_right, clause
            )
        else:
            raise UnevaluatableError(
                f"Custom operator {operator.opstring!r} can't be evaluated "
                "in Python unless it specifies a callable using "
                "`.python_impl`."
            )

    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator comparing both operands with ``==``.

        Unlike ``_straight_evaluate`` there is no ``None`` short circuit,
        so two ``None`` operands compare as ``True``.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val == right_val

        return evaluate

    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator comparing both operands with ``!=``.

        As with ``visit_is_binary_op`` there is no ``None`` short circuit.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val != right_val

        return evaluate

    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
        """Return an evaluator applying ``operator`` to both operand values.

        The result is ``_EXPIRED_OBJECT`` if either operand is expired and
        ``None`` if either operand is ``None``; otherwise it is
        ``operator(left, right)``.

        :param operator: two-argument callable applied to the values.
        :param eval_left: evaluator for the left operand.
        :param eval_right: evaluator for the right operand.
        :param clause: the originating expression; unused here, accepted so
         the method can serve directly as a ``visit_*_binary_op`` handler.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            elif left_val is None or right_val is None:
                return None

            # reuse the values checked above rather than running both
            # operand evaluators a second time
            return operator(left_val, right_val)

        return evaluate

    def _straight_evaluate_numeric_only(
        self, operator, eval_left, eval_right, clause
    ):
        """Like ``_straight_evaluate``, restricted to numeric operands.

        :raises UnevaluatableError: unless the type affinity of both sides
         of ``clause`` is ``Numeric`` or ``Integer``.
        """
        if clause.left.type._type_affinity not in (
            Numeric,
            Integer,
        ) or clause.right.type._type_affinity not in (Numeric, Integer):
            raise UnevaluatableError(
                f'Cannot evaluate math operator "{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(operator, eval_left, eval_right, clause)

    # math operators are evaluated for numeric datatypes only
    visit_add_binary_op = _straight_evaluate_numeric_only
    visit_mul_binary_op = _straight_evaluate_numeric_only
    visit_sub_binary_op = _straight_evaluate_numeric_only
    visit_mod_binary_op = _straight_evaluate_numeric_only
    visit_truediv_binary_op = _straight_evaluate_numeric_only

    # comparison operators apply the operator function to the values as is
    visit_lt_binary_op = _straight_evaluate
    visit_le_binary_op = _straight_evaluate
    visit_ne_binary_op = _straight_evaluate
    visit_gt_binary_op = _straight_evaluate
    visit_ge_binary_op = _straight_evaluate
    visit_eq_binary_op = _straight_evaluate

    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
        """Evaluate IN as Python ``in``; ``None`` for a ``_NO_OBJECT`` left."""
        return self._straight_evaluate(
            lambda a, b: a in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_not_in_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate NOT IN as ``not in``; ``None`` for a ``_NO_OBJECT`` left."""
        return self._straight_evaluate(
            lambda a, b: a not in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_concat_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate concatenation as Python ``+``.

        :raises UnevaluatableError: unless the type affinity of both sides
         of ``clause`` is a ``Concatenable`` subclass.
        """
        if not issubclass(
            clause.left.type._type_affinity, Concatenable
        ) or not issubclass(clause.right.type._type_affinity, Concatenable):
            raise UnevaluatableError(
                f"Cannot evaluate concatenate operator "
                f'"{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(
            lambda a, b: a + b, eval_left, eval_right, clause
        )

    def visit_startswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate the operator by calling ``startswith`` on the left value."""
        return self._straight_evaluate(
            lambda a, b: a.startswith(b), eval_left, eval_right, clause
        )

    def visit_endswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate the operator by calling ``endswith`` on the left value."""
        return self._straight_evaluate(
            lambda a, b: a.endswith(b), eval_left, eval_right, clause
        )

    def visit_unary(self, clause):
        """Return an evaluator for a unary expression.

        Only ``operators.inv`` is handled: the inner value is negated with
        ``not``, while ``_EXPIRED_OBJECT`` and ``None`` pass through
        unchanged.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None:
                    return None
                return not value

            return evaluate

        raise UnevaluatableError(
            f"Cannot evaluate {type(clause).__name__} "
            f"with operator {clause.operator}"
        )

    def visit_bindparam(self, clause):
        """Return an evaluator yielding the bound parameter's value.

        The value is resolved once, when this visitor runs: from
        ``clause.callable()`` if a callable is set, else ``clause.value``.
        """
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
    """Resolve module attributes that are not defined directly.

    Only the legacy name ``EvaluatorCompiler`` is served: a deprecation
    warning is emitted and ``_EvaluatorCompiler`` is returned.

    :raises AttributeError: for every other name.
    """
    if name != "EvaluatorCompiler":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warn_deprecated(
        "Direct use of 'EvaluatorCompiler' is not supported, and this "
        "name will be removed in a future release. "
        "'_EvaluatorCompiler' is for internal use only",
        "2.0",
    )
    return _EvaluatorCompiler
|