schemacompare.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from itertools import zip_longest
  2. from sqlalchemy import schema
  3. from sqlalchemy.sql.elements import ClauseList
  4. class CompareTable:
  5. def __init__(self, table):
  6. self.table = table
  7. def __eq__(self, other):
  8. if self.table.name != other.name or self.table.schema != other.schema:
  9. return False
  10. for c1, c2 in zip_longest(self.table.c, other.c):
  11. if (c1 is None and c2 is not None) or (
  12. c2 is None and c1 is not None
  13. ):
  14. return False
  15. if CompareColumn(c1) != c2:
  16. return False
  17. return True
  18. # TODO: compare constraints, indexes
  19. def __ne__(self, other):
  20. return not self.__eq__(other)
  21. class CompareColumn:
  22. def __init__(self, column):
  23. self.column = column
  24. def __eq__(self, other):
  25. return (
  26. self.column.name == other.name
  27. and self.column.nullable == other.nullable
  28. )
  29. # TODO: datatypes etc
  30. def __ne__(self, other):
  31. return not self.__eq__(other)
  32. class CompareIndex:
  33. def __init__(self, index, name_only=False):
  34. self.index = index
  35. self.name_only = name_only
  36. def __eq__(self, other):
  37. if self.name_only:
  38. return self.index.name == other.name
  39. else:
  40. return (
  41. str(schema.CreateIndex(self.index))
  42. == str(schema.CreateIndex(other))
  43. and self.index.dialect_kwargs == other.dialect_kwargs
  44. )
  45. def __ne__(self, other):
  46. return not self.__eq__(other)
  47. def __repr__(self):
  48. expr = ClauseList(*self.index.expressions)
  49. try:
  50. expr_str = expr.compile().string
  51. except Exception:
  52. expr_str = str(expr)
  53. return f"<CompareIndex {self.index.name}({expr_str})>"
  54. class CompareCheckConstraint:
  55. def __init__(self, constraint):
  56. self.constraint = constraint
  57. def __eq__(self, other):
  58. return (
  59. isinstance(other, schema.CheckConstraint)
  60. and self.constraint.name == other.name
  61. and (str(self.constraint.sqltext) == str(other.sqltext))
  62. and (other.table.name == self.constraint.table.name)
  63. and other.table.schema == self.constraint.table.schema
  64. )
  65. def __ne__(self, other):
  66. return not self.__eq__(other)
  67. class CompareForeignKey:
  68. def __init__(self, constraint):
  69. self.constraint = constraint
  70. def __eq__(self, other):
  71. r1 = (
  72. isinstance(other, schema.ForeignKeyConstraint)
  73. and self.constraint.name == other.name
  74. and (other.table.name == self.constraint.table.name)
  75. and other.table.schema == self.constraint.table.schema
  76. )
  77. if not r1:
  78. return False
  79. for c1, c2 in zip_longest(self.constraint.columns, other.columns):
  80. if (c1 is None and c2 is not None) or (
  81. c2 is None and c1 is not None
  82. ):
  83. return False
  84. if CompareColumn(c1) != c2:
  85. return False
  86. return True
  87. def __ne__(self, other):
  88. return not self.__eq__(other)
  89. class ComparePrimaryKey:
  90. def __init__(self, constraint):
  91. self.constraint = constraint
  92. def __eq__(self, other):
  93. r1 = (
  94. isinstance(other, schema.PrimaryKeyConstraint)
  95. and self.constraint.name == other.name
  96. and (other.table.name == self.constraint.table.name)
  97. and other.table.schema == self.constraint.table.schema
  98. )
  99. if not r1:
  100. return False
  101. for c1, c2 in zip_longest(self.constraint.columns, other.columns):
  102. if (c1 is None and c2 is not None) or (
  103. c2 is None and c1 is not None
  104. ):
  105. return False
  106. if CompareColumn(c1) != c2:
  107. return False
  108. return True
  109. def __ne__(self, other):
  110. return not self.__eq__(other)
  111. class CompareUniqueConstraint:
  112. def __init__(self, constraint):
  113. self.constraint = constraint
  114. def __eq__(self, other):
  115. r1 = (
  116. isinstance(other, schema.UniqueConstraint)
  117. and self.constraint.name == other.name
  118. and (other.table.name == self.constraint.table.name)
  119. and other.table.schema == self.constraint.table.schema
  120. )
  121. if not r1:
  122. return False
  123. for c1, c2 in zip_longest(self.constraint.columns, other.columns):
  124. if (c1 is None and c2 is not None) or (
  125. c2 is None and c1 is not None
  126. ):
  127. return False
  128. if CompareColumn(c1) != c2:
  129. return False
  130. return True
  131. def __ne__(self, other):
  132. return not self.__eq__(other)