test_cte.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. # testing/suite/test_cte.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. from .. import fixtures
  9. from ..assertions import eq_
  10. from ..schema import Column
  11. from ..schema import Table
  12. from ... import column
  13. from ... import ForeignKey
  14. from ... import Integer
  15. from ... import select
  16. from ... import String
  17. from ... import testing
  18. from ... import values
  19. class CTETest(fixtures.TablesTest):
  20. __backend__ = True
  21. __requires__ = ("ctes",)
  22. run_inserts = "each"
  23. run_deletes = "each"
  24. @classmethod
  25. def define_tables(cls, metadata):
  26. Table(
  27. "some_table",
  28. metadata,
  29. Column("id", Integer, primary_key=True),
  30. Column("data", String(50)),
  31. Column("parent_id", ForeignKey("some_table.id")),
  32. )
  33. Table(
  34. "some_other_table",
  35. metadata,
  36. Column("id", Integer, primary_key=True),
  37. Column("data", String(50)),
  38. Column("parent_id", Integer),
  39. )
  40. @classmethod
  41. def insert_data(cls, connection):
  42. connection.execute(
  43. cls.tables.some_table.insert(),
  44. [
  45. {"id": 1, "data": "d1", "parent_id": None},
  46. {"id": 2, "data": "d2", "parent_id": 1},
  47. {"id": 3, "data": "d3", "parent_id": 1},
  48. {"id": 4, "data": "d4", "parent_id": 3},
  49. {"id": 5, "data": "d5", "parent_id": 3},
  50. ],
  51. )
  52. def test_select_nonrecursive_round_trip(self, connection):
  53. some_table = self.tables.some_table
  54. cte = (
  55. select(some_table)
  56. .where(some_table.c.data.in_(["d2", "d3", "d4"]))
  57. .cte("some_cte")
  58. )
  59. result = connection.execute(
  60. select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
  61. )
  62. eq_(result.fetchall(), [("d4",)])
  63. def test_select_recursive_round_trip(self, connection):
  64. some_table = self.tables.some_table
  65. cte = (
  66. select(some_table)
  67. .where(some_table.c.data.in_(["d2", "d3", "d4"]))
  68. .cte("some_cte", recursive=True)
  69. )
  70. cte_alias = cte.alias("c1")
  71. st1 = some_table.alias()
  72. # note that SQL Server requires this to be UNION ALL,
  73. # can't be UNION
  74. cte = cte.union_all(
  75. select(st1).where(st1.c.id == cte_alias.c.parent_id)
  76. )
  77. result = connection.execute(
  78. select(cte.c.data)
  79. .where(cte.c.data != "d2")
  80. .order_by(cte.c.data.desc())
  81. )
  82. eq_(
  83. result.fetchall(),
  84. [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
  85. )
  86. def test_insert_from_select_round_trip(self, connection):
  87. some_table = self.tables.some_table
  88. some_other_table = self.tables.some_other_table
  89. cte = (
  90. select(some_table)
  91. .where(some_table.c.data.in_(["d2", "d3", "d4"]))
  92. .cte("some_cte")
  93. )
  94. connection.execute(
  95. some_other_table.insert().from_select(
  96. ["id", "data", "parent_id"], select(cte)
  97. )
  98. )
  99. eq_(
  100. connection.execute(
  101. select(some_other_table).order_by(some_other_table.c.id)
  102. ).fetchall(),
  103. [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
  104. )
  105. @testing.requires.ctes_with_update_delete
  106. @testing.requires.update_from
  107. def test_update_from_round_trip(self, connection):
  108. some_table = self.tables.some_table
  109. some_other_table = self.tables.some_other_table
  110. connection.execute(
  111. some_other_table.insert().from_select(
  112. ["id", "data", "parent_id"], select(some_table)
  113. )
  114. )
  115. cte = (
  116. select(some_table)
  117. .where(some_table.c.data.in_(["d2", "d3", "d4"]))
  118. .cte("some_cte")
  119. )
  120. connection.execute(
  121. some_other_table.update()
  122. .values(parent_id=5)
  123. .where(some_other_table.c.data == cte.c.data)
  124. )
  125. eq_(
  126. connection.execute(
  127. select(some_other_table).order_by(some_other_table.c.id)
  128. ).fetchall(),
  129. [
  130. (1, "d1", None),
  131. (2, "d2", 5),
  132. (3, "d3", 5),
  133. (4, "d4", 5),
  134. (5, "d5", 3),
  135. ],
  136. )
  137. @testing.requires.ctes_with_update_delete
  138. @testing.requires.delete_from
  139. def test_delete_from_round_trip(self, connection):
  140. some_table = self.tables.some_table
  141. some_other_table = self.tables.some_other_table
  142. connection.execute(
  143. some_other_table.insert().from_select(
  144. ["id", "data", "parent_id"], select(some_table)
  145. )
  146. )
  147. cte = (
  148. select(some_table)
  149. .where(some_table.c.data.in_(["d2", "d3", "d4"]))
  150. .cte("some_cte")
  151. )
  152. connection.execute(
  153. some_other_table.delete().where(
  154. some_other_table.c.data == cte.c.data
  155. )
  156. )
  157. eq_(
  158. connection.execute(
  159. select(some_other_table).order_by(some_other_table.c.id)
  160. ).fetchall(),
  161. [(1, "d1", None), (5, "d5", 3)],
  162. )
  163. @testing.requires.ctes_with_update_delete
  164. def test_delete_scalar_subq_round_trip(self, connection):
  165. some_table = self.tables.some_table
  166. some_other_table = self.tables.some_other_table
  167. connection.execute(
  168. some_other_table.insert().from_select(
  169. ["id", "data", "parent_id"], select(some_table)
  170. )
  171. )
  172. cte = (
  173. select(some_table)
  174. .where(some_table.c.data.in_(["d2", "d3", "d4"]))
  175. .cte("some_cte")
  176. )
  177. connection.execute(
  178. some_other_table.delete().where(
  179. some_other_table.c.data
  180. == select(cte.c.data)
  181. .where(cte.c.id == some_other_table.c.id)
  182. .scalar_subquery()
  183. )
  184. )
  185. eq_(
  186. connection.execute(
  187. select(some_other_table).order_by(some_other_table.c.id)
  188. ).fetchall(),
  189. [(1, "d1", None), (5, "d5", 3)],
  190. )
  191. @testing.variation("values_named", [True, False])
  192. @testing.variation("cte_named", [True, False])
  193. @testing.variation("literal_binds", [True, False])
  194. @testing.requires.ctes_with_values
  195. def test_values_named_via_cte(
  196. self, connection, values_named, cte_named, literal_binds
  197. ):
  198. cte1 = (
  199. values(
  200. column("col1", String),
  201. column("col2", Integer),
  202. literal_binds=bool(literal_binds),
  203. name="some name" if values_named else None,
  204. )
  205. .data([("a", 2), ("b", 3)])
  206. .cte("cte1" if cte_named else None)
  207. )
  208. stmt = select(cte1)
  209. rows = connection.execute(stmt).all()
  210. eq_(rows, [("a", 2), ("b", 3)])