| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524 |
- # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
- # mypy: no-warn-return-any, allow-any-generics
- from __future__ import annotations
- import re
- from typing import Any
- from typing import Optional
- from typing import TYPE_CHECKING
- from typing import Union
- from sqlalchemy import schema
- from sqlalchemy import types as sqltypes
- from sqlalchemy.sql import elements
- from sqlalchemy.sql import functions
- from sqlalchemy.sql import operators
- from .base import alter_table
- from .base import AlterColumn
- from .base import ColumnDefault
- from .base import ColumnName
- from .base import ColumnNullable
- from .base import ColumnType
- from .base import format_column_name
- from .base import format_server_default
- from .impl import DefaultImpl
- from .. import util
- from ..util import sqla_compat
- from ..util.sqla_compat import _is_type_bound
- from ..util.sqla_compat import compiles
- if TYPE_CHECKING:
- from typing import Literal
- from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
- from sqlalchemy.sql.ddl import DropConstraint
- from sqlalchemy.sql.elements import ClauseElement
- from sqlalchemy.sql.schema import Constraint
- from sqlalchemy.sql.type_api import TypeEngine
- from .base import _ServerDefault
class MySQLImpl(DefaultImpl):
    """Migration implementation for the ``mysql`` dialect.

    Overrides column alteration so that changes are emitted as a single
    CHANGE / MODIFY / ALTER COLUMN construct, and adjusts autogenerate
    comparisons for server defaults, indexes and foreign keys.
    """

    __dialect__ = "mysql"

    transactional_ddl = False
    type_synonyms = DefaultImpl.type_synonyms + (
        {"BOOL", "TINYINT"},
        {"JSON", "LONGTEXT"},
    )
    type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]

    def render_ddl_sql_expr(
        self,
        expr: ClauseElement,
        is_server_default: bool = False,
        is_index: bool = False,
        **kw: Any,
    ) -> str:
        """Render ``expr`` via the parent implementation.

        When ``is_index`` is true, binary expressions, function elements
        and unary expressions whose modifier is not ASC/DESC are wrapped
        in a ``Grouping`` before being handed to the parent.
        """
        # apply Grouping to index expressions;
        # see https://github.com/sqlalchemy/sqlalchemy/blob/
        # 36da2eaf3e23269f2cf28420ae73674beafd0661/
        # lib/sqlalchemy/dialects/mysql/base.py#L2191
        if is_index:
            ordering_ops = (operators.desc_op, operators.asc_op)
            wrap = isinstance(expr, elements.BinaryExpression)
            if not wrap and isinstance(expr, elements.UnaryExpression):
                wrap = expr.modifier not in ordering_ops
            if not wrap:
                wrap = isinstance(expr, functions.FunctionElement)
            if wrap:
                expr = elements.Grouping(expr)

        return super().render_ddl_sql_expr(
            expr, is_server_default=is_server_default, is_index=is_index, **kw
        )

    def alter_column(
        self,
        table_name: str,
        column_name: str,
        *,
        nullable: Optional[bool] = None,
        server_default: Optional[
            Union[_ServerDefault, Literal[False]]
        ] = False,
        name: Optional[str] = None,
        type_: Optional[TypeEngine] = None,
        schema: Optional[str] = None,
        existing_type: Optional[TypeEngine] = None,
        existing_server_default: Optional[_ServerDefault] = None,
        existing_nullable: Optional[bool] = None,
        autoincrement: Optional[bool] = None,
        existing_autoincrement: Optional[bool] = None,
        comment: Optional[Union[str, Literal[False]]] = False,
        existing_comment: Optional[str] = None,
        **kw: Any,
    ) -> None:
        """Emit one MySQL construct describing the requested alteration.

        Each new value falls back to its ``existing_*`` counterpart when
        it was not supplied (``None``, or ``False`` for ``server_default``
        and ``comment``); nullability falls back to ``True`` when neither
        value is given.

        * a rename, or a functional default accepted by
          ``_is_mysql_allowed_functional_default``, executes a
          ``MySQLChangeColumn``;
        * otherwise a change of nullability, type, autoincrement or
          comment executes a ``MySQLModifyColumn``;
        * otherwise a change of only the server default executes a
          ``MySQLAlterDefault``.
        """
        touches_identity_or_computed = (
            sqla_compat._server_default_is_identity(
                server_default, existing_server_default
            )
            or sqla_compat._server_default_is_computed(
                server_default, existing_server_default
            )
        )
        if touches_identity_or_computed:
            # modifying computed or identity columns is not supported
            # the default will raise
            super().alter_column(
                table_name,
                column_name,
                nullable=nullable,
                type_=type_,
                schema=schema,
                existing_type=existing_type,
                existing_nullable=existing_nullable,
                server_default=server_default,
                existing_server_default=existing_server_default,
                **kw,
            )

        # Resolve every "new value or existing value" pair once; the
        # CHANGE and MODIFY constructs take the same arguments.
        effective_type = type_ if type_ is not None else existing_type

        if nullable is not None:
            effective_nullable = nullable
        elif existing_nullable is not None:
            effective_nullable = existing_nullable
        else:
            effective_nullable = True

        if server_default is not False:
            effective_default = server_default
        else:
            effective_default = existing_server_default

        if autoincrement is not None:
            effective_autoincrement = autoincrement
        else:
            effective_autoincrement = existing_autoincrement

        effective_comment = comment if comment is not False else existing_comment

        column_args = {
            "schema": schema,
            "newname": name if name is not None else column_name,
            "nullable": effective_nullable,
            "type_": effective_type,
            "default": effective_default,
            "autoincrement": effective_autoincrement,
            "comment": effective_comment,
        }

        if name is not None or self._is_mysql_allowed_functional_default(
            effective_type, server_default
        ):
            self._exec(MySQLChangeColumn(table_name, column_name, **column_args))
        elif not (
            nullable is None
            and type_ is None
            and autoincrement is None
            and comment is False
        ):
            self._exec(MySQLModifyColumn(table_name, column_name, **column_args))
        elif server_default is not False:
            self._exec(
                MySQLAlterDefault(
                    table_name, column_name, server_default, schema=schema
                )
            )

    def drop_constraint(
        self,
        const: Constraint,
        **kw: Any,
    ) -> None:
        """Drop ``const`` via the parent, skipping type-bound CHECKs.

        A ``CheckConstraint`` for which ``_is_type_bound`` is true is
        ignored and nothing is emitted.
        """
        skip = isinstance(const, schema.CheckConstraint) and _is_type_bound(
            const
        )
        if not skip:
            super().drop_constraint(const)

    def _is_mysql_allowed_functional_default(
        self,
        type_: Optional[TypeEngine],
        server_default: Optional[Union[_ServerDefault, Literal[False]]],
    ) -> bool:
        """Return True for a DateTime-affinity type with a non-None default.

        Note that ``server_default=False`` counts as "not None" here.
        """
        if type_ is None:
            return False
        if type_._type_affinity is not sqltypes.DateTime:
            return False
        return server_default is not None

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        """Return True when the two rendered server defaults differ.

        Applies type-specific normalisation (quote stripping for integer
        and string columns, ``ON UPDATE`` and trailing ``()`` handling
        otherwise) before comparing.
        """
        # partially a workaround for SQLAlchemy issue #3023; if the
        # column were created without "NOT NULL", MySQL may have added
        # an implicit default of '0' which we need to skip
        # TODO: this is not really covered anymore ?
        if (
            metadata_column.type._type_affinity is sqltypes.Integer
            and inspector_column.primary_key
            and not inspector_column.autoincrement
            and not rendered_metadata_default
            and rendered_inspector_default == "'0'"
        ):
            return False

        if (
            rendered_inspector_default
            and inspector_column.type._type_affinity is sqltypes.Integer
        ):
            # the reflected default is known to be non-empty here, so the
            # surrounding quotes can be stripped unconditionally
            unquoted = re.sub(r"^'|'$", "", rendered_inspector_default)
            return unquoted != rendered_metadata_default

        if (
            rendered_metadata_default
            and metadata_column.type._type_affinity is sqltypes.String
        ):
            metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
            return rendered_inspector_default != f"'{metadata_default}'"

        if rendered_inspector_default and rendered_metadata_default:
            # adjust for "function()" vs. "FUNCTION" as can occur particularly
            # for the CURRENT_TIMESTAMP function on newer MariaDB versions

            # SQLAlchemy MySQL dialect bundles ON UPDATE into the server
            # default; adjust for this possibly being present.
            on_update_pattern = r"(.*) (on update.*?)(?:\(\))?$"
            inspector_match = re.match(
                on_update_pattern, rendered_inspector_default.lower()
            )
            metadata_match = re.match(
                on_update_pattern, rendered_metadata_default.lower()
            )

            inspector_text = rendered_inspector_default
            metadata_text = rendered_metadata_default
            if inspector_match:
                if not metadata_match:
                    return True
                if inspector_match.group(2) != metadata_match.group(2):
                    return True
                # ON UPDATE clauses agree; compare only what precedes them
                inspector_text = inspector_match.group(1)
                metadata_text = metadata_match.group(1)

            trailing_parens = r"(.*?)(?:\(\))?$"
            normalized_inspector = re.sub(
                trailing_parens, r"\1", inspector_text.lower()
            )
            normalized_metadata = re.sub(
                trailing_parens, r"\1", metadata_text.lower()
            )
            return normalized_inspector != normalized_metadata

        return rendered_inspector_default != rendered_metadata_default

    def correct_for_autogen_constraints(
        self,
        conn_unique_constraints,
        conn_indexes,
        metadata_unique_constraints,
        metadata_indexes,
    ):
        """Remove implicit FK indexes from both index collections in place.

        A non-unique reflected index is removed when its name equals the
        name of one of its columns or of a foreign key on one of its
        columns; metadata indexes with a removed name are removed too.
        """
        # TODO: if SQLA 1.0, make use of "duplicates_index"
        # metadata
        dropped_names = set()

        for index in list(conn_indexes):
            if index.unique:
                continue

            # MySQL puts implicit indexes on FK columns, even if
            # composite and even if MyISAM, so can't check this too easily.
            # the name of the index may be the column name or it may
            # be the name of the FK constraint.
            for column in index.columns:
                is_implicit = index.name == column.name or any(
                    fk.name == index.name for fk in column.foreign_keys
                )
                if is_implicit:
                    conn_indexes.remove(index)
                    dropped_names.add(index.name)
                # also stops when an earlier index of the same name was
                # already dropped
                if index.name in dropped_names:
                    break

        # then remove indexes from the "metadata_indexes"
        # that we've removed from reflected, otherwise they come out
        # as adds (see #202)
        stale = [
            index for index in metadata_indexes if index.name in dropped_names
        ]
        for index in stale:
            metadata_indexes.remove(index)

    def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
        """Set unreported RESTRICT rules on matching reflected FKs.

        Foreign keys are paired by their ``unnamed_no_options`` signature.
        """
        conn_fk_by_sig = {}
        for fk in conn_fks:
            sig = self._create_reflected_constraint_sig(fk).unnamed_no_options
            conn_fk_by_sig[sig] = fk

        metadata_fk_by_sig = {}
        for fk in metadata_fks:
            sig = self._create_metadata_constraint_sig(fk).unnamed_no_options
            metadata_fk_by_sig[sig] = fk

        for sig in conn_fk_by_sig.keys() & metadata_fk_by_sig.keys():
            metadata_fk = metadata_fk_by_sig[sig]
            conn_fk = conn_fk_by_sig[sig]

            # MySQL considers RESTRICT to be the default and doesn't
            # report on it.  if the model has explicit RESTRICT and
            # the conn FK has None, set it to RESTRICT
            for rule in ("ondelete", "onupdate"):
                declared = getattr(metadata_fk, rule)
                if (
                    declared is not None
                    and declared.lower() == "restrict"
                    and getattr(conn_fk, rule) is None
                ):
                    setattr(conn_fk, rule, "RESTRICT")
class MariaDBImpl(MySQLImpl):
    """Same behavior as ``MySQLImpl``, under the ``mariadb`` dialect name."""

    __dialect__ = "mariadb"
class MySQLAlterDefault(AlterColumn):
    """DDL element carrying a column's new server default.

    ``default`` of ``None`` is rendered as ``DROP DEFAULT`` by the
    compiler function registered for this construct.
    """

    def __init__(
        self,
        name: str,
        column_name: str,
        default: Optional[_ServerDefault],
        schema: Optional[str] = None,
    ) -> None:
        # super(AlterColumn, self) starts the lookup *after* AlterColumn
        # in the MRO, so AlterColumn's own __init__ is not invoked.
        super(AlterColumn, self).__init__(name, schema=schema)
        self.default = default
        self.column_name = column_name
class MySQLChangeColumn(AlterColumn):
    """DDL element holding a full column specification for CHANGE.

    Raises ``util.CommandError`` when no type is supplied.
    """

    def __init__(
        self,
        name: str,
        column_name: str,
        schema: Optional[str] = None,
        newname: Optional[str] = None,
        type_: Optional[TypeEngine] = None,
        nullable: Optional[bool] = None,
        default: Optional[Union[_ServerDefault, Literal[False]]] = False,
        autoincrement: Optional[bool] = None,
        comment: Optional[Union[str, Literal[False]]] = False,
    ) -> None:
        # super(AlterColumn, self) starts the lookup *after* AlterColumn
        # in the MRO, so AlterColumn's own __init__ is not invoked.
        super(AlterColumn, self).__init__(name, schema=schema)

        self.column_name = column_name
        self.newname = newname
        self.nullable = nullable
        self.default = default
        self.autoincrement = autoincrement
        self.comment = comment

        type_missing = type_ is None
        if type_missing:
            raise util.CommandError(
                "All MySQL CHANGE/MODIFY COLUMN operations "
                "require the existing type."
            )

        # accept either a type class or a type instance
        self.type_ = sqltypes.to_instance(type_)
class MySQLModifyColumn(MySQLChangeColumn):
    """Same data as ``MySQLChangeColumn``; a distinct type for compilation."""
@compiles(ColumnNullable, "mysql", "mariadb")
@compiles(ColumnName, "mysql", "mariadb")
@compiles(ColumnDefault, "mysql", "mariadb")
@compiles(ColumnType, "mysql", "mariadb")
def _mysql_doesnt_support_individual(element, compiler, **kw):
    """Reject the generic single-attribute ALTER COLUMN constructs.

    Raises:
        NotImplementedError: always.
    """
    message = "Individual alter column constructs not supported by MySQL"
    raise NotImplementedError(message)
@compiles(MySQLAlterDefault, "mysql", "mariadb")
def _mysql_alter_default(
    element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw
) -> str:
    """Render ``ALTER COLUMN ... SET DEFAULT <value>`` or ``DROP DEFAULT``.

    ``DROP DEFAULT`` is used when ``element.default`` is ``None``.
    """
    table_clause = alter_table(compiler, element.table_name, element.schema)
    column_clause = format_column_name(compiler, element.column_name)

    if element.default is None:
        action = "DROP DEFAULT"
    else:
        action = "SET DEFAULT %s" % format_server_default(
            compiler, element.default
        )

    return "%s ALTER COLUMN %s %s" % (table_clause, column_clause, action)
@compiles(MySQLModifyColumn, "mysql", "mariadb")
def _mysql_modify_column(
    element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
    """Render ``<alter table> MODIFY <column> <column specification>``."""
    table_clause = alter_table(compiler, element.table_name, element.schema)
    column_clause = format_column_name(compiler, element.column_name)
    column_spec = _mysql_colspec(
        compiler,
        nullable=element.nullable,
        server_default=element.default,
        type_=element.type_,
        autoincrement=element.autoincrement,
        comment=element.comment,
    )
    return "%s MODIFY %s %s" % (table_clause, column_clause, column_spec)
@compiles(MySQLChangeColumn, "mysql", "mariadb")
def _mysql_change_column(
    element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
    """Render ``<alter table> CHANGE <old name> <new name> <column spec>``."""
    table_clause = alter_table(compiler, element.table_name, element.schema)
    old_name = format_column_name(compiler, element.column_name)
    new_name = format_column_name(compiler, element.newname)
    column_spec = _mysql_colspec(
        compiler,
        nullable=element.nullable,
        server_default=element.default,
        type_=element.type_,
        autoincrement=element.autoincrement,
        comment=element.comment,
    )
    return "%s CHANGE %s %s %s" % (
        table_clause,
        old_name,
        new_name,
        column_spec,
    )
def _mysql_colspec(
    compiler: MySQLDDLCompiler,
    nullable: Optional[bool],
    server_default: Optional[Union[_ServerDefault, Literal[False]]],
    type_: TypeEngine,
    autoincrement: Optional[bool],
    comment: Optional[Union[str, Literal[False]]],
) -> str:
    """Build the column specification used by CHANGE and MODIFY.

    The pieces appear in this order: compiled type, ``NULL`` or
    ``NOT NULL``, optional ``AUTO_INCREMENT``, optional ``DEFAULT``,
    optional ``COMMENT``. A falsy ``nullable`` renders ``NOT NULL``; a
    ``server_default`` of ``None`` or ``False`` and a falsy ``comment``
    are omitted.
    """
    null_clause = "NULL" if nullable else "NOT NULL"
    pieces = [
        "%s %s" % (compiler.dialect.type_compiler.process(type_), null_clause)
    ]

    if autoincrement:
        pieces.append("AUTO_INCREMENT")

    # identity checks only: a default object may define its own equality
    if not (server_default is False or server_default is None):
        pieces.append(
            "DEFAULT %s" % format_server_default(compiler, server_default)
        )

    if comment:
        rendered_comment = compiler.sql_compiler.render_literal_value(
            comment, sqltypes.String()
        )
        pieces.append("COMMENT %s" % rendered_comment)

    return " ".join(pieces)
@compiles(schema.DropConstraint, "mysql", "mariadb")
def _mysql_drop_constraint(
    element: DropConstraint, compiler: MySQLDDLCompiler, **kw
) -> str:
    """Render a constraint drop, requiring a known constraint type.

    Foreign key, primary key and unique constraints are delegated to
    ``compiler.visit_drop_constraint``. CHECK constraints are rendered
    here, as ``DROP CONSTRAINT`` when ``compiler.dialect.is_mariadb`` is
    true and as ``DROP CHECK`` otherwise.

    Raises:
        NotImplementedError: for any other constraint type.
    """
    constraint = element.element
    delegated_types = (
        schema.ForeignKeyConstraint,
        schema.PrimaryKeyConstraint,
        schema.UniqueConstraint,
    )

    if isinstance(constraint, delegated_types):
        assert not kw
        return compiler.visit_drop_constraint(element)

    if isinstance(constraint, schema.CheckConstraint):
        # note that SQLAlchemy as of 1.2 does not yet support
        # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
        # here.
        if compiler.dialect.is_mariadb:
            template = "ALTER TABLE %s DROP CONSTRAINT %s"
        else:
            template = "ALTER TABLE %s DROP CHECK %s"
        return template % (
            compiler.preparer.format_table(constraint.table),
            compiler.preparer.format_constraint(constraint),
        )

    raise NotImplementedError(
        "No generic 'DROP CONSTRAINT' in MySQL - "
        "please specify constraint type"
    )
|