plugin.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # ext/mypy/plugin.py
  2. # Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. """
  8. Mypy plugin for SQLAlchemy ORM.
  9. """
  10. from __future__ import annotations
  11. from typing import Callable
  12. from typing import List
  13. from typing import Optional
  14. from typing import Tuple
  15. from typing import Type as TypingType
  16. from typing import Union
  17. from mypy import nodes
  18. from mypy.mro import calculate_mro
  19. from mypy.mro import MroError
  20. from mypy.nodes import Block
  21. from mypy.nodes import ClassDef
  22. from mypy.nodes import GDEF
  23. from mypy.nodes import MypyFile
  24. from mypy.nodes import NameExpr
  25. from mypy.nodes import SymbolTable
  26. from mypy.nodes import SymbolTableNode
  27. from mypy.nodes import TypeInfo
  28. from mypy.plugin import AttributeContext
  29. from mypy.plugin import ClassDefContext
  30. from mypy.plugin import DynamicClassDefContext
  31. from mypy.plugin import Plugin
  32. from mypy.plugin import SemanticAnalyzerPluginInterface
  33. from mypy.types import get_proper_type
  34. from mypy.types import Instance
  35. from mypy.types import Type
  36. from . import decl_class
  37. from . import names
  38. from . import util
# Import-time guard: probe for an importable module named
# "sqlalchemy-stubs" and refuse to load this plugin if it is found.
# NOTE(review): only "sqlalchemy-stubs" is probed here, although the error
# message below also names sqlalchemy2-stubs and other third party stub
# packages -- confirm whether those are meant to be detected as well.
try:
    __import__("sqlalchemy-stubs")
except ImportError:
    # the stubs module could not be imported; nothing to complain about
    pass
else:
    # raised from the ``else`` clause so that this ImportError is not
    # swallowed by the ``except ImportError`` handler above
    raise ImportError(
        "The SQLAlchemy mypy plugin in SQLAlchemy "
        "2.0 does not work with sqlalchemy-stubs or "
        "sqlalchemy2-stubs installed, as well as with any other third party "
        "SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs "
        "packages."
    )
  51. class SQLAlchemyPlugin(Plugin):
  52. def get_dynamic_class_hook(
  53. self, fullname: str
  54. ) -> Optional[Callable[[DynamicClassDefContext], None]]:
  55. if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
  56. return _dynamic_class_hook
  57. return None
  58. def get_customize_class_mro_hook(
  59. self, fullname: str
  60. ) -> Optional[Callable[[ClassDefContext], None]]:
  61. return _fill_in_decorators
  62. def get_class_decorator_hook(
  63. self, fullname: str
  64. ) -> Optional[Callable[[ClassDefContext], None]]:
  65. sym = self.lookup_fully_qualified(fullname)
  66. if sym is not None and sym.node is not None:
  67. type_id = names.type_id_for_named_node(sym.node)
  68. if type_id is names.MAPPED_DECORATOR:
  69. return _cls_decorator_hook
  70. elif type_id in (
  71. names.AS_DECLARATIVE,
  72. names.AS_DECLARATIVE_BASE,
  73. ):
  74. return _base_cls_decorator_hook
  75. elif type_id is names.DECLARATIVE_MIXIN:
  76. return _declarative_mixin_hook
  77. return None
  78. def get_metaclass_hook(
  79. self, fullname: str
  80. ) -> Optional[Callable[[ClassDefContext], None]]:
  81. if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
  82. # Set any classes that explicitly have metaclass=DeclarativeMeta
  83. # as declarative so the check in `get_base_class_hook()` works
  84. return _metaclass_cls_hook
  85. return None
  86. def get_base_class_hook(
  87. self, fullname: str
  88. ) -> Optional[Callable[[ClassDefContext], None]]:
  89. sym = self.lookup_fully_qualified(fullname)
  90. if (
  91. sym
  92. and isinstance(sym.node, TypeInfo)
  93. and util.has_declarative_base(sym.node)
  94. ):
  95. return _base_cls_hook
  96. return None
  97. def get_attribute_hook(
  98. self, fullname: str
  99. ) -> Optional[Callable[[AttributeContext], Type]]:
  100. if fullname.startswith(
  101. "sqlalchemy.orm.attributes.QueryableAttribute."
  102. ):
  103. return _queryable_getattr_hook
  104. return None
  105. def get_additional_deps(
  106. self, file: MypyFile
  107. ) -> List[Tuple[int, str, int]]:
  108. return [
  109. #
  110. (10, "sqlalchemy.orm", -1),
  111. (10, "sqlalchemy.orm.attributes", -1),
  112. (10, "sqlalchemy.orm.decl_api", -1),
  113. ]
  114. def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
  115. return SQLAlchemyPlugin
  116. def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
  117. """Generate a declarative Base class when the declarative_base() function
  118. is encountered."""
  119. _add_globals(ctx)
  120. cls = ClassDef(ctx.name, Block([]))
  121. cls.fullname = ctx.api.qualified_name(ctx.name)
  122. info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
  123. cls.info = info
  124. _set_declarative_metaclass(ctx.api, cls)
  125. cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
  126. if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
  127. util.set_is_base(cls_arg.node)
  128. decl_class.scan_declarative_assignments_and_apply_types(
  129. cls_arg.node.defn, ctx.api, is_mixin_scan=True
  130. )
  131. info.bases = [Instance(cls_arg.node, [])]
  132. else:
  133. obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
  134. info.bases = [obj]
  135. try:
  136. calculate_mro(info)
  137. except MroError:
  138. util.fail(
  139. ctx.api, "Not able to calculate MRO for declarative base", ctx.call
  140. )
  141. obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
  142. info.bases = [obj]
  143. info.fallback_to_any = True
  144. ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
  145. util.set_is_base(info)
  146. def _fill_in_decorators(ctx: ClassDefContext) -> None:
  147. for decorator in ctx.cls.decorators:
  148. # set the ".fullname" attribute of a class decorator
  149. # that is a MemberExpr. This causes the logic in
  150. # semanal.py->apply_class_plugin_hooks to invoke the
  151. # get_class_decorator_hook for our "registry.map_class()"
  152. # and "registry.as_declarative_base()" methods.
  153. # this seems like a bug in mypy that these decorators are otherwise
  154. # skipped.
  155. if (
  156. isinstance(decorator, nodes.CallExpr)
  157. and isinstance(decorator.callee, nodes.MemberExpr)
  158. and decorator.callee.name == "as_declarative_base"
  159. ):
  160. target = decorator.callee
  161. elif (
  162. isinstance(decorator, nodes.MemberExpr)
  163. and decorator.name == "mapped"
  164. ):
  165. target = decorator
  166. else:
  167. continue
  168. if isinstance(target.expr, NameExpr):
  169. sym = ctx.api.lookup_qualified(
  170. target.expr.name, target, suppress_errors=True
  171. )
  172. else:
  173. continue
  174. if sym and sym.node:
  175. sym_type = get_proper_type(sym.type)
  176. if isinstance(sym_type, Instance):
  177. target.fullname = f"{sym_type.type.fullname}.{target.name}"
  178. else:
  179. # if the registry is in the same file as where the
  180. # decorator is used, it might not have semantic
  181. # symbols applied and we can't get a fully qualified
  182. # name or an inferred type, so we are actually going to
  183. # flag an error in this case that they need to annotate
  184. # it. The "registry" is declared just
  185. # once (or few times), so they have to just not use
  186. # type inference for its assignment in this one case.
  187. util.fail(
  188. ctx.api,
  189. "Class decorator called %s(), but we can't "
  190. "tell if it's from an ORM registry. Please "
  191. "annotate the registry assignment, e.g. "
  192. "my_registry: registry = registry()" % target.name,
  193. sym.node,
  194. )
  195. def _cls_decorator_hook(ctx: ClassDefContext) -> None:
  196. _add_globals(ctx)
  197. assert isinstance(ctx.reason, nodes.MemberExpr)
  198. expr = ctx.reason.expr
  199. assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
  200. node_type = get_proper_type(expr.node.type)
  201. assert (
  202. isinstance(node_type, Instance)
  203. and names.type_id_for_named_node(node_type.type) is names.REGISTRY
  204. )
  205. decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
  206. def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
  207. _add_globals(ctx)
  208. cls = ctx.cls
  209. _set_declarative_metaclass(ctx.api, cls)
  210. util.set_is_base(ctx.cls.info)
  211. decl_class.scan_declarative_assignments_and_apply_types(
  212. cls, ctx.api, is_mixin_scan=True
  213. )
  214. def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
  215. _add_globals(ctx)
  216. util.set_is_base(ctx.cls.info)
  217. decl_class.scan_declarative_assignments_and_apply_types(
  218. ctx.cls, ctx.api, is_mixin_scan=True
  219. )
  220. def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
  221. util.set_is_base(ctx.cls.info)
  222. def _base_cls_hook(ctx: ClassDefContext) -> None:
  223. _add_globals(ctx)
  224. decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
  225. def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
  226. # how do I....tell it it has no attribute of a certain name?
  227. # can't find any Type that seems to match that
  228. return ctx.default_attr_type
  229. def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
  230. """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
  231. for all class defs
  232. """
  233. util.add_global(ctx, "sqlalchemy.orm", "Mapped", "__sa_Mapped")
  234. def _set_declarative_metaclass(
  235. api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
  236. ) -> None:
  237. info = target_cls.info
  238. sym = api.lookup_fully_qualified_or_none(
  239. "sqlalchemy.orm.decl_api.DeclarativeMeta"
  240. )
  241. assert sym is not None and isinstance(sym.node, TypeInfo)
  242. info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])