assertions.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from __future__ import annotations
  2. import contextlib
  3. import re
  4. import sys
  5. from typing import Any
  6. from typing import Dict
  7. from sqlalchemy import exc as sa_exc
  8. from sqlalchemy.engine import default
  9. from sqlalchemy.engine import URL
  10. from sqlalchemy.testing.assertions import _expect_warnings
  11. from sqlalchemy.testing.assertions import eq_ # noqa
  12. from sqlalchemy.testing.assertions import is_ # noqa
  13. from sqlalchemy.testing.assertions import is_false # noqa
  14. from sqlalchemy.testing.assertions import is_not_ # noqa
  15. from sqlalchemy.testing.assertions import is_true # noqa
  16. from sqlalchemy.testing.assertions import ne_ # noqa
  17. from sqlalchemy.util import decorator
  18. def _assert_proper_exception_context(exception):
  19. """assert that any exception we're catching does not have a __context__
  20. without a __cause__, and that __suppress_context__ is never set.
  21. Python 3 will report nested as exceptions as "during the handling of
  22. error X, error Y occurred". That's not what we want to do. we want
  23. these exceptions in a cause chain.
  24. """
  25. if (
  26. exception.__context__ is not exception.__cause__
  27. and not exception.__suppress_context__
  28. ):
  29. assert False, (
  30. "Exception %r was correctly raised but did not set a cause, "
  31. "within context %r as its cause."
  32. % (exception, exception.__context__)
  33. )
  34. def assert_raises(except_cls, callable_, *args, **kw):
  35. return _assert_raises(except_cls, callable_, args, kw, check_context=True)
  36. def assert_raises_context_ok(except_cls, callable_, *args, **kw):
  37. return _assert_raises(except_cls, callable_, args, kw)
  38. def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
  39. return _assert_raises(
  40. except_cls, callable_, args, kwargs, msg=msg, check_context=True
  41. )
  42. def assert_raises_message_context_ok(
  43. except_cls, msg, callable_, *args, **kwargs
  44. ):
  45. return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
  46. def _assert_raises(
  47. except_cls, callable_, args, kwargs, msg=None, check_context=False
  48. ):
  49. with _expect_raises(except_cls, msg, check_context) as ec:
  50. callable_(*args, **kwargs)
  51. return ec.error
  52. class _ErrorContainer:
  53. error: Any = None
  54. @contextlib.contextmanager
  55. def _expect_raises(
  56. except_cls, msg=None, check_context=False, text_exact=False
  57. ):
  58. ec = _ErrorContainer()
  59. if check_context:
  60. are_we_already_in_a_traceback = sys.exc_info()[0]
  61. try:
  62. yield ec
  63. success = False
  64. except except_cls as err:
  65. ec.error = err
  66. success = True
  67. if msg is not None:
  68. if text_exact:
  69. assert str(err) == msg, f"{msg} != {err}"
  70. else:
  71. assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
  72. if check_context and not are_we_already_in_a_traceback:
  73. _assert_proper_exception_context(err)
  74. print(str(err).encode("utf-8"))
  75. # assert outside the block so it works for AssertionError too !
  76. assert success, "Callable did not raise an exception"
  77. def expect_raises(except_cls, check_context=True):
  78. return _expect_raises(except_cls, check_context=check_context)
  79. def expect_raises_message(
  80. except_cls, msg, check_context=True, text_exact=False
  81. ):
  82. return _expect_raises(
  83. except_cls, msg=msg, check_context=check_context, text_exact=text_exact
  84. )
  85. def eq_ignore_whitespace(a, b, msg=None):
  86. a = re.sub(r"^\s+?|\n", "", a)
  87. a = re.sub(r" {2,}", " ", a)
  88. b = re.sub(r"^\s+?|\n", "", b)
  89. b = re.sub(r" {2,}", " ", b)
  90. assert a == b, msg or "%r != %r" % (a, b)
# Module-level mapping that starts empty.  Nothing in this visible section of
# the file reads or writes it.  NOTE(review): presumably a cache of dialect
# modules/objects keyed by name — confirm against other users before relying
# on it.
_dialect_mods: Dict[Any, Any] = {}
  92. def _get_dialect(name):
  93. if name is None or name == "default":
  94. return default.DefaultDialect()
  95. else:
  96. d = URL.create(name).get_dialect()()
  97. if name == "postgresql":
  98. d.implicit_returning = True
  99. elif name == "mssql":
  100. d.legacy_schema_aliasing = False
  101. return d
  102. def expect_warnings(*messages, **kw):
  103. """Context manager which expects one or more warnings.
  104. With no arguments, squelches all SAWarnings emitted via
  105. sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
  106. pass string expressions that will match selected warnings via regex;
  107. all non-matching warnings are sent through.
  108. The expect version **asserts** that the warnings were in fact seen.
  109. Note that the test suite sets SAWarning warnings to raise exceptions.
  110. """
  111. return _expect_warnings(Warning, messages, **kw)
  112. def emits_python_deprecation_warning(*messages):
  113. """Decorator form of expect_warnings().
  114. Note that emits_warning does **not** assert that the warnings
  115. were in fact seen.
  116. """
  117. @decorator
  118. def decorate(fn, *args, **kw):
  119. with _expect_warnings(DeprecationWarning, assert_=False, *messages):
  120. return fn(*args, **kw)
  121. return decorate
  122. def expect_deprecated(*messages, **kw):
  123. return _expect_warnings(DeprecationWarning, messages, **kw)
  124. def expect_sqlalchemy_deprecated(*messages, **kw):
  125. return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
  126. def expect_sqlalchemy_deprecated_20(*messages, **kw):
  127. return _expect_warnings(sa_exc.RemovedIn20Warning, messages, **kw)