base.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055
  1. from __future__ import annotations
  2. from contextlib import contextmanager
  3. import datetime
  4. import os
  5. from pathlib import Path
  6. import re
  7. import shutil
  8. import sys
  9. from types import ModuleType
  10. from typing import Any
  11. from typing import cast
  12. from typing import Iterator
  13. from typing import List
  14. from typing import Optional
  15. from typing import Sequence
  16. from typing import Set
  17. from typing import Tuple
  18. from typing import TYPE_CHECKING
  19. from typing import Union
  20. from . import revision
  21. from . import write_hooks
  22. from .. import util
  23. from ..runtime import migration
  24. from ..util import compat
  25. from ..util import not_none
  26. from ..util.pyfiles import _preserving_path_as_str
  27. if TYPE_CHECKING:
  28. from .revision import _GetRevArg
  29. from .revision import _RevIdType
  30. from .revision import Revision
  31. from ..config import Config
  32. from ..config import MessagingOptions
  33. from ..config import PostWriteHookConfig
  34. from ..runtime.migration import RevisionStep
  35. from ..runtime.migration import StampStep
  36. try:
  37. if compat.py39:
  38. from zoneinfo import ZoneInfo
  39. from zoneinfo import ZoneInfoNotFoundError
  40. else:
  41. from backports.zoneinfo import ZoneInfo # type: ignore[import-not-found,no-redef] # noqa: E501
  42. from backports.zoneinfo import ZoneInfoNotFoundError # type: ignore[no-redef] # noqa: E501
  43. except ImportError:
  44. ZoneInfo = None # type: ignore[assignment, misc]
  45. _sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
  46. _only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
  47. _legacy_rev = re.compile(r"([a-f0-9]+)\.py$")
  48. _slug_re = re.compile(r"\w+")
  49. _default_file_template = "%(rev)s_%(slug)s"
  50. class ScriptDirectory:
  51. """Provides operations upon an Alembic script directory.
  52. This object is useful to get information as to current revisions,
  53. most notably being able to get at the "head" revision, for schemes
  54. that want to test if the current revision in the database is the most
  55. recent::
  56. from alembic.script import ScriptDirectory
  57. from alembic.config import Config
  58. config = Config()
  59. config.set_main_option("script_location", "myapp:migrations")
  60. script = ScriptDirectory.from_config(config)
  61. head_revision = script.get_current_head()
  62. """
  63. def __init__(
  64. self,
  65. dir: Union[str, os.PathLike[str]], # noqa: A002
  66. file_template: str = _default_file_template,
  67. truncate_slug_length: Optional[int] = 40,
  68. version_locations: Optional[
  69. Sequence[Union[str, os.PathLike[str]]]
  70. ] = None,
  71. sourceless: bool = False,
  72. output_encoding: str = "utf-8",
  73. timezone: Optional[str] = None,
  74. hooks: list[PostWriteHookConfig] = [],
  75. recursive_version_locations: bool = False,
  76. messaging_opts: MessagingOptions = cast(
  77. "MessagingOptions", util.EMPTY_DICT
  78. ),
  79. ) -> None:
  80. self.dir = _preserving_path_as_str(dir)
  81. self.version_locations = [
  82. _preserving_path_as_str(p) for p in version_locations or ()
  83. ]
  84. self.file_template = file_template
  85. self.truncate_slug_length = truncate_slug_length or 40
  86. self.sourceless = sourceless
  87. self.output_encoding = output_encoding
  88. self.revision_map = revision.RevisionMap(self._load_revisions)
  89. self.timezone = timezone
  90. self.hooks = hooks
  91. self.recursive_version_locations = recursive_version_locations
  92. self.messaging_opts = messaging_opts
  93. if not os.access(dir, os.F_OK):
  94. raise util.CommandError(
  95. f"Path doesn't exist: {dir}. Please use "
  96. "the 'init' command to create a new "
  97. "scripts folder."
  98. )
  99. @property
  100. def versions(self) -> str:
  101. """return a single version location based on the sole path passed
  102. within version_locations.
  103. If multiple version locations are configured, an error is raised.
  104. """
  105. return str(self._singular_version_location)
  106. @util.memoized_property
  107. def _singular_version_location(self) -> Path:
  108. loc = self._version_locations
  109. if len(loc) > 1:
  110. raise util.CommandError("Multiple version_locations present")
  111. else:
  112. return loc[0]
  113. @util.memoized_property
  114. def _version_locations(self) -> Sequence[Path]:
  115. if self.version_locations:
  116. return [
  117. util.coerce_resource_to_filename(location).absolute()
  118. for location in self.version_locations
  119. ]
  120. else:
  121. return [Path(self.dir, "versions").absolute()]
  122. def _load_revisions(self) -> Iterator[Script]:
  123. paths = [vers for vers in self._version_locations if vers.exists()]
  124. dupes = set()
  125. for vers in paths:
  126. for file_path in Script._list_py_dir(self, vers):
  127. real_path = file_path.resolve()
  128. if real_path in dupes:
  129. util.warn(
  130. f"File {real_path} loaded twice! ignoring. "
  131. "Please ensure version_locations is unique."
  132. )
  133. continue
  134. dupes.add(real_path)
  135. script = Script._from_path(self, real_path)
  136. if script is None:
  137. continue
  138. yield script
  139. @classmethod
  140. def from_config(cls, config: Config) -> ScriptDirectory:
  141. """Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
  142. instance.
  143. The :class:`.Config` need only have the ``script_location`` key
  144. present.
  145. """
  146. script_location = config.get_alembic_option("script_location")
  147. if script_location is None:
  148. raise util.CommandError(
  149. "No 'script_location' key found in configuration."
  150. )
  151. truncate_slug_length: Optional[int]
  152. tsl = config.get_alembic_option("truncate_slug_length")
  153. if tsl is not None:
  154. truncate_slug_length = int(tsl)
  155. else:
  156. truncate_slug_length = None
  157. prepend_sys_path = config.get_prepend_sys_paths_list()
  158. if prepend_sys_path:
  159. sys.path[:0] = prepend_sys_path
  160. rvl = config.get_alembic_boolean_option("recursive_version_locations")
  161. return ScriptDirectory(
  162. util.coerce_resource_to_filename(script_location),
  163. file_template=config.get_alembic_option(
  164. "file_template", _default_file_template
  165. ),
  166. truncate_slug_length=truncate_slug_length,
  167. sourceless=config.get_alembic_boolean_option("sourceless"),
  168. output_encoding=config.get_alembic_option(
  169. "output_encoding", "utf-8"
  170. ),
  171. version_locations=config.get_version_locations_list(),
  172. timezone=config.get_alembic_option("timezone"),
  173. hooks=config.get_hooks_list(),
  174. recursive_version_locations=rvl,
  175. messaging_opts=config.messaging_opts,
  176. )
  177. @contextmanager
  178. def _catch_revision_errors(
  179. self,
  180. ancestor: Optional[str] = None,
  181. multiple_heads: Optional[str] = None,
  182. start: Optional[str] = None,
  183. end: Optional[str] = None,
  184. resolution: Optional[str] = None,
  185. ) -> Iterator[None]:
  186. try:
  187. yield
  188. except revision.RangeNotAncestorError as rna:
  189. if start is None:
  190. start = cast(Any, rna.lower)
  191. if end is None:
  192. end = cast(Any, rna.upper)
  193. if not ancestor:
  194. ancestor = (
  195. "Requested range %(start)s:%(end)s does not refer to "
  196. "ancestor/descendant revisions along the same branch"
  197. )
  198. ancestor = ancestor % {"start": start, "end": end}
  199. raise util.CommandError(ancestor) from rna
  200. except revision.MultipleHeads as mh:
  201. if not multiple_heads:
  202. multiple_heads = (
  203. "Multiple head revisions are present for given "
  204. "argument '%(head_arg)s'; please "
  205. "specify a specific target revision, "
  206. "'<branchname>@%(head_arg)s' to "
  207. "narrow to a specific head, or 'heads' for all heads"
  208. )
  209. multiple_heads = multiple_heads % {
  210. "head_arg": end or mh.argument,
  211. "heads": util.format_as_comma(mh.heads),
  212. }
  213. raise util.CommandError(multiple_heads) from mh
  214. except revision.ResolutionError as re:
  215. if resolution is None:
  216. resolution = "Can't locate revision identified by '%s'" % (
  217. re.argument
  218. )
  219. raise util.CommandError(resolution) from re
  220. except revision.RevisionError as err:
  221. raise util.CommandError(err.args[0]) from err
  222. def walk_revisions(
  223. self, base: str = "base", head: str = "heads"
  224. ) -> Iterator[Script]:
  225. """Iterate through all revisions.
  226. :param base: the base revision, or "base" to start from the
  227. empty revision.
  228. :param head: the head revision; defaults to "heads" to indicate
  229. all head revisions. May also be "head" to indicate a single
  230. head revision.
  231. """
  232. with self._catch_revision_errors(start=base, end=head):
  233. for rev in self.revision_map.iterate_revisions(
  234. head, base, inclusive=True, assert_relative_length=False
  235. ):
  236. yield cast(Script, rev)
  237. def get_revisions(self, id_: _GetRevArg) -> Tuple[Script, ...]:
  238. """Return the :class:`.Script` instance with the given rev identifier,
  239. symbolic name, or sequence of identifiers.
  240. """
  241. with self._catch_revision_errors():
  242. return cast(
  243. Tuple[Script, ...],
  244. self.revision_map.get_revisions(id_),
  245. )
  246. def get_all_current(self, id_: Tuple[str, ...]) -> Set[Script]:
  247. with self._catch_revision_errors():
  248. return cast(Set[Script], self.revision_map._get_all_current(id_))
  249. def get_revision(self, id_: str) -> Script:
  250. """Return the :class:`.Script` instance with the given rev id.
  251. .. seealso::
  252. :meth:`.ScriptDirectory.get_revisions`
  253. """
  254. with self._catch_revision_errors():
  255. return cast(Script, self.revision_map.get_revision(id_))
  256. def as_revision_number(
  257. self, id_: Optional[str]
  258. ) -> Optional[Union[str, Tuple[str, ...]]]:
  259. """Convert a symbolic revision, i.e. 'head' or 'base', into
  260. an actual revision number."""
  261. with self._catch_revision_errors():
  262. rev, branch_name = self.revision_map._resolve_revision_number(id_)
  263. if not rev:
  264. # convert () to None
  265. return None
  266. elif id_ == "heads":
  267. return rev
  268. else:
  269. return rev[0]
  270. def iterate_revisions(
  271. self,
  272. upper: Union[str, Tuple[str, ...], None],
  273. lower: Union[str, Tuple[str, ...], None],
  274. **kw: Any,
  275. ) -> Iterator[Script]:
  276. """Iterate through script revisions, starting at the given
  277. upper revision identifier and ending at the lower.
  278. The traversal uses strictly the `down_revision`
  279. marker inside each migration script, so
  280. it is a requirement that upper >= lower,
  281. else you'll get nothing back.
  282. The iterator yields :class:`.Script` objects.
  283. .. seealso::
  284. :meth:`.RevisionMap.iterate_revisions`
  285. """
  286. return cast(
  287. Iterator[Script],
  288. self.revision_map.iterate_revisions(upper, lower, **kw),
  289. )
  290. def get_current_head(self) -> Optional[str]:
  291. """Return the current head revision.
  292. If the script directory has multiple heads
  293. due to branching, an error is raised;
  294. :meth:`.ScriptDirectory.get_heads` should be
  295. preferred.
  296. :return: a string revision number.
  297. .. seealso::
  298. :meth:`.ScriptDirectory.get_heads`
  299. """
  300. with self._catch_revision_errors(
  301. multiple_heads=(
  302. "The script directory has multiple heads (due to branching)."
  303. "Please use get_heads(), or merge the branches using "
  304. "alembic merge."
  305. )
  306. ):
  307. return self.revision_map.get_current_head()
  308. def get_heads(self) -> List[str]:
  309. """Return all "versioned head" revisions as strings.
  310. This is normally a list of length one,
  311. unless branches are present. The
  312. :meth:`.ScriptDirectory.get_current_head()` method
  313. can be used normally when a script directory
  314. has only one head.
  315. :return: a tuple of string revision numbers.
  316. """
  317. return list(self.revision_map.heads)
  318. def get_base(self) -> Optional[str]:
  319. """Return the "base" revision as a string.
  320. This is the revision number of the script that
  321. has a ``down_revision`` of None.
  322. If the script directory has multiple bases, an error is raised;
  323. :meth:`.ScriptDirectory.get_bases` should be
  324. preferred.
  325. """
  326. bases = self.get_bases()
  327. if len(bases) > 1:
  328. raise util.CommandError(
  329. "The script directory has multiple bases. "
  330. "Please use get_bases()."
  331. )
  332. elif bases:
  333. return bases[0]
  334. else:
  335. return None
  336. def get_bases(self) -> List[str]:
  337. """return all "base" revisions as strings.
  338. This is the revision number of all scripts that
  339. have a ``down_revision`` of None.
  340. """
  341. return list(self.revision_map.bases)
  342. def _upgrade_revs(
  343. self, destination: str, current_rev: str
  344. ) -> List[RevisionStep]:
  345. with self._catch_revision_errors(
  346. ancestor="Destination %(end)s is not a valid upgrade "
  347. "target from current head(s)",
  348. end=destination,
  349. ):
  350. revs = self.iterate_revisions(
  351. destination, current_rev, implicit_base=True
  352. )
  353. return [
  354. migration.MigrationStep.upgrade_from_script(
  355. self.revision_map, script
  356. )
  357. for script in reversed(list(revs))
  358. ]
  359. def _downgrade_revs(
  360. self, destination: str, current_rev: Optional[str]
  361. ) -> List[RevisionStep]:
  362. with self._catch_revision_errors(
  363. ancestor="Destination %(end)s is not a valid downgrade "
  364. "target from current head(s)",
  365. end=destination,
  366. ):
  367. revs = self.iterate_revisions(
  368. current_rev, destination, select_for_downgrade=True
  369. )
  370. return [
  371. migration.MigrationStep.downgrade_from_script(
  372. self.revision_map, script
  373. )
  374. for script in revs
  375. ]
  376. def _stamp_revs(
  377. self, revision: _RevIdType, heads: _RevIdType
  378. ) -> List[StampStep]:
  379. with self._catch_revision_errors(
  380. multiple_heads="Multiple heads are present; please specify a "
  381. "single target revision"
  382. ):
  383. heads_revs = self.get_revisions(heads)
  384. steps = []
  385. if not revision:
  386. revision = "base"
  387. filtered_heads: List[Script] = []
  388. for rev in util.to_tuple(revision):
  389. if rev:
  390. filtered_heads.extend(
  391. self.revision_map.filter_for_lineage(
  392. cast(Sequence[Script], heads_revs),
  393. rev,
  394. include_dependencies=True,
  395. )
  396. )
  397. filtered_heads = util.unique_list(filtered_heads)
  398. dests = self.get_revisions(revision) or [None]
  399. for dest in dests:
  400. if dest is None:
  401. # dest is 'base'. Return a "delete branch" migration
  402. # for all applicable heads.
  403. steps.extend(
  404. [
  405. migration.StampStep(
  406. head.revision,
  407. None,
  408. False,
  409. True,
  410. self.revision_map,
  411. )
  412. for head in filtered_heads
  413. ]
  414. )
  415. continue
  416. elif dest in filtered_heads:
  417. # the dest is already in the version table, do nothing.
  418. continue
  419. # figure out if the dest is a descendant or an
  420. # ancestor of the selected nodes
  421. descendants = set(
  422. self.revision_map._get_descendant_nodes([dest])
  423. )
  424. ancestors = set(self.revision_map._get_ancestor_nodes([dest]))
  425. if descendants.intersection(filtered_heads):
  426. # heads are above the target, so this is a downgrade.
  427. # we can treat them as a "merge", single step.
  428. assert not ancestors.intersection(filtered_heads)
  429. todo_heads = [head.revision for head in filtered_heads]
  430. step = migration.StampStep(
  431. todo_heads,
  432. dest.revision,
  433. False,
  434. False,
  435. self.revision_map,
  436. )
  437. steps.append(step)
  438. continue
  439. elif ancestors.intersection(filtered_heads):
  440. # heads are below the target, so this is an upgrade.
  441. # we can treat them as a "merge", single step.
  442. todo_heads = [head.revision for head in filtered_heads]
  443. step = migration.StampStep(
  444. todo_heads,
  445. dest.revision,
  446. True,
  447. False,
  448. self.revision_map,
  449. )
  450. steps.append(step)
  451. continue
  452. else:
  453. # destination is in a branch not represented,
  454. # treat it as new branch
  455. step = migration.StampStep(
  456. (), dest.revision, True, True, self.revision_map
  457. )
  458. steps.append(step)
  459. continue
  460. return steps
  461. def run_env(self) -> None:
  462. """Run the script environment.
  463. This basically runs the ``env.py`` script present
  464. in the migration environment. It is called exclusively
  465. by the command functions in :mod:`alembic.command`.
  466. """
  467. util.load_python_file(self.dir, "env.py")
  468. @property
  469. def env_py_location(self) -> str:
  470. return str(Path(self.dir, "env.py"))
  471. def _append_template(self, src: Path, dest: Path, **kw: Any) -> None:
  472. with util.status(
  473. f"Appending to existing {dest.absolute()}",
  474. **self.messaging_opts,
  475. ):
  476. util.template_to_file(
  477. src,
  478. dest,
  479. self.output_encoding,
  480. append_with_newlines=True,
  481. **kw,
  482. )
  483. def _generate_template(self, src: Path, dest: Path, **kw: Any) -> None:
  484. with util.status(
  485. f"Generating {dest.absolute()}", **self.messaging_opts
  486. ):
  487. util.template_to_file(src, dest, self.output_encoding, **kw)
  488. def _copy_file(self, src: Path, dest: Path) -> None:
  489. with util.status(
  490. f"Generating {dest.absolute()}", **self.messaging_opts
  491. ):
  492. shutil.copy(src, dest)
  493. def _ensure_directory(self, path: Path) -> None:
  494. path = path.absolute()
  495. if not path.exists():
  496. with util.status(
  497. f"Creating directory {path}", **self.messaging_opts
  498. ):
  499. os.makedirs(path)
  500. def _generate_create_date(self) -> datetime.datetime:
  501. if self.timezone is not None:
  502. if ZoneInfo is None:
  503. raise util.CommandError(
  504. "Python >= 3.9 is required for timezone support or "
  505. "the 'backports.zoneinfo' package must be installed."
  506. )
  507. # First, assume correct capitalization
  508. try:
  509. tzinfo = ZoneInfo(self.timezone)
  510. except ZoneInfoNotFoundError:
  511. tzinfo = None
  512. if tzinfo is None:
  513. try:
  514. tzinfo = ZoneInfo(self.timezone.upper())
  515. except ZoneInfoNotFoundError:
  516. raise util.CommandError(
  517. "Can't locate timezone: %s" % self.timezone
  518. ) from None
  519. create_date = datetime.datetime.now(
  520. tz=datetime.timezone.utc
  521. ).astimezone(tzinfo)
  522. else:
  523. create_date = datetime.datetime.now()
  524. return create_date
  525. def generate_revision(
  526. self,
  527. revid: str,
  528. message: Optional[str],
  529. head: Optional[_RevIdType] = None,
  530. splice: Optional[bool] = False,
  531. branch_labels: Optional[_RevIdType] = None,
  532. version_path: Union[str, os.PathLike[str], None] = None,
  533. file_template: Optional[str] = None,
  534. depends_on: Optional[_RevIdType] = None,
  535. **kw: Any,
  536. ) -> Optional[Script]:
  537. """Generate a new revision file.
  538. This runs the ``script.py.mako`` template, given
  539. template arguments, and creates a new file.
  540. :param revid: String revision id. Typically this
  541. comes from ``alembic.util.rev_id()``.
  542. :param message: the revision message, the one passed
  543. by the -m argument to the ``revision`` command.
  544. :param head: the head revision to generate against. Defaults
  545. to the current "head" if no branches are present, else raises
  546. an exception.
  547. :param splice: if True, allow the "head" version to not be an
  548. actual head; otherwise, the selected head must be a head
  549. (e.g. endpoint) revision.
  550. """
  551. if head is None:
  552. head = "head"
  553. try:
  554. Script.verify_rev_id(revid)
  555. except revision.RevisionError as err:
  556. raise util.CommandError(err.args[0]) from err
  557. with self._catch_revision_errors(
  558. multiple_heads=(
  559. "Multiple heads are present; please specify the head "
  560. "revision on which the new revision should be based, "
  561. "or perform a merge."
  562. )
  563. ):
  564. heads = cast(
  565. Tuple[Optional["Revision"], ...],
  566. self.revision_map.get_revisions(head),
  567. )
  568. for h in heads:
  569. assert h != "base" # type: ignore[comparison-overlap]
  570. if len(set(heads)) != len(heads):
  571. raise util.CommandError("Duplicate head revisions specified")
  572. create_date = self._generate_create_date()
  573. if version_path is None:
  574. if len(self._version_locations) > 1:
  575. for head_ in heads:
  576. if head_ is not None:
  577. assert isinstance(head_, Script)
  578. version_path = head_._script_path.parent
  579. break
  580. else:
  581. raise util.CommandError(
  582. "Multiple version locations present, "
  583. "please specify --version-path"
  584. )
  585. else:
  586. version_path = self._singular_version_location
  587. else:
  588. version_path = Path(version_path)
  589. assert isinstance(version_path, Path)
  590. norm_path = version_path.absolute()
  591. for vers_path in self._version_locations:
  592. if vers_path.absolute() == norm_path:
  593. break
  594. else:
  595. raise util.CommandError(
  596. f"Path {version_path} is not represented in current "
  597. "version locations"
  598. )
  599. if self.version_locations:
  600. self._ensure_directory(version_path)
  601. path = self._rev_path(version_path, revid, message, create_date)
  602. if not splice:
  603. for head_ in heads:
  604. if head_ is not None and not head_.is_head:
  605. raise util.CommandError(
  606. "Revision %s is not a head revision; please specify "
  607. "--splice to create a new branch from this revision"
  608. % head_.revision
  609. )
  610. resolved_depends_on: Optional[List[str]]
  611. if depends_on:
  612. with self._catch_revision_errors():
  613. resolved_depends_on = [
  614. (
  615. dep
  616. if dep in rev.branch_labels # maintain branch labels
  617. else rev.revision
  618. ) # resolve partial revision identifiers
  619. for rev, dep in [
  620. (not_none(self.revision_map.get_revision(dep)), dep)
  621. for dep in util.to_list(depends_on)
  622. ]
  623. ]
  624. else:
  625. resolved_depends_on = None
  626. self._generate_template(
  627. Path(self.dir, "script.py.mako"),
  628. path,
  629. up_revision=str(revid),
  630. down_revision=revision.tuple_rev_as_scalar(
  631. tuple(h.revision if h is not None else None for h in heads)
  632. ),
  633. branch_labels=util.to_tuple(branch_labels),
  634. depends_on=revision.tuple_rev_as_scalar(resolved_depends_on),
  635. create_date=create_date,
  636. comma=util.format_as_comma,
  637. message=message if message is not None else ("empty message"),
  638. **kw,
  639. )
  640. post_write_hooks = self.hooks
  641. if post_write_hooks:
  642. write_hooks._run_hooks(path, post_write_hooks)
  643. try:
  644. script = Script._from_path(self, path)
  645. except revision.RevisionError as err:
  646. raise util.CommandError(err.args[0]) from err
  647. if script is None:
  648. return None
  649. if branch_labels and not script.branch_labels:
  650. raise util.CommandError(
  651. "Version %s specified branch_labels %s, however the "
  652. "migration file %s does not have them; have you upgraded "
  653. "your script.py.mako to include the "
  654. "'branch_labels' section?"
  655. % (script.revision, branch_labels, script.path)
  656. )
  657. self.revision_map.add_revision(script)
  658. return script
  659. def _rev_path(
  660. self,
  661. path: Union[str, os.PathLike[str]],
  662. rev_id: str,
  663. message: Optional[str],
  664. create_date: datetime.datetime,
  665. ) -> Path:
  666. epoch = int(create_date.timestamp())
  667. slug = "_".join(_slug_re.findall(message or "")).lower()
  668. if len(slug) > self.truncate_slug_length:
  669. slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_"
  670. filename = "%s.py" % (
  671. self.file_template
  672. % {
  673. "rev": rev_id,
  674. "slug": slug,
  675. "epoch": epoch,
  676. "year": create_date.year,
  677. "month": create_date.month,
  678. "day": create_date.day,
  679. "hour": create_date.hour,
  680. "minute": create_date.minute,
  681. "second": create_date.second,
  682. }
  683. )
  684. return Path(path) / filename
  685. class Script(revision.Revision):
  686. """Represent a single revision file in a ``versions/`` directory.
  687. The :class:`.Script` instance is returned by methods
  688. such as :meth:`.ScriptDirectory.iterate_revisions`.
  689. """
  690. def __init__(
  691. self,
  692. module: ModuleType,
  693. rev_id: str,
  694. path: Union[str, os.PathLike[str]],
  695. ):
  696. self.module = module
  697. self.path = _preserving_path_as_str(path)
  698. super().__init__(
  699. rev_id,
  700. module.down_revision,
  701. branch_labels=util.to_tuple(
  702. getattr(module, "branch_labels", None), default=()
  703. ),
  704. dependencies=util.to_tuple(
  705. getattr(module, "depends_on", None), default=()
  706. ),
  707. )
  708. module: ModuleType
  709. """The Python module representing the actual script itself."""
  710. path: str
  711. """Filesystem path of the script."""
  712. @property
  713. def _script_path(self) -> Path:
  714. return Path(self.path)
  715. _db_current_indicator: Optional[bool] = None
  716. """Utility variable which when set will cause string output to indicate
  717. this is a "current" version in some database"""
  718. @property
  719. def doc(self) -> str:
  720. """Return the docstring given in the script."""
  721. return re.split("\n\n", self.longdoc)[0]
  722. @property
  723. def longdoc(self) -> str:
  724. """Return the docstring given in the script."""
  725. doc = self.module.__doc__
  726. if doc:
  727. if hasattr(self.module, "_alembic_source_encoding"):
  728. doc = doc.decode( # type: ignore[attr-defined]
  729. self.module._alembic_source_encoding
  730. )
  731. return doc.strip()
  732. else:
  733. return ""
  734. @property
  735. def log_entry(self) -> str:
  736. entry = "Rev: %s%s%s%s%s\n" % (
  737. self.revision,
  738. " (head)" if self.is_head else "",
  739. " (branchpoint)" if self.is_branch_point else "",
  740. " (mergepoint)" if self.is_merge_point else "",
  741. " (current)" if self._db_current_indicator else "",
  742. )
  743. if self.is_merge_point:
  744. entry += "Merges: %s\n" % (self._format_down_revision(),)
  745. else:
  746. entry += "Parent: %s\n" % (self._format_down_revision(),)
  747. if self.dependencies:
  748. entry += "Also depends on: %s\n" % (
  749. util.format_as_comma(self.dependencies)
  750. )
  751. if self.is_branch_point:
  752. entry += "Branches into: %s\n" % (
  753. util.format_as_comma(self.nextrev)
  754. )
  755. if self.branch_labels:
  756. entry += "Branch names: %s\n" % (
  757. util.format_as_comma(self.branch_labels),
  758. )
  759. entry += "Path: %s\n" % (self.path,)
  760. entry += "\n%s\n" % (
  761. "\n".join(" %s" % para for para in self.longdoc.splitlines())
  762. )
  763. return entry
  764. def __str__(self) -> str:
  765. return "%s -> %s%s%s%s, %s" % (
  766. self._format_down_revision(),
  767. self.revision,
  768. " (head)" if self.is_head else "",
  769. " (branchpoint)" if self.is_branch_point else "",
  770. " (mergepoint)" if self.is_merge_point else "",
  771. self.doc,
  772. )
  773. def _head_only(
  774. self,
  775. include_branches: bool = False,
  776. include_doc: bool = False,
  777. include_parents: bool = False,
  778. tree_indicators: bool = True,
  779. head_indicators: bool = True,
  780. ) -> str:
  781. text = self.revision
  782. if include_parents:
  783. if self.dependencies:
  784. text = "%s (%s) -> %s" % (
  785. self._format_down_revision(),
  786. util.format_as_comma(self.dependencies),
  787. text,
  788. )
  789. else:
  790. text = "%s -> %s" % (self._format_down_revision(), text)
  791. assert text is not None
  792. if include_branches and self.branch_labels:
  793. text += " (%s)" % util.format_as_comma(self.branch_labels)
  794. if head_indicators or tree_indicators:
  795. text += "%s%s%s" % (
  796. " (head)" if self._is_real_head else "",
  797. (
  798. " (effective head)"
  799. if self.is_head and not self._is_real_head
  800. else ""
  801. ),
  802. " (current)" if self._db_current_indicator else "",
  803. )
  804. if tree_indicators:
  805. text += "%s%s" % (
  806. " (branchpoint)" if self.is_branch_point else "",
  807. " (mergepoint)" if self.is_merge_point else "",
  808. )
  809. if include_doc:
  810. text += ", %s" % self.doc
  811. return text
  812. def cmd_format(
  813. self,
  814. verbose: bool,
  815. include_branches: bool = False,
  816. include_doc: bool = False,
  817. include_parents: bool = False,
  818. tree_indicators: bool = True,
  819. ) -> str:
  820. if verbose:
  821. return self.log_entry
  822. else:
  823. return self._head_only(
  824. include_branches, include_doc, include_parents, tree_indicators
  825. )
  826. def _format_down_revision(self) -> str:
  827. if not self.down_revision:
  828. return "<base>"
  829. else:
  830. return util.format_as_comma(self._versioned_down_revisions)
  831. @classmethod
  832. def _list_py_dir(
  833. cls, scriptdir: ScriptDirectory, path: Path
  834. ) -> List[Path]:
  835. paths = []
  836. for root, dirs, files in compat.path_walk(path, top_down=True):
  837. if root.name.endswith("__pycache__"):
  838. # a special case - we may include these files
  839. # if a `sourceless` option is specified
  840. continue
  841. for filename in sorted(files):
  842. paths.append(root / filename)
  843. if scriptdir.sourceless:
  844. # look for __pycache__
  845. py_cache_path = root / "__pycache__"
  846. if py_cache_path.exists():
  847. # add all files from __pycache__ whose filename is not
  848. # already in the names we got from the version directory.
  849. # add as relative paths including __pycache__ token
  850. names = {
  851. Path(filename).name.split(".")[0] for filename in files
  852. }
  853. paths.extend(
  854. py_cache_path / pyc
  855. for pyc in py_cache_path.iterdir()
  856. if pyc.name.split(".")[0] not in names
  857. )
  858. if not scriptdir.recursive_version_locations:
  859. break
  860. # the real script order is defined by revision,
  861. # but it may be undefined if there are many files with a same
  862. # `down_revision`, for a better user experience (ex. debugging),
  863. # we use a deterministic order
  864. dirs.sort()
  865. return paths
  866. @classmethod
  867. def _from_path(
  868. cls, scriptdir: ScriptDirectory, path: Union[str, os.PathLike[str]]
  869. ) -> Optional[Script]:
  870. path = Path(path)
  871. dir_, filename = path.parent, path.name
  872. if scriptdir.sourceless:
  873. py_match = _sourceless_rev_file.match(filename)
  874. else:
  875. py_match = _only_source_rev_file.match(filename)
  876. if not py_match:
  877. return None
  878. py_filename = py_match.group(1)
  879. if scriptdir.sourceless:
  880. is_c = py_match.group(2) == "c"
  881. is_o = py_match.group(2) == "o"
  882. else:
  883. is_c = is_o = False
  884. if is_o or is_c:
  885. py_exists = (dir_ / py_filename).exists()
  886. pyc_exists = (dir_ / (py_filename + "c")).exists()
  887. # prefer .py over .pyc because we'd like to get the
  888. # source encoding; prefer .pyc over .pyo because we'd like to
  889. # have the docstrings which a -OO file would not have
  890. if py_exists or is_o and pyc_exists:
  891. return None
  892. module = util.load_python_file(dir_, filename)
  893. if not hasattr(module, "revision"):
  894. # attempt to get the revision id from the script name,
  895. # this for legacy only
  896. m = _legacy_rev.match(filename)
  897. if not m:
  898. raise util.CommandError(
  899. "Could not determine revision id from "
  900. f"filename {filename}. "
  901. "Be sure the 'revision' variable is "
  902. "declared inside the script (please see 'Upgrading "
  903. "from Alembic 0.1 to 0.2' in the documentation)."
  904. )
  905. else:
  906. revision = m.group(1)
  907. else:
  908. revision = module.revision
  909. return Script(module, revision, dir_ / filename)