util.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # ext/mypy/util.py
  2. # Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from __future__ import annotations
  8. import re
  9. from typing import Any
  10. from typing import Iterable
  11. from typing import Iterator
  12. from typing import List
  13. from typing import Optional
  14. from typing import overload
  15. from typing import Tuple
  16. from typing import Type as TypingType
  17. from typing import TypeVar
  18. from typing import Union
  19. from mypy import version
  20. from mypy.messages import format_type as _mypy_format_type
  21. from mypy.nodes import CallExpr
  22. from mypy.nodes import ClassDef
  23. from mypy.nodes import CLASSDEF_NO_INFO
  24. from mypy.nodes import Context
  25. from mypy.nodes import Expression
  26. from mypy.nodes import FuncDef
  27. from mypy.nodes import IfStmt
  28. from mypy.nodes import JsonDict
  29. from mypy.nodes import MemberExpr
  30. from mypy.nodes import NameExpr
  31. from mypy.nodes import Statement
  32. from mypy.nodes import SymbolTableNode
  33. from mypy.nodes import TypeAlias
  34. from mypy.nodes import TypeInfo
  35. from mypy.options import Options
  36. from mypy.plugin import ClassDefContext
  37. from mypy.plugin import DynamicClassDefContext
  38. from mypy.plugin import SemanticAnalyzerPluginInterface
  39. from mypy.plugins.common import deserialize_and_fixup_type
  40. from mypy.typeops import map_type_from_supertype
  41. from mypy.types import CallableType
  42. from mypy.types import get_proper_type
  43. from mypy.types import Instance
  44. from mypy.types import NoneType
  45. from mypy.types import Type
  46. from mypy.types import TypeVarType
  47. from mypy.types import UnboundType
  48. from mypy.types import UnionType
  49. _vers = tuple(
  50. [int(x) for x in version.__version__.split(".") if re.match(r"^\d+$", x)]
  51. )
  52. mypy_14 = _vers >= (1, 4)
  53. _TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
  54. class SQLAlchemyAttribute:
  55. def __init__(
  56. self,
  57. name: str,
  58. line: int,
  59. column: int,
  60. typ: Optional[Type],
  61. info: TypeInfo,
  62. ) -> None:
  63. self.name = name
  64. self.line = line
  65. self.column = column
  66. self.type = typ
  67. self.info = info
  68. def serialize(self) -> JsonDict:
  69. assert self.type
  70. return {
  71. "name": self.name,
  72. "line": self.line,
  73. "column": self.column,
  74. "type": serialize_type(self.type),
  75. }
  76. def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
  77. """Expands type vars in the context of a subtype when an attribute is
  78. inherited from a generic super type.
  79. """
  80. if not isinstance(self.type, TypeVarType):
  81. return
  82. self.type = map_type_from_supertype(self.type, sub_type, self.info)
  83. @classmethod
  84. def deserialize(
  85. cls,
  86. info: TypeInfo,
  87. data: JsonDict,
  88. api: SemanticAnalyzerPluginInterface,
  89. ) -> SQLAlchemyAttribute:
  90. data = data.copy()
  91. typ = deserialize_and_fixup_type(data.pop("type"), api)
  92. return cls(typ=typ, info=info, **data)
  93. def name_is_dunder(name: str) -> bool:
  94. return bool(re.match(r"^__.+?__$", name))
  95. def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
  96. info.metadata.setdefault("sqlalchemy", {})[key] = data
  97. def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
  98. return info.metadata.get("sqlalchemy", {}).get(key, None)
  99. def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
  100. if info.mro:
  101. for base in info.mro:
  102. metadata = _get_info_metadata(base, key)
  103. if metadata is not None:
  104. return metadata
  105. return None
  106. def establish_as_sqlalchemy(info: TypeInfo) -> None:
  107. info.metadata.setdefault("sqlalchemy", {})
  108. def set_is_base(info: TypeInfo) -> None:
  109. _set_info_metadata(info, "is_base", True)
  110. def get_is_base(info: TypeInfo) -> bool:
  111. is_base = _get_info_metadata(info, "is_base")
  112. return is_base is True
  113. def has_declarative_base(info: TypeInfo) -> bool:
  114. is_base = _get_info_mro_metadata(info, "is_base")
  115. return is_base is True
  116. def set_has_table(info: TypeInfo) -> None:
  117. _set_info_metadata(info, "has_table", True)
  118. def get_has_table(info: TypeInfo) -> bool:
  119. is_base = _get_info_metadata(info, "has_table")
  120. return is_base is True
  121. def get_mapped_attributes(
  122. info: TypeInfo, api: SemanticAnalyzerPluginInterface
  123. ) -> Optional[List[SQLAlchemyAttribute]]:
  124. mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
  125. info, "mapped_attributes"
  126. )
  127. if mapped_attributes is None:
  128. return None
  129. attributes: List[SQLAlchemyAttribute] = []
  130. for data in mapped_attributes:
  131. attr = SQLAlchemyAttribute.deserialize(info, data, api)
  132. attr.expand_typevar_from_subtype(info)
  133. attributes.append(attr)
  134. return attributes
  135. def format_type(typ_: Type, options: Options) -> str:
  136. if mypy_14:
  137. return _mypy_format_type(typ_, options)
  138. else:
  139. return _mypy_format_type(typ_) # type: ignore
  140. def set_mapped_attributes(
  141. info: TypeInfo, attributes: List[SQLAlchemyAttribute]
  142. ) -> None:
  143. _set_info_metadata(
  144. info,
  145. "mapped_attributes",
  146. [attribute.serialize() for attribute in attributes],
  147. )
  148. def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
  149. msg = "[SQLAlchemy Mypy plugin] %s" % msg
  150. return api.fail(msg, ctx)
  151. def add_global(
  152. ctx: Union[ClassDefContext, DynamicClassDefContext],
  153. module: str,
  154. symbol_name: str,
  155. asname: str,
  156. ) -> None:
  157. module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
  158. if asname not in module_globals:
  159. lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
  160. symbol_name
  161. ]
  162. module_globals[asname] = lookup_sym
  163. @overload
  164. def get_callexpr_kwarg(
  165. callexpr: CallExpr, name: str, *, expr_types: None = ...
  166. ) -> Optional[Union[CallExpr, NameExpr]]: ...
  167. @overload
  168. def get_callexpr_kwarg(
  169. callexpr: CallExpr,
  170. name: str,
  171. *,
  172. expr_types: Tuple[TypingType[_TArgType], ...],
  173. ) -> Optional[_TArgType]: ...
  174. def get_callexpr_kwarg(
  175. callexpr: CallExpr,
  176. name: str,
  177. *,
  178. expr_types: Optional[Tuple[TypingType[Any], ...]] = None,
  179. ) -> Optional[Any]:
  180. try:
  181. arg_idx = callexpr.arg_names.index(name)
  182. except ValueError:
  183. return None
  184. kwarg = callexpr.args[arg_idx]
  185. if isinstance(
  186. kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
  187. ):
  188. return kwarg
  189. return None
  190. def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
  191. for stmt in stmts:
  192. if (
  193. isinstance(stmt, IfStmt)
  194. and isinstance(stmt.expr[0], NameExpr)
  195. and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
  196. ):
  197. yield from stmt.body[0].body
  198. else:
  199. yield stmt
  200. def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]:
  201. if isinstance(callee, (MemberExpr, NameExpr)):
  202. if isinstance(callee.node, FuncDef):
  203. if callee.node.type and isinstance(callee.node.type, CallableType):
  204. ret_type = get_proper_type(callee.node.type.ret_type)
  205. if isinstance(ret_type, Instance):
  206. return ret_type
  207. return None
  208. elif isinstance(callee.node, TypeAlias):
  209. target_type = get_proper_type(callee.node.target)
  210. if isinstance(target_type, Instance):
  211. return target_type
  212. elif isinstance(callee.node, TypeInfo):
  213. return callee.node
  214. return None
  215. def unbound_to_instance(
  216. api: SemanticAnalyzerPluginInterface, typ: Type
  217. ) -> Type:
  218. """Take the UnboundType that we seem to get as the ret_type from a FuncDef
  219. and convert it into an Instance/TypeInfo kind of structure that seems
  220. to work as the left-hand type of an AssignmentStatement.
  221. """
  222. if not isinstance(typ, UnboundType):
  223. return typ
  224. # TODO: figure out a more robust way to check this. The node is some
  225. # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
  226. # but I can't figure out how to get them to match up
  227. if typ.name == "Optional":
  228. # convert from "Optional?" to the more familiar
  229. # UnionType[..., NoneType()]
  230. return unbound_to_instance(
  231. api,
  232. UnionType(
  233. [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
  234. + [NoneType()]
  235. ),
  236. )
  237. node = api.lookup_qualified(typ.name, typ)
  238. if (
  239. node is not None
  240. and isinstance(node, SymbolTableNode)
  241. and isinstance(node.node, TypeInfo)
  242. ):
  243. bound_type = node.node
  244. return Instance(
  245. bound_type,
  246. [
  247. (
  248. unbound_to_instance(api, arg)
  249. if isinstance(arg, UnboundType)
  250. else arg
  251. )
  252. for arg in typ.args
  253. ],
  254. )
  255. else:
  256. return typ
  257. def info_for_cls(
  258. cls: ClassDef, api: SemanticAnalyzerPluginInterface
  259. ) -> Optional[TypeInfo]:
  260. if cls.info is CLASSDEF_NO_INFO:
  261. sym = api.lookup_qualified(cls.name, cls)
  262. if sym is None:
  263. return None
  264. assert sym and isinstance(sym.node, TypeInfo)
  265. return sym.node
  266. return cls.info
  267. def serialize_type(typ: Type) -> Union[str, JsonDict]:
  268. try:
  269. return typ.serialize()
  270. except Exception:
  271. pass
  272. if hasattr(typ, "args"):
  273. typ.args = tuple(
  274. (
  275. a.resolve_string_annotation()
  276. if hasattr(a, "resolve_string_annotation")
  277. else a
  278. )
  279. for a in typ.args
  280. )
  281. elif hasattr(typ, "resolve_string_annotation"):
  282. typ = typ.resolve_string_annotation()
  283. return typ.serialize()