_autogen_fixtures.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. from __future__ import annotations
  2. from typing import Any
  3. from typing import Dict
  4. from typing import Set
  5. from sqlalchemy import CHAR
  6. from sqlalchemy import CheckConstraint
  7. from sqlalchemy import Column
  8. from sqlalchemy import event
  9. from sqlalchemy import ForeignKey
  10. from sqlalchemy import Index
  11. from sqlalchemy import inspect
  12. from sqlalchemy import Integer
  13. from sqlalchemy import MetaData
  14. from sqlalchemy import Numeric
  15. from sqlalchemy import PrimaryKeyConstraint
  16. from sqlalchemy import String
  17. from sqlalchemy import Table
  18. from sqlalchemy import Text
  19. from sqlalchemy import text
  20. from sqlalchemy import UniqueConstraint
  21. from ... import autogenerate
  22. from ... import util
  23. from ...autogenerate import api
  24. from ...ddl.base import _fk_spec
  25. from ...migration import MigrationContext
  26. from ...operations import ops
  27. from ...testing import config
  28. from ...testing import eq_
  29. from ...testing.env import clear_staging_env
  30. from ...testing.env import staging_env
  31. names_in_this_test: Set[Any] = set()
  32. @event.listens_for(Table, "after_parent_attach")
  33. def new_table(table, parent):
  34. names_in_this_test.add(table.name)
  35. def _default_include_object(obj, name, type_, reflected, compare_to):
  36. if type_ == "table":
  37. return name in names_in_this_test
  38. else:
  39. return True
  40. _default_object_filters: Any = _default_include_object
  41. _default_name_filters: Any = None
  42. class ModelOne:
  43. __requires__ = ("unique_constraint_reflection",)
  44. schema: Any = None
  45. @classmethod
  46. def _get_db_schema(cls):
  47. schema = cls.schema
  48. m = MetaData(schema=schema)
  49. Table(
  50. "user",
  51. m,
  52. Column("id", Integer, primary_key=True),
  53. Column("name", String(50)),
  54. Column("a1", Text),
  55. Column("pw", String(50)),
  56. Index("pw_idx", "pw"),
  57. )
  58. Table(
  59. "address",
  60. m,
  61. Column("id", Integer, primary_key=True),
  62. Column("email_address", String(100), nullable=False),
  63. )
  64. Table(
  65. "order",
  66. m,
  67. Column("order_id", Integer, primary_key=True),
  68. Column(
  69. "amount",
  70. Numeric(8, 2),
  71. nullable=False,
  72. server_default=text("0"),
  73. ),
  74. CheckConstraint("amount >= 0", name="ck_order_amount"),
  75. )
  76. Table(
  77. "extra",
  78. m,
  79. Column("x", CHAR),
  80. Column("uid", Integer, ForeignKey("user.id")),
  81. )
  82. return m
  83. @classmethod
  84. def _get_model_schema(cls):
  85. schema = cls.schema
  86. m = MetaData(schema=schema)
  87. Table(
  88. "user",
  89. m,
  90. Column("id", Integer, primary_key=True),
  91. Column("name", String(50), nullable=False),
  92. Column("a1", Text, server_default="x"),
  93. )
  94. Table(
  95. "address",
  96. m,
  97. Column("id", Integer, primary_key=True),
  98. Column("email_address", String(100), nullable=False),
  99. Column("street", String(50)),
  100. UniqueConstraint("email_address", name="uq_email"),
  101. )
  102. Table(
  103. "order",
  104. m,
  105. Column("order_id", Integer, primary_key=True),
  106. Column(
  107. "amount",
  108. Numeric(10, 2),
  109. nullable=True,
  110. server_default=text("0"),
  111. ),
  112. Column("user_id", Integer, ForeignKey("user.id")),
  113. CheckConstraint("amount > -1", name="ck_order_amount"),
  114. )
  115. Table(
  116. "item",
  117. m,
  118. Column("id", Integer, primary_key=True),
  119. Column("description", String(100)),
  120. Column("order_id", Integer, ForeignKey("order.order_id")),
  121. CheckConstraint("len(description) > 5"),
  122. )
  123. return m
  124. class NamingConvModel:
  125. __requires__ = ("unique_constraint_reflection",)
  126. configure_opts = {"conv_all_constraint_names": True}
  127. naming_convention = {
  128. "ix": "ix_%(column_0_label)s",
  129. "uq": "uq_%(table_name)s_%(constraint_name)s",
  130. "ck": "ck_%(table_name)s_%(constraint_name)s",
  131. "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
  132. "pk": "pk_%(table_name)s",
  133. }
  134. @classmethod
  135. def _get_db_schema(cls):
  136. # database side - assume all constraints have a name that
  137. # we would assume here is a "db generated" name. need to make
  138. # sure these all render with op.f().
  139. m = MetaData()
  140. Table(
  141. "x1",
  142. m,
  143. Column("q", Integer),
  144. Index("db_x1_index_q", "q"),
  145. PrimaryKeyConstraint("q", name="db_x1_primary_q"),
  146. )
  147. Table(
  148. "x2",
  149. m,
  150. Column("q", Integer),
  151. Column("p", ForeignKey("x1.q", name="db_x2_foreign_q")),
  152. CheckConstraint("q > 5", name="db_x2_check_q"),
  153. )
  154. Table(
  155. "x3",
  156. m,
  157. Column("q", Integer),
  158. Column("r", Integer),
  159. Column("s", Integer),
  160. UniqueConstraint("q", name="db_x3_unique_q"),
  161. )
  162. Table(
  163. "x4",
  164. m,
  165. Column("q", Integer),
  166. PrimaryKeyConstraint("q", name="db_x4_primary_q"),
  167. )
  168. Table(
  169. "x5",
  170. m,
  171. Column("q", Integer),
  172. Column("p", ForeignKey("x4.q", name="db_x5_foreign_q")),
  173. Column("r", Integer),
  174. Column("s", Integer),
  175. PrimaryKeyConstraint("q", name="db_x5_primary_q"),
  176. UniqueConstraint("r", name="db_x5_unique_r"),
  177. CheckConstraint("s > 5", name="db_x5_check_s"),
  178. )
  179. # SQLite and it's "no names needed" thing. bleh.
  180. # we can't have a name for these so you'll see "None" for the name.
  181. Table(
  182. "unnamed_sqlite",
  183. m,
  184. Column("q", Integer),
  185. Column("r", Integer),
  186. PrimaryKeyConstraint("q"),
  187. UniqueConstraint("r"),
  188. )
  189. return m
  190. @classmethod
  191. def _get_model_schema(cls):
  192. from sqlalchemy.sql.naming import conv
  193. m = MetaData(naming_convention=cls.naming_convention)
  194. Table(
  195. "x1", m, Column("q", Integer, primary_key=True), Index(None, "q")
  196. )
  197. Table(
  198. "x2",
  199. m,
  200. Column("q", Integer),
  201. Column("p", ForeignKey("x1.q")),
  202. CheckConstraint("q > 5", name="token_x2check1"),
  203. )
  204. Table(
  205. "x3",
  206. m,
  207. Column("q", Integer),
  208. Column("r", Integer),
  209. Column("s", Integer),
  210. UniqueConstraint("r", name="token_x3r"),
  211. UniqueConstraint("s", name=conv("userdef_x3_unique_s")),
  212. )
  213. Table(
  214. "x4",
  215. m,
  216. Column("q", Integer, primary_key=True),
  217. Index("userdef_x4_idx_q", "q"),
  218. )
  219. Table(
  220. "x6",
  221. m,
  222. Column("q", Integer, primary_key=True),
  223. Column("p", ForeignKey("x4.q")),
  224. Column("r", Integer),
  225. Column("s", Integer),
  226. UniqueConstraint("r", name="token_x6r"),
  227. CheckConstraint("s > 5", "token_x6check1"),
  228. CheckConstraint("s < 20", conv("userdef_x6_check_s")),
  229. )
  230. return m
  231. class _ComparesFKs:
  232. def _assert_fk_diff(
  233. self,
  234. diff,
  235. type_,
  236. source_table,
  237. source_columns,
  238. target_table,
  239. target_columns,
  240. name=None,
  241. conditional_name=None,
  242. source_schema=None,
  243. onupdate=None,
  244. ondelete=None,
  245. initially=None,
  246. deferrable=None,
  247. ):
  248. # the public API for ForeignKeyConstraint was not very rich
  249. # in 0.7, 0.8, so here we use the well-known but slightly
  250. # private API to get at its elements
  251. (
  252. fk_source_schema,
  253. fk_source_table,
  254. fk_source_columns,
  255. fk_target_schema,
  256. fk_target_table,
  257. fk_target_columns,
  258. fk_onupdate,
  259. fk_ondelete,
  260. fk_deferrable,
  261. fk_initially,
  262. ) = _fk_spec(diff[1])
  263. eq_(diff[0], type_)
  264. eq_(fk_source_table, source_table)
  265. eq_(fk_source_columns, source_columns)
  266. eq_(fk_target_table, target_table)
  267. eq_(fk_source_schema, source_schema)
  268. eq_(fk_onupdate, onupdate)
  269. eq_(fk_ondelete, ondelete)
  270. eq_(fk_initially, initially)
  271. eq_(fk_deferrable, deferrable)
  272. eq_([elem.column.name for elem in diff[1].elements], target_columns)
  273. if conditional_name is not None:
  274. if conditional_name == "servergenerated":
  275. fks = inspect(self.bind).get_foreign_keys(source_table)
  276. server_fk_name = fks[0]["name"]
  277. eq_(diff[1].name, server_fk_name)
  278. else:
  279. eq_(diff[1].name, conditional_name)
  280. else:
  281. eq_(diff[1].name, name)
  282. class AutogenTest(_ComparesFKs):
  283. def _flatten_diffs(self, diffs):
  284. for d in diffs:
  285. if isinstance(d, list):
  286. yield from self._flatten_diffs(d)
  287. else:
  288. yield d
  289. @classmethod
  290. def _get_bind(cls):
  291. return config.db
  292. configure_opts: Dict[Any, Any] = {}
  293. @classmethod
  294. def setup_class(cls):
  295. staging_env()
  296. cls.bind = cls._get_bind()
  297. cls.m1 = cls._get_db_schema()
  298. cls.m1.create_all(cls.bind)
  299. cls.m2 = cls._get_model_schema()
  300. @classmethod
  301. def teardown_class(cls):
  302. cls.m1.drop_all(cls.bind)
  303. clear_staging_env()
  304. def setUp(self):
  305. self.conn = conn = self.bind.connect()
  306. ctx_opts = {
  307. "compare_type": True,
  308. "compare_server_default": True,
  309. "target_metadata": self.m2,
  310. "upgrade_token": "upgrades",
  311. "downgrade_token": "downgrades",
  312. "alembic_module_prefix": "op.",
  313. "sqlalchemy_module_prefix": "sa.",
  314. "include_object": _default_object_filters,
  315. "include_name": _default_name_filters,
  316. }
  317. if self.configure_opts:
  318. ctx_opts.update(self.configure_opts)
  319. self.context = context = MigrationContext.configure(
  320. connection=conn, opts=ctx_opts
  321. )
  322. self.autogen_context = api.AutogenContext(context, self.m2)
  323. def tearDown(self):
  324. self.conn.close()
  325. def _update_context(
  326. self, object_filters=None, name_filters=None, include_schemas=None
  327. ):
  328. if include_schemas is not None:
  329. self.autogen_context.opts["include_schemas"] = include_schemas
  330. if object_filters is not None:
  331. self.autogen_context._object_filters = [object_filters]
  332. if name_filters is not None:
  333. self.autogen_context._name_filters = [name_filters]
  334. return self.autogen_context
  335. class AutogenFixtureTest(_ComparesFKs):
  336. def _fixture(
  337. self,
  338. m1,
  339. m2,
  340. include_schemas=False,
  341. opts=None,
  342. object_filters=_default_object_filters,
  343. name_filters=_default_name_filters,
  344. return_ops=False,
  345. max_identifier_length=None,
  346. ):
  347. if max_identifier_length:
  348. dialect = self.bind.dialect
  349. existing_length = dialect.max_identifier_length
  350. dialect.max_identifier_length = (
  351. dialect._user_defined_max_identifier_length
  352. ) = max_identifier_length
  353. try:
  354. self._alembic_metadata, model_metadata = m1, m2
  355. for m in util.to_list(self._alembic_metadata):
  356. m.create_all(self.bind)
  357. with self.bind.connect() as conn:
  358. ctx_opts = {
  359. "compare_type": True,
  360. "compare_server_default": True,
  361. "target_metadata": model_metadata,
  362. "upgrade_token": "upgrades",
  363. "downgrade_token": "downgrades",
  364. "alembic_module_prefix": "op.",
  365. "sqlalchemy_module_prefix": "sa.",
  366. "include_object": object_filters,
  367. "include_name": name_filters,
  368. "include_schemas": include_schemas,
  369. }
  370. if opts:
  371. ctx_opts.update(opts)
  372. self.context = context = MigrationContext.configure(
  373. connection=conn, opts=ctx_opts
  374. )
  375. autogen_context = api.AutogenContext(context, model_metadata)
  376. uo = ops.UpgradeOps(ops=[])
  377. autogenerate._produce_net_changes(autogen_context, uo)
  378. if return_ops:
  379. return uo
  380. else:
  381. return uo.as_diffs()
  382. finally:
  383. if max_identifier_length:
  384. dialect = self.bind.dialect
  385. dialect.max_identifier_length = (
  386. dialect._user_defined_max_identifier_length
  387. ) = existing_length
  388. def setUp(self):
  389. staging_env()
  390. self.bind = config.db
  391. def tearDown(self):
  392. if hasattr(self, "_alembic_metadata"):
  393. for m in util.to_list(self._alembic_metadata):
  394. m.drop_all(self.bind)
  395. clear_staging_env()