||
- from __future__ import annotations
- from typing import Any
- from typing import Dict
- from typing import Set
- from sqlalchemy import CHAR
- from sqlalchemy import CheckConstraint
- from sqlalchemy import Column
- from sqlalchemy import event
- from sqlalchemy import ForeignKey
- from sqlalchemy import Index
- from sqlalchemy import inspect
- from sqlalchemy import Integer
- from sqlalchemy import MetaData
- from sqlalchemy import Numeric
- from sqlalchemy import PrimaryKeyConstraint
- from sqlalchemy import String
- from sqlalchemy import Table
- from sqlalchemy import Text
- from sqlalchemy import text
- from sqlalchemy import UniqueConstraint
- from ... import autogenerate
- from ... import util
- from ...autogenerate import api
- from ...ddl.base import _fk_spec
- from ...migration import MigrationContext
- from ...operations import ops
- from ...testing import config
- from ...testing import eq_
- from ...testing.env import clear_staging_env
- from ...testing.env import staging_env
# Name of every Table that has been attached to a parent in this process.
# The default ``include_object`` filter in this module consults it, so tables
# whose names were never registered here are excluded from comparison.
# Nothing in this module clears it.
names_in_this_test: Set[Any] = set()


@event.listens_for(Table, "after_parent_attach")
def new_table(table, parent):
    """Register ``table.name`` whenever a Table is attached to a parent.

    ``parent`` is part of the event signature and is not used here.
    """
    registry = names_in_this_test
    registry.update((table.name,))
- def _default_include_object(obj, name, type_, reflected, compare_to):
- if type_ == "table":
- return name in names_in_this_test
- else:
- return True
- _default_object_filters: Any = _default_include_object
- _default_name_filters: Any = None
class ModelOne:
    """A "database" schema and a "model" schema that differ in known ways.

    Provides the ``_get_db_schema()`` / ``_get_model_schema()`` hooks that
    ``AutogenTest.setup_class`` calls. Both schemas are placed in
    ``cls.schema`` (``None`` means the default schema).
    """

    __requires__ = ("unique_constraint_reflection",)

    schema: Any = None

    @classmethod
    def _get_db_schema(cls):
        """Return the MetaData standing in for the existing database."""
        metadata = MetaData(schema=cls.schema)

        Table(
            "user",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(length=50)),
            Column("a1", Text),
            Column("pw", String(length=50)),
            Index("pw_idx", "pw"),
        )

        Table(
            "address",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(length=100), nullable=False),
        )

        Table(
            "order",
            metadata,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount",
                Numeric(precision=8, scale=2),
                nullable=False,
                server_default=text("0"),
            ),
            CheckConstraint("amount >= 0", name="ck_order_amount"),
        )

        Table(
            "extra",
            metadata,
            Column("x", CHAR),
            Column("uid", Integer, ForeignKey("user.id")),
        )

        return metadata

    @classmethod
    def _get_model_schema(cls):
        """Return the MetaData standing in for the application model."""
        metadata = MetaData(schema=cls.schema)

        Table(
            "user",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(length=50), nullable=False),
            Column("a1", Text, server_default="x"),
        )

        Table(
            "address",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(length=100), nullable=False),
            Column("street", String(length=50)),
            UniqueConstraint("email_address", name="uq_email"),
        )

        Table(
            "order",
            metadata,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount",
                Numeric(precision=10, scale=2),
                nullable=True,
                server_default=text("0"),
            ),
            Column("user_id", Integer, ForeignKey("user.id")),
            CheckConstraint("amount > -1", name="ck_order_amount"),
        )

        Table(
            "item",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("description", String(length=100)),
            Column("order_id", Integer, ForeignKey("order.order_id")),
            CheckConstraint("len(description) > 5"),
        )

        return metadata
class NamingConvModel:
    """Schemas exercising constraint/index naming under a naming convention.

    The database side gives every constraint an explicit name. The model
    side uses ``naming_convention``, mixing token names, ``conv()``-marked
    literal names and unnamed objects. ``configure_opts`` is merged into the
    MigrationContext options by ``AutogenTest.setUp``.
    """

    __requires__ = ("unique_constraint_reflection",)

    configure_opts = {"conv_all_constraint_names": True}

    naming_convention = {
        "ix": "ix_%(column_0_label)s",
        "uq": "uq_%(table_name)s_%(constraint_name)s",
        "ck": "ck_%(table_name)s_%(constraint_name)s",
        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
        "pk": "pk_%(table_name)s",
    }

    @classmethod
    def _get_db_schema(cls):
        """Return the MetaData standing in for the existing database.

        Every constraint here has a name that plays the role of a
        "db generated" name; these all need to render with op.f().
        """
        metadata = MetaData()

        Table(
            "x1",
            metadata,
            Column("q", Integer),
            Index("db_x1_index_q", "q"),
            PrimaryKeyConstraint("q", name="db_x1_primary_q"),
        )

        Table(
            "x2",
            metadata,
            Column("q", Integer),
            Column("p", ForeignKey("x1.q", name="db_x2_foreign_q")),
            CheckConstraint("q > 5", name="db_x2_check_q"),
        )

        Table(
            "x3",
            metadata,
            Column("q", Integer),
            Column("r", Integer),
            Column("s", Integer),
            UniqueConstraint("q", name="db_x3_unique_q"),
        )

        Table(
            "x4",
            metadata,
            Column("q", Integer),
            PrimaryKeyConstraint("q", name="db_x4_primary_q"),
        )

        Table(
            "x5",
            metadata,
            Column("q", Integer),
            Column("p", ForeignKey("x4.q", name="db_x5_foreign_q")),
            Column("r", Integer),
            Column("s", Integer),
            PrimaryKeyConstraint("q", name="db_x5_primary_q"),
            UniqueConstraint("r", name="db_x5_unique_r"),
            CheckConstraint("s > 5", name="db_x5_check_s"),
        )

        # SQLite doesn't require constraint names, so these are left unnamed;
        # expect to see "None" for their names.
        Table(
            "unnamed_sqlite",
            metadata,
            Column("q", Integer),
            Column("r", Integer),
            PrimaryKeyConstraint("q"),
            UniqueConstraint("r"),
        )

        return metadata

    @classmethod
    def _get_model_schema(cls):
        """Return the model MetaData, which uses ``naming_convention``."""
        from sqlalchemy.sql.naming import conv

        metadata = MetaData(naming_convention=cls.naming_convention)

        Table(
            "x1",
            metadata,
            Column("q", Integer, primary_key=True),
            Index(None, "q"),
        )

        Table(
            "x2",
            metadata,
            Column("q", Integer),
            Column("p", ForeignKey("x1.q")),
            CheckConstraint("q > 5", name="token_x2check1"),
        )

        Table(
            "x3",
            metadata,
            Column("q", Integer),
            Column("r", Integer),
            Column("s", Integer),
            UniqueConstraint("r", name="token_x3r"),
            UniqueConstraint("s", name=conv("userdef_x3_unique_s")),
        )

        Table(
            "x4",
            metadata,
            Column("q", Integer, primary_key=True),
            Index("userdef_x4_idx_q", "q"),
        )

        Table(
            "x6",
            metadata,
            Column("q", Integer, primary_key=True),
            Column("p", ForeignKey("x4.q")),
            Column("r", Integer),
            Column("s", Integer),
            UniqueConstraint("r", name="token_x6r"),
            CheckConstraint("s > 5", name="token_x6check1"),
            CheckConstraint("s < 20", name=conv("userdef_x6_check_s")),
        )

        return metadata
class _ComparesFKs:
    """Mixin providing an assertion helper for foreign-key diff entries."""

    def _assert_fk_diff(
        self,
        diff,
        type_,
        source_table,
        source_columns,
        target_table,
        target_columns,
        name=None,
        conditional_name=None,
        source_schema=None,
        onupdate=None,
        ondelete=None,
        initially=None,
        deferrable=None,
    ):
        """Assert that a single autogenerate diff describes the expected FK.

        ``diff[0]`` is compared with ``type_``. ``diff[1]`` is the
        constraint object (presumably a ForeignKeyConstraint), decomposed
        with the ``_fk_spec`` helper. Target columns are checked through
        ``diff[1].elements``; the target schema is not asserted.

        The expected constraint name is:

        * ``name`` when ``conditional_name`` is None;
        * the name of the first foreign key that ``inspect(self.bind)``
          reports for ``source_table`` when ``conditional_name`` is
          ``"servergenerated"`` (the only path that reads ``self.bind``);
        * ``conditional_name`` itself otherwise.
        """
        constraint = diff[1]

        # _fk_spec is a semi-private helper: older ForeignKeyConstraint
        # public APIs did not expose all of these elements.
        (
            spec_source_schema,
            spec_source_table,
            spec_source_columns,
            _spec_target_schema,
            spec_target_table,
            _spec_target_columns,
            spec_onupdate,
            spec_ondelete,
            spec_deferrable,
            spec_initially,
        ) = _fk_spec(constraint)

        # (actual, expected) pairs, checked in a fixed order so the first
        # mismatch reported is deterministic.
        expectations = (
            (diff[0], type_),
            (spec_source_table, source_table),
            (spec_source_columns, source_columns),
            (spec_target_table, target_table),
            (spec_source_schema, source_schema),
            (spec_onupdate, onupdate),
            (spec_ondelete, ondelete),
            (spec_initially, initially),
            (spec_deferrable, deferrable),
        )
        for actual, expected in expectations:
            eq_(actual, expected)

        eq_(
            [element.column.name for element in constraint.elements],
            target_columns,
        )

        if conditional_name is None:
            expected_name = name
        elif conditional_name == "servergenerated":
            reflected_fks = inspect(self.bind).get_foreign_keys(source_table)
            expected_name = reflected_fks[0]["name"]
        else:
            expected_name = conditional_name
        eq_(constraint.name, expected_name)
class AutogenTest(_ComparesFKs):
    """Base for tests that compare one fixed pair of schemas per class.

    Subclasses must provide ``_get_db_schema()`` and ``_get_model_schema()``
    (as ModelOne and NamingConvModel do). The first MetaData is created in
    the database once per class. The second becomes the ``target_metadata``
    of a fresh MigrationContext for every test.
    """

    def _flatten_diffs(self, diffs):
        """Yield the entries of ``diffs``, recursively expanding nested lists."""
        for d in diffs:
            if isinstance(d, list):
                yield from self._flatten_diffs(d)
            else:
                yield d

    @classmethod
    def _get_bind(cls):
        """Return the engine to test against (``config.db``)."""
        return config.db

    # Extra MigrationContext options merged in by setUp() when non-empty.
    # Read-only here; subclasses override it rather than mutate it.
    configure_opts: Dict[Any, Any] = {}

    @classmethod
    def setup_class(cls):
        """Create the "database" schema and build the "model" schema."""
        staging_env()
        cls.bind = cls._get_bind()
        cls.m1 = cls._get_db_schema()
        cls.m1.create_all(cls.bind)
        cls.m2 = cls._get_model_schema()

    @classmethod
    def teardown_class(cls):
        """Drop the "database" schema and always clear the staging env."""
        try:
            cls.m1.drop_all(cls.bind)
        finally:
            # Run even if drop_all fails, so the staging environment is
            # not left behind for later test classes.
            clear_staging_env()

    def setUp(self):
        """Open a connection and configure a MigrationContext for it."""
        self.conn = conn = self.bind.connect()
        ctx_opts = {
            "compare_type": True,
            "compare_server_default": True,
            "target_metadata": self.m2,
            "upgrade_token": "upgrades",
            "downgrade_token": "downgrades",
            "alembic_module_prefix": "op.",
            "sqlalchemy_module_prefix": "sa.",
            "include_object": _default_object_filters,
            "include_name": _default_name_filters,
        }
        if self.configure_opts:
            ctx_opts.update(self.configure_opts)
        try:
            self.context = context = MigrationContext.configure(
                connection=conn, opts=ctx_opts
            )
            self.autogen_context = api.AutogenContext(context, self.m2)
        except Exception:
            # tearDown is not guaranteed to run when setUp fails; don't
            # leak the connection opened above.
            conn.close()
            raise

    def tearDown(self):
        """Close the per-test connection opened in setUp()."""
        self.conn.close()

    def _update_context(
        self, object_filters=None, name_filters=None, include_schemas=None
    ):
        """Adjust the per-test AutogenContext and return it.

        :param object_filters: if given, becomes the sole object filter.
        :param name_filters: if given, becomes the sole name filter.
        :param include_schemas: if given, stored as the
         ``include_schemas`` option.
        """
        if include_schemas is not None:
            self.autogen_context.opts["include_schemas"] = include_schemas
        if object_filters is not None:
            self.autogen_context._object_filters = [object_filters]
        if name_filters is not None:
            self.autogen_context._name_filters = [name_filters]
        return self.autogen_context
class AutogenFixtureTest(_ComparesFKs):
    """Base for tests that build a fresh pair of schemas inside each test.

    Each test calls :meth:`_fixture` with a "database" MetaData and a
    "model" MetaData. :meth:`tearDown` drops whatever :meth:`_fixture`
    created.
    """

    def _fixture(
        self,
        m1,
        m2,
        include_schemas=False,
        opts=None,
        object_filters=_default_object_filters,
        name_filters=_default_name_filters,
        return_ops=False,
        max_identifier_length=None,
    ):
        """Create ``m1`` in the database and autogenerate against ``m2``.

        :param m1: MetaData, or a list of them (normalized with
         ``util.to_list``), created with ``create_all`` to represent the
         database state. Stored as ``self._alembic_metadata`` so
         :meth:`tearDown` can drop it.
        :param m2: MetaData used as ``target_metadata`` (the model).
        :param include_schemas: value for the ``include_schemas`` option.
        :param opts: extra MigrationContext options. Applied last, so they
         override the defaults built here.
        :param object_filters: the ``include_object`` hook.
        :param name_filters: the ``include_name`` hook.
        :param return_ops: if true, return the ``UpgradeOps`` itself instead
         of ``UpgradeOps.as_diffs()``.
        :param max_identifier_length: when truthy, the bind's dialect uses
         this identifier length for the duration of the call. The previous
         values are restored afterwards, even on error.
        :return: the ``UpgradeOps`` or its ``as_diffs()`` list.
        """
        if max_identifier_length:
            dialect = self.bind.dialect
            # Save both attributes separately. They need not be equal (the
            # user-defined one may be None). Restoring the private attribute
            # from the public value would leak a changed dialect into later
            # tests.
            saved_max_length = dialect.max_identifier_length
            saved_user_length = getattr(
                dialect, "_user_defined_max_identifier_length", None
            )
            dialect.max_identifier_length = (
                dialect._user_defined_max_identifier_length
            ) = max_identifier_length
        try:
            self._alembic_metadata, model_metadata = m1, m2
            for m in util.to_list(self._alembic_metadata):
                m.create_all(self.bind)

            with self.bind.connect() as conn:
                ctx_opts = {
                    "compare_type": True,
                    "compare_server_default": True,
                    "target_metadata": model_metadata,
                    "upgrade_token": "upgrades",
                    "downgrade_token": "downgrades",
                    "alembic_module_prefix": "op.",
                    "sqlalchemy_module_prefix": "sa.",
                    "include_object": object_filters,
                    "include_name": name_filters,
                    "include_schemas": include_schemas,
                }
                if opts:
                    ctx_opts.update(opts)
                self.context = context = MigrationContext.configure(
                    connection=conn, opts=ctx_opts
                )

                autogen_context = api.AutogenContext(context, model_metadata)
                uo = ops.UpgradeOps(ops=[])
                autogenerate._produce_net_changes(autogen_context, uo)

                if return_ops:
                    return uo
                return uo.as_diffs()
        finally:
            if max_identifier_length:
                dialect.max_identifier_length = saved_max_length
                dialect._user_defined_max_identifier_length = (
                    saved_user_length
                )

    def setUp(self):
        """Set up the staging environment and the engine under test."""
        staging_env()
        self.bind = config.db

    def tearDown(self):
        """Drop anything _fixture() created; always clear the staging env."""
        try:
            if hasattr(self, "_alembic_metadata"):
                for m in util.to_list(self._alembic_metadata):
                    m.drop_all(self.bind)
        finally:
            # Run even if drop_all fails, so later tests start clean.
            clear_staging_env()
|