| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557 |
- import importlib.machinery
- import os
- from pathlib import Path
- import shutil
- import textwrap
- from sqlalchemy.testing import config
- from sqlalchemy.testing import provision
- from . import util as testing_util
- from .. import command
- from .. import script
- from .. import util
- from ..script import Script
- from ..script import ScriptDirectory
- def _get_staging_directory():
- if provision.FOLLOWER_IDENT:
- return f"scratch_{provision.FOLLOWER_IDENT}"
- else:
- return "scratch"
- def staging_env(create=True, template="generic", sourceless=False):
- cfg = _testing_config()
- if create:
- path = _join_path(_get_staging_directory(), "scripts")
- assert not os.path.exists(path), (
- "staging directory %s already exists; poor cleanup?" % path
- )
- command.init(cfg, path, template=template)
- if sourceless:
- try:
- # do an import so that a .pyc/.pyo is generated.
- util.load_python_file(path, "env.py")
- except AttributeError:
- # we don't have the migration context set up yet
- # so running the .env py throws this exception.
- # theoretically we could be using py_compiler here to
- # generate .pyc/.pyo without importing but not really
- # worth it.
- pass
- assert sourceless in (
- "pep3147_envonly",
- "simple",
- "pep3147_everything",
- ), sourceless
- make_sourceless(
- _join_path(path, "env.py"),
- "pep3147" if "pep3147" in sourceless else "simple",
- )
- sc = script.ScriptDirectory.from_config(cfg)
- return sc
- def clear_staging_env():
- from sqlalchemy.testing import engines
- engines.testing_reaper.close_all()
- shutil.rmtree(_get_staging_directory(), True)
- def script_file_fixture(txt):
- dir_ = _join_path(_get_staging_directory(), "scripts")
- path = _join_path(dir_, "script.py.mako")
- with open(path, "w") as f:
- f.write(txt)
- def env_file_fixture(txt):
- dir_ = _join_path(_get_staging_directory(), "scripts")
- txt = (
- """
- from alembic import context
- config = context.config
- """
- + txt
- )
- path = _join_path(dir_, "env.py")
- pyc_path = util.pyc_file_from_path(path)
- if pyc_path:
- os.unlink(pyc_path)
- with open(path, "w") as f:
- f.write(txt)
- def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options):
- dir_ = _join_path(_get_staging_directory(), "scripts")
- url = "sqlite:///%s/%s" % (dir_, tempname)
- if scope:
- options["scope"] = scope
- return testing_util.testing_engine(url=url, future=future, options=options)
- def _sqlite_testing_config(sourceless=False, future=False):
- dir_ = _join_path(_get_staging_directory(), "scripts")
- url = f"sqlite:///{dir_}/foo.db"
- sqlalchemy_future = future or ("future" in config.db.__class__.__module__)
- return _write_config_file(
- f"""
- [alembic]
- script_location = {dir_}
- sqlalchemy.url = {url}
- sourceless = {"true" if sourceless else "false"}
- {"sqlalchemy.future = true" if sqlalchemy_future else ""}
- [loggers]
- keys = root,sqlalchemy
- [handlers]
- keys = console
- [logger_root]
- level = WARNING
- handlers = console
- qualname =
- [logger_sqlalchemy]
- level = DEBUG
- handlers =
- qualname = sqlalchemy.engine
- [handler_console]
- class = StreamHandler
- args = (sys.stderr,)
- level = NOTSET
- formatter = generic
- [formatters]
- keys = generic
- [formatter_generic]
- format = %%(levelname)-5.5s [%%(name)s] %%(message)s
- datefmt = %%H:%%M:%%S
- """
- )
- def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
- dir_ = _join_path(_get_staging_directory(), "scripts")
- sqlalchemy_future = "future" in config.db.__class__.__module__
- url = "sqlite:///%s/foo.db" % dir_
- return _write_config_file(
- f"""
- [alembic]
- script_location = {dir_}
- sqlalchemy.url = {url}
- sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
- sourceless = {"true" if sourceless else "false"}
- path_separator = space
- version_locations = %(here)s/model1/ %(here)s/model2/ %(here)s/model3/ \
- {extra_version_location}
- [loggers]
- keys = root
- [handlers]
- keys = console
- [logger_root]
- level = WARNING
- handlers = console
- qualname =
- [handler_console]
- class = StreamHandler
- args = (sys.stderr,)
- level = NOTSET
- formatter = generic
- [formatters]
- keys = generic
- [formatter_generic]
- format = %%(levelname)-5.5s [%%(name)s] %%(message)s
- datefmt = %%H:%%M:%%S
- """
- )
- def _no_sql_pyproject_config(dialect="postgresql", directives=""):
- """use a postgresql url with no host so that
- connections guaranteed to fail"""
- dir_ = _join_path(_get_staging_directory(), "scripts")
- return _write_toml_config(
- f"""
- [tool.alembic]
- script_location ="{dir_}"
- {textwrap.dedent(directives)}
- """,
- f"""
- [alembic]
- sqlalchemy.url = {dialect}://
- [loggers]
- keys = root
- [handlers]
- keys = console
- [logger_root]
- level = WARNING
- handlers = console
- qualname =
- [handler_console]
- class = StreamHandler
- args = (sys.stderr,)
- level = NOTSET
- formatter = generic
- [formatters]
- keys = generic
- [formatter_generic]
- format = %%(levelname)-5.5s [%%(name)s] %%(message)s
- datefmt = %%H:%%M:%%S
- """,
- )
- def _no_sql_testing_config(dialect="postgresql", directives=""):
- """use a postgresql url with no host so that
- connections guaranteed to fail"""
- dir_ = _join_path(_get_staging_directory(), "scripts")
- return _write_config_file(
- f"""
- [alembic]
- script_location ={dir_}
- sqlalchemy.url = {dialect}://
- {directives}
- [loggers]
- keys = root
- [handlers]
- keys = console
- [logger_root]
- level = WARNING
- handlers = console
- qualname =
- [handler_console]
- class = StreamHandler
- args = (sys.stderr,)
- level = NOTSET
- formatter = generic
- [formatters]
- keys = generic
- [formatter_generic]
- format = %%(levelname)-5.5s [%%(name)s] %%(message)s
- datefmt = %%H:%%M:%%S
- """
- )
- def _write_toml_config(tomltext, initext):
- cfg = _write_config_file(initext)
- with open(cfg.toml_file_name, "w") as f:
- f.write(tomltext)
- return cfg
- def _write_config_file(text):
- cfg = _testing_config()
- with open(cfg.config_file_name, "w") as f:
- f.write(text)
- return cfg
- def _testing_config():
- from alembic.config import Config
- if not os.access(_get_staging_directory(), os.F_OK):
- os.mkdir(_get_staging_directory())
- return Config(
- _join_path(_get_staging_directory(), "test_alembic.ini"),
- _join_path(_get_staging_directory(), "pyproject.toml"),
- )
- def write_script(
- scriptdir, rev_id, content, encoding="ascii", sourceless=False
- ):
- old = scriptdir.revision_map.get_revision(rev_id)
- path = old.path
- content = textwrap.dedent(content)
- if encoding:
- content = content.encode(encoding)
- with open(path, "wb") as fp:
- fp.write(content)
- pyc_path = util.pyc_file_from_path(path)
- if pyc_path:
- os.unlink(pyc_path)
- script = Script._from_path(scriptdir, path)
- old = scriptdir.revision_map.get_revision(script.revision)
- if old.down_revision != script.down_revision:
- raise Exception("Can't change down_revision on a refresh operation.")
- scriptdir.revision_map.add_revision(script, _replace=True)
- if sourceless:
- make_sourceless(
- path, "pep3147" if sourceless == "pep3147_everything" else "simple"
- )
- def make_sourceless(path, style):
- import py_compile
- py_compile.compile(path)
- if style == "simple":
- pyc_path = util.pyc_file_from_path(path)
- suffix = importlib.machinery.BYTECODE_SUFFIXES[0]
- filepath, ext = os.path.splitext(path)
- simple_pyc_path = filepath + suffix
- shutil.move(pyc_path, simple_pyc_path)
- pyc_path = simple_pyc_path
- else:
- assert style in ("pep3147", "simple")
- pyc_path = util.pyc_file_from_path(path)
- assert os.access(pyc_path, os.F_OK)
- os.unlink(path)
- def three_rev_fixture(cfg):
- a = util.rev_id()
- b = util.rev_id()
- c = util.rev_id()
- script = ScriptDirectory.from_config(cfg)
- script.generate_revision(a, "revision a", refresh=True, head="base")
- write_script(
- script,
- a,
- f"""\
- "Rev A"
- revision = '{a}'
- down_revision = None
- from alembic import op
- def upgrade():
- op.execute("CREATE STEP 1")
- def downgrade():
- op.execute("DROP STEP 1")
- """,
- )
- script.generate_revision(b, "revision b", refresh=True, head=a)
- write_script(
- script,
- b,
- f"""# coding: utf-8
- "Rev B, méil, %3"
- revision = '{b}'
- down_revision = '{a}'
- from alembic import op
- def upgrade():
- op.execute("CREATE STEP 2")
- def downgrade():
- op.execute("DROP STEP 2")
- """,
- encoding="utf-8",
- )
- script.generate_revision(c, "revision c", refresh=True, head=b)
- write_script(
- script,
- c,
- f"""\
- "Rev C"
- revision = '{c}'
- down_revision = '{b}'
- from alembic import op
- def upgrade():
- op.execute("CREATE STEP 3")
- def downgrade():
- op.execute("DROP STEP 3")
- """,
- )
- return a, b, c
- def multi_heads_fixture(cfg, a, b, c):
- """Create a multiple head fixture from the three-revs fixture"""
- # a->b->c
- # -> d -> e
- # -> f
- d = util.rev_id()
- e = util.rev_id()
- f = util.rev_id()
- script = ScriptDirectory.from_config(cfg)
- script.generate_revision(
- d, "revision d from b", head=b, splice=True, refresh=True
- )
- write_script(
- script,
- d,
- f"""\
- "Rev D"
- revision = '{d}'
- down_revision = '{b}'
- from alembic import op
- def upgrade():
- op.execute("CREATE STEP 4")
- def downgrade():
- op.execute("DROP STEP 4")
- """,
- )
- script.generate_revision(
- e, "revision e from d", head=d, splice=True, refresh=True
- )
- write_script(
- script,
- e,
- f"""\
- "Rev E"
- revision = '{e}'
- down_revision = '{d}'
- from alembic import op
- def upgrade():
- op.execute("CREATE STEP 5")
- def downgrade():
- op.execute("DROP STEP 5")
- """,
- )
- script.generate_revision(
- f, "revision f from b", head=b, splice=True, refresh=True
- )
- write_script(
- script,
- f,
- f"""\
- "Rev F"
- revision = '{f}'
- down_revision = '{b}'
- from alembic import op
- def upgrade():
- op.execute("CREATE STEP 6")
- def downgrade():
- op.execute("DROP STEP 6")
- """,
- )
- return d, e, f
- def _multidb_testing_config(engines):
- """alembic.ini fixture to work exactly with the 'multidb' template"""
- dir_ = _join_path(_get_staging_directory(), "scripts")
- sqlalchemy_future = "future" in config.db.__class__.__module__
- databases = ", ".join(engines.keys())
- engines = "\n\n".join(
- f"[{key}]\nsqlalchemy.url = {value.url}"
- for key, value in engines.items()
- )
- return _write_config_file(
- f"""
- [alembic]
- script_location = {dir_}
- sourceless = false
- sqlalchemy.future = {"true" if sqlalchemy_future else "false"}
- databases = {databases}
- {engines}
- [loggers]
- keys = root
- [handlers]
- keys = console
- [logger_root]
- level = WARNING
- handlers = console
- qualname =
- [handler_console]
- class = StreamHandler
- args = (sys.stderr,)
- level = NOTSET
- formatter = generic
- [formatters]
- keys = generic
- [formatter_generic]
- format = %%(levelname)-5.5s [%%(name)s] %%(message)s
- datefmt = %%H:%%M:%%S
- """
- )
- def _join_path(base: str, *more: str):
- return str(Path(base).joinpath(*more).as_posix())
|