| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332 |
- # testing/fixtures/mypy.py
- # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- # mypy: ignore-errors
- from __future__ import annotations
- import inspect
- import os
- from pathlib import Path
- import re
- import shutil
- import sys
- import tempfile
- from .base import TestBase
- from .. import config
- from ..assertions import eq_
- from ... import util
- @config.add_to_marker.mypy
- class MypyTest(TestBase):
- __requires__ = ("no_sqlalchemy2_stubs",)
- @config.fixture(scope="function")
- def per_func_cachedir(self):
- yield from self._cachedir()
- @config.fixture(scope="class")
- def cachedir(self):
- yield from self._cachedir()
- def _cachedir(self):
- # as of mypy 0.971 i think we need to keep mypy_path empty
- mypy_path = ""
- with tempfile.TemporaryDirectory() as cachedir:
- with open(
- Path(cachedir) / "sqla_mypy_config.cfg", "w"
- ) as config_file:
- config_file.write(
- f"""
- [mypy]\n
- plugins = sqlalchemy.ext.mypy.plugin\n
- show_error_codes = True\n
- {mypy_path}
- disable_error_code = no-untyped-call
- [mypy-sqlalchemy.*]
- ignore_errors = True
- """
- )
- with open(
- Path(cachedir) / "plain_mypy_config.cfg", "w"
- ) as config_file:
- config_file.write(
- f"""
- [mypy]\n
- show_error_codes = True\n
- {mypy_path}
- disable_error_code = var-annotated,no-untyped-call
- [mypy-sqlalchemy.*]
- ignore_errors = True
- """
- )
- yield cachedir
- @config.fixture()
- def mypy_runner(self, cachedir):
- from mypy import api
- def run(path, use_plugin=False, use_cachedir=None):
- if use_cachedir is None:
- use_cachedir = cachedir
- args = [
- "--strict",
- "--raise-exceptions",
- "--cache-dir",
- use_cachedir,
- "--config-file",
- os.path.join(
- use_cachedir,
- (
- "sqla_mypy_config.cfg"
- if use_plugin
- else "plain_mypy_config.cfg"
- ),
- ),
- ]
- # mypy as of 0.990 is more aggressively blocking messaging
- # for paths that are in sys.path, and as pytest puts currdir,
- # test/ etc in sys.path, just copy the source file to the
- # tempdir we are working in so that we don't have to try to
- # manipulate sys.path and/or guess what mypy is doing
- filename = os.path.basename(path)
- test_program = os.path.join(use_cachedir, filename)
- if path != test_program:
- shutil.copyfile(path, test_program)
- args.append(test_program)
- # I set this locally but for the suite here needs to be
- # disabled
- os.environ.pop("MYPY_FORCE_COLOR", None)
- stdout, stderr, exitcode = api.run(args)
- return stdout, stderr, exitcode
- return run
- @config.fixture
- def mypy_typecheck_file(self, mypy_runner):
- def run(path, use_plugin=False):
- expected_messages = self._collect_messages(path)
- stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
- self._check_output(
- path, expected_messages, stdout, stderr, exitcode
- )
- return run
- @staticmethod
- def file_combinations(dirname):
- if os.path.isabs(dirname):
- path = dirname
- else:
- caller_path = inspect.stack()[1].filename
- path = os.path.join(os.path.dirname(caller_path), dirname)
- files = list(Path(path).glob("**/*.py"))
- for extra_dir in config.options.mypy_extra_test_paths:
- if extra_dir and os.path.isdir(extra_dir):
- files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
- return files
- def _collect_messages(self, path):
- from sqlalchemy.ext.mypy.util import mypy_14
- expected_messages = []
- expected_re = re.compile(
- r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)"
- )
- py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
- with open(path) as file_:
- current_assert_messages = []
- for num, line in enumerate(file_, 1):
- m = py_ver_re.match(line)
- if m:
- major, _, minor = m.group(1).partition(".")
- if sys.version_info < (int(major), int(minor)):
- config.skip_test(
- "Requires python >= %s" % (m.group(1))
- )
- continue
- m = expected_re.match(line)
- if m:
- is_mypy = bool(m.group(1))
- is_re = bool(m.group(2))
- is_row = bool(m.group(3))
- is_type = bool(m.group(4))
- expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5))
- if is_row:
- expected_msg = re.sub(
- r"Row\[([^\]]+)\]",
- lambda m: f"tuple[{m.group(1)}, fallback=s"
- f"qlalchemy.engine.row.{m.group(0)}]",
- expected_msg,
- )
- # For some reason it does not use or syntax (|)
- expected_msg = re.sub(
- r"Optional\[(.*)\]",
- lambda m: f"Union[{m.group(1)}, None]",
- expected_msg,
- )
- if is_type:
- if not is_re:
- # the goal here is that we can cut-and-paste
- # from vscode -> pylance into the
- # EXPECTED_TYPE: line, then the test suite will
- # validate that line against what mypy produces
- expected_msg = re.sub(
- r"([\[\]])",
- lambda m: rf"\{m.group(0)}",
- expected_msg,
- )
- # note making sure preceding text matches
- # with a dot, so that an expect for "Select"
- # does not match "TypedSelect"
- expected_msg = re.sub(
- r"([\w_]+)",
- lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
- expected_msg,
- )
- expected_msg = re.sub(
- "List", "builtins.list", expected_msg
- )
- expected_msg = re.sub(
- r"\b(int|str|float|bool)\b",
- lambda m: rf"builtins.{m.group(0)}\*?",
- expected_msg,
- )
- # expected_msg = re.sub(
- # r"(Sequence|Tuple|List|Union)",
- # lambda m: fr"typing.{m.group(0)}\*?",
- # expected_msg,
- # )
- is_mypy = is_re = True
- expected_msg = f'Revealed type is "{expected_msg}"'
- if mypy_14 and util.py39:
- # use_lowercase_names, py39 and above
- # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501
- # skip first character which could be capitalized
- # "List item x not found" type of message
- expected_msg = expected_msg[0] + re.sub(
- (
- r"\b(List|Tuple|Dict|Set)\b"
- if is_type
- else r"\b(List|Tuple|Dict|Set|Type)\b"
- ),
- lambda m: m.group(1).lower(),
- expected_msg[1:],
- )
- if mypy_14 and util.py310:
- # use_or_syntax, py310 and above
- # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501
- expected_msg = re.sub(
- r"Optional\[(.*?)\]",
- lambda m: f"{m.group(1)} | None",
- expected_msg,
- )
- current_assert_messages.append(
- (is_mypy, is_re, expected_msg.strip())
- )
- elif current_assert_messages:
- expected_messages.extend(
- (num, is_mypy, is_re, expected_msg)
- for (
- is_mypy,
- is_re,
- expected_msg,
- ) in current_assert_messages
- )
- current_assert_messages[:] = []
- return expected_messages
- def _check_output(
- self, path, expected_messages, stdout: str, stderr, exitcode
- ):
- not_located = []
- filename = os.path.basename(path)
- if expected_messages:
- # mypy 0.990 changed how return codes work, so don't assume a
- # 1 or a 0 return code here, could be either depending on if
- # errors were generated or not
- output = []
- raw_lines = stdout.split("\n")
- while raw_lines:
- e = raw_lines.pop(0)
- if re.match(r".+\.py:\d+: error: .*", e):
- output.append(("error", e))
- elif re.match(
- r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
- ):
- while raw_lines:
- ol = raw_lines.pop(0)
- if not re.match(r".+\.py:\d+: note: +def .*", ol):
- raw_lines.insert(0, ol)
- break
- elif re.match(
- r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
- ):
- pass
- elif re.match(r".+\.py:\d+: note: .*", e):
- output.append(("note", e))
- for num, is_mypy, is_re, msg in expected_messages:
- msg = msg.replace("'", '"')
- prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
- for idx, (typ, errmsg) in enumerate(output):
- if is_re:
- if re.match(
- rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
- errmsg,
- ):
- break
- elif (
- f"{filename}:{num}: {typ}: {prefix}{msg}"
- in errmsg.replace("'", '"')
- ):
- break
- else:
- not_located.append(msg)
- continue
- del output[idx]
- if not_located:
- missing = "\n".join(not_located)
- print("Couldn't locate expected messages:", missing, sep="\n")
- if output:
- extra = "\n".join(msg for _, msg in output)
- print("Remaining messages:", extra, sep="\n")
- assert False, "expected messages not found, see stdout"
- if output:
- print(f"{len(output)} messages from mypy were not consumed:")
- print("\n".join(msg for _, msg in output))
- assert False, "errors and/or notes remain, see stdout"
- else:
- if exitcode != 0:
- print(stdout, stderr, sep="\n")
- eq_(exitcode, 0, msg=stdout)
|