base.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
  2. # mypy: no-warn-return-any, allow-any-generics
  3. from __future__ import annotations
  4. import functools
  5. from typing import Optional
  6. from typing import TYPE_CHECKING
  7. from typing import Union
  8. from sqlalchemy import exc
  9. from sqlalchemy import Integer
  10. from sqlalchemy import types as sqltypes
  11. from sqlalchemy.ext.compiler import compiles
  12. from sqlalchemy.schema import Column
  13. from sqlalchemy.schema import DDLElement
  14. from sqlalchemy.sql.elements import quoted_name
  15. from ..util.sqla_compat import _columns_for_constraint # noqa
  16. from ..util.sqla_compat import _find_columns # noqa
  17. from ..util.sqla_compat import _fk_spec # noqa
  18. from ..util.sqla_compat import _is_type_bound # noqa
  19. from ..util.sqla_compat import _table_for_constraint # noqa
  20. if TYPE_CHECKING:
  21. from typing import Any
  22. from sqlalchemy import Computed
  23. from sqlalchemy import Identity
  24. from sqlalchemy.sql.compiler import Compiled
  25. from sqlalchemy.sql.compiler import DDLCompiler
  26. from sqlalchemy.sql.elements import TextClause
  27. from sqlalchemy.sql.functions import Function
  28. from sqlalchemy.sql.schema import FetchedValue
  29. from sqlalchemy.sql.type_api import TypeEngine
  30. from .impl import DefaultImpl
  31. _ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
  32. class AlterTable(DDLElement):
  33. """Represent an ALTER TABLE statement.
  34. Only the string name and optional schema name of the table
  35. is required, not a full Table object.
  36. """
  37. def __init__(
  38. self,
  39. table_name: str,
  40. schema: Optional[Union[quoted_name, str]] = None,
  41. ) -> None:
  42. self.table_name = table_name
  43. self.schema = schema
  44. class RenameTable(AlterTable):
  45. def __init__(
  46. self,
  47. old_table_name: str,
  48. new_table_name: Union[quoted_name, str],
  49. schema: Optional[Union[quoted_name, str]] = None,
  50. ) -> None:
  51. super().__init__(old_table_name, schema=schema)
  52. self.new_table_name = new_table_name
  53. class AlterColumn(AlterTable):
  54. def __init__(
  55. self,
  56. name: str,
  57. column_name: str,
  58. schema: Optional[str] = None,
  59. existing_type: Optional[TypeEngine] = None,
  60. existing_nullable: Optional[bool] = None,
  61. existing_server_default: Optional[_ServerDefault] = None,
  62. existing_comment: Optional[str] = None,
  63. ) -> None:
  64. super().__init__(name, schema=schema)
  65. self.column_name = column_name
  66. self.existing_type = (
  67. sqltypes.to_instance(existing_type)
  68. if existing_type is not None
  69. else None
  70. )
  71. self.existing_nullable = existing_nullable
  72. self.existing_server_default = existing_server_default
  73. self.existing_comment = existing_comment
  74. class ColumnNullable(AlterColumn):
  75. def __init__(
  76. self, name: str, column_name: str, nullable: bool, **kw
  77. ) -> None:
  78. super().__init__(name, column_name, **kw)
  79. self.nullable = nullable
  80. class ColumnType(AlterColumn):
  81. def __init__(
  82. self, name: str, column_name: str, type_: TypeEngine, **kw
  83. ) -> None:
  84. super().__init__(name, column_name, **kw)
  85. self.type_ = sqltypes.to_instance(type_)
  86. class ColumnName(AlterColumn):
  87. def __init__(
  88. self, name: str, column_name: str, newname: str, **kw
  89. ) -> None:
  90. super().__init__(name, column_name, **kw)
  91. self.newname = newname
  92. class ColumnDefault(AlterColumn):
  93. def __init__(
  94. self,
  95. name: str,
  96. column_name: str,
  97. default: Optional[_ServerDefault],
  98. **kw,
  99. ) -> None:
  100. super().__init__(name, column_name, **kw)
  101. self.default = default
  102. class ComputedColumnDefault(AlterColumn):
  103. def __init__(
  104. self, name: str, column_name: str, default: Optional[Computed], **kw
  105. ) -> None:
  106. super().__init__(name, column_name, **kw)
  107. self.default = default
  108. class IdentityColumnDefault(AlterColumn):
  109. def __init__(
  110. self,
  111. name: str,
  112. column_name: str,
  113. default: Optional[Identity],
  114. impl: DefaultImpl,
  115. **kw,
  116. ) -> None:
  117. super().__init__(name, column_name, **kw)
  118. self.default = default
  119. self.impl = impl
  120. class AddColumn(AlterTable):
  121. def __init__(
  122. self,
  123. name: str,
  124. column: Column[Any],
  125. schema: Optional[Union[quoted_name, str]] = None,
  126. if_not_exists: Optional[bool] = None,
  127. ) -> None:
  128. super().__init__(name, schema=schema)
  129. self.column = column
  130. self.if_not_exists = if_not_exists
  131. class DropColumn(AlterTable):
  132. def __init__(
  133. self,
  134. name: str,
  135. column: Column[Any],
  136. schema: Optional[str] = None,
  137. if_exists: Optional[bool] = None,
  138. ) -> None:
  139. super().__init__(name, schema=schema)
  140. self.column = column
  141. self.if_exists = if_exists
  142. class ColumnComment(AlterColumn):
  143. def __init__(
  144. self, name: str, column_name: str, comment: Optional[str], **kw
  145. ) -> None:
  146. super().__init__(name, column_name, **kw)
  147. self.comment = comment
  148. @compiles(RenameTable)
  149. def visit_rename_table(
  150. element: RenameTable, compiler: DDLCompiler, **kw
  151. ) -> str:
  152. return "%s RENAME TO %s" % (
  153. alter_table(compiler, element.table_name, element.schema),
  154. format_table_name(compiler, element.new_table_name, element.schema),
  155. )
  156. @compiles(AddColumn)
  157. def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
  158. return "%s %s" % (
  159. alter_table(compiler, element.table_name, element.schema),
  160. add_column(
  161. compiler, element.column, if_not_exists=element.if_not_exists, **kw
  162. ),
  163. )
  164. @compiles(DropColumn)
  165. def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
  166. return "%s %s" % (
  167. alter_table(compiler, element.table_name, element.schema),
  168. drop_column(
  169. compiler, element.column.name, if_exists=element.if_exists, **kw
  170. ),
  171. )
  172. @compiles(ColumnNullable)
  173. def visit_column_nullable(
  174. element: ColumnNullable, compiler: DDLCompiler, **kw
  175. ) -> str:
  176. return "%s %s %s" % (
  177. alter_table(compiler, element.table_name, element.schema),
  178. alter_column(compiler, element.column_name),
  179. "DROP NOT NULL" if element.nullable else "SET NOT NULL",
  180. )
  181. @compiles(ColumnType)
  182. def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
  183. return "%s %s %s" % (
  184. alter_table(compiler, element.table_name, element.schema),
  185. alter_column(compiler, element.column_name),
  186. "TYPE %s" % format_type(compiler, element.type_),
  187. )
  188. @compiles(ColumnName)
  189. def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
  190. return "%s RENAME %s TO %s" % (
  191. alter_table(compiler, element.table_name, element.schema),
  192. format_column_name(compiler, element.column_name),
  193. format_column_name(compiler, element.newname),
  194. )
  195. @compiles(ColumnDefault)
  196. def visit_column_default(
  197. element: ColumnDefault, compiler: DDLCompiler, **kw
  198. ) -> str:
  199. return "%s %s %s" % (
  200. alter_table(compiler, element.table_name, element.schema),
  201. alter_column(compiler, element.column_name),
  202. (
  203. "SET DEFAULT %s" % format_server_default(compiler, element.default)
  204. if element.default is not None
  205. else "DROP DEFAULT"
  206. ),
  207. )
  208. @compiles(ComputedColumnDefault)
  209. def visit_computed_column(
  210. element: ComputedColumnDefault, compiler: DDLCompiler, **kw
  211. ):
  212. raise exc.CompileError(
  213. 'Adding or removing a "computed" construct, e.g. GENERATED '
  214. "ALWAYS AS, to or from an existing column is not supported."
  215. )
  216. @compiles(IdentityColumnDefault)
  217. def visit_identity_column(
  218. element: IdentityColumnDefault, compiler: DDLCompiler, **kw
  219. ):
  220. raise exc.CompileError(
  221. 'Adding, removing or modifying an "identity" construct, '
  222. "e.g. GENERATED AS IDENTITY, to or from an existing "
  223. "column is not supported in this dialect."
  224. )
  225. def quote_dotted(
  226. name: Union[quoted_name, str], quote: functools.partial
  227. ) -> Union[quoted_name, str]:
  228. """quote the elements of a dotted name"""
  229. if isinstance(name, quoted_name):
  230. return quote(name)
  231. result = ".".join([quote(x) for x in name.split(".")])
  232. return result
  233. def format_table_name(
  234. compiler: Compiled,
  235. name: Union[quoted_name, str],
  236. schema: Optional[Union[quoted_name, str]],
  237. ) -> Union[quoted_name, str]:
  238. quote = functools.partial(compiler.preparer.quote)
  239. if schema:
  240. return quote_dotted(schema, quote) + "." + quote(name)
  241. else:
  242. return quote(name)
  243. def format_column_name(
  244. compiler: DDLCompiler, name: Optional[Union[quoted_name, str]]
  245. ) -> Union[quoted_name, str]:
  246. return compiler.preparer.quote(name) # type: ignore[arg-type]
  247. def format_server_default(
  248. compiler: DDLCompiler,
  249. default: Optional[_ServerDefault],
  250. ) -> str:
  251. # this can be updated to use compiler.render_default_string
  252. # for SQLAlchemy 2.0 and above; not in 1.4
  253. default_str = compiler.get_column_default_string(
  254. Column("x", Integer, server_default=default)
  255. )
  256. assert default_str is not None
  257. return default_str
  258. def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
  259. return compiler.dialect.type_compiler.process(type_)
  260. def alter_table(
  261. compiler: DDLCompiler,
  262. name: str,
  263. schema: Optional[str],
  264. ) -> str:
  265. return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
  266. def drop_column(
  267. compiler: DDLCompiler, name: str, if_exists: Optional[bool] = None, **kw
  268. ) -> str:
  269. return "DROP COLUMN %s%s" % (
  270. "IF EXISTS " if if_exists else "",
  271. format_column_name(compiler, name),
  272. )
  273. def alter_column(compiler: DDLCompiler, name: str) -> str:
  274. return "ALTER COLUMN %s" % format_column_name(compiler, name)
  275. def add_column(
  276. compiler: DDLCompiler,
  277. column: Column[Any],
  278. if_not_exists: Optional[bool] = None,
  279. **kw,
  280. ) -> str:
  281. text = "ADD COLUMN %s%s" % (
  282. "IF NOT EXISTS " if if_not_exists else "",
  283. compiler.get_column_specification(column, **kw),
  284. )
  285. const = " ".join(
  286. compiler.process(constraint) for constraint in column.constraints
  287. )
  288. if const:
  289. text += " " + const
  290. return text