| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
- # mypy: no-warn-return-any, allow-any-generics
- from __future__ import annotations
- import contextlib
- import re
- from typing import Any
- from typing import Callable
- from typing import Dict
- from typing import Iterable
- from typing import Iterator
- from typing import Optional
- from typing import Protocol
- from typing import Set
- from typing import Type
- from typing import TYPE_CHECKING
- from typing import TypeVar
- from typing import Union
- from sqlalchemy import __version__
- from sqlalchemy import schema
- from sqlalchemy import sql
- from sqlalchemy import types as sqltypes
- from sqlalchemy.schema import CheckConstraint
- from sqlalchemy.schema import Column
- from sqlalchemy.schema import ForeignKeyConstraint
- from sqlalchemy.sql import visitors
- from sqlalchemy.sql.base import DialectKWArgs
- from sqlalchemy.sql.elements import BindParameter
- from sqlalchemy.sql.elements import ColumnClause
- from sqlalchemy.sql.elements import TextClause
- from sqlalchemy.sql.elements import UnaryExpression
- from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501
- from sqlalchemy.sql.visitors import traverse
- from typing_extensions import TypeGuard
- if TYPE_CHECKING:
- from sqlalchemy import ClauseElement
- from sqlalchemy import Identity
- from sqlalchemy import Index
- from sqlalchemy import Table
- from sqlalchemy.engine import Connection
- from sqlalchemy.engine import Dialect
- from sqlalchemy.engine import Transaction
- from sqlalchemy.sql.base import ColumnCollection
- from sqlalchemy.sql.compiler import SQLCompiler
- from sqlalchemy.sql.elements import ColumnElement
- from sqlalchemy.sql.schema import Constraint
- from sqlalchemy.sql.schema import SchemaItem
# Type variable used by helpers that hand back the same kind of object they
# were given; bounded to a ColumnElement or a SchemaItem.
_CE = TypeVar(
    "_CE",
    bound=Union["ColumnElement[Any]", "SchemaItem"],
)
class _CompilerProtocol(Protocol):
    """Structural type of a compilation hook.

    A conforming callable receives the element being compiled, the
    compiler object and arbitrary keyword arguments, and returns a string.
    """

    def __call__(self, element: Any, compiler: Any, **kw: Any) -> str:
        ...
def _safe_int(value: str) -> Union[int, str]:
    """Convert ``value`` to an ``int`` when possible.

    :param value: a token such as ``"14"`` or ``"b2"``.
    :return: the integer form of ``value``, or ``value`` unchanged when it
     cannot be converted.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        # only conversion failures are expected; anything else, such as
        # KeyboardInterrupt or SystemExit, must propagate to the caller
        return value
# split the installed SQLAlchemy version string into int / str tokens so it
# can be compared against tuples such as (1, 4, 18)
_vers = tuple(
    _safe_int(token) for token in re.findall(r"(\d+|[abc]\d)", __version__)
)

# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
sqla_14_18 = _vers >= (1, 4, 18)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)

# the unparsed version string, re-exported under a descriptive name
sqlalchemy_version = __version__
if TYPE_CHECKING:
    # typed stand-in for the decorator, visible to type checkers only

    def compiles(
        element: Type[ClauseElement], *dialects: str
    ) -> Callable[[_CompilerProtocol], _CompilerProtocol]:
        """Typed signature for ``sqlalchemy.ext.compiler.compiles``."""
        ...

else:
    # at runtime the real SQLAlchemy decorator is used
    from sqlalchemy.ext.compiler import compiles  # noqa: I100,I202
# True when the installed SQLAlchemy's Identity is a DialectKWArgs subclass
identity_has_dialect_kwargs = issubclass(schema.Identity, DialectKWArgs)
def _get_identity_options_dict(
    identity: Union[Identity, schema.Sequence, None],
    dialect_kwargs: bool = False,
) -> Dict[str, Any]:
    """Return the options of an Identity / Sequence as a dictionary.

    :param identity: the construct to inspect, or ``None``.
    :param dialect_kwargs: when True, and ``Identity`` supports dialect
     keyword arguments, ``identity.dialect_kwargs`` is merged into the
     result.
    :return: mapping of option name to value; empty when ``identity`` is
     ``None``.
    """
    if identity is None:
        return {}

    if identity_has_dialect_kwargs:
        # the construct can describe itself
        assert hasattr(identity, "_as_dict")
        options = identity._as_dict()
        if dialect_kwargs:
            assert isinstance(identity, DialectKWArgs)
            options.update(identity.dialect_kwargs)
        return options

    # otherwise collect the options attribute by attribute
    options = {}
    if isinstance(identity, schema.Identity):
        # always=None means something different than always=False
        options["always"] = identity.always
        if identity.on_null is not None:
            options["on_null"] = identity.on_null

    # attributes common to Identity and Sequence; only those that are
    # present and not None are reported
    for key in (
        "start",
        "increment",
        "minvalue",
        "maxvalue",
        "nominvalue",
        "nomaxvalue",
        "cycle",
        "cache",
        "order",
    ):
        value = getattr(identity, key, None)
        if value is not None:
            options[key] = value
    return options
if sqla_2:
    # SQLAlchemy 2.0 provides the _NoneName class directly
    from sqlalchemy.sql.base import _NoneName
else:
    # earlier versions: the ``symbol`` class is used under the same name
    from sqlalchemy.util import symbol as _NoneName  # type: ignore[assignment]

# a constraint name as passed around by the helpers below; may be None
_ConstraintName = Union[None, str, _NoneName]

# a constraint name that is known not to be None
_ConstraintNameDefined = Union[str, _NoneName]
def constraint_name_defined(
    name: _ConstraintName,
) -> TypeGuard[_ConstraintNameDefined]:
    """Return True if ``name`` is ``_NONE_NAME``, a string or a ``_NoneName``."""
    if name is _NONE_NAME:
        return True
    return isinstance(name, (str, _NoneName))
def constraint_name_string(name: _ConstraintName) -> TypeGuard[str]:
    """Return True if ``name`` is a plain string."""
    is_string = isinstance(name, str)
    return is_string
def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
    """Return ``name`` when it is a plain string, otherwise ``None``."""
    if isinstance(name, str):
        return name
    return None
# NOTE(review): presumably the marker for an "autoincrement" setting that
# was left at its default -- confirm against the callers
AUTOINCREMENT_DEFAULT = "auto"
@contextlib.contextmanager
def _ensure_scope_for_ddl(
    connection: Optional[Connection],
) -> Iterator[None]:
    """Run the managed block inside a transaction when one can be begun.

    If ``connection`` has no ``in_transaction`` attribute (for example a
    MockConnection or ``None``), or it reports that a transaction is
    already in progress, the block runs as is.  Otherwise the block runs
    inside ``connection.begin()``.
    """
    # MockConnection and None do not have this attribute
    in_transaction = getattr(connection, "in_transaction", None)

    if in_transaction is None or in_transaction():
        # nothing to begin here
        yield
        return

    assert connection is not None
    with connection.begin():
        yield
def _safe_begin_connection_transaction(
    connection: Connection,
) -> Transaction:
    """Return the connection's current transaction, beginning one if needed."""
    existing = connection.get_transaction()
    return existing if existing else connection.begin()
def _safe_commit_connection_transaction(
    connection: Connection,
) -> None:
    """Commit the connection's current transaction, if there is one."""
    current = connection.get_transaction()
    if not current:
        return
    current.commit()
def _safe_rollback_connection_transaction(
    connection: Connection,
) -> None:
    """Roll back the connection's current transaction, if there is one."""
    current = connection.get_transaction()
    if not current:
        return
    current.rollback()
def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
    """Return the result of ``connection.in_transaction()``.

    ``False`` is returned when ``connection`` has no ``in_transaction``
    attribute, as is the case for a MockConnection.
    """
    try:
        checker = connection.in_transaction  # type: ignore
    except AttributeError:
        # catch for MockConnection
        return False
    return checker()
def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
    """Return the ``expressions`` collection of the given index."""
    expressions = idx.expressions  # type: ignore
    return expressions
def _copy(schema_item: _CE, **kw) -> _CE:
    """Copy ``schema_item``, forwarding ``kw`` to its copy method.

    The item's ``_copy()`` method is used when it has one; otherwise
    ``copy()`` is called.
    """
    method_name = "_copy" if hasattr(schema_item, "_copy") else "copy"
    copier = getattr(schema_item, method_name)
    return copier(**kw)
def _connectable_has_table(
    connectable: Connection, tablename: str, schemaname: Union[str, None]
) -> bool:
    """Ask the connectable's dialect whether the given table exists.

    :param connectable: object whose ``dialect`` performs the check.
    :param tablename: name of the table to look for.
    :param schemaname: schema name, or ``None``.
    """
    dialect = connectable.dialect
    return dialect.has_table(connectable, tablename, schemaname)
def _exec_on_inspector(inspector, statement, **params):
    """Execute ``statement`` on a connection obtained from ``inspector``.

    The connection comes from the inspector's ``_operation_context()``
    context manager; ``params`` is passed to ``execute()`` as a single
    dictionary and the result of that call is returned.
    """
    with inspector._operation_context() as connection:
        result = connection.execute(statement, params)
        return result
def _nullability_might_be_unset(metadata_column):
    """Return True if the column's ``_user_defined_nullable`` is
    ``NULL_UNSPECIFIED``."""
    # imported locally under its own alias, leaving the module-level
    # ``schema`` name untouched
    from sqlalchemy.sql import schema as sql_schema

    return (
        metadata_column._user_defined_nullable is sql_schema.NULL_UNSPECIFIED
    )
def _server_default_is_computed(*server_default) -> bool:
    """Return True if any of the given objects is a ``schema.Computed``."""
    for default in server_default:
        if isinstance(default, schema.Computed):
            return True
    return False
def _server_default_is_identity(*server_default) -> bool:
    """Return True if any of the given objects is a ``schema.Identity``."""
    for default in server_default:
        if isinstance(default, schema.Identity):
            return True
    return False
def _table_for_constraint(constraint: Constraint) -> Table:
    """Return the table the given constraint is attached to.

    For a ForeignKeyConstraint this is its ``parent``; for every other
    constraint it is the ``table`` attribute.
    """
    if not isinstance(constraint, ForeignKeyConstraint):
        return constraint.table

    parent_table = constraint.parent
    assert parent_table is not None
    return parent_table  # type: ignore[return-value]
def _columns_for_constraint(constraint):
    """Return the columns the given constraint refers to.

    * ForeignKeyConstraint: list of the ``parent`` of each element
    * CheckConstraint: the columns found inside its ``sqltext``
    * anything else: list of ``constraint.columns``
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [element.parent for element in constraint.elements]
    if isinstance(constraint, CheckConstraint):
        return _find_columns(constraint.sqltext)
    return list(constraint.columns)
def _resolve_for_variant(type_, dialect):
    """Return the form of ``type_`` that applies to ``dialect``.

    A type without variants is returned unchanged.  Otherwise the variant
    registered under ``dialect.name`` is returned, falling back to the
    base type when there is none.
    """
    if not _type_has_variants(type_):
        return type_

    base_type, variants = _get_variant_mapping(type_)
    return variants.get(dialect.name, base_type)
if hasattr(sqltypes.TypeEngine, "_variant_mapping"):  # 2.0
    # variants are stored on the type itself

    def _type_has_variants(type_):
        """Return True if ``type_`` has a non-empty variant mapping."""
        return bool(type_._variant_mapping)

    def _get_variant_mapping(type_):
        """Return ``(base type, variant mapping)`` for ``type_``."""
        return type_, type_._variant_mapping

else:
    # variants are represented by a separate Variant wrapper type

    def _type_has_variants(type_):
        """Return True if ``type_`` is exactly a ``Variant``."""
        return type(type_) is sqltypes.Variant

    def _get_variant_mapping(type_):
        """Return ``(base type, variant mapping)`` for ``type_``."""
        return type_.impl, type_.mapping
def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
    """Break a foreign key constraint down into a tuple of plain values.

    :return: a 10-tuple ``(source_schema, source_table, source_columns,
     target_schema, target_table, target_columns, onupdate, ondelete,
     deferrable, initially)``.  The target schema and table are taken
     from the first element of the constraint.
    """
    if TYPE_CHECKING:
        assert constraint.columns is not None
        assert constraint.elements is not None
        assert isinstance(constraint.parent, Table)

    source_columns = [
        constraint.columns[key].name for key in constraint.column_keys
    ]
    parent = constraint.parent
    referred_table = constraint.elements[0].column.table
    target_columns = [element.column.name for element in constraint.elements]

    return (
        parent.schema,
        parent.name,
        source_columns,
        referred_table.schema,
        referred_table.name,
        target_columns,
        constraint.onupdate,
        constraint.ondelete,
        constraint.deferrable,
        constraint.initially,
    )
def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
    """Return True if the constraint refers to its own parent table.

    The column specification of the first element is compared, minus its
    trailing column name, with the ``key`` of the constraint's parent.
    """
    colspec = constraint.elements[0]._get_colspec()
    # drop the last dot-separated token, which is the column name
    table_key, _, _ = colspec.rpartition(".")
    assert constraint.parent is not None
    return table_key == constraint.parent.key
def _is_type_bound(constraint: Constraint) -> bool:
    """Return the constraint's ``_type_bound`` flag."""
    # this deals with SQLAlchemy #3260: CHECK constraints that will be
    # generated by the type should not be copied.  The flag is the
    # feature that was added for #3260.
    type_bound = constraint._type_bound
    return type_bound
def _find_columns(clause):
    """Locate Column objects within the given expression.

    :return: a set of everything the traversal visited as a ``"column"``.
    """
    found: Set[ColumnElement[Any]] = set()
    visitor_map = {"column": found.add}
    traverse(clause, {}, visitor_map)
    return found
def _remove_column_from_collection(
    collection: ColumnCollection, column: Union[Column[Any], ColumnClause[Any]]
) -> None:
    """Remove a column from a ColumnCollection.

    The object removed is the one stored in ``collection`` under
    ``column.key``, which is not necessarily ``column`` itself.
    """
    assert column.key is not None

    # workaround for older SQLAlchemy: remove the same object that's
    # present in the collection
    stored = collection[column.key]

    # SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
    # (renamed from ImmutableColumnCollection); for those the removal
    # goes through the parent collection
    is_read_only = any(
        hasattr(collection, marker) for marker in ("_immutable", "_readonly")
    )
    owner = collection._parent if is_read_only else collection
    owner.remove(stored)
def _textual_index_column(
    table: Table, text_: Union[str, TextClause, ColumnElement[Any]]
) -> Union[ColumnElement[Any], Column[Any]]:
    """A workaround for the Index construct's severe lack of flexibility.

    Turns ``text_`` into something usable as an index expression on
    ``table``:

    * a string becomes a new NULLTYPE Column appended to ``table``
    * a ``text()`` construct is wrapped in a ``_textual_index_element``
    * a ``_textual_index_element`` is re-created from its ``text``
    * any other ColumnElement is copied against ``table``

    :raises ValueError: if ``text_`` is none of the above.
    """
    if isinstance(text_, str):
        column = Column(text_, sqltypes.NULLTYPE)
        table.append_column(column)
        return column

    if isinstance(text_, TextClause):
        return _textual_index_element(table, text_)

    if isinstance(text_, _textual_index_element):
        return _textual_index_column(table, text_.text)

    if isinstance(text_, sql.ColumnElement):
        return _copy_expression(text_, table)

    raise ValueError("String or text() construct expected")
def _copy_expression(expression: _CE, target_table: Table) -> _CE:
    """Return ``expression`` with its columns re-pointed at ``target_table``.

    Each Column inside the expression that belongs to a different table is
    replaced by the same-named column of ``target_table``; when the target
    has no such column, a copy is appended to it and used instead.
    """

    def replace(col):
        belongs_elsewhere = (
            isinstance(col, Column)
            and col.table is not None
            and col.table is not target_table
        )
        if not belongs_elsewhere:
            # leave this element as it is
            return None

        if col.name in target_table.c:
            return target_table.c[col.name]

        copied = _copy(col)
        target_table.append_column(copied)
        return copied

    return visitors.replacement_traverse(  # type: ignore[call-overload]
        expression, {}, replace
    )
class _textual_index_element(sql.ColumnElement):
    """Wrapper that lets a sqlalchemy text() construct pass as a
    column-oriented SQL expression inside an Index construct.

    The Postgresql dialect, the biggest recipient of functional indexes,
    currently keys all the index expressions to the corresponding column
    expressions when rendering CREATE INDEX.  The Index built here
    therefore needs a .columns collection of the same length as its
    .expressions collection, which is what the placeholder column created
    by this class is for.  Ultimately SQLAlchemy should support text()
    expressions in indexes.

    See SQLAlchemy issue 3174.

    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table: Table, text: TextClause) -> None:
        """Wrap ``text`` and append a placeholder column to ``table``."""
        self.table = table
        self.text = text
        self.key = text.text
        # NULLTYPE column named after the SQL text itself
        self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
        table.append_column(self.fake_column)

    def get_children(self, **kw):
        """Return the placeholder column as this element's only child."""
        return [self.fake_column]
@compiles(_textual_index_element)
def _render_textual_index_column(
    element: _textual_index_element, compiler: SQLCompiler, **kw
) -> str:
    """Render a ``_textual_index_element`` by compiling the text it wraps."""
    wrapped_text = element.text
    return compiler.process(wrapped_text, **kw)
class _literal_bindparam(BindParameter):
    """BindParameter subclass with a compilation hook of its own."""
@compiles(_literal_bindparam)
def _render_literal_bindparam(
    element: _literal_bindparam, compiler: SQLCompiler, **kw
) -> str:
    """Render the parameter via ``compiler.render_literal_bindparam()``."""
    rendered = compiler.render_literal_bindparam(element, **kw)
    return rendered
def _get_constraint_final_name(
    constraint: Union[Index, Constraint], dialect: Optional[Dialect]
) -> Optional[str]:
    """Return the name of ``constraint`` as formatted by ``dialect``.

    :param constraint: the Index or Constraint to name.
    :param dialect: dialect whose identifier preparer formats the name;
     it is only used when the constraint has a name.
    :return: ``None`` when ``constraint.name`` is ``None``, otherwise the
     result of the preparer's ``format_constraint()``, requested without
     quoting.
    """
    if constraint.name is not None:
        assert dialect is not None

        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; the SQLAlchemy
        # API used here returns what would be the final compiled form of
        # the name for this dialect.
        preparer = dialect.identifier_preparer
        return preparer.format_constraint(constraint, _alembic_quote=False)

    return None
def _constraint_is_named(
    constraint: Union[Constraint, Index], dialect: Optional[Dialect]
) -> bool:
    """Return True if ``constraint`` ends up with a name on ``dialect``.

    A constraint whose ``name`` is ``None`` is unnamed.  Otherwise the
    dialect's identifier preparer formats the name, and the constraint is
    named when that result is not ``None``.
    """
    if constraint.name is not None:
        assert dialect is not None
        preparer = dialect.identifier_preparer
        final_name = preparer.format_constraint(
            constraint, _alembic_quote=False
        )
        return final_name is not None

    return False
def is_expression_index(index: Index) -> bool:
    """Return True if any of the index's expressions is an expression
    according to ``is_expression()``."""
    return any(is_expression(expr) for expr in index.expressions)
def is_expression(expr: Any) -> bool:
    """Return False for a non-literal ColumnClause, True for anything else.

    UnaryExpression wrappers are removed first, so the decision is made on
    the innermost wrapped element.
    """
    while isinstance(expr, UnaryExpression):
        expr = expr.element

    if isinstance(expr, ColumnClause) and not expr.is_literal:
        return False
    return True
|