__init__.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import argparse
  2. from functools import wraps
  3. import logging
  4. import os
  5. import sys
  6. from flask import current_app, g
  7. from alembic import __version__ as __alembic_version__
  8. from alembic.config import Config as AlembicConfig
  9. from alembic import command
  10. from alembic.util import CommandError
  11. alembic_version = tuple([int(v) for v in __alembic_version__.split('.')[0:3]])
  12. log = logging.getLogger(__name__)
  13. class _MigrateConfig(object):
  14. def __init__(self, migrate, db, **kwargs):
  15. self.migrate = migrate
  16. self.db = db
  17. self.directory = migrate.directory
  18. self.configure_args = kwargs
  19. @property
  20. def metadata(self):
  21. """
  22. Backwards compatibility, in old releases app.extensions['migrate']
  23. was set to db, and env.py accessed app.extensions['migrate'].metadata
  24. """
  25. return self.db.metadata
  26. class Config(AlembicConfig):
  27. def __init__(self, *args, **kwargs):
  28. self.template_directory = kwargs.pop('template_directory', None)
  29. super().__init__(*args, **kwargs)
  30. def get_template_directory(self):
  31. if self.template_directory:
  32. return self.template_directory
  33. package_dir = os.path.abspath(os.path.dirname(__file__))
  34. return os.path.join(package_dir, 'templates')
  35. class Migrate(object):
  36. def __init__(self, app=None, db=None, directory='migrations', command='db',
  37. compare_type=True, render_as_batch=True, **kwargs):
  38. self.configure_callbacks = []
  39. self.db = db
  40. self.command = command
  41. self.directory = str(directory)
  42. self.alembic_ctx_kwargs = kwargs
  43. self.alembic_ctx_kwargs['compare_type'] = compare_type
  44. self.alembic_ctx_kwargs['render_as_batch'] = render_as_batch
  45. if app is not None and db is not None:
  46. self.init_app(app, db, directory)
  47. def init_app(self, app, db=None, directory=None, command=None,
  48. compare_type=None, render_as_batch=None, **kwargs):
  49. self.db = db or self.db
  50. self.command = command or self.command
  51. self.directory = str(directory or self.directory)
  52. self.alembic_ctx_kwargs.update(kwargs)
  53. if compare_type is not None:
  54. self.alembic_ctx_kwargs['compare_type'] = compare_type
  55. if render_as_batch is not None:
  56. self.alembic_ctx_kwargs['render_as_batch'] = render_as_batch
  57. if not hasattr(app, 'extensions'):
  58. app.extensions = {}
  59. app.extensions['migrate'] = _MigrateConfig(
  60. self, self.db, **self.alembic_ctx_kwargs)
  61. from flask_migrate.cli import db as db_cli_group
  62. app.cli.add_command(db_cli_group, name=self.command)
  63. def configure(self, f):
  64. self.configure_callbacks.append(f)
  65. return f
  66. def call_configure_callbacks(self, config):
  67. for f in self.configure_callbacks:
  68. config = f(config)
  69. return config
  70. def get_config(self, directory=None, x_arg=None, opts=None):
  71. if directory is None:
  72. directory = self.directory
  73. directory = str(directory)
  74. config = Config(os.path.join(directory, 'alembic.ini'))
  75. config.set_main_option('script_location', directory)
  76. if config.cmd_opts is None:
  77. config.cmd_opts = argparse.Namespace()
  78. for opt in opts or []:
  79. setattr(config.cmd_opts, opt, True)
  80. if not hasattr(config.cmd_opts, 'x'):
  81. setattr(config.cmd_opts, 'x', [])
  82. for x in getattr(g, 'x_arg', []):
  83. config.cmd_opts.x.append(x)
  84. if x_arg is not None:
  85. if isinstance(x_arg, list) or isinstance(x_arg, tuple):
  86. for x in x_arg:
  87. config.cmd_opts.x.append(x)
  88. else:
  89. config.cmd_opts.x.append(x_arg)
  90. return self.call_configure_callbacks(config)
  91. def catch_errors(f):
  92. @wraps(f)
  93. def wrapped(*args, **kwargs):
  94. try:
  95. f(*args, **kwargs)
  96. except (CommandError, RuntimeError) as exc:
  97. log.error('Error: ' + str(exc))
  98. sys.exit(1)
  99. return wrapped
  100. @catch_errors
  101. def list_templates():
  102. """List available templates."""
  103. config = Config()
  104. config.print_stdout("Available templates:\n")
  105. for tempname in sorted(os.listdir(config.get_template_directory())):
  106. with open(
  107. os.path.join(config.get_template_directory(), tempname, "README")
  108. ) as readme:
  109. synopsis = next(readme).strip()
  110. config.print_stdout("%s - %s", tempname, synopsis)
  111. @catch_errors
  112. def init(directory=None, multidb=False, template=None, package=False):
  113. """Creates a new migration repository"""
  114. if directory is None:
  115. directory = current_app.extensions['migrate'].directory
  116. template_directory = None
  117. if template is not None and ('/' in template or '\\' in template):
  118. template_directory, template = os.path.split(template)
  119. config = Config(template_directory=template_directory)
  120. config.set_main_option('script_location', directory)
  121. config.config_file_name = os.path.join(directory, 'alembic.ini')
  122. config = current_app.extensions['migrate'].\
  123. migrate.call_configure_callbacks(config)
  124. if multidb and template is None:
  125. template = 'flask-multidb'
  126. elif template is None:
  127. template = 'flask'
  128. command.init(config, directory, template=template, package=package)
  129. @catch_errors
  130. def revision(directory=None, message=None, autogenerate=False, sql=False,
  131. head='head', splice=False, branch_label=None, version_path=None,
  132. rev_id=None):
  133. """Create a new revision file."""
  134. opts = ['autogenerate'] if autogenerate else None
  135. config = current_app.extensions['migrate'].migrate.get_config(
  136. directory, opts=opts)
  137. command.revision(config, message, autogenerate=autogenerate, sql=sql,
  138. head=head, splice=splice, branch_label=branch_label,
  139. version_path=version_path, rev_id=rev_id)
  140. @catch_errors
  141. def migrate(directory=None, message=None, sql=False, head='head', splice=False,
  142. branch_label=None, version_path=None, rev_id=None, x_arg=None):
  143. """Alias for 'revision --autogenerate'"""
  144. config = current_app.extensions['migrate'].migrate.get_config(
  145. directory, opts=['autogenerate'], x_arg=x_arg)
  146. command.revision(config, message, autogenerate=True, sql=sql,
  147. head=head, splice=splice, branch_label=branch_label,
  148. version_path=version_path, rev_id=rev_id)
  149. @catch_errors
  150. def edit(directory=None, revision='current'):
  151. """Edit current revision."""
  152. if alembic_version >= (0, 8, 0):
  153. config = current_app.extensions['migrate'].migrate.get_config(
  154. directory)
  155. command.edit(config, revision)
  156. else:
  157. raise RuntimeError('Alembic 0.8.0 or greater is required')
  158. @catch_errors
  159. def merge(directory=None, revisions='', message=None, branch_label=None,
  160. rev_id=None):
  161. """Merge two revisions together. Creates a new migration file"""
  162. config = current_app.extensions['migrate'].migrate.get_config(directory)
  163. command.merge(config, revisions, message=message,
  164. branch_label=branch_label, rev_id=rev_id)
  165. @catch_errors
  166. def upgrade(directory=None, revision='head', sql=False, tag=None, x_arg=None):
  167. """Upgrade to a later version"""
  168. config = current_app.extensions['migrate'].migrate.get_config(directory,
  169. x_arg=x_arg)
  170. command.upgrade(config, revision, sql=sql, tag=tag)
  171. @catch_errors
  172. def downgrade(directory=None, revision='-1', sql=False, tag=None, x_arg=None):
  173. """Revert to a previous version"""
  174. config = current_app.extensions['migrate'].migrate.get_config(directory,
  175. x_arg=x_arg)
  176. if sql and revision == '-1':
  177. revision = 'head:-1'
  178. command.downgrade(config, revision, sql=sql, tag=tag)
  179. @catch_errors
  180. def show(directory=None, revision='head'):
  181. """Show the revision denoted by the given symbol."""
  182. config = current_app.extensions['migrate'].migrate.get_config(directory)
  183. command.show(config, revision)
  184. @catch_errors
  185. def history(directory=None, rev_range=None, verbose=False,
  186. indicate_current=False):
  187. """List changeset scripts in chronological order."""
  188. config = current_app.extensions['migrate'].migrate.get_config(directory)
  189. if alembic_version >= (0, 9, 9):
  190. command.history(config, rev_range, verbose=verbose,
  191. indicate_current=indicate_current)
  192. else:
  193. command.history(config, rev_range, verbose=verbose)
  194. @catch_errors
  195. def heads(directory=None, verbose=False, resolve_dependencies=False):
  196. """Show current available heads in the script directory"""
  197. config = current_app.extensions['migrate'].migrate.get_config(directory)
  198. command.heads(config, verbose=verbose,
  199. resolve_dependencies=resolve_dependencies)
  200. @catch_errors
  201. def branches(directory=None, verbose=False):
  202. """Show current branch points"""
  203. config = current_app.extensions['migrate'].migrate.get_config(directory)
  204. command.branches(config, verbose=verbose)
  205. @catch_errors
  206. def current(directory=None, verbose=False):
  207. """Display the current revision for each database."""
  208. config = current_app.extensions['migrate'].migrate.get_config(directory)
  209. command.current(config, verbose=verbose)
  210. @catch_errors
  211. def stamp(directory=None, revision='head', sql=False, tag=None, purge=False):
  212. """'stamp' the revision table with the given revision; don't run any
  213. migrations"""
  214. config = current_app.extensions['migrate'].migrate.get_config(directory)
  215. command.stamp(config, revision, sql=sql, tag=tag, purge=purge)
  216. @catch_errors
  217. def check(directory=None):
  218. """Check if there are any new operations to migrate"""
  219. config = current_app.extensions['migrate'].migrate.get_config(directory)
  220. command.check(config)