apply.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # ext/mypy/apply.py
  2. # Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from __future__ import annotations
  8. from typing import List
  9. from typing import Optional
  10. from typing import Union
  11. from mypy.nodes import ARG_NAMED_OPT
  12. from mypy.nodes import Argument
  13. from mypy.nodes import AssignmentStmt
  14. from mypy.nodes import CallExpr
  15. from mypy.nodes import ClassDef
  16. from mypy.nodes import MDEF
  17. from mypy.nodes import MemberExpr
  18. from mypy.nodes import NameExpr
  19. from mypy.nodes import RefExpr
  20. from mypy.nodes import StrExpr
  21. from mypy.nodes import SymbolTableNode
  22. from mypy.nodes import TempNode
  23. from mypy.nodes import TypeInfo
  24. from mypy.nodes import Var
  25. from mypy.plugin import SemanticAnalyzerPluginInterface
  26. from mypy.plugins.common import add_method_to_class
  27. from mypy.types import AnyType
  28. from mypy.types import get_proper_type
  29. from mypy.types import Instance
  30. from mypy.types import NoneTyp
  31. from mypy.types import ProperType
  32. from mypy.types import TypeOfAny
  33. from mypy.types import UnboundType
  34. from mypy.types import UnionType
  35. from . import infer
  36. from . import util
  37. from .names import expr_to_mapped_constructor
  38. from .names import NAMED_TYPE_SQLA_MAPPED
  39. def apply_mypy_mapped_attr(
  40. cls: ClassDef,
  41. api: SemanticAnalyzerPluginInterface,
  42. item: Union[NameExpr, StrExpr],
  43. attributes: List[util.SQLAlchemyAttribute],
  44. ) -> None:
  45. if isinstance(item, NameExpr):
  46. name = item.name
  47. elif isinstance(item, StrExpr):
  48. name = item.value
  49. else:
  50. return None
  51. for stmt in cls.defs.body:
  52. if (
  53. isinstance(stmt, AssignmentStmt)
  54. and isinstance(stmt.lvalues[0], NameExpr)
  55. and stmt.lvalues[0].name == name
  56. ):
  57. break
  58. else:
  59. util.fail(api, f"Can't find mapped attribute {name}", cls)
  60. return None
  61. if stmt.type is None:
  62. util.fail(
  63. api,
  64. "Statement linked from _mypy_mapped_attrs has no "
  65. "typing information",
  66. stmt,
  67. )
  68. return None
  69. left_hand_explicit_type = get_proper_type(stmt.type)
  70. assert isinstance(
  71. left_hand_explicit_type, (Instance, UnionType, UnboundType)
  72. )
  73. attributes.append(
  74. util.SQLAlchemyAttribute(
  75. name=name,
  76. line=item.line,
  77. column=item.column,
  78. typ=left_hand_explicit_type,
  79. info=cls.info,
  80. )
  81. )
  82. apply_type_to_mapped_statement(
  83. api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
  84. )
  85. def re_apply_declarative_assignments(
  86. cls: ClassDef,
  87. api: SemanticAnalyzerPluginInterface,
  88. attributes: List[util.SQLAlchemyAttribute],
  89. ) -> None:
  90. """For multiple class passes, re-apply our left-hand side types as mypy
  91. seems to reset them in place.
  92. """
  93. mapped_attr_lookup = {attr.name: attr for attr in attributes}
  94. update_cls_metadata = False
  95. for stmt in cls.defs.body:
  96. # for a re-apply, all of our statements are AssignmentStmt;
  97. # @declared_attr calls will have been converted and this
  98. # currently seems to be preserved by mypy (but who knows if this
  99. # will change).
  100. if (
  101. isinstance(stmt, AssignmentStmt)
  102. and isinstance(stmt.lvalues[0], NameExpr)
  103. and stmt.lvalues[0].name in mapped_attr_lookup
  104. and isinstance(stmt.lvalues[0].node, Var)
  105. ):
  106. left_node = stmt.lvalues[0].node
  107. python_type_for_type = mapped_attr_lookup[
  108. stmt.lvalues[0].name
  109. ].type
  110. left_node_proper_type = get_proper_type(left_node.type)
  111. # if we have scanned an UnboundType and now there's a more
  112. # specific type than UnboundType, call the re-scan so we
  113. # can get that set up correctly
  114. if (
  115. isinstance(python_type_for_type, UnboundType)
  116. and not isinstance(left_node_proper_type, UnboundType)
  117. and (
  118. isinstance(stmt.rvalue, CallExpr)
  119. and isinstance(stmt.rvalue.callee, MemberExpr)
  120. and isinstance(stmt.rvalue.callee.expr, NameExpr)
  121. and stmt.rvalue.callee.expr.node is not None
  122. and stmt.rvalue.callee.expr.node.fullname
  123. == NAMED_TYPE_SQLA_MAPPED
  124. and stmt.rvalue.callee.name == "_empty_constructor"
  125. and isinstance(stmt.rvalue.args[0], CallExpr)
  126. and isinstance(stmt.rvalue.args[0].callee, RefExpr)
  127. )
  128. ):
  129. new_python_type_for_type = (
  130. infer.infer_type_from_right_hand_nameexpr(
  131. api,
  132. stmt,
  133. left_node,
  134. left_node_proper_type,
  135. stmt.rvalue.args[0].callee,
  136. )
  137. )
  138. if new_python_type_for_type is not None and not isinstance(
  139. new_python_type_for_type, UnboundType
  140. ):
  141. python_type_for_type = new_python_type_for_type
  142. # update the SQLAlchemyAttribute with the better
  143. # information
  144. mapped_attr_lookup[stmt.lvalues[0].name].type = (
  145. python_type_for_type
  146. )
  147. update_cls_metadata = True
  148. if (
  149. not isinstance(left_node.type, Instance)
  150. or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED
  151. ):
  152. assert python_type_for_type is not None
  153. left_node.type = api.named_type(
  154. NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
  155. )
  156. if update_cls_metadata:
  157. util.set_mapped_attributes(cls.info, attributes)
  158. def apply_type_to_mapped_statement(
  159. api: SemanticAnalyzerPluginInterface,
  160. stmt: AssignmentStmt,
  161. lvalue: NameExpr,
  162. left_hand_explicit_type: Optional[ProperType],
  163. python_type_for_type: Optional[ProperType],
  164. ) -> None:
  165. """Apply the Mapped[<type>] annotation and right hand object to a
  166. declarative assignment statement.
  167. This converts a Python declarative class statement such as::
  168. class User(Base):
  169. # ...
  170. attrname = Column(Integer)
  171. To one that describes the final Python behavior to Mypy::
  172. ... format: off
  173. class User(Base):
  174. # ...
  175. attrname : Mapped[Optional[int]] = <meaningless temp node>
  176. ... format: on
  177. """
  178. left_node = lvalue.node
  179. assert isinstance(left_node, Var)
  180. # to be completely honest I have no idea what the difference between
  181. # left_node.type and stmt.type is, what it means if these are different
  182. # vs. the same, why in order to get tests to pass I have to assign
  183. # to stmt.type for the second case and not the first. this is complete
  184. # trying every combination until it works stuff.
  185. if left_hand_explicit_type is not None:
  186. lvalue.is_inferred_def = False
  187. left_node.type = api.named_type(
  188. NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
  189. )
  190. else:
  191. lvalue.is_inferred_def = False
  192. left_node.type = api.named_type(
  193. NAMED_TYPE_SQLA_MAPPED,
  194. (
  195. [AnyType(TypeOfAny.special_form)]
  196. if python_type_for_type is None
  197. else [python_type_for_type]
  198. ),
  199. )
  200. # so to have it skip the right side totally, we can do this:
  201. # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
  202. # however, if we instead manufacture a new node that uses the old
  203. # one, then we can still get type checking for the call itself,
  204. # e.g. the Column, relationship() call, etc.
  205. # rewrite the node as:
  206. # <attr> : Mapped[<typ>] =
  207. # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
  208. # the original right-hand side is maintained so it gets type checked
  209. # internally
  210. stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue)
  211. if stmt.type is not None and python_type_for_type is not None:
  212. stmt.type = python_type_for_type
  213. def add_additional_orm_attributes(
  214. cls: ClassDef,
  215. api: SemanticAnalyzerPluginInterface,
  216. attributes: List[util.SQLAlchemyAttribute],
  217. ) -> None:
  218. """Apply __init__, __table__ and other attributes to the mapped class."""
  219. info = util.info_for_cls(cls, api)
  220. if info is None:
  221. return
  222. is_base = util.get_is_base(info)
  223. if "__init__" not in info.names and not is_base:
  224. mapped_attr_names = {attr.name: attr.type for attr in attributes}
  225. for base in info.mro[1:-1]:
  226. if "sqlalchemy" not in info.metadata:
  227. continue
  228. base_cls_attributes = util.get_mapped_attributes(base, api)
  229. if base_cls_attributes is None:
  230. continue
  231. for attr in base_cls_attributes:
  232. mapped_attr_names.setdefault(attr.name, attr.type)
  233. arguments = []
  234. for name, typ in mapped_attr_names.items():
  235. if typ is None:
  236. typ = AnyType(TypeOfAny.special_form)
  237. arguments.append(
  238. Argument(
  239. variable=Var(name, typ),
  240. type_annotation=typ,
  241. initializer=TempNode(typ),
  242. kind=ARG_NAMED_OPT,
  243. )
  244. )
  245. add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
  246. if "__table__" not in info.names and util.get_has_table(info):
  247. _apply_placeholder_attr_to_class(
  248. api, cls, "sqlalchemy.sql.schema.Table", "__table__"
  249. )
  250. if not is_base:
  251. _apply_placeholder_attr_to_class(
  252. api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
  253. )
  254. def _apply_placeholder_attr_to_class(
  255. api: SemanticAnalyzerPluginInterface,
  256. cls: ClassDef,
  257. qualified_name: str,
  258. attrname: str,
  259. ) -> None:
  260. sym = api.lookup_fully_qualified_or_none(qualified_name)
  261. if sym:
  262. assert isinstance(sym.node, TypeInfo)
  263. type_: ProperType = Instance(sym.node, [])
  264. else:
  265. type_ = AnyType(TypeOfAny.special_form)
  266. var = Var(attrname)
  267. var._fullname = cls.fullname + "." + attrname
  268. var.info = cls.info
  269. var.type = type_
  270. cls.info.names[attrname] = SymbolTableNode(MDEF, var)