_autogen.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
  2. # mypy: no-warn-return-any, allow-any-generics
  3. from __future__ import annotations
  4. from typing import Any
  5. from typing import ClassVar
  6. from typing import Dict
  7. from typing import Generic
  8. from typing import NamedTuple
  9. from typing import Optional
  10. from typing import Sequence
  11. from typing import Tuple
  12. from typing import Type
  13. from typing import TYPE_CHECKING
  14. from typing import TypeVar
  15. from typing import Union
  16. from sqlalchemy.sql.schema import Constraint
  17. from sqlalchemy.sql.schema import ForeignKeyConstraint
  18. from sqlalchemy.sql.schema import Index
  19. from sqlalchemy.sql.schema import UniqueConstraint
  20. from typing_extensions import TypeGuard
  21. from .. import util
  22. from ..util import sqla_compat
  23. if TYPE_CHECKING:
  24. from typing import Literal
  25. from alembic.autogenerate.api import AutogenContext
  26. from alembic.ddl.impl import DefaultImpl
  27. CompareConstraintType = Union[Constraint, Index]
  28. _C = TypeVar("_C", bound=CompareConstraintType)
  29. _clsreg: Dict[str, Type[_constraint_sig]] = {}
  30. class ComparisonResult(NamedTuple):
  31. status: Literal["equal", "different", "skip"]
  32. message: str
  33. @property
  34. def is_equal(self) -> bool:
  35. return self.status == "equal"
  36. @property
  37. def is_different(self) -> bool:
  38. return self.status == "different"
  39. @property
  40. def is_skip(self) -> bool:
  41. return self.status == "skip"
  42. @classmethod
  43. def Equal(cls) -> ComparisonResult:
  44. """the constraints are equal."""
  45. return cls("equal", "The two constraints are equal")
  46. @classmethod
  47. def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
  48. """the constraints are different for the provided reason(s)."""
  49. return cls("different", ", ".join(util.to_list(reason)))
  50. @classmethod
  51. def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
  52. """the constraint cannot be compared for the provided reason(s).
  53. The message is logged, but the constraints will be otherwise
  54. considered equal, meaning that no migration command will be
  55. generated.
  56. """
  57. return cls("skip", ", ".join(util.to_list(reason)))
  58. class _constraint_sig(Generic[_C]):
  59. const: _C
  60. _sig: Tuple[Any, ...]
  61. name: Optional[sqla_compat._ConstraintNameDefined]
  62. impl: DefaultImpl
  63. _is_index: ClassVar[bool] = False
  64. _is_fk: ClassVar[bool] = False
  65. _is_uq: ClassVar[bool] = False
  66. _is_metadata: bool
  67. def __init_subclass__(cls) -> None:
  68. cls._register()
  69. @classmethod
  70. def _register(cls):
  71. raise NotImplementedError()
  72. def __init__(
  73. self, is_metadata: bool, impl: DefaultImpl, const: _C
  74. ) -> None:
  75. raise NotImplementedError()
  76. def compare_to_reflected(
  77. self, other: _constraint_sig[Any]
  78. ) -> ComparisonResult:
  79. assert self.impl is other.impl
  80. assert self._is_metadata
  81. assert not other._is_metadata
  82. return self._compare_to_reflected(other)
  83. def _compare_to_reflected(
  84. self, other: _constraint_sig[_C]
  85. ) -> ComparisonResult:
  86. raise NotImplementedError()
  87. @classmethod
  88. def from_constraint(
  89. cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
  90. ) -> _constraint_sig[_C]:
  91. # these could be cached by constraint/impl, however, if the
  92. # constraint is modified in place, then the sig is wrong. the mysql
  93. # impl currently does this, and if we fixed that we can't be sure
  94. # someone else might do it too, so play it safe.
  95. sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
  96. return sig
  97. def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
  98. return sqla_compat._get_constraint_final_name(
  99. self.const, context.dialect
  100. )
  101. @util.memoized_property
  102. def is_named(self):
  103. return sqla_compat._constraint_is_named(self.const, self.impl.dialect)
  104. @util.memoized_property
  105. def unnamed(self) -> Tuple[Any, ...]:
  106. return self._sig
  107. @util.memoized_property
  108. def unnamed_no_options(self) -> Tuple[Any, ...]:
  109. raise NotImplementedError()
  110. @util.memoized_property
  111. def _full_sig(self) -> Tuple[Any, ...]:
  112. return (self.name,) + self.unnamed
  113. def __eq__(self, other) -> bool:
  114. return self._full_sig == other._full_sig
  115. def __ne__(self, other) -> bool:
  116. return self._full_sig != other._full_sig
  117. def __hash__(self) -> int:
  118. return hash(self._full_sig)
  119. class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
  120. _is_uq = True
  121. @classmethod
  122. def _register(cls) -> None:
  123. _clsreg["unique_constraint"] = cls
  124. is_unique = True
  125. def __init__(
  126. self,
  127. is_metadata: bool,
  128. impl: DefaultImpl,
  129. const: UniqueConstraint,
  130. ) -> None:
  131. self.impl = impl
  132. self.const = const
  133. self.name = sqla_compat.constraint_name_or_none(const.name)
  134. self._sig = tuple(sorted([col.name for col in const.columns]))
  135. self._is_metadata = is_metadata
  136. @property
  137. def column_names(self) -> Tuple[str, ...]:
  138. return tuple([col.name for col in self.const.columns])
  139. def _compare_to_reflected(
  140. self, other: _constraint_sig[_C]
  141. ) -> ComparisonResult:
  142. assert self._is_metadata
  143. metadata_obj = self
  144. conn_obj = other
  145. assert is_uq_sig(conn_obj)
  146. return self.impl.compare_unique_constraint(
  147. metadata_obj.const, conn_obj.const
  148. )
  149. class _ix_constraint_sig(_constraint_sig[Index]):
  150. _is_index = True
  151. name: sqla_compat._ConstraintName
  152. @classmethod
  153. def _register(cls) -> None:
  154. _clsreg["index"] = cls
  155. def __init__(
  156. self, is_metadata: bool, impl: DefaultImpl, const: Index
  157. ) -> None:
  158. self.impl = impl
  159. self.const = const
  160. self.name = const.name
  161. self.is_unique = bool(const.unique)
  162. self._is_metadata = is_metadata
  163. def _compare_to_reflected(
  164. self, other: _constraint_sig[_C]
  165. ) -> ComparisonResult:
  166. assert self._is_metadata
  167. metadata_obj = self
  168. conn_obj = other
  169. assert is_index_sig(conn_obj)
  170. return self.impl.compare_indexes(metadata_obj.const, conn_obj.const)
  171. @util.memoized_property
  172. def has_expressions(self):
  173. return sqla_compat.is_expression_index(self.const)
  174. @util.memoized_property
  175. def column_names(self) -> Tuple[str, ...]:
  176. return tuple([col.name for col in self.const.columns])
  177. @util.memoized_property
  178. def column_names_optional(self) -> Tuple[Optional[str], ...]:
  179. return tuple(
  180. [getattr(col, "name", None) for col in self.const.expressions]
  181. )
  182. @util.memoized_property
  183. def is_named(self):
  184. return True
  185. @util.memoized_property
  186. def unnamed(self):
  187. return (self.is_unique,) + self.column_names_optional
  188. class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
  189. _is_fk = True
  190. @classmethod
  191. def _register(cls) -> None:
  192. _clsreg["foreign_key_constraint"] = cls
  193. def __init__(
  194. self,
  195. is_metadata: bool,
  196. impl: DefaultImpl,
  197. const: ForeignKeyConstraint,
  198. ) -> None:
  199. self._is_metadata = is_metadata
  200. self.impl = impl
  201. self.const = const
  202. self.name = sqla_compat.constraint_name_or_none(const.name)
  203. (
  204. self.source_schema,
  205. self.source_table,
  206. self.source_columns,
  207. self.target_schema,
  208. self.target_table,
  209. self.target_columns,
  210. onupdate,
  211. ondelete,
  212. deferrable,
  213. initially,
  214. ) = sqla_compat._fk_spec(const)
  215. self._sig: Tuple[Any, ...] = (
  216. self.source_schema,
  217. self.source_table,
  218. tuple(self.source_columns),
  219. self.target_schema,
  220. self.target_table,
  221. tuple(self.target_columns),
  222. ) + (
  223. (
  224. (None if onupdate.lower() == "no action" else onupdate.lower())
  225. if onupdate
  226. else None
  227. ),
  228. (
  229. (None if ondelete.lower() == "no action" else ondelete.lower())
  230. if ondelete
  231. else None
  232. ),
  233. # convert initially + deferrable into one three-state value
  234. (
  235. "initially_deferrable"
  236. if initially and initially.lower() == "deferred"
  237. else "deferrable" if deferrable else "not deferrable"
  238. ),
  239. )
  240. @util.memoized_property
  241. def unnamed_no_options(self):
  242. return (
  243. self.source_schema,
  244. self.source_table,
  245. tuple(self.source_columns),
  246. self.target_schema,
  247. self.target_table,
  248. tuple(self.target_columns),
  249. )
  250. def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
  251. return sig._is_index
  252. def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
  253. return sig._is_uq
  254. def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
  255. return sig._is_fk