write_hooks.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
  2. # mypy: no-warn-return-any, allow-any-generics
  3. from __future__ import annotations
  4. import importlib.util
  5. import os
  6. import shlex
  7. import subprocess
  8. import sys
  9. from typing import Any
  10. from typing import Callable
  11. from typing import Dict
  12. from typing import List
  13. from typing import Optional
  14. from typing import TYPE_CHECKING
  15. from typing import Union
  16. from .. import util
  17. from ..util import compat
  18. from ..util.pyfiles import _preserving_path_as_str
  19. if TYPE_CHECKING:
  20. from ..config import PostWriteHookConfig
  # Placeholder accepted inside a hook's ``options`` string;
  # _parse_cmdline_options substitutes the revision script path for it and
  # prepends it when the string does not mention it at all.
  REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
  # Hook type name -> hook callable, filled in by the ``register`` decorator.
  _registry: Dict[str, Callable[..., Any]] = dict()
  def register(name: str) -> Callable:
      """Return a decorator that records a function as a post write hook.

      The decorated function is stored in the module-level registry under
      ``name`` (overwriting any earlier entry with that name) and is handed
      back unchanged. See the documentation linked below for an example.

      .. seealso::

          :ref:`post_write_hooks_custom`

      """

      def _add_hook(hook_fn):
          _registry[name] = hook_fn
          return hook_fn

      return _add_hook
  def _invoke(
      name: str,
      revision_path: Union[str, os.PathLike[str]],
      options: PostWriteHookConfig,
  ) -> Any:
      """Invokes the formatter registered for the given name.

      :param name: The name of a formatter in the registry
      :param revision_path: path to the revision file, as a string or a
       path-like object
      :param options: the post write hook configuration; it is passed
       unchanged to the registered formatter.
      :return: whatever the registered formatter returns.
      :raises: :class:`alembic.util.CommandError` if no formatter is
       registered under ``name``.
      """
      # The registered hooks are annotated to take ``path: str``, so convert
      # path-like input before handing it over.
      revision_path = _preserving_path_as_str(revision_path)
      try:
          hook = _registry[name]
      except KeyError as ke:
          raise util.CommandError(
              f"No formatter with name '{name}' registered"
          ) from ke
      else:
          return hook(revision_path, options)
  def _run_hooks(
      path: Union[str, os.PathLike[str]], hooks: list[PostWriteHookConfig]
  ) -> None:
      """Run every configured post write hook against a generated revision.

      Hooks are processed in list order. Each hook configuration must
      provide a ``type`` key naming a registered hook; its ``_hook_name``
      key is used in status and error messages.

      :raises: :class:`alembic.util.CommandError` if a hook has no ``type``
       key, or (from ``_invoke``) if its type is not registered.
      """
      for hook_config in hooks:
          hook_name = hook_config["_hook_name"]
          try:
              hook_type = hook_config["type"]
          except KeyError as missing:
              raise util.CommandError(
                  f"Key '{hook_name}.type' (or 'type' in toml) is required "
                  f"for post write hook {hook_name!r}"
              ) from missing
          with util.status(
              f"Running post write hook {hook_name!r}", newline=True
          ):
              _invoke(hook_type, path, hook_config)
  def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
      """Split a hook's ``options`` string into a list of arguments.

      Every occurrence of ``REVISION_SCRIPT_TOKEN`` in the resulting
      arguments is replaced with ``path``. If the token does not appear in
      the string at all, it is prepended, so the revision script becomes the
      first argument. Splitting is done by :func:`shlex.split` with
      ``posix=compat.is_posix``.
      """
      if REVISION_SCRIPT_TOKEN in cmdline_options_str:
          options_str = cmdline_options_str
      else:
          options_str = f"{REVISION_SCRIPT_TOKEN} {cmdline_options_str}"
      tokens = shlex.split(options_str, posix=compat.is_posix)
      return [token.replace(REVISION_SCRIPT_TOKEN, path) for token in tokens]
  def _get_required_option(options: dict, name: str) -> str:
      """Return ``options[name]``, raising a CommandError if it is missing.

      :param options: hook configuration; on the error path its
       ``_hook_name`` key is also read to build the message.
      :param name: the option key that must be present.
      :raises: :class:`alembic.util.CommandError` if ``name`` is absent.
      """
      try:
          value = options[name]
      except KeyError as missing:
          hook_name = options["_hook_name"]
          raise util.CommandError(
              f"Key {hook_name}.{name} is required for post "
              f"write hook {hook_name!r}"
          ) from missing
      return value
  def _run_hook(
      path: str, options: dict, ignore_output: bool, command: List[str]
  ) -> None:
      """Run ``command`` followed by the hook's parsed options as a subprocess.

      :param path: revision script path substituted into the options.
      :param options: hook configuration; the optional ``options`` string
       supplies extra arguments and the optional ``cwd`` sets the working
       directory.
      :param ignore_output: when true, stdout and stderr go to DEVNULL.
      :param command: argument prefix (executable plus any fixed arguments).

      The subprocess' exit status is not checked.
      """
      extra_args = _parse_cmdline_options(options.get("options", ""), path)
      stream_kw: Dict[str, Any] = (
          {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}
          if ignore_output
          else {}
      )
      working_dir: Optional[str] = options.get("cwd", None)
      subprocess.run(list(command) + extra_args, cwd=working_dir, **stream_kw)
  @register("console_scripts")
  def console_scripts(
      path: str, options: dict, ignore_output: bool = False
  ) -> None:
      """Post write hook that runs a ``console_scripts`` entry point.

      Looks up the entry point named by the required ``entrypoint`` option
      and runs it with the current interpreter as
      ``python -c "import <module>; <module>.<attr>()"``.

      :raises: :class:`alembic.util.CommandError` if the option is missing
       or no such entry point is found.
      """
      entrypoint_name = _get_required_option(options, "entrypoint")
      impl: Any = next(
          (
              candidate
              for candidate in compat.importlib_metadata_get("console_scripts")
              if candidate.name == entrypoint_name
          ),
          None,
      )
      if impl is None:
          raise util.CommandError(
              f"Could not find entrypoint console_scripts.{entrypoint_name}"
          )
      invocation = f"import {impl.module}; {impl.module}.{impl.attr}()"
      _run_hook(path, options, ignore_output, [sys.executable, "-c", invocation])
  @register("exec")
  def exec_(path: str, options: dict, ignore_output: bool = False) -> None:
      """Post write hook that runs the configured ``executable`` directly.

      :raises: :class:`alembic.util.CommandError` if the ``executable``
       option is missing.
      """
      _run_hook(
          path,
          options,
          ignore_output,
          command=[_get_required_option(options, "executable")],
      )
  @register("module")
  def module(path: str, options: dict, ignore_output: bool = False) -> None:
      """Post write hook that runs ``python -m <module>`` on the revision file.

      Reads the required ``module`` option and verifies that the module can
      be located before running it with the current interpreter.

      :param path: path of the generated revision script.
      :param options: hook configuration; must contain ``module``.
      :param ignore_output: when true, discard the subprocess' output.
      :raises: :class:`alembic.util.CommandError` if the ``module`` option is
       missing or the module cannot be found.
      """
      module_name = _get_required_option(options, "module")
      try:
          # For a dotted name, find_spec() imports the parent package and
          # raises ModuleNotFoundError (rather than returning None) when that
          # parent is missing or is not a package; report it the same way.
          spec = importlib.util.find_spec(module_name)
      except ImportError as err:
          raise util.CommandError(
              f"Could not find module {module_name}"
          ) from err
      if spec is None:
          raise util.CommandError(f"Could not find module {module_name}")
      command = [sys.executable, "-m", module_name]
      _run_hook(path, options, ignore_output, command)