exclusions.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # testing/exclusions.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. import contextlib
  9. import operator
  10. import re
  11. import sys
  12. from . import config
  13. from .. import util
  14. from ..util import decorator
  15. from ..util.compat import inspect_getfullargspec
  16. def skip_if(predicate, reason=None):
  17. rule = compound()
  18. pred = _as_predicate(predicate, reason)
  19. rule.skips.add(pred)
  20. return rule
  21. def fails_if(predicate, reason=None):
  22. rule = compound()
  23. pred = _as_predicate(predicate, reason)
  24. rule.fails.add(pred)
  25. return rule
  26. class compound:
  27. def __init__(self):
  28. self.fails = set()
  29. self.skips = set()
  30. def __add__(self, other):
  31. return self.add(other)
  32. def as_skips(self):
  33. rule = compound()
  34. rule.skips.update(self.skips)
  35. rule.skips.update(self.fails)
  36. return rule
  37. def add(self, *others):
  38. copy = compound()
  39. copy.fails.update(self.fails)
  40. copy.skips.update(self.skips)
  41. for other in others:
  42. copy.fails.update(other.fails)
  43. copy.skips.update(other.skips)
  44. return copy
  45. def not_(self):
  46. copy = compound()
  47. copy.fails.update(NotPredicate(fail) for fail in self.fails)
  48. copy.skips.update(NotPredicate(skip) for skip in self.skips)
  49. return copy
  50. @property
  51. def enabled(self):
  52. return self.enabled_for_config(config._current)
  53. def enabled_for_config(self, config):
  54. for predicate in self.skips.union(self.fails):
  55. if predicate(config):
  56. return False
  57. else:
  58. return True
  59. def matching_config_reasons(self, config):
  60. return [
  61. predicate._as_string(config)
  62. for predicate in self.skips.union(self.fails)
  63. if predicate(config)
  64. ]
  65. def _extend(self, other):
  66. self.skips.update(other.skips)
  67. self.fails.update(other.fails)
  68. def __call__(self, fn):
  69. if hasattr(fn, "_sa_exclusion_extend"):
  70. fn._sa_exclusion_extend._extend(self)
  71. return fn
  72. @decorator
  73. def decorate(fn, *args, **kw):
  74. return self._do(config._current, fn, *args, **kw)
  75. decorated = decorate(fn)
  76. decorated._sa_exclusion_extend = self
  77. return decorated
  78. @contextlib.contextmanager
  79. def fail_if(self):
  80. all_fails = compound()
  81. all_fails.fails.update(self.skips.union(self.fails))
  82. try:
  83. yield
  84. except Exception as ex:
  85. all_fails._expect_failure(config._current, ex)
  86. else:
  87. all_fails._expect_success(config._current)
  88. def _do(self, cfg, fn, *args, **kw):
  89. for skip in self.skips:
  90. if skip(cfg):
  91. msg = "'%s' : %s" % (
  92. config.get_current_test_name(),
  93. skip._as_string(cfg),
  94. )
  95. config.skip_test(msg)
  96. try:
  97. return_value = fn(*args, **kw)
  98. except Exception as ex:
  99. self._expect_failure(cfg, ex, name=fn.__name__)
  100. else:
  101. self._expect_success(cfg, name=fn.__name__)
  102. return return_value
  103. def _expect_failure(self, config, ex, name="block"):
  104. for fail in self.fails:
  105. if fail(config):
  106. print(
  107. "%s failed as expected (%s): %s "
  108. % (name, fail._as_string(config), ex)
  109. )
  110. break
  111. else:
  112. raise ex.with_traceback(sys.exc_info()[2])
  113. def _expect_success(self, config, name="block"):
  114. if not self.fails:
  115. return
  116. for fail in self.fails:
  117. if fail(config):
  118. raise AssertionError(
  119. "Unexpected success for '%s' (%s)"
  120. % (
  121. name,
  122. " and ".join(
  123. fail._as_string(config) for fail in self.fails
  124. ),
  125. )
  126. )
  127. def only_if(predicate, reason=None):
  128. predicate = _as_predicate(predicate)
  129. return skip_if(NotPredicate(predicate), reason)
  130. def succeeds_if(predicate, reason=None):
  131. predicate = _as_predicate(predicate)
  132. return fails_if(NotPredicate(predicate), reason)
  133. class Predicate:
  134. @classmethod
  135. def as_predicate(cls, predicate, description=None):
  136. if isinstance(predicate, compound):
  137. return cls.as_predicate(predicate.enabled_for_config, description)
  138. elif isinstance(predicate, Predicate):
  139. if description and predicate.description is None:
  140. predicate.description = description
  141. return predicate
  142. elif isinstance(predicate, (list, set)):
  143. return OrPredicate(
  144. [cls.as_predicate(pred) for pred in predicate], description
  145. )
  146. elif isinstance(predicate, tuple):
  147. return SpecPredicate(*predicate)
  148. elif isinstance(predicate, str):
  149. tokens = re.match(
  150. r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
  151. )
  152. if not tokens:
  153. raise ValueError(
  154. "Couldn't locate DB name in predicate: %r" % predicate
  155. )
  156. db = tokens.group(1)
  157. op = tokens.group(2)
  158. spec = (
  159. tuple(int(d) for d in tokens.group(3).split("."))
  160. if tokens.group(3)
  161. else None
  162. )
  163. return SpecPredicate(db, op, spec, description=description)
  164. elif callable(predicate):
  165. return LambdaPredicate(predicate, description)
  166. else:
  167. assert False, "unknown predicate type: %s" % predicate
  168. def _format_description(self, config, negate=False):
  169. bool_ = self(config)
  170. if negate:
  171. bool_ = not negate
  172. return self.description % {
  173. "driver": (
  174. config.db.url.get_driver_name() if config else "<no driver>"
  175. ),
  176. "database": (
  177. config.db.url.get_backend_name() if config else "<no database>"
  178. ),
  179. "doesnt_support": "doesn't support" if bool_ else "does support",
  180. "does_support": "does support" if bool_ else "doesn't support",
  181. }
  182. def _as_string(self, config=None, negate=False):
  183. raise NotImplementedError()
  184. class BooleanPredicate(Predicate):
  185. def __init__(self, value, description=None):
  186. self.value = value
  187. self.description = description or "boolean %s" % value
  188. def __call__(self, config):
  189. return self.value
  190. def _as_string(self, config, negate=False):
  191. return self._format_description(config, negate=negate)
  192. class SpecPredicate(Predicate):
  193. def __init__(self, db, op=None, spec=None, description=None):
  194. self.db = db
  195. self.op = op
  196. self.spec = spec
  197. self.description = description
  198. _ops = {
  199. "<": operator.lt,
  200. ">": operator.gt,
  201. "==": operator.eq,
  202. "!=": operator.ne,
  203. "<=": operator.le,
  204. ">=": operator.ge,
  205. "in": operator.contains,
  206. "between": lambda val, pair: val >= pair[0] and val <= pair[1],
  207. }
  208. def __call__(self, config):
  209. if config is None:
  210. return False
  211. engine = config.db
  212. if "+" in self.db:
  213. dialect, driver = self.db.split("+")
  214. else:
  215. dialect, driver = self.db, None
  216. if dialect and engine.name != dialect:
  217. return False
  218. if driver is not None and engine.driver != driver:
  219. return False
  220. if self.op is not None:
  221. assert driver is None, "DBAPI version specs not supported yet"
  222. version = _server_version(engine)
  223. oper = (
  224. hasattr(self.op, "__call__") and self.op or self._ops[self.op]
  225. )
  226. return oper(version, self.spec)
  227. else:
  228. return True
  229. def _as_string(self, config, negate=False):
  230. if self.description is not None:
  231. return self._format_description(config)
  232. elif self.op is None:
  233. if negate:
  234. return "not %s" % self.db
  235. else:
  236. return "%s" % self.db
  237. else:
  238. if negate:
  239. return "not %s %s %s" % (self.db, self.op, self.spec)
  240. else:
  241. return "%s %s %s" % (self.db, self.op, self.spec)
  242. class LambdaPredicate(Predicate):
  243. def __init__(self, lambda_, description=None, args=None, kw=None):
  244. spec = inspect_getfullargspec(lambda_)
  245. if not spec[0]:
  246. self.lambda_ = lambda db: lambda_()
  247. else:
  248. self.lambda_ = lambda_
  249. self.args = args or ()
  250. self.kw = kw or {}
  251. if description:
  252. self.description = description
  253. elif lambda_.__doc__:
  254. self.description = lambda_.__doc__
  255. else:
  256. self.description = "custom function"
  257. def __call__(self, config):
  258. return self.lambda_(config)
  259. def _as_string(self, config, negate=False):
  260. return self._format_description(config)
  261. class NotPredicate(Predicate):
  262. def __init__(self, predicate, description=None):
  263. self.predicate = predicate
  264. self.description = description
  265. def __call__(self, config):
  266. return not self.predicate(config)
  267. def _as_string(self, config, negate=False):
  268. if self.description:
  269. return self._format_description(config, not negate)
  270. else:
  271. return self.predicate._as_string(config, not negate)
  272. class OrPredicate(Predicate):
  273. def __init__(self, predicates, description=None):
  274. self.predicates = predicates
  275. self.description = description
  276. def __call__(self, config):
  277. for pred in self.predicates:
  278. if pred(config):
  279. return True
  280. return False
  281. def _eval_str(self, config, negate=False):
  282. if negate:
  283. conjunction = " and "
  284. else:
  285. conjunction = " or "
  286. return conjunction.join(
  287. p._as_string(config, negate=negate) for p in self.predicates
  288. )
  289. def _negation_str(self, config):
  290. if self.description is not None:
  291. return "Not " + self._format_description(config)
  292. else:
  293. return self._eval_str(config, negate=True)
  294. def _as_string(self, config, negate=False):
  295. if negate:
  296. return self._negation_str(config)
  297. else:
  298. if self.description is not None:
  299. return self._format_description(config)
  300. else:
  301. return self._eval_str(config)
  302. _as_predicate = Predicate.as_predicate
  303. def _is_excluded(db, op, spec):
  304. return SpecPredicate(db, op, spec)(config._current)
  305. def _server_version(engine):
  306. """Return a server_version_info tuple."""
  307. # force metadata to be retrieved
  308. conn = engine.connect()
  309. version = getattr(engine.dialect, "server_version_info", None)
  310. if version is None:
  311. version = ()
  312. conn.close()
  313. return version
  314. def db_spec(*dbs):
  315. return OrPredicate([Predicate.as_predicate(db) for db in dbs])
  316. def open(): # noqa
  317. return skip_if(BooleanPredicate(False, "mark as execute"))
  318. def closed(reason="marked as skip"):
  319. return skip_if(BooleanPredicate(True, reason))
  320. def fails(reason=None):
  321. return fails_if(BooleanPredicate(True, reason or "expected to fail"))
  322. def future():
  323. return fails_if(BooleanPredicate(True, "Future feature"))
  324. def fails_on(db, reason=None):
  325. return fails_if(db, reason)
  326. def fails_on_everything_except(*dbs):
  327. return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
  328. def skip(db, reason=None):
  329. return skip_if(db, reason)
  330. def only_on(dbs, reason=None):
  331. return only_if(
  332. OrPredicate(
  333. [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
  334. )
  335. )
  336. def exclude(db, op, spec, reason=None):
  337. return skip_if(SpecPredicate(db, op, spec), reason)
  338. def against(config, *queries):
  339. assert queries, "no queries sent!"
  340. return OrPredicate([Predicate.as_predicate(query) for query in queries])(
  341. config
  342. )