| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395 |
- # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
- # mypy: no-warn-return-any, allow-any-generics
- from __future__ import annotations
- from contextlib import contextmanager
- from contextlib import nullcontext
- import logging
- import sys
- from typing import Any
- from typing import Callable
- from typing import cast
- from typing import Collection
- from typing import Dict
- from typing import Iterable
- from typing import Iterator
- from typing import List
- from typing import Optional
- from typing import Set
- from typing import Tuple
- from typing import TYPE_CHECKING
- from typing import Union
- from sqlalchemy import Column
- from sqlalchemy import literal_column
- from sqlalchemy import select
- from sqlalchemy.engine import Engine
- from sqlalchemy.engine import url as sqla_url
- from sqlalchemy.engine.strategies import MockEngineStrategy
- from typing_extensions import ContextManager
- from .. import ddl
- from .. import util
- from ..util import sqla_compat
- from ..util.compat import EncodedIO
- if TYPE_CHECKING:
- from sqlalchemy.engine import Dialect
- from sqlalchemy.engine import URL
- from sqlalchemy.engine.base import Connection
- from sqlalchemy.engine.base import Transaction
- from sqlalchemy.engine.mock import MockConnection
- from sqlalchemy.sql import Executable
- from .environment import EnvironmentContext
- from ..config import Config
- from ..script.base import Script
- from ..script.base import ScriptDirectory
- from ..script.revision import _RevisionOrBase
- from ..script.revision import Revision
- from ..script.revision import RevisionMap
- log = logging.getLogger(__name__)
- class _ProxyTransaction:
- def __init__(self, migration_context: MigrationContext) -> None:
- self.migration_context = migration_context
- @property
- def _proxied_transaction(self) -> Optional[Transaction]:
- return self.migration_context._transaction
- def rollback(self) -> None:
- t = self._proxied_transaction
- assert t is not None
- t.rollback()
- self.migration_context._transaction = None
- def commit(self) -> None:
- t = self._proxied_transaction
- assert t is not None
- t.commit()
- self.migration_context._transaction = None
- def __enter__(self) -> _ProxyTransaction:
- return self
- def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
- if self._proxied_transaction is not None:
- self._proxied_transaction.__exit__(type_, value, traceback)
- self.migration_context._transaction = None
- class MigrationContext:
- """Represent the database state made available to a migration
- script.
- :class:`.MigrationContext` is the front end to an actual
- database connection, or alternatively a string output
- stream given a particular database dialect,
- from an Alembic perspective.
- When inside the ``env.py`` script, the :class:`.MigrationContext`
- is available via the
- :meth:`.EnvironmentContext.get_context` method,
- which is available at ``alembic.context``::
- # from within env.py script
- from alembic import context
- migration_context = context.get_context()
- For usage outside of an ``env.py`` script, such as for
- utility routines that want to check the current version
- in the database, the :meth:`.MigrationContext.configure`
- method to create new :class:`.MigrationContext` objects.
- For example, to get at the current revision in the
- database using :meth:`.MigrationContext.get_current_revision`::
- # in any application, outside of an env.py script
- from alembic.migration import MigrationContext
- from sqlalchemy import create_engine
- engine = create_engine("postgresql://mydatabase")
- conn = engine.connect()
- context = MigrationContext.configure(conn)
- current_rev = context.get_current_revision()
- The above context can also be used to produce
- Alembic migration operations with an :class:`.Operations`
- instance::
- # in any application, outside of the normal Alembic environment
- from alembic.operations import Operations
- op = Operations(context)
- op.alter_column("mytable", "somecolumn", nullable=True)
- """
- def __init__(
- self,
- dialect: Dialect,
- connection: Optional[Connection],
- opts: Dict[str, Any],
- environment_context: Optional[EnvironmentContext] = None,
- ) -> None:
- self.environment_context = environment_context
- self.opts = opts
- self.dialect = dialect
- self.script: Optional[ScriptDirectory] = opts.get("script")
- as_sql: bool = opts.get("as_sql", False)
- transactional_ddl = opts.get("transactional_ddl")
- self._transaction_per_migration = opts.get(
- "transaction_per_migration", False
- )
- self.on_version_apply_callbacks = opts.get("on_version_apply", ())
- self._transaction: Optional[Transaction] = None
- if as_sql:
- self.connection = cast(
- Optional["Connection"], self._stdout_connection(connection)
- )
- assert self.connection is not None
- self._in_external_transaction = False
- else:
- self.connection = connection
- self._in_external_transaction = (
- sqla_compat._get_connection_in_transaction(connection)
- )
- self._migrations_fn: Optional[
- Callable[..., Iterable[RevisionStep]]
- ] = opts.get("fn")
- self.as_sql = as_sql
- self.purge = opts.get("purge", False)
- if "output_encoding" in opts:
- self.output_buffer = EncodedIO(
- opts.get("output_buffer")
- or sys.stdout, # type:ignore[arg-type]
- opts["output_encoding"],
- )
- else:
- self.output_buffer = opts.get(
- "output_buffer", sys.stdout
- ) # type:ignore[assignment] # noqa: E501
- self.transactional_ddl = transactional_ddl
- self._user_compare_type = opts.get("compare_type", True)
- self._user_compare_server_default = opts.get(
- "compare_server_default", False
- )
- self.version_table = version_table = opts.get(
- "version_table", "alembic_version"
- )
- self.version_table_schema = version_table_schema = opts.get(
- "version_table_schema", None
- )
- self._start_from_rev: Optional[str] = opts.get("starting_rev")
- self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
- dialect,
- self.connection,
- self.as_sql,
- transactional_ddl,
- self.output_buffer,
- opts,
- )
- self._version = self.impl.version_table_impl(
- version_table=version_table,
- version_table_schema=version_table_schema,
- version_table_pk=opts.get("version_table_pk", True),
- )
- log.info("Context impl %s.", self.impl.__class__.__name__)
- if self.as_sql:
- log.info("Generating static SQL")
- log.info(
- "Will assume %s DDL.",
- (
- "transactional"
- if self.impl.transactional_ddl
- else "non-transactional"
- ),
- )
- @classmethod
- def configure(
- cls,
- connection: Optional[Connection] = None,
- url: Optional[Union[str, URL]] = None,
- dialect_name: Optional[str] = None,
- dialect: Optional[Dialect] = None,
- environment_context: Optional[EnvironmentContext] = None,
- dialect_opts: Optional[Dict[str, str]] = None,
- opts: Optional[Any] = None,
- ) -> MigrationContext:
- """Create a new :class:`.MigrationContext`.
- This is a factory method usually called
- by :meth:`.EnvironmentContext.configure`.
- :param connection: a :class:`~sqlalchemy.engine.Connection`
- to use for SQL execution in "online" mode. When present,
- is also used to determine the type of dialect in use.
- :param url: a string database url, or a
- :class:`sqlalchemy.engine.url.URL` object.
- The type of dialect to be used will be derived from this if
- ``connection`` is not passed.
- :param dialect_name: string name of a dialect, such as
- "postgresql", "mssql", etc. The type of dialect to be used will be
- derived from this if ``connection`` and ``url`` are not passed.
- :param opts: dictionary of options. Most other options
- accepted by :meth:`.EnvironmentContext.configure` are passed via
- this dictionary.
- """
- if opts is None:
- opts = {}
- if dialect_opts is None:
- dialect_opts = {}
- if connection:
- if isinstance(connection, Engine):
- raise util.CommandError(
- "'connection' argument to configure() is expected "
- "to be a sqlalchemy.engine.Connection instance, "
- "got %r" % connection,
- )
- dialect = connection.dialect
- elif url:
- url_obj = sqla_url.make_url(url)
- dialect = url_obj.get_dialect()(**dialect_opts)
- elif dialect_name:
- url_obj = sqla_url.make_url("%s://" % dialect_name)
- dialect = url_obj.get_dialect()(**dialect_opts)
- elif not dialect:
- raise Exception("Connection, url, or dialect_name is required.")
- assert dialect is not None
- return MigrationContext(dialect, connection, opts, environment_context)
- @contextmanager
- def autocommit_block(self) -> Iterator[None]:
- """Enter an "autocommit" block, for databases that support AUTOCOMMIT
- isolation levels.
- This special directive is intended to support the occasional database
- DDL or system operation that specifically has to be run outside of
- any kind of transaction block. The PostgreSQL database platform
- is the most common target for this style of operation, as many
- of its DDL operations must be run outside of transaction blocks, even
- though the database overall supports transactional DDL.
- The method is used as a context manager within a migration script, by
- calling on :meth:`.Operations.get_context` to retrieve the
- :class:`.MigrationContext`, then invoking
- :meth:`.MigrationContext.autocommit_block` using the ``with:``
- statement::
- def upgrade():
- with op.get_context().autocommit_block():
- op.execute("ALTER TYPE mood ADD VALUE 'soso'")
- Above, a PostgreSQL "ALTER TYPE..ADD VALUE" directive is emitted,
- which must be run outside of a transaction block at the database level.
- The :meth:`.MigrationContext.autocommit_block` method makes use of the
- SQLAlchemy ``AUTOCOMMIT`` isolation level setting, which against the
- psycogp2 DBAPI corresponds to the ``connection.autocommit`` setting,
- to ensure that the database driver is not inside of a DBAPI level
- transaction block.
- .. warning::
- As is necessary, **the database transaction preceding the block is
- unconditionally committed**. This means that the run of migrations
- preceding the operation will be committed, before the overall
- migration operation is complete.
- It is recommended that when an application includes migrations with
- "autocommit" blocks, that
- :paramref:`.EnvironmentContext.transaction_per_migration` be used
- so that the calling environment is tuned to expect short per-file
- migrations whether or not one of them has an autocommit block.
- """
- _in_connection_transaction = self._in_connection_transaction()
- if self.impl.transactional_ddl and self.as_sql:
- self.impl.emit_commit()
- elif _in_connection_transaction:
- assert self._transaction is not None
- self._transaction.commit()
- self._transaction = None
- if not self.as_sql:
- assert self.connection is not None
- current_level = self.connection.get_isolation_level()
- base_connection = self.connection
- # in 1.3 and 1.4 non-future mode, the connection gets switched
- # out. we can use the base connection with the new mode
- # except that it will not know it's in "autocommit" and will
- # emit deprecation warnings when an autocommit action takes
- # place.
- self.connection = self.impl.connection = (
- base_connection.execution_options(isolation_level="AUTOCOMMIT")
- )
- # sqlalchemy future mode will "autobegin" in any case, so take
- # control of that "transaction" here
- fake_trans: Optional[Transaction] = self.connection.begin()
- else:
- fake_trans = None
- try:
- yield
- finally:
- if not self.as_sql:
- assert self.connection is not None
- if fake_trans is not None:
- fake_trans.commit()
- self.connection.execution_options(
- isolation_level=current_level
- )
- self.connection = self.impl.connection = base_connection
- if self.impl.transactional_ddl and self.as_sql:
- self.impl.emit_begin()
- elif _in_connection_transaction:
- assert self.connection is not None
- self._transaction = self.connection.begin()
- def begin_transaction(
- self, _per_migration: bool = False
- ) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
- """Begin a logical transaction for migration operations.
- This method is used within an ``env.py`` script to demarcate where
- the outer "transaction" for a series of migrations begins. Example::
- def run_migrations_online():
- connectable = create_engine(...)
- with connectable.connect() as connection:
- context.configure(
- connection=connection, target_metadata=target_metadata
- )
- with context.begin_transaction():
- context.run_migrations()
- Above, :meth:`.MigrationContext.begin_transaction` is used to demarcate
- where the outer logical transaction occurs around the
- :meth:`.MigrationContext.run_migrations` operation.
- A "Logical" transaction means that the operation may or may not
- correspond to a real database transaction. If the target database
- supports transactional DDL (or
- :paramref:`.EnvironmentContext.configure.transactional_ddl` is true),
- the :paramref:`.EnvironmentContext.configure.transaction_per_migration`
- flag is not set, and the migration is against a real database
- connection (as opposed to using "offline" ``--sql`` mode), a real
- transaction will be started. If ``--sql`` mode is in effect, the
- operation would instead correspond to a string such as "BEGIN" being
- emitted to the string output.
- The returned object is a Python context manager that should only be
- used in the context of a ``with:`` statement as indicated above.
- The object has no other guaranteed API features present.
- .. seealso::
- :meth:`.MigrationContext.autocommit_block`
- """
- if self._in_external_transaction:
- return nullcontext()
- if self.impl.transactional_ddl:
- transaction_now = _per_migration == self._transaction_per_migration
- else:
- transaction_now = _per_migration is True
- if not transaction_now:
- return nullcontext()
- elif not self.impl.transactional_ddl:
- assert _per_migration
- if self.as_sql:
- return nullcontext()
- else:
- # track our own notion of a "transaction block", which must be
- # committed when complete. Don't rely upon whether or not the
- # SQLAlchemy connection reports as "in transaction"; this
- # because SQLAlchemy future connection features autobegin
- # behavior, so it may already be in a transaction from our
- # emitting of queries like "has_version_table", etc. While we
- # could track these operations as well, that leaves open the
- # possibility of new operations or other things happening in
- # the user environment that still may be triggering
- # "autobegin".
- in_transaction = self._transaction is not None
- if in_transaction:
- return nullcontext()
- else:
- assert self.connection is not None
- self._transaction = (
- sqla_compat._safe_begin_connection_transaction(
- self.connection
- )
- )
- return _ProxyTransaction(self)
- elif self.as_sql:
- @contextmanager
- def begin_commit():
- self.impl.emit_begin()
- yield
- self.impl.emit_commit()
- return begin_commit()
- else:
- assert self.connection is not None
- self._transaction = sqla_compat._safe_begin_connection_transaction(
- self.connection
- )
- return _ProxyTransaction(self)
- def get_current_revision(self) -> Optional[str]:
- """Return the current revision, usually that which is present
- in the ``alembic_version`` table in the database.
- This method intends to be used only for a migration stream that
- does not contain unmerged branches in the target database;
- if there are multiple branches present, an exception is raised.
- The :meth:`.MigrationContext.get_current_heads` should be preferred
- over this method going forward in order to be compatible with
- branch migration support.
- If this :class:`.MigrationContext` was configured in "offline"
- mode, that is with ``as_sql=True``, the ``starting_rev``
- parameter is returned instead, if any.
- """
- heads = self.get_current_heads()
- if len(heads) == 0:
- return None
- elif len(heads) > 1:
- raise util.CommandError(
- "Version table '%s' has more than one head present; "
- "please use get_current_heads()" % self.version_table
- )
- else:
- return heads[0]
- def get_current_heads(self) -> Tuple[str, ...]:
- """Return a tuple of the current 'head versions' that are represented
- in the target database.
- For a migration stream without branches, this will be a single
- value, synonymous with that of
- :meth:`.MigrationContext.get_current_revision`. However when multiple
- unmerged branches exist within the target database, the returned tuple
- will contain a value for each head.
- If this :class:`.MigrationContext` was configured in "offline"
- mode, that is with ``as_sql=True``, the ``starting_rev``
- parameter is returned in a one-length tuple.
- If no version table is present, or if there are no revisions
- present, an empty tuple is returned.
- """
- if self.as_sql:
- start_from_rev: Any = self._start_from_rev
- if start_from_rev == "base":
- start_from_rev = None
- elif start_from_rev is not None and self.script:
- start_from_rev = [
- self.script.get_revision(sfr).revision
- for sfr in util.to_list(start_from_rev)
- if sfr not in (None, "base")
- ]
- return util.to_tuple(start_from_rev, default=())
- else:
- if self._start_from_rev:
- raise util.CommandError(
- "Can't specify current_rev to context "
- "when using a database connection"
- )
- if not self._has_version_table():
- return ()
- assert self.connection is not None
- return tuple(
- row[0]
- for row in self.connection.execute(
- select(self._version.c.version_num)
- )
- )
- def _ensure_version_table(self, purge: bool = False) -> None:
- with sqla_compat._ensure_scope_for_ddl(self.connection):
- assert self.connection is not None
- self._version.create(self.connection, checkfirst=True)
- if purge:
- assert self.connection is not None
- self.connection.execute(self._version.delete())
- def _has_version_table(self) -> bool:
- assert self.connection is not None
- return sqla_compat._connectable_has_table(
- self.connection, self.version_table, self.version_table_schema
- )
- def stamp(self, script_directory: ScriptDirectory, revision: str) -> None:
- """Stamp the version table with a specific revision.
- This method calculates those branches to which the given revision
- can apply, and updates those branches as though they were migrated
- towards that revision (either up or down). If no current branches
- include the revision, it is added as a new branch head.
- """
- heads = self.get_current_heads()
- if not self.as_sql and not heads:
- self._ensure_version_table()
- head_maintainer = HeadMaintainer(self, heads)
- for step in script_directory._stamp_revs(revision, heads):
- head_maintainer.update_to_step(step)
- def run_migrations(self, **kw: Any) -> None:
- r"""Run the migration scripts established for this
- :class:`.MigrationContext`, if any.
- The commands in :mod:`alembic.command` will set up a function
- that is ultimately passed to the :class:`.MigrationContext`
- as the ``fn`` argument. This function represents the "work"
- that will be done when :meth:`.MigrationContext.run_migrations`
- is called, typically from within the ``env.py`` script of the
- migration environment. The "work function" then provides an iterable
- of version callables and other version information which
- in the case of the ``upgrade`` or ``downgrade`` commands are the
- list of version scripts to invoke. Other commands yield nothing,
- in the case that a command wants to run some other operation
- against the database such as the ``current`` or ``stamp`` commands.
- :param \**kw: keyword arguments here will be passed to each
- migration callable, that is the ``upgrade()`` or ``downgrade()``
- method within revision scripts.
- """
- self.impl.start_migrations()
- heads: Tuple[str, ...]
- if self.purge:
- if self.as_sql:
- raise util.CommandError("Can't use --purge with --sql mode")
- self._ensure_version_table(purge=True)
- heads = ()
- else:
- heads = self.get_current_heads()
- dont_mutate = self.opts.get("dont_mutate", False)
- if not self.as_sql and not heads and not dont_mutate:
- self._ensure_version_table()
- head_maintainer = HeadMaintainer(self, heads)
- assert self._migrations_fn is not None
- for step in self._migrations_fn(heads, self):
- with self.begin_transaction(_per_migration=True):
- if self.as_sql and not head_maintainer.heads:
- # for offline mode, include a CREATE TABLE from
- # the base
- assert self.connection is not None
- self._version.create(self.connection)
- log.info("Running %s", step)
- if self.as_sql:
- self.impl.static_output(
- "-- Running %s" % (step.short_log,)
- )
- step.migration_fn(**kw)
- # previously, we wouldn't stamp per migration
- # if we were in a transaction, however given the more
- # complex model that involves any number of inserts
- # and row-targeted updates and deletes, it's simpler for now
- # just to run the operations on every version
- head_maintainer.update_to_step(step)
- for callback in self.on_version_apply_callbacks:
- callback(
- ctx=self,
- step=step.info,
- heads=set(head_maintainer.heads),
- run_args=kw,
- )
- if self.as_sql and not head_maintainer.heads:
- assert self.connection is not None
- self._version.drop(self.connection)
- def _in_connection_transaction(self) -> bool:
- try:
- meth = self.connection.in_transaction # type:ignore[union-attr]
- except AttributeError:
- return False
- else:
- return meth()
- def execute(
- self,
- sql: Union[Executable, str],
- execution_options: Optional[Dict[str, Any]] = None,
- ) -> None:
- """Execute a SQL construct or string statement.
- The underlying execution mechanics are used, that is
- if this is "offline mode" the SQL is written to the
- output buffer, otherwise the SQL is emitted on
- the current SQLAlchemy connection.
- """
- self.impl._exec(sql, execution_options)
- def _stdout_connection(
- self, connection: Optional[Connection]
- ) -> MockConnection:
- def dump(construct, *multiparams, **params):
- self.impl._exec(construct)
- return MockEngineStrategy.MockConnection(self.dialect, dump)
- @property
- def bind(self) -> Optional[Connection]:
- """Return the current "bind".
- In online mode, this is an instance of
- :class:`sqlalchemy.engine.Connection`, and is suitable
- for ad-hoc execution of any kind of usage described
- in SQLAlchemy Core documentation as well as
- for usage with the :meth:`sqlalchemy.schema.Table.create`
- and :meth:`sqlalchemy.schema.MetaData.create_all` methods
- of :class:`~sqlalchemy.schema.Table`,
- :class:`~sqlalchemy.schema.MetaData`.
- Note that when "standard output" mode is enabled,
- this bind will be a "mock" connection handler that cannot
- return results and is only appropriate for a very limited
- subset of commands.
- """
- return self.connection
- @property
- def config(self) -> Optional[Config]:
- """Return the :class:`.Config` used by the current environment,
- if any."""
- if self.environment_context:
- return self.environment_context.config
- else:
- return None
- def _compare_type(
- self, inspector_column: Column[Any], metadata_column: Column
- ) -> bool:
- if self._user_compare_type is False:
- return False
- if callable(self._user_compare_type):
- user_value = self._user_compare_type(
- self,
- inspector_column,
- metadata_column,
- inspector_column.type,
- metadata_column.type,
- )
- if user_value is not None:
- return user_value
- return self.impl.compare_type(inspector_column, metadata_column)
- def _compare_server_default(
- self,
- inspector_column: Column[Any],
- metadata_column: Column[Any],
- rendered_metadata_default: Optional[str],
- rendered_column_default: Optional[str],
- ) -> bool:
- if self._user_compare_server_default is False:
- return False
- if callable(self._user_compare_server_default):
- user_value = self._user_compare_server_default(
- self,
- inspector_column,
- metadata_column,
- rendered_column_default,
- metadata_column.server_default,
- rendered_metadata_default,
- )
- if user_value is not None:
- return user_value
- return self.impl.compare_server_default(
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_column_default,
- )
- class HeadMaintainer:
- def __init__(self, context: MigrationContext, heads: Any) -> None:
- self.context = context
- self.heads = set(heads)
- def _insert_version(self, version: str) -> None:
- assert version not in self.heads
- self.heads.add(version)
- self.context.impl._exec(
- self.context._version.insert().values(
- version_num=literal_column("'%s'" % version)
- )
- )
- def _delete_version(self, version: str) -> None:
- self.heads.remove(version)
- ret = self.context.impl._exec(
- self.context._version.delete().where(
- self.context._version.c.version_num
- == literal_column("'%s'" % version)
- )
- )
- if (
- not self.context.as_sql
- and self.context.dialect.supports_sane_rowcount
- and ret is not None
- and ret.rowcount != 1
- ):
- raise util.CommandError(
- "Online migration expected to match one "
- "row when deleting '%s' in '%s'; "
- "%d found"
- % (version, self.context.version_table, ret.rowcount)
- )
- def _update_version(self, from_: str, to_: str) -> None:
- assert to_ not in self.heads
- self.heads.remove(from_)
- self.heads.add(to_)
- ret = self.context.impl._exec(
- self.context._version.update()
- .values(version_num=literal_column("'%s'" % to_))
- .where(
- self.context._version.c.version_num
- == literal_column("'%s'" % from_)
- )
- )
- if (
- not self.context.as_sql
- and self.context.dialect.supports_sane_rowcount
- and ret is not None
- and ret.rowcount != 1
- ):
- raise util.CommandError(
- "Online migration expected to match one "
- "row when updating '%s' to '%s' in '%s'; "
- "%d found"
- % (from_, to_, self.context.version_table, ret.rowcount)
- )
- def update_to_step(self, step: Union[RevisionStep, StampStep]) -> None:
- if step.should_delete_branch(self.heads):
- vers = step.delete_version_num
- log.debug("branch delete %s", vers)
- self._delete_version(vers)
- elif step.should_create_branch(self.heads):
- vers = step.insert_version_num
- log.debug("new branch insert %s", vers)
- self._insert_version(vers)
- elif step.should_merge_branches(self.heads):
- # delete revs, update from rev, update to rev
- (
- delete_revs,
- update_from_rev,
- update_to_rev,
- ) = step.merge_branch_idents(self.heads)
- log.debug(
- "merge, delete %s, update %s to %s",
- delete_revs,
- update_from_rev,
- update_to_rev,
- )
- for delrev in delete_revs:
- self._delete_version(delrev)
- self._update_version(update_from_rev, update_to_rev)
- elif step.should_unmerge_branches(self.heads):
- (
- update_from_rev,
- update_to_rev,
- insert_revs,
- ) = step.unmerge_branch_idents(self.heads)
- log.debug(
- "unmerge, insert %s, update %s to %s",
- insert_revs,
- update_from_rev,
- update_to_rev,
- )
- for insrev in insert_revs:
- self._insert_version(insrev)
- self._update_version(update_from_rev, update_to_rev)
- else:
- from_, to_ = step.update_version_num(self.heads)
- log.debug("update %s to %s", from_, to_)
- self._update_version(from_, to_)
- class MigrationInfo:
- """Exposes information about a migration step to a callback listener.
- The :class:`.MigrationInfo` object is available exclusively for the
- benefit of the :paramref:`.EnvironmentContext.on_version_apply`
- callback hook.
- """
- is_upgrade: bool
- """True/False: indicates whether this operation ascends or descends the
- version tree."""
- is_stamp: bool
- """True/False: indicates whether this operation is a stamp (i.e. whether
- it results in any actual database operations)."""
- up_revision_id: Optional[str]
- """Version string corresponding to :attr:`.Revision.revision`.
- In the case of a stamp operation, it is advised to use the
- :attr:`.MigrationInfo.up_revision_ids` tuple as a stamp operation can
- make a single movement from one or more branches down to a single
- branchpoint, in which case there will be multiple "up" revisions.
- .. seealso::
- :attr:`.MigrationInfo.up_revision_ids`
- """
- up_revision_ids: Tuple[str, ...]
- """Tuple of version strings corresponding to :attr:`.Revision.revision`.
- In the majority of cases, this tuple will be a single value, synonymous
- with the scalar value of :attr:`.MigrationInfo.up_revision_id`.
- It can be multiple revision identifiers only in the case of an
- ``alembic stamp`` operation which is moving downwards from multiple
- branches down to their common branch point.
- """
- down_revision_ids: Tuple[str, ...]
- """Tuple of strings representing the base revisions of this migration step.
- If empty, this represents a root revision; otherwise, the first item
- corresponds to :attr:`.Revision.down_revision`, and the rest are inferred
- from dependencies.
- """
- revision_map: RevisionMap
- """The revision map inside of which this operation occurs."""
- def __init__(
- self,
- revision_map: RevisionMap,
- is_upgrade: bool,
- is_stamp: bool,
- up_revisions: Union[str, Tuple[str, ...]],
- down_revisions: Union[str, Tuple[str, ...]],
- ) -> None:
- self.revision_map = revision_map
- self.is_upgrade = is_upgrade
- self.is_stamp = is_stamp
- self.up_revision_ids = util.to_tuple(up_revisions, default=())
- if self.up_revision_ids:
- self.up_revision_id = self.up_revision_ids[0]
- else:
- # this should never be the case with
- # "upgrade", "downgrade", or "stamp" as we are always
- # measuring movement in terms of at least one upgrade version
- self.up_revision_id = None
- self.down_revision_ids = util.to_tuple(down_revisions, default=())
- @property
- def is_migration(self) -> bool:
- """True/False: indicates whether this operation is a migration.
- At present this is true if and only the migration is not a stamp.
- If other operation types are added in the future, both this attribute
- and :attr:`~.MigrationInfo.is_stamp` will be false.
- """
- return not self.is_stamp
- @property
- def source_revision_ids(self) -> Tuple[str, ...]:
- """Active revisions before this migration step is applied."""
- return (
- self.down_revision_ids if self.is_upgrade else self.up_revision_ids
- )
- @property
- def destination_revision_ids(self) -> Tuple[str, ...]:
- """Active revisions after this migration step is applied."""
- return (
- self.up_revision_ids if self.is_upgrade else self.down_revision_ids
- )
- @property
- def up_revision(self) -> Optional[Revision]:
- """Get :attr:`~.MigrationInfo.up_revision_id` as
- a :class:`.Revision`.
- """
- return self.revision_map.get_revision(self.up_revision_id)
- @property
- def up_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
- """Get :attr:`~.MigrationInfo.up_revision_ids` as a
- :class:`.Revision`."""
- return self.revision_map.get_revisions(self.up_revision_ids)
- @property
- def down_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
- """Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of
- :class:`Revisions <.Revision>`."""
- return self.revision_map.get_revisions(self.down_revision_ids)
- @property
- def source_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
- """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
- :class:`Revisions <.Revision>`."""
- return self.revision_map.get_revisions(self.source_revision_ids)
- @property
- def destination_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
- """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
- :class:`Revisions <.Revision>`."""
- return self.revision_map.get_revisions(self.destination_revision_ids)
- class MigrationStep:
- from_revisions_no_deps: Tuple[str, ...]
- to_revisions_no_deps: Tuple[str, ...]
- is_upgrade: bool
- migration_fn: Any
- if TYPE_CHECKING:
- @property
- def doc(self) -> Optional[str]: ...
- @property
- def name(self) -> str:
- return self.migration_fn.__name__
- @classmethod
- def upgrade_from_script(
- cls, revision_map: RevisionMap, script: Script
- ) -> RevisionStep:
- return RevisionStep(revision_map, script, True)
- @classmethod
- def downgrade_from_script(
- cls, revision_map: RevisionMap, script: Script
- ) -> RevisionStep:
- return RevisionStep(revision_map, script, False)
- @property
- def is_downgrade(self) -> bool:
- return not self.is_upgrade
- @property
- def short_log(self) -> str:
- return "%s %s -> %s" % (
- self.name,
- util.format_as_comma(self.from_revisions_no_deps),
- util.format_as_comma(self.to_revisions_no_deps),
- )
- def __str__(self):
- if self.doc:
- return "%s %s -> %s, %s" % (
- self.name,
- util.format_as_comma(self.from_revisions_no_deps),
- util.format_as_comma(self.to_revisions_no_deps),
- self.doc,
- )
- else:
- return self.short_log
- class RevisionStep(MigrationStep):
- def __init__(
- self, revision_map: RevisionMap, revision: Script, is_upgrade: bool
- ) -> None:
- self.revision_map = revision_map
- self.revision = revision
- self.is_upgrade = is_upgrade
- if is_upgrade:
- self.migration_fn = revision.module.upgrade
- else:
- self.migration_fn = revision.module.downgrade
- def __repr__(self):
- return "RevisionStep(%r, is_upgrade=%r)" % (
- self.revision.revision,
- self.is_upgrade,
- )
- def __eq__(self, other: object) -> bool:
- return (
- isinstance(other, RevisionStep)
- and other.revision == self.revision
- and self.is_upgrade == other.is_upgrade
- )
- @property
- def doc(self) -> Optional[str]:
- return self.revision.doc
- @property
- def from_revisions(self) -> Tuple[str, ...]:
- if self.is_upgrade:
- return self.revision._normalized_down_revisions
- else:
- return (self.revision.revision,)
- @property
- def from_revisions_no_deps( # type:ignore[override]
- self,
- ) -> Tuple[str, ...]:
- if self.is_upgrade:
- return self.revision._versioned_down_revisions
- else:
- return (self.revision.revision,)
- @property
- def to_revisions(self) -> Tuple[str, ...]:
- if self.is_upgrade:
- return (self.revision.revision,)
- else:
- return self.revision._normalized_down_revisions
- @property
- def to_revisions_no_deps( # type:ignore[override]
- self,
- ) -> Tuple[str, ...]:
- if self.is_upgrade:
- return (self.revision.revision,)
- else:
- return self.revision._versioned_down_revisions
- @property
- def _has_scalar_down_revision(self) -> bool:
- return len(self.revision._normalized_down_revisions) == 1
- def should_delete_branch(self, heads: Set[str]) -> bool:
- """A delete is when we are a. in a downgrade and b.
- we are going to the "base" or we are going to a version that
- is implied as a dependency on another version that is remaining.
- """
- if not self.is_downgrade:
- return False
- if self.revision.revision not in heads:
- return False
- downrevs = self.revision._normalized_down_revisions
- if not downrevs:
- # is a base
- return True
- else:
- # determine what the ultimate "to_revisions" for an
- # unmerge would be. If there are none, then we're a delete.
- to_revisions = self._unmerge_to_revisions(heads)
- return not to_revisions
- def merge_branch_idents(
- self, heads: Set[str]
- ) -> Tuple[List[str], str, str]:
- other_heads = set(heads).difference(self.from_revisions)
- if other_heads:
- ancestors = {
- r.revision
- for r in self.revision_map._get_ancestor_nodes(
- self.revision_map.get_revisions(other_heads), check=False
- )
- }
- from_revisions = list(
- set(self.from_revisions).difference(ancestors)
- )
- else:
- from_revisions = list(self.from_revisions)
- return (
- # delete revs, update from rev, update to rev
- list(from_revisions[0:-1]),
- from_revisions[-1],
- self.to_revisions[0],
- )
- def _unmerge_to_revisions(self, heads: Set[str]) -> Tuple[str, ...]:
- other_heads = set(heads).difference([self.revision.revision])
- if other_heads:
- ancestors = {
- r.revision
- for r in self.revision_map._get_ancestor_nodes(
- self.revision_map.get_revisions(other_heads), check=False
- )
- }
- return tuple(set(self.to_revisions).difference(ancestors))
- else:
- # for each revision we plan to return, compute its ancestors
- # (excluding self), and remove those from the final output since
- # they are already accounted for.
- ancestors = {
- r.revision
- for to_revision in self.to_revisions
- for r in self.revision_map._get_ancestor_nodes(
- self.revision_map.get_revisions(to_revision), check=False
- )
- if r.revision != to_revision
- }
- return tuple(set(self.to_revisions).difference(ancestors))
- def unmerge_branch_idents(
- self, heads: Set[str]
- ) -> Tuple[str, str, Tuple[str, ...]]:
- to_revisions = self._unmerge_to_revisions(heads)
- return (
- # update from rev, update to rev, insert revs
- self.from_revisions[0],
- to_revisions[-1],
- to_revisions[0:-1],
- )
- def should_create_branch(self, heads: Set[str]) -> bool:
- if not self.is_upgrade:
- return False
- downrevs = self.revision._normalized_down_revisions
- if not downrevs:
- # is a base
- return True
- else:
- # none of our downrevs are present, so...
- # we have to insert our version. This is true whether
- # or not there is only one downrev, or multiple (in the latter
- # case, we're a merge point.)
- if not heads.intersection(downrevs):
- return True
- else:
- return False
- def should_merge_branches(self, heads: Set[str]) -> bool:
- if not self.is_upgrade:
- return False
- downrevs = self.revision._normalized_down_revisions
- if len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1:
- return True
- return False
- def should_unmerge_branches(self, heads: Set[str]) -> bool:
- if not self.is_downgrade:
- return False
- downrevs = self.revision._normalized_down_revisions
- if self.revision.revision in heads and len(downrevs) > 1:
- return True
- return False
- def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
- if not self._has_scalar_down_revision:
- downrev = heads.intersection(
- self.revision._normalized_down_revisions
- )
- assert (
- len(downrev) == 1
- ), "Can't do an UPDATE because downrevision is ambiguous"
- down_revision = list(downrev)[0]
- else:
- down_revision = self.revision._normalized_down_revisions[0]
- if self.is_upgrade:
- return down_revision, self.revision.revision
- else:
- return self.revision.revision, down_revision
- @property
- def delete_version_num(self) -> str:
- return self.revision.revision
- @property
- def insert_version_num(self) -> str:
- return self.revision.revision
- @property
- def info(self) -> MigrationInfo:
- return MigrationInfo(
- revision_map=self.revision_map,
- up_revisions=self.revision.revision,
- down_revisions=self.revision._normalized_down_revisions,
- is_upgrade=self.is_upgrade,
- is_stamp=False,
- )
- class StampStep(MigrationStep):
- def __init__(
- self,
- from_: Optional[Union[str, Collection[str]]],
- to_: Optional[Union[str, Collection[str]]],
- is_upgrade: bool,
- branch_move: bool,
- revision_map: Optional[RevisionMap] = None,
- ) -> None:
- self.from_: Tuple[str, ...] = util.to_tuple(from_, default=())
- self.to_: Tuple[str, ...] = util.to_tuple(to_, default=())
- self.is_upgrade = is_upgrade
- self.branch_move = branch_move
- self.migration_fn = self.stamp_revision
- self.revision_map = revision_map
- doc: Optional[str] = None
- def stamp_revision(self, **kw: Any) -> None:
- return None
- def __eq__(self, other):
- return (
- isinstance(other, StampStep)
- and other.from_revisions == self.from_revisions
- and other.to_revisions == self.to_revisions
- and other.branch_move == self.branch_move
- and self.is_upgrade == other.is_upgrade
- )
- @property
- def from_revisions(self):
- return self.from_
- @property
- def to_revisions(self) -> Tuple[str, ...]:
- return self.to_
- @property
- def from_revisions_no_deps( # type:ignore[override]
- self,
- ) -> Tuple[str, ...]:
- return self.from_
- @property
- def to_revisions_no_deps( # type:ignore[override]
- self,
- ) -> Tuple[str, ...]:
- return self.to_
- @property
- def delete_version_num(self) -> str:
- assert len(self.from_) == 1
- return self.from_[0]
- @property
- def insert_version_num(self) -> str:
- assert len(self.to_) == 1
- return self.to_[0]
- def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
- assert len(self.from_) == 1
- assert len(self.to_) == 1
- return self.from_[0], self.to_[0]
- def merge_branch_idents(
- self, heads: Union[Set[str], List[str]]
- ) -> Union[Tuple[List[Any], str, str], Tuple[List[str], str, str]]:
- return (
- # delete revs, update from rev, update to rev
- list(self.from_[0:-1]),
- self.from_[-1],
- self.to_[0],
- )
- def unmerge_branch_idents(
- self, heads: Set[str]
- ) -> Tuple[str, str, List[str]]:
- return (
- # update from rev, update to rev, insert revs
- self.from_[0],
- self.to_[-1],
- list(self.to_[0:-1]),
- )
- def should_delete_branch(self, heads: Set[str]) -> bool:
- # TODO: we probably need to look for self.to_ inside of heads,
- # in a similar manner as should_create_branch, however we have
- # no tests for this yet (stamp downgrades w/ branches)
- return self.is_downgrade and self.branch_move
- def should_create_branch(self, heads: Set[str]) -> Union[Set[str], bool]:
- return (
- self.is_upgrade
- and (self.branch_move or set(self.from_).difference(heads))
- and set(self.to_).difference(heads)
- )
- def should_merge_branches(self, heads: Set[str]) -> bool:
- return len(self.from_) > 1
- def should_unmerge_branches(self, heads: Set[str]) -> bool:
- return len(self.to_) > 1
- @property
- def info(self) -> MigrationInfo:
- up, down = (
- (self.to_, self.from_)
- if self.is_upgrade
- else (self.from_, self.to_)
- )
- assert self.revision_map is not None
- return MigrationInfo(
- revision_map=self.revision_map,
- up_revisions=up,
- down_revisions=down,
- is_upgrade=self.is_upgrade,
- is_stamp=True,
- )
|