evaluator.py 12 KB


  1. # orm/evaluator.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
"""Evaluation functions used **INTERNALLY** by ORM DML use cases.


This module is **private, for internal use by SQLAlchemy**.

.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to
   ``_EvaluatorCompiler``.

"""
  13. from __future__ import annotations
  14. from typing import Type
  15. from . import exc as orm_exc
  16. from .base import LoaderCallableStatus
  17. from .base import PassiveFlag
  18. from .. import exc
  19. from .. import inspect
  20. from ..sql import and_
  21. from ..sql import operators
  22. from ..sql.sqltypes import Concatenable
  23. from ..sql.sqltypes import Integer
  24. from ..sql.sqltypes import Numeric
  25. from ..util import warn_deprecated
  26. class UnevaluatableError(exc.InvalidRequestError):
  27. pass
  28. class _NoObject(operators.ColumnOperators):
  29. def operate(self, *arg, **kw):
  30. return None
  31. def reverse_operate(self, *arg, **kw):
  32. return None
  33. class _ExpiredObject(operators.ColumnOperators):
  34. def operate(self, *arg, **kw):
  35. return self
  36. def reverse_operate(self, *arg, **kw):
  37. return self
  38. _NO_OBJECT = _NoObject()
  39. _EXPIRED_OBJECT = _ExpiredObject()
  40. class _EvaluatorCompiler:
  41. def __init__(self, target_cls=None):
  42. self.target_cls = target_cls
  43. def process(self, clause, *clauses):
  44. if clauses:
  45. clause = and_(clause, *clauses)
  46. meth = getattr(self, f"visit_{clause.__visit_name__}", None)
  47. if not meth:
  48. raise UnevaluatableError(
  49. f"Cannot evaluate {type(clause).__name__}"
  50. )
  51. return meth(clause)
  52. def visit_grouping(self, clause):
  53. return self.process(clause.element)
  54. def visit_null(self, clause):
  55. return lambda obj: None
  56. def visit_false(self, clause):
  57. return lambda obj: False
  58. def visit_true(self, clause):
  59. return lambda obj: True
  60. def visit_column(self, clause):
  61. try:
  62. parentmapper = clause._annotations["parentmapper"]
  63. except KeyError as ke:
  64. raise UnevaluatableError(
  65. f"Cannot evaluate column: {clause}"
  66. ) from ke
  67. if self.target_cls and not issubclass(
  68. self.target_cls, parentmapper.class_
  69. ):
  70. raise UnevaluatableError(
  71. "Can't evaluate criteria against "
  72. f"alternate class {parentmapper.class_}"
  73. )
  74. parentmapper._check_configure()
  75. # we'd like to use "proxy_key" annotation to get the "key", however
  76. # in relationship primaryjoin cases proxy_key is sometimes deannotated
  77. # and sometimes apparently not present in the first place (?).
  78. # While I can stop it from being deannotated (though need to see if
  79. # this breaks other things), not sure right now about cases where it's
  80. # not there in the first place. can fix at some later point.
  81. # key = clause._annotations["proxy_key"]
  82. # for now, use the old way
  83. try:
  84. key = parentmapper._columntoproperty[clause].key
  85. except orm_exc.UnmappedColumnError as err:
  86. raise UnevaluatableError(
  87. f"Cannot evaluate expression: {err}"
  88. ) from err
  89. # note this used to fall back to a simple `getattr(obj, key)` evaluator
  90. # if impl was None; as of #8656, we ensure mappers are configured
  91. # so that impl is available
  92. impl = parentmapper.class_manager[key].impl
  93. def get_corresponding_attr(obj):
  94. if obj is None:
  95. return _NO_OBJECT
  96. state = inspect(obj)
  97. dict_ = state.dict
  98. value = impl.get(
  99. state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
  100. )
  101. if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
  102. return _EXPIRED_OBJECT
  103. return value
  104. return get_corresponding_attr
  105. def visit_tuple(self, clause):
  106. return self.visit_clauselist(clause)
  107. def visit_expression_clauselist(self, clause):
  108. return self.visit_clauselist(clause)
  109. def visit_clauselist(self, clause):
  110. evaluators = [self.process(clause) for clause in clause.clauses]
  111. dispatch = (
  112. f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
  113. )
  114. meth = getattr(self, dispatch, None)
  115. if meth:
  116. return meth(clause.operator, evaluators, clause)
  117. else:
  118. raise UnevaluatableError(
  119. f"Cannot evaluate clauselist with operator {clause.operator}"
  120. )
  121. def visit_binary(self, clause):
  122. eval_left = self.process(clause.left)
  123. eval_right = self.process(clause.right)
  124. dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
  125. meth = getattr(self, dispatch, None)
  126. if meth:
  127. return meth(clause.operator, eval_left, eval_right, clause)
  128. else:
  129. raise UnevaluatableError(
  130. f"Cannot evaluate {type(clause).__name__} with "
  131. f"operator {clause.operator}"
  132. )
  133. def visit_or_clauselist_op(self, operator, evaluators, clause):
  134. def evaluate(obj):
  135. has_null = False
  136. for sub_evaluate in evaluators:
  137. value = sub_evaluate(obj)
  138. if value is _EXPIRED_OBJECT:
  139. return _EXPIRED_OBJECT
  140. elif value:
  141. return True
  142. has_null = has_null or value is None
  143. if has_null:
  144. return None
  145. return False
  146. return evaluate
  147. def visit_and_clauselist_op(self, operator, evaluators, clause):
  148. def evaluate(obj):
  149. for sub_evaluate in evaluators:
  150. value = sub_evaluate(obj)
  151. if value is _EXPIRED_OBJECT:
  152. return _EXPIRED_OBJECT
  153. if not value:
  154. if value is None or value is _NO_OBJECT:
  155. return None
  156. return False
  157. return True
  158. return evaluate
  159. def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
  160. def evaluate(obj):
  161. values = []
  162. for sub_evaluate in evaluators:
  163. value = sub_evaluate(obj)
  164. if value is _EXPIRED_OBJECT:
  165. return _EXPIRED_OBJECT
  166. elif value is None or value is _NO_OBJECT:
  167. return None
  168. values.append(value)
  169. return tuple(values)
  170. return evaluate
  171. def visit_custom_op_binary_op(
  172. self, operator, eval_left, eval_right, clause
  173. ):
  174. if operator.python_impl:
  175. return self._straight_evaluate(
  176. operator, eval_left, eval_right, clause
  177. )
  178. else:
  179. raise UnevaluatableError(
  180. f"Custom operator {operator.opstring!r} can't be evaluated "
  181. "in Python unless it specifies a callable using "
  182. "`.python_impl`."
  183. )
  184. def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
  185. def evaluate(obj):
  186. left_val = eval_left(obj)
  187. right_val = eval_right(obj)
  188. if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
  189. return _EXPIRED_OBJECT
  190. return left_val == right_val
  191. return evaluate
  192. def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
  193. def evaluate(obj):
  194. left_val = eval_left(obj)
  195. right_val = eval_right(obj)
  196. if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
  197. return _EXPIRED_OBJECT
  198. return left_val != right_val
  199. return evaluate
  200. def _straight_evaluate(self, operator, eval_left, eval_right, clause):
  201. def evaluate(obj):
  202. left_val = eval_left(obj)
  203. right_val = eval_right(obj)
  204. if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
  205. return _EXPIRED_OBJECT
  206. elif left_val is None or right_val is None:
  207. return None
  208. return operator(eval_left(obj), eval_right(obj))
  209. return evaluate
  210. def _straight_evaluate_numeric_only(
  211. self, operator, eval_left, eval_right, clause
  212. ):
  213. if clause.left.type._type_affinity not in (
  214. Numeric,
  215. Integer,
  216. ) or clause.right.type._type_affinity not in (Numeric, Integer):
  217. raise UnevaluatableError(
  218. f'Cannot evaluate math operator "{operator.__name__}" for '
  219. f"datatypes {clause.left.type}, {clause.right.type}"
  220. )
  221. return self._straight_evaluate(operator, eval_left, eval_right, clause)
  222. visit_add_binary_op = _straight_evaluate_numeric_only
  223. visit_mul_binary_op = _straight_evaluate_numeric_only
  224. visit_sub_binary_op = _straight_evaluate_numeric_only
  225. visit_mod_binary_op = _straight_evaluate_numeric_only
  226. visit_truediv_binary_op = _straight_evaluate_numeric_only
  227. visit_lt_binary_op = _straight_evaluate
  228. visit_le_binary_op = _straight_evaluate
  229. visit_ne_binary_op = _straight_evaluate
  230. visit_gt_binary_op = _straight_evaluate
  231. visit_ge_binary_op = _straight_evaluate
  232. visit_eq_binary_op = _straight_evaluate
  233. def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
  234. return self._straight_evaluate(
  235. lambda a, b: a in b if a is not _NO_OBJECT else None,
  236. eval_left,
  237. eval_right,
  238. clause,
  239. )
  240. def visit_not_in_op_binary_op(
  241. self, operator, eval_left, eval_right, clause
  242. ):
  243. return self._straight_evaluate(
  244. lambda a, b: a not in b if a is not _NO_OBJECT else None,
  245. eval_left,
  246. eval_right,
  247. clause,
  248. )
  249. def visit_concat_op_binary_op(
  250. self, operator, eval_left, eval_right, clause
  251. ):
  252. if not issubclass(
  253. clause.left.type._type_affinity, Concatenable
  254. ) or not issubclass(clause.right.type._type_affinity, Concatenable):
  255. raise UnevaluatableError(
  256. f"Cannot evaluate concatenate operator "
  257. f'"{operator.__name__}" for '
  258. f"datatypes {clause.left.type}, {clause.right.type}"
  259. )
  260. return self._straight_evaluate(
  261. lambda a, b: a + b, eval_left, eval_right, clause
  262. )
  263. def visit_startswith_op_binary_op(
  264. self, operator, eval_left, eval_right, clause
  265. ):
  266. return self._straight_evaluate(
  267. lambda a, b: a.startswith(b), eval_left, eval_right, clause
  268. )
  269. def visit_endswith_op_binary_op(
  270. self, operator, eval_left, eval_right, clause
  271. ):
  272. return self._straight_evaluate(
  273. lambda a, b: a.endswith(b), eval_left, eval_right, clause
  274. )
  275. def visit_unary(self, clause):
  276. eval_inner = self.process(clause.element)
  277. if clause.operator is operators.inv:
  278. def evaluate(obj):
  279. value = eval_inner(obj)
  280. if value is _EXPIRED_OBJECT:
  281. return _EXPIRED_OBJECT
  282. elif value is None:
  283. return None
  284. return not value
  285. return evaluate
  286. raise UnevaluatableError(
  287. f"Cannot evaluate {type(clause).__name__} "
  288. f"with operator {clause.operator}"
  289. )
  290. def visit_bindparam(self, clause):
  291. if clause.callable:
  292. val = clause.callable()
  293. else:
  294. val = clause.value
  295. return lambda obj: val
  296. def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
  297. if name == "EvaluatorCompiler":
  298. warn_deprecated(
  299. "Direct use of 'EvaluatorCompiler' is not supported, and this "
  300. "name will be removed in a future release. "
  301. "'_EvaluatorCompiler' is for internal use only",
  302. "2.0",
  303. )
  304. return _EvaluatorCompiler
  305. else:
  306. raise AttributeError(f"module {__name__!r} has no attribute {name!r}")