fixtures.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. from __future__ import annotations
  2. import configparser
  3. from contextlib import contextmanager
  4. import io
  5. import os
  6. import re
  7. import shutil
  8. from typing import Any
  9. from typing import Dict
  10. from sqlalchemy import Column
  11. from sqlalchemy import create_mock_engine
  12. from sqlalchemy import inspect
  13. from sqlalchemy import MetaData
  14. from sqlalchemy import String
  15. from sqlalchemy import Table
  16. from sqlalchemy import testing
  17. from sqlalchemy import text
  18. from sqlalchemy.testing import config
  19. from sqlalchemy.testing import mock
  20. from sqlalchemy.testing.assertions import eq_
  21. from sqlalchemy.testing.fixtures import FutureEngineMixin
  22. from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
  23. from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
  24. import alembic
  25. from .assertions import _get_dialect
  26. from .env import _get_staging_directory
  27. from ..environment import EnvironmentContext
  28. from ..migration import MigrationContext
  29. from ..operations import Operations
  30. from ..util import sqla_compat
  31. from ..util.sqla_compat import sqla_2
  32. testing_config = configparser.ConfigParser()
  33. testing_config.read(["test.cfg"])
  34. class TestBase(SQLAlchemyTestBase):
  35. is_sqlalchemy_future = sqla_2
  36. @testing.fixture()
  37. def clear_staging_dir(self):
  38. yield
  39. location = _get_staging_directory()
  40. for filename in os.listdir(location):
  41. file_path = os.path.join(location, filename)
  42. if os.path.isfile(file_path) or os.path.islink(file_path):
  43. os.unlink(file_path)
  44. elif os.path.isdir(file_path):
  45. shutil.rmtree(file_path)
  46. @contextmanager
  47. def pushd(self, dirname):
  48. current_dir = os.getcwd()
  49. try:
  50. os.chdir(dirname)
  51. yield
  52. finally:
  53. os.chdir(current_dir)
  54. @testing.fixture()
  55. def pop_alembic_config_env(self):
  56. yield
  57. os.environ.pop("ALEMBIC_CONFIG", None)
  58. @testing.fixture()
  59. def ops_context(self, migration_context):
  60. with migration_context.begin_transaction(_per_migration=True):
  61. yield Operations(migration_context)
  62. @testing.fixture
  63. def migration_context(self, connection):
  64. return MigrationContext.configure(
  65. connection, opts=dict(transaction_per_migration=True)
  66. )
  67. @testing.fixture
  68. def as_sql_migration_context(self, connection):
  69. return MigrationContext.configure(
  70. connection, opts=dict(transaction_per_migration=True, as_sql=True)
  71. )
  72. @testing.fixture
  73. def connection(self):
  74. with config.db.connect() as conn:
  75. yield conn
  76. class TablesTest(TestBase, SQLAlchemyTablesTest):
  77. pass
  78. FutureEngineMixin.is_sqlalchemy_future = True
  79. def capture_db(dialect="postgresql://"):
  80. buf = []
  81. def dump(sql, *multiparams, **params):
  82. buf.append(str(sql.compile(dialect=engine.dialect)))
  83. engine = create_mock_engine(dialect, dump)
  84. return engine, buf
  85. _engs: Dict[Any, Any] = {}
  86. @contextmanager
  87. def capture_context_buffer(**kw):
  88. if kw.pop("bytes_io", False):
  89. buf = io.BytesIO()
  90. else:
  91. buf = io.StringIO()
  92. kw.update({"dialect_name": "sqlite", "output_buffer": buf})
  93. conf = EnvironmentContext.configure
  94. def configure(*arg, **opt):
  95. opt.update(**kw)
  96. return conf(*arg, **opt)
  97. with mock.patch.object(EnvironmentContext, "configure", configure):
  98. yield buf
  99. @contextmanager
  100. def capture_engine_context_buffer(**kw):
  101. from .env import _sqlite_file_db
  102. from sqlalchemy import event
  103. buf = io.StringIO()
  104. eng = _sqlite_file_db()
  105. conn = eng.connect()
  106. @event.listens_for(conn, "before_cursor_execute")
  107. def bce(conn, cursor, statement, parameters, context, executemany):
  108. buf.write(statement + "\n")
  109. kw.update({"connection": conn})
  110. conf = EnvironmentContext.configure
  111. def configure(*arg, **opt):
  112. opt.update(**kw)
  113. return conf(*arg, **opt)
  114. with mock.patch.object(EnvironmentContext, "configure", configure):
  115. yield buf
  116. def op_fixture(
  117. dialect="default",
  118. as_sql=False,
  119. naming_convention=None,
  120. literal_binds=False,
  121. native_boolean=None,
  122. ):
  123. opts = {}
  124. if naming_convention:
  125. opts["target_metadata"] = MetaData(naming_convention=naming_convention)
  126. class buffer_:
  127. def __init__(self):
  128. self.lines = []
  129. def write(self, msg):
  130. msg = msg.strip()
  131. msg = re.sub(r"[\n\t]", "", msg)
  132. if as_sql:
  133. # the impl produces soft tabs,
  134. # so search for blocks of 4 spaces
  135. msg = re.sub(r" ", "", msg)
  136. msg = re.sub(r"\;\n*$", "", msg)
  137. self.lines.append(msg)
  138. def flush(self):
  139. pass
  140. buf = buffer_()
  141. class ctx(MigrationContext):
  142. def get_buf(self):
  143. return buf
  144. def clear_assertions(self):
  145. buf.lines[:] = []
  146. def assert_(self, *sql):
  147. # TODO: make this more flexible about
  148. # whitespace and such
  149. eq_(buf.lines, [re.sub(r"[\n\t]", "", s) for s in sql])
  150. def assert_contains(self, sql):
  151. for stmt in buf.lines:
  152. if re.sub(r"[\n\t]", "", sql) in stmt:
  153. return
  154. else:
  155. assert False, "Could not locate fragment %r in %r" % (
  156. sql,
  157. buf.lines,
  158. )
  159. if as_sql:
  160. opts["as_sql"] = as_sql
  161. if literal_binds:
  162. opts["literal_binds"] = literal_binds
  163. ctx_dialect = _get_dialect(dialect)
  164. if native_boolean is not None:
  165. ctx_dialect.supports_native_boolean = native_boolean
  166. # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
  167. # which breaks assumptions in the alembic test suite
  168. ctx_dialect.non_native_boolean_check_constraint = True
  169. if not as_sql:
  170. def execute(stmt, *multiparam, **param):
  171. if isinstance(stmt, str):
  172. stmt = text(stmt)
  173. assert stmt.supports_execution
  174. sql = str(stmt.compile(dialect=ctx_dialect))
  175. buf.write(sql)
  176. connection = mock.Mock(dialect=ctx_dialect, execute=execute)
  177. else:
  178. opts["output_buffer"] = buf
  179. connection = None
  180. context = ctx(ctx_dialect, connection, opts)
  181. alembic.op._proxy = Operations(context)
  182. return context
  183. class AlterColRoundTripFixture:
  184. # since these tests are about syntax, use more recent SQLAlchemy as some of
  185. # the type / server default compare logic might not work on older
  186. # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
  187. __requires__ = ("alter_column",)
  188. def setUp(self):
  189. self.conn = config.db.connect()
  190. self.ctx = MigrationContext.configure(self.conn)
  191. self.op = Operations(self.ctx)
  192. self.metadata = MetaData()
  193. def _compare_type(self, t1, t2):
  194. c1 = Column("q", t1)
  195. c2 = Column("q", t2)
  196. assert not self.ctx.impl.compare_type(
  197. c1, c2
  198. ), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)
  199. def _compare_server_default(self, t1, s1, t2, s2):
  200. c1 = Column("q", t1, server_default=s1)
  201. c2 = Column("q", t2, server_default=s2)
  202. assert not self.ctx.impl.compare_server_default(
  203. c1, c2, s2, s1
  204. ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
  205. def tearDown(self):
  206. sqla_compat._safe_rollback_connection_transaction(self.conn)
  207. with self.conn.begin():
  208. self.metadata.drop_all(self.conn)
  209. self.conn.close()
  210. def _run_alter_col(self, from_, to_, compare=None):
  211. column = Column(
  212. from_.get("name", "colname"),
  213. from_.get("type", String(10)),
  214. nullable=from_.get("nullable", True),
  215. server_default=from_.get("server_default", None),
  216. # comment=from_.get("comment", None)
  217. )
  218. t = Table("x", self.metadata, column)
  219. with sqla_compat._ensure_scope_for_ddl(self.conn):
  220. t.create(self.conn)
  221. insp = inspect(self.conn)
  222. old_col = insp.get_columns("x")[0]
  223. # TODO: conditional comment support
  224. self.op.alter_column(
  225. "x",
  226. column.name,
  227. existing_type=column.type,
  228. existing_server_default=(
  229. column.server_default
  230. if column.server_default is not None
  231. else False
  232. ),
  233. existing_nullable=True if column.nullable else False,
  234. # existing_comment=column.comment,
  235. nullable=to_.get("nullable", None),
  236. # modify_comment=False,
  237. server_default=to_.get("server_default", False),
  238. new_column_name=to_.get("name", None),
  239. type_=to_.get("type", None),
  240. )
  241. insp = inspect(self.conn)
  242. new_col = insp.get_columns("x")[0]
  243. if compare is None:
  244. compare = to_
  245. eq_(
  246. new_col["name"],
  247. compare["name"] if "name" in compare else column.name,
  248. )
  249. self._compare_type(
  250. new_col["type"], compare.get("type", old_col["type"])
  251. )
  252. eq_(new_col["nullable"], compare.get("nullable", column.nullable))
  253. self._compare_server_default(
  254. new_col["type"],
  255. new_col.get("default", None),
  256. compare.get("type", old_col["type"]),
  257. (
  258. compare["server_default"].text
  259. if "server_default" in compare
  260. else (
  261. column.server_default.arg.text
  262. if column.server_default is not None
  263. else None
  264. )
  265. ),
  266. )