# util/topological.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

"""Topological sorting algorithms."""
  8. from __future__ import annotations
  9. from typing import Any
  10. from typing import Collection
  11. from typing import DefaultDict
  12. from typing import Iterable
  13. from typing import Iterator
  14. from typing import Sequence
  15. from typing import Set
  16. from typing import Tuple
  17. from typing import TypeVar
  18. from .. import util
  19. from ..exc import CircularDependencyError
  20. _T = TypeVar("_T", bound=Any)
  21. __all__ = ["sort", "sort_as_subsets", "find_cycles"]
  22. def sort_as_subsets(
  23. tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T]
  24. ) -> Iterator[Sequence[_T]]:
  25. edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
  26. for parent, child in tuples:
  27. edges[child].add(parent)
  28. todo = list(allitems)
  29. todo_set = set(allitems)
  30. while todo_set:
  31. output = []
  32. for node in todo:
  33. if todo_set.isdisjoint(edges[node]):
  34. output.append(node)
  35. if not output:
  36. raise CircularDependencyError(
  37. "Circular dependency detected.",
  38. find_cycles(tuples, allitems),
  39. _gen_edges(edges),
  40. )
  41. todo_set.difference_update(output)
  42. todo = [t for t in todo if t in todo_set]
  43. yield output
  44. def sort(
  45. tuples: Collection[Tuple[_T, _T]],
  46. allitems: Collection[_T],
  47. deterministic_order: bool = True,
  48. ) -> Iterator[_T]:
  49. """sort the given list of items by dependency.
  50. 'tuples' is a list of tuples representing a partial ordering.
  51. deterministic_order is no longer used, the order is now always
  52. deterministic given the order of "allitems". the flag is there
  53. for backwards compatibility with Alembic.
  54. """
  55. for set_ in sort_as_subsets(tuples, allitems):
  56. yield from set_
  57. def find_cycles(
  58. tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
  59. ) -> Set[_T]:
  60. # adapted from:
  61. # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
  62. edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
  63. for parent, child in tuples:
  64. edges[parent].add(child)
  65. nodes_to_test = set(edges)
  66. output = set()
  67. # we'd like to find all nodes that are
  68. # involved in cycles, so we do the full
  69. # pass through the whole thing for each
  70. # node in the original list.
  71. # we can go just through parent edge nodes.
  72. # if a node is only a child and never a parent,
  73. # by definition it can't be part of a cycle. same
  74. # if it's not in the edges at all.
  75. for node in nodes_to_test:
  76. stack = [node]
  77. todo = nodes_to_test.difference(stack)
  78. while stack:
  79. top = stack[-1]
  80. for node in edges[top]:
  81. if node in stack:
  82. cyc = stack[stack.index(node) :]
  83. todo.difference_update(cyc)
  84. output.update(cyc)
  85. if node in todo:
  86. stack.append(node)
  87. todo.remove(node)
  88. break
  89. else:
  90. stack.pop()
  91. return output
  92. def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]:
  93. return {(right, left) for left in edges for right in edges[left]}