sqla_compat.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
  2. # mypy: no-warn-return-any, allow-any-generics
  3. from __future__ import annotations
  4. import contextlib
  5. import re
  6. from typing import Any
  7. from typing import Callable
  8. from typing import Dict
  9. from typing import Iterable
  10. from typing import Iterator
  11. from typing import Optional
  12. from typing import Protocol
  13. from typing import Set
  14. from typing import Type
  15. from typing import TYPE_CHECKING
  16. from typing import TypeVar
  17. from typing import Union
  18. from sqlalchemy import __version__
  19. from sqlalchemy import schema
  20. from sqlalchemy import sql
  21. from sqlalchemy import types as sqltypes
  22. from sqlalchemy.schema import CheckConstraint
  23. from sqlalchemy.schema import Column
  24. from sqlalchemy.schema import ForeignKeyConstraint
  25. from sqlalchemy.sql import visitors
  26. from sqlalchemy.sql.base import DialectKWArgs
  27. from sqlalchemy.sql.elements import BindParameter
  28. from sqlalchemy.sql.elements import ColumnClause
  29. from sqlalchemy.sql.elements import TextClause
  30. from sqlalchemy.sql.elements import UnaryExpression
  31. from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501
  32. from sqlalchemy.sql.visitors import traverse
  33. from typing_extensions import TypeGuard
  34. if TYPE_CHECKING:
  35. from sqlalchemy import ClauseElement
  36. from sqlalchemy import Identity
  37. from sqlalchemy import Index
  38. from sqlalchemy import Table
  39. from sqlalchemy.engine import Connection
  40. from sqlalchemy.engine import Dialect
  41. from sqlalchemy.engine import Transaction
  42. from sqlalchemy.sql.base import ColumnCollection
  43. from sqlalchemy.sql.compiler import SQLCompiler
  44. from sqlalchemy.sql.elements import ColumnElement
  45. from sqlalchemy.sql.schema import Constraint
  46. from sqlalchemy.sql.schema import SchemaItem
  47. _CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
  48. class _CompilerProtocol(Protocol):
  49. def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: ...
  50. def _safe_int(value: str) -> Union[int, str]:
  51. try:
  52. return int(value)
  53. except:
  54. return value
  55. _vers = tuple(
  56. [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
  57. )
  58. # https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
  59. sqla_14_18 = _vers >= (1, 4, 18)
  60. sqla_14_26 = _vers >= (1, 4, 26)
  61. sqla_2 = _vers >= (2,)
  62. sqlalchemy_version = __version__
# For the type checker, give ``compiles`` a precise stub signature: it takes
# the element class plus optional dialect names and returns a decorator that
# maps a _CompilerProtocol function to itself.  At runtime the real decorator
# is imported from SQLAlchemy.
if TYPE_CHECKING:

    def compiles(
        element: Type[ClauseElement], *dialects: str
    ) -> Callable[[_CompilerProtocol], _CompilerProtocol]: ...

else:
    from sqlalchemy.ext.compiler import compiles  # noqa: I100,I202
  69. identity_has_dialect_kwargs = issubclass(schema.Identity, DialectKWArgs)
  70. def _get_identity_options_dict(
  71. identity: Union[Identity, schema.Sequence, None],
  72. dialect_kwargs: bool = False,
  73. ) -> Dict[str, Any]:
  74. if identity is None:
  75. return {}
  76. elif identity_has_dialect_kwargs:
  77. assert hasattr(identity, "_as_dict")
  78. as_dict = identity._as_dict()
  79. if dialect_kwargs:
  80. assert isinstance(identity, DialectKWArgs)
  81. as_dict.update(identity.dialect_kwargs)
  82. else:
  83. as_dict = {}
  84. if isinstance(identity, schema.Identity):
  85. # always=None means something different than always=False
  86. as_dict["always"] = identity.always
  87. if identity.on_null is not None:
  88. as_dict["on_null"] = identity.on_null
  89. # attributes common to Identity and Sequence
  90. attrs = (
  91. "start",
  92. "increment",
  93. "minvalue",
  94. "maxvalue",
  95. "nominvalue",
  96. "nomaxvalue",
  97. "cycle",
  98. "cache",
  99. "order",
  100. )
  101. as_dict.update(
  102. {
  103. key: getattr(identity, key, None)
  104. for key in attrs
  105. if getattr(identity, key, None) is not None
  106. }
  107. )
  108. return as_dict
# ``_NoneName`` is imported from a different module depending on whether
# SQLAlchemy 2.x is installed (``sqla_2``).
if sqla_2:
    from sqlalchemy.sql.base import _NoneName
else:
    from sqlalchemy.util import symbol as _NoneName  # type: ignore[assignment]

# Possible values of a constraint's name: None, a plain string, or a
# _NoneName instance (the type checked for in constraint_name_defined).
_ConstraintName = Union[None, str, _NoneName]
# Same as _ConstraintName, without None.
_ConstraintNameDefined = Union[str, _NoneName]
  115. def constraint_name_defined(
  116. name: _ConstraintName,
  117. ) -> TypeGuard[_ConstraintNameDefined]:
  118. return name is _NONE_NAME or isinstance(name, (str, _NoneName))
  119. def constraint_name_string(name: _ConstraintName) -> TypeGuard[str]:
  120. return isinstance(name, str)
  121. def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
  122. return name if constraint_name_string(name) else None
  123. AUTOINCREMENT_DEFAULT = "auto"
  124. @contextlib.contextmanager
  125. def _ensure_scope_for_ddl(
  126. connection: Optional[Connection],
  127. ) -> Iterator[None]:
  128. try:
  129. in_transaction = connection.in_transaction # type: ignore[union-attr]
  130. except AttributeError:
  131. # catch for MockConnection, None
  132. in_transaction = None
  133. pass
  134. # yield outside the catch
  135. if in_transaction is None:
  136. yield
  137. else:
  138. if not in_transaction():
  139. assert connection is not None
  140. with connection.begin():
  141. yield
  142. else:
  143. yield
  144. def _safe_begin_connection_transaction(
  145. connection: Connection,
  146. ) -> Transaction:
  147. transaction = connection.get_transaction()
  148. if transaction:
  149. return transaction
  150. else:
  151. return connection.begin()
  152. def _safe_commit_connection_transaction(
  153. connection: Connection,
  154. ) -> None:
  155. transaction = connection.get_transaction()
  156. if transaction:
  157. transaction.commit()
  158. def _safe_rollback_connection_transaction(
  159. connection: Connection,
  160. ) -> None:
  161. transaction = connection.get_transaction()
  162. if transaction:
  163. transaction.rollback()
  164. def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
  165. try:
  166. in_transaction = connection.in_transaction # type: ignore
  167. except AttributeError:
  168. # catch for MockConnection
  169. return False
  170. else:
  171. return in_transaction()
  172. def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
  173. return idx.expressions # type: ignore
  174. def _copy(schema_item: _CE, **kw) -> _CE:
  175. if hasattr(schema_item, "_copy"):
  176. return schema_item._copy(**kw)
  177. else:
  178. return schema_item.copy(**kw) # type: ignore[union-attr]
  179. def _connectable_has_table(
  180. connectable: Connection, tablename: str, schemaname: Union[str, None]
  181. ) -> bool:
  182. return connectable.dialect.has_table(connectable, tablename, schemaname)
  183. def _exec_on_inspector(inspector, statement, **params):
  184. with inspector._operation_context() as conn:
  185. return conn.execute(statement, params)
  186. def _nullability_might_be_unset(metadata_column):
  187. from sqlalchemy.sql import schema
  188. return metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
  189. def _server_default_is_computed(*server_default) -> bool:
  190. return any(isinstance(sd, schema.Computed) for sd in server_default)
  191. def _server_default_is_identity(*server_default) -> bool:
  192. return any(isinstance(sd, schema.Identity) for sd in server_default)
  193. def _table_for_constraint(constraint: Constraint) -> Table:
  194. if isinstance(constraint, ForeignKeyConstraint):
  195. table = constraint.parent
  196. assert table is not None
  197. return table # type: ignore[return-value]
  198. else:
  199. return constraint.table
  200. def _columns_for_constraint(constraint):
  201. if isinstance(constraint, ForeignKeyConstraint):
  202. return [fk.parent for fk in constraint.elements]
  203. elif isinstance(constraint, CheckConstraint):
  204. return _find_columns(constraint.sqltext)
  205. else:
  206. return list(constraint.columns)
  207. def _resolve_for_variant(type_, dialect):
  208. if _type_has_variants(type_):
  209. base_type, mapping = _get_variant_mapping(type_)
  210. return mapping.get(dialect.name, base_type)
  211. else:
  212. return type_
# Variant handling depends on the SQLAlchemy version.  When TypeEngine has a
# ``_variant_mapping`` attribute (2.0), the mapping is read directly off the
# type.  Otherwise a variant is a ``sqltypes.Variant`` instance exposing
# ``impl`` and ``mapping``.
if hasattr(sqltypes.TypeEngine, "_variant_mapping"):  # 2.0

    def _type_has_variants(type_):
        """Return True if *type_*'s ``_variant_mapping`` is non-empty."""
        return bool(type_._variant_mapping)

    def _get_variant_mapping(type_):
        """Return ``(base_type, dialect_name -> variant_type mapping)``."""
        return type_, type_._variant_mapping

else:

    def _type_has_variants(type_):
        """Return True if *type_* is exactly a ``sqltypes.Variant``."""
        return type(type_) is sqltypes.Variant

    def _get_variant_mapping(type_):
        """Return ``(base_type, mapping)`` from a ``sqltypes.Variant``."""
        return type_.impl, type_.mapping
  223. def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
  224. if TYPE_CHECKING:
  225. assert constraint.columns is not None
  226. assert constraint.elements is not None
  227. assert isinstance(constraint.parent, Table)
  228. source_columns = [
  229. constraint.columns[key].name for key in constraint.column_keys
  230. ]
  231. source_table = constraint.parent.name
  232. source_schema = constraint.parent.schema
  233. target_schema = constraint.elements[0].column.table.schema
  234. target_table = constraint.elements[0].column.table.name
  235. target_columns = [element.column.name for element in constraint.elements]
  236. ondelete = constraint.ondelete
  237. onupdate = constraint.onupdate
  238. deferrable = constraint.deferrable
  239. initially = constraint.initially
  240. return (
  241. source_schema,
  242. source_table,
  243. source_columns,
  244. target_schema,
  245. target_table,
  246. target_columns,
  247. onupdate,
  248. ondelete,
  249. deferrable,
  250. initially,
  251. )
  252. def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
  253. spec = constraint.elements[0]._get_colspec()
  254. tokens = spec.split(".")
  255. tokens.pop(-1) # colname
  256. tablekey = ".".join(tokens)
  257. assert constraint.parent is not None
  258. return tablekey == constraint.parent.key
  259. def _is_type_bound(constraint: Constraint) -> bool:
  260. # this deals with SQLAlchemy #3260, don't copy CHECK constraints
  261. # that will be generated by the type.
  262. # new feature added for #3260
  263. return constraint._type_bound
  264. def _find_columns(clause):
  265. """locate Column objects within the given expression."""
  266. cols: Set[ColumnElement[Any]] = set()
  267. traverse(clause, {}, {"column": cols.add})
  268. return cols
  269. def _remove_column_from_collection(
  270. collection: ColumnCollection, column: Union[Column[Any], ColumnClause[Any]]
  271. ) -> None:
  272. """remove a column from a ColumnCollection."""
  273. # workaround for older SQLAlchemy, remove the
  274. # same object that's present
  275. assert column.key is not None
  276. to_remove = collection[column.key]
  277. # SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
  278. # (renamed from ImmutableColumnCollection)
  279. if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"):
  280. collection._parent.remove(to_remove)
  281. else:
  282. collection.remove(to_remove)
  283. def _textual_index_column(
  284. table: Table, text_: Union[str, TextClause, ColumnElement[Any]]
  285. ) -> Union[ColumnElement[Any], Column[Any]]:
  286. """a workaround for the Index construct's severe lack of flexibility"""
  287. if isinstance(text_, str):
  288. c = Column(text_, sqltypes.NULLTYPE)
  289. table.append_column(c)
  290. return c
  291. elif isinstance(text_, TextClause):
  292. return _textual_index_element(table, text_)
  293. elif isinstance(text_, _textual_index_element):
  294. return _textual_index_column(table, text_.text)
  295. elif isinstance(text_, sql.ColumnElement):
  296. return _copy_expression(text_, table)
  297. else:
  298. raise ValueError("String or text() construct expected")
  299. def _copy_expression(expression: _CE, target_table: Table) -> _CE:
  300. def replace(col):
  301. if (
  302. isinstance(col, Column)
  303. and col.table is not None
  304. and col.table is not target_table
  305. ):
  306. if col.name in target_table.c:
  307. return target_table.c[col.name]
  308. else:
  309. c = _copy(col)
  310. target_table.append_column(c)
  311. return c
  312. else:
  313. return None
  314. return visitors.replacement_traverse( # type: ignore[call-overload]
  315. expression, {}, replace
  316. )
  317. class _textual_index_element(sql.ColumnElement):
  318. """Wrap around a sqlalchemy text() construct in such a way that
  319. we appear like a column-oriented SQL expression to an Index
  320. construct.
  321. The issue here is that currently the Postgresql dialect, the biggest
  322. recipient of functional indexes, keys all the index expressions to
  323. the corresponding column expressions when rendering CREATE INDEX,
  324. so the Index we create here needs to have a .columns collection that
  325. is the same length as the .expressions collection. Ultimately
  326. SQLAlchemy should support text() expressions in indexes.
  327. See SQLAlchemy issue 3174.
  328. """
  329. __visit_name__ = "_textual_idx_element"
  330. def __init__(self, table: Table, text: TextClause) -> None:
  331. self.table = table
  332. self.text = text
  333. self.key = text.text
  334. self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
  335. table.append_column(self.fake_column)
  336. def get_children(self, **kw):
  337. return [self.fake_column]
  338. @compiles(_textual_index_element)
  339. def _render_textual_index_column(
  340. element: _textual_index_element, compiler: SQLCompiler, **kw
  341. ) -> str:
  342. return compiler.process(element.text, **kw)
  343. class _literal_bindparam(BindParameter):
  344. pass
  345. @compiles(_literal_bindparam)
  346. def _render_literal_bindparam(
  347. element: _literal_bindparam, compiler: SQLCompiler, **kw
  348. ) -> str:
  349. return compiler.render_literal_bindparam(element, **kw)
  350. def _get_constraint_final_name(
  351. constraint: Union[Index, Constraint], dialect: Optional[Dialect]
  352. ) -> Optional[str]:
  353. if constraint.name is None:
  354. return None
  355. assert dialect is not None
  356. # for SQLAlchemy 1.4 we would like to have the option to expand
  357. # the use of "deferred" names for constraints as well as to have
  358. # some flexibility with "None" name and similar; make use of new
  359. # SQLAlchemy API to return what would be the final compiled form of
  360. # the name for this dialect.
  361. return dialect.identifier_preparer.format_constraint(
  362. constraint, _alembic_quote=False
  363. )
  364. def _constraint_is_named(
  365. constraint: Union[Constraint, Index], dialect: Optional[Dialect]
  366. ) -> bool:
  367. if constraint.name is None:
  368. return False
  369. assert dialect is not None
  370. name = dialect.identifier_preparer.format_constraint(
  371. constraint, _alembic_quote=False
  372. )
  373. return name is not None
  374. def is_expression_index(index: Index) -> bool:
  375. for expr in index.expressions:
  376. if is_expression(expr):
  377. return True
  378. return False
  379. def is_expression(expr: Any) -> bool:
  380. while isinstance(expr, UnaryExpression):
  381. expr = expr.element
  382. if not isinstance(expr, ColumnClause) or expr.is_literal:
  383. return True
  384. return False