mypy.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # testing/fixtures/mypy.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. from __future__ import annotations
  9. import inspect
  10. import os
  11. from pathlib import Path
  12. import re
  13. import shutil
  14. import sys
  15. import tempfile
  16. from .base import TestBase
  17. from .. import config
  18. from ..assertions import eq_
  19. from ... import util
  20. @config.add_to_marker.mypy
  21. class MypyTest(TestBase):
  22. __requires__ = ("no_sqlalchemy2_stubs",)
  23. @config.fixture(scope="function")
  24. def per_func_cachedir(self):
  25. yield from self._cachedir()
  26. @config.fixture(scope="class")
  27. def cachedir(self):
  28. yield from self._cachedir()
  29. def _cachedir(self):
  30. # as of mypy 0.971 i think we need to keep mypy_path empty
  31. mypy_path = ""
  32. with tempfile.TemporaryDirectory() as cachedir:
  33. with open(
  34. Path(cachedir) / "sqla_mypy_config.cfg", "w"
  35. ) as config_file:
  36. config_file.write(
  37. f"""
  38. [mypy]\n
  39. plugins = sqlalchemy.ext.mypy.plugin\n
  40. show_error_codes = True\n
  41. {mypy_path}
  42. disable_error_code = no-untyped-call
  43. [mypy-sqlalchemy.*]
  44. ignore_errors = True
  45. """
  46. )
  47. with open(
  48. Path(cachedir) / "plain_mypy_config.cfg", "w"
  49. ) as config_file:
  50. config_file.write(
  51. f"""
  52. [mypy]\n
  53. show_error_codes = True\n
  54. {mypy_path}
  55. disable_error_code = var-annotated,no-untyped-call
  56. [mypy-sqlalchemy.*]
  57. ignore_errors = True
  58. """
  59. )
  60. yield cachedir
  61. @config.fixture()
  62. def mypy_runner(self, cachedir):
  63. from mypy import api
  64. def run(path, use_plugin=False, use_cachedir=None):
  65. if use_cachedir is None:
  66. use_cachedir = cachedir
  67. args = [
  68. "--strict",
  69. "--raise-exceptions",
  70. "--cache-dir",
  71. use_cachedir,
  72. "--config-file",
  73. os.path.join(
  74. use_cachedir,
  75. (
  76. "sqla_mypy_config.cfg"
  77. if use_plugin
  78. else "plain_mypy_config.cfg"
  79. ),
  80. ),
  81. ]
  82. # mypy as of 0.990 is more aggressively blocking messaging
  83. # for paths that are in sys.path, and as pytest puts currdir,
  84. # test/ etc in sys.path, just copy the source file to the
  85. # tempdir we are working in so that we don't have to try to
  86. # manipulate sys.path and/or guess what mypy is doing
  87. filename = os.path.basename(path)
  88. test_program = os.path.join(use_cachedir, filename)
  89. if path != test_program:
  90. shutil.copyfile(path, test_program)
  91. args.append(test_program)
  92. # I set this locally but for the suite here needs to be
  93. # disabled
  94. os.environ.pop("MYPY_FORCE_COLOR", None)
  95. stdout, stderr, exitcode = api.run(args)
  96. return stdout, stderr, exitcode
  97. return run
  98. @config.fixture
  99. def mypy_typecheck_file(self, mypy_runner):
  100. def run(path, use_plugin=False):
  101. expected_messages = self._collect_messages(path)
  102. stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
  103. self._check_output(
  104. path, expected_messages, stdout, stderr, exitcode
  105. )
  106. return run
  107. @staticmethod
  108. def file_combinations(dirname):
  109. if os.path.isabs(dirname):
  110. path = dirname
  111. else:
  112. caller_path = inspect.stack()[1].filename
  113. path = os.path.join(os.path.dirname(caller_path), dirname)
  114. files = list(Path(path).glob("**/*.py"))
  115. for extra_dir in config.options.mypy_extra_test_paths:
  116. if extra_dir and os.path.isdir(extra_dir):
  117. files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
  118. return files
  119. def _collect_messages(self, path):
  120. from sqlalchemy.ext.mypy.util import mypy_14
  121. expected_messages = []
  122. expected_re = re.compile(
  123. r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)"
  124. )
  125. py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
  126. with open(path) as file_:
  127. current_assert_messages = []
  128. for num, line in enumerate(file_, 1):
  129. m = py_ver_re.match(line)
  130. if m:
  131. major, _, minor = m.group(1).partition(".")
  132. if sys.version_info < (int(major), int(minor)):
  133. config.skip_test(
  134. "Requires python >= %s" % (m.group(1))
  135. )
  136. continue
  137. m = expected_re.match(line)
  138. if m:
  139. is_mypy = bool(m.group(1))
  140. is_re = bool(m.group(2))
  141. is_row = bool(m.group(3))
  142. is_type = bool(m.group(4))
  143. expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5))
  144. if is_row:
  145. expected_msg = re.sub(
  146. r"Row\[([^\]]+)\]",
  147. lambda m: f"tuple[{m.group(1)}, fallback=s"
  148. f"qlalchemy.engine.row.{m.group(0)}]",
  149. expected_msg,
  150. )
  151. # For some reason it does not use or syntax (|)
  152. expected_msg = re.sub(
  153. r"Optional\[(.*)\]",
  154. lambda m: f"Union[{m.group(1)}, None]",
  155. expected_msg,
  156. )
  157. if is_type:
  158. if not is_re:
  159. # the goal here is that we can cut-and-paste
  160. # from vscode -> pylance into the
  161. # EXPECTED_TYPE: line, then the test suite will
  162. # validate that line against what mypy produces
  163. expected_msg = re.sub(
  164. r"([\[\]])",
  165. lambda m: rf"\{m.group(0)}",
  166. expected_msg,
  167. )
  168. # note making sure preceding text matches
  169. # with a dot, so that an expect for "Select"
  170. # does not match "TypedSelect"
  171. expected_msg = re.sub(
  172. r"([\w_]+)",
  173. lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
  174. expected_msg,
  175. )
  176. expected_msg = re.sub(
  177. "List", "builtins.list", expected_msg
  178. )
  179. expected_msg = re.sub(
  180. r"\b(int|str|float|bool)\b",
  181. lambda m: rf"builtins.{m.group(0)}\*?",
  182. expected_msg,
  183. )
  184. # expected_msg = re.sub(
  185. # r"(Sequence|Tuple|List|Union)",
  186. # lambda m: fr"typing.{m.group(0)}\*?",
  187. # expected_msg,
  188. # )
  189. is_mypy = is_re = True
  190. expected_msg = f'Revealed type is "{expected_msg}"'
  191. if mypy_14 and util.py39:
  192. # use_lowercase_names, py39 and above
  193. # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501
  194. # skip first character which could be capitalized
  195. # "List item x not found" type of message
  196. expected_msg = expected_msg[0] + re.sub(
  197. (
  198. r"\b(List|Tuple|Dict|Set)\b"
  199. if is_type
  200. else r"\b(List|Tuple|Dict|Set|Type)\b"
  201. ),
  202. lambda m: m.group(1).lower(),
  203. expected_msg[1:],
  204. )
  205. if mypy_14 and util.py310:
  206. # use_or_syntax, py310 and above
  207. # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501
  208. expected_msg = re.sub(
  209. r"Optional\[(.*?)\]",
  210. lambda m: f"{m.group(1)} | None",
  211. expected_msg,
  212. )
  213. current_assert_messages.append(
  214. (is_mypy, is_re, expected_msg.strip())
  215. )
  216. elif current_assert_messages:
  217. expected_messages.extend(
  218. (num, is_mypy, is_re, expected_msg)
  219. for (
  220. is_mypy,
  221. is_re,
  222. expected_msg,
  223. ) in current_assert_messages
  224. )
  225. current_assert_messages[:] = []
  226. return expected_messages
  227. def _check_output(
  228. self, path, expected_messages, stdout: str, stderr, exitcode
  229. ):
  230. not_located = []
  231. filename = os.path.basename(path)
  232. if expected_messages:
  233. # mypy 0.990 changed how return codes work, so don't assume a
  234. # 1 or a 0 return code here, could be either depending on if
  235. # errors were generated or not
  236. output = []
  237. raw_lines = stdout.split("\n")
  238. while raw_lines:
  239. e = raw_lines.pop(0)
  240. if re.match(r".+\.py:\d+: error: .*", e):
  241. output.append(("error", e))
  242. elif re.match(
  243. r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
  244. ):
  245. while raw_lines:
  246. ol = raw_lines.pop(0)
  247. if not re.match(r".+\.py:\d+: note: +def .*", ol):
  248. raw_lines.insert(0, ol)
  249. break
  250. elif re.match(
  251. r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
  252. ):
  253. pass
  254. elif re.match(r".+\.py:\d+: note: .*", e):
  255. output.append(("note", e))
  256. for num, is_mypy, is_re, msg in expected_messages:
  257. msg = msg.replace("'", '"')
  258. prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
  259. for idx, (typ, errmsg) in enumerate(output):
  260. if is_re:
  261. if re.match(
  262. rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
  263. errmsg,
  264. ):
  265. break
  266. elif (
  267. f"{filename}:{num}: {typ}: {prefix}{msg}"
  268. in errmsg.replace("'", '"')
  269. ):
  270. break
  271. else:
  272. not_located.append(msg)
  273. continue
  274. del output[idx]
  275. if not_located:
  276. missing = "\n".join(not_located)
  277. print("Couldn't locate expected messages:", missing, sep="\n")
  278. if output:
  279. extra = "\n".join(msg for _, msg in output)
  280. print("Remaining messages:", extra, sep="\n")
  281. assert False, "expected messages not found, see stdout"
  282. if output:
  283. print(f"{len(output)} messages from mypy were not consumed:")
  284. print("\n".join(msg for _, msg in output))
  285. assert False, "errors and/or notes remain, see stdout"
  286. else:
  287. if exitcode != 0:
  288. print(stdout, stderr, sep="\n")
  289. eq_(exitcode, 0, msg=stdout)