oracle.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
  2. # mypy: no-warn-return-any, allow-any-generics
  3. from __future__ import annotations
  4. import re
  5. from typing import Any
  6. from typing import Optional
  7. from typing import TYPE_CHECKING
  8. from sqlalchemy.sql import sqltypes
  9. from .base import AddColumn
  10. from .base import alter_table
  11. from .base import ColumnComment
  12. from .base import ColumnDefault
  13. from .base import ColumnName
  14. from .base import ColumnNullable
  15. from .base import ColumnType
  16. from .base import format_column_name
  17. from .base import format_server_default
  18. from .base import format_table_name
  19. from .base import format_type
  20. from .base import IdentityColumnDefault
  21. from .base import RenameTable
  22. from .impl import DefaultImpl
  23. from ..util.sqla_compat import compiles
  24. if TYPE_CHECKING:
  25. from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
  26. from sqlalchemy.engine.cursor import CursorResult
  27. from sqlalchemy.sql.schema import Column
  28. class OracleImpl(DefaultImpl):
  29. __dialect__ = "oracle"
  30. transactional_ddl = False
  31. batch_separator = "/"
  32. command_terminator = ""
  33. type_synonyms = DefaultImpl.type_synonyms + (
  34. {"VARCHAR", "VARCHAR2"},
  35. {"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"},
  36. {"DOUBLE", "FLOAT", "DOUBLE_PRECISION"},
  37. )
  38. identity_attrs_ignore = ()
  39. def __init__(self, *arg, **kw) -> None:
  40. super().__init__(*arg, **kw)
  41. self.batch_separator = self.context_opts.get(
  42. "oracle_batch_separator", self.batch_separator
  43. )
  44. def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
  45. result = super()._exec(construct, *args, **kw)
  46. if self.as_sql and self.batch_separator:
  47. self.static_output(self.batch_separator)
  48. return result
  49. def compare_server_default(
  50. self,
  51. inspector_column,
  52. metadata_column,
  53. rendered_metadata_default,
  54. rendered_inspector_default,
  55. ):
  56. if rendered_metadata_default is not None:
  57. rendered_metadata_default = re.sub(
  58. r"^\((.+)\)$", r"\1", rendered_metadata_default
  59. )
  60. rendered_metadata_default = re.sub(
  61. r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
  62. )
  63. if rendered_inspector_default is not None:
  64. rendered_inspector_default = re.sub(
  65. r"^\((.+)\)$", r"\1", rendered_inspector_default
  66. )
  67. rendered_inspector_default = re.sub(
  68. r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
  69. )
  70. rendered_inspector_default = rendered_inspector_default.strip()
  71. return rendered_inspector_default != rendered_metadata_default
  72. def emit_begin(self) -> None:
  73. self._exec("SET TRANSACTION READ WRITE")
  74. def emit_commit(self) -> None:
  75. self._exec("COMMIT")
  76. @compiles(AddColumn, "oracle")
  77. def visit_add_column(
  78. element: AddColumn, compiler: OracleDDLCompiler, **kw
  79. ) -> str:
  80. return "%s %s" % (
  81. alter_table(compiler, element.table_name, element.schema),
  82. add_column(compiler, element.column, **kw),
  83. )
  84. @compiles(ColumnNullable, "oracle")
  85. def visit_column_nullable(
  86. element: ColumnNullable, compiler: OracleDDLCompiler, **kw
  87. ) -> str:
  88. return "%s %s %s" % (
  89. alter_table(compiler, element.table_name, element.schema),
  90. alter_column(compiler, element.column_name),
  91. "NULL" if element.nullable else "NOT NULL",
  92. )
  93. @compiles(ColumnType, "oracle")
  94. def visit_column_type(
  95. element: ColumnType, compiler: OracleDDLCompiler, **kw
  96. ) -> str:
  97. return "%s %s %s" % (
  98. alter_table(compiler, element.table_name, element.schema),
  99. alter_column(compiler, element.column_name),
  100. "%s" % format_type(compiler, element.type_),
  101. )
  102. @compiles(ColumnName, "oracle")
  103. def visit_column_name(
  104. element: ColumnName, compiler: OracleDDLCompiler, **kw
  105. ) -> str:
  106. return "%s RENAME COLUMN %s TO %s" % (
  107. alter_table(compiler, element.table_name, element.schema),
  108. format_column_name(compiler, element.column_name),
  109. format_column_name(compiler, element.newname),
  110. )
  111. @compiles(ColumnDefault, "oracle")
  112. def visit_column_default(
  113. element: ColumnDefault, compiler: OracleDDLCompiler, **kw
  114. ) -> str:
  115. return "%s %s %s" % (
  116. alter_table(compiler, element.table_name, element.schema),
  117. alter_column(compiler, element.column_name),
  118. (
  119. "DEFAULT %s" % format_server_default(compiler, element.default)
  120. if element.default is not None
  121. else "DEFAULT NULL"
  122. ),
  123. )
  124. @compiles(ColumnComment, "oracle")
  125. def visit_column_comment(
  126. element: ColumnComment, compiler: OracleDDLCompiler, **kw
  127. ) -> str:
  128. ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
  129. comment = compiler.sql_compiler.render_literal_value(
  130. (element.comment if element.comment is not None else ""),
  131. sqltypes.String(),
  132. )
  133. return ddl.format(
  134. table_name=element.table_name,
  135. column_name=element.column_name,
  136. comment=comment,
  137. )
  138. @compiles(RenameTable, "oracle")
  139. def visit_rename_table(
  140. element: RenameTable, compiler: OracleDDLCompiler, **kw
  141. ) -> str:
  142. return "%s RENAME TO %s" % (
  143. alter_table(compiler, element.table_name, element.schema),
  144. format_table_name(compiler, element.new_table_name, None),
  145. )
  146. def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
  147. return "MODIFY %s" % format_column_name(compiler, name)
  148. def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str:
  149. return "ADD %s" % compiler.get_column_specification(column, **kw)
  150. @compiles(IdentityColumnDefault, "oracle")
  151. def visit_identity_column(
  152. element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw
  153. ):
  154. text = "%s %s " % (
  155. alter_table(compiler, element.table_name, element.schema),
  156. alter_column(compiler, element.column_name),
  157. )
  158. if element.default is None:
  159. # drop identity
  160. text += "DROP IDENTITY"
  161. return text
  162. else:
  163. text += compiler.visit_identity_column(element.default)
  164. return text