test_rowcount.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # testing/suite/test_rowcount.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. from sqlalchemy import bindparam
  9. from sqlalchemy import Column
  10. from sqlalchemy import Integer
  11. from sqlalchemy import MetaData
  12. from sqlalchemy import select
  13. from sqlalchemy import String
  14. from sqlalchemy import Table
  15. from sqlalchemy import testing
  16. from sqlalchemy import text
  17. from sqlalchemy.testing import eq_
  18. from sqlalchemy.testing import fixtures
  19. class RowCountTest(fixtures.TablesTest):
  20. """test rowcount functionality"""
  21. __requires__ = ("sane_rowcount",)
  22. __backend__ = True
  23. @classmethod
  24. def define_tables(cls, metadata):
  25. Table(
  26. "employees",
  27. metadata,
  28. Column(
  29. "employee_id",
  30. Integer,
  31. autoincrement=False,
  32. primary_key=True,
  33. ),
  34. Column("name", String(50)),
  35. Column("department", String(1)),
  36. )
  37. @classmethod
  38. def insert_data(cls, connection):
  39. cls.data = data = [
  40. ("Angela", "A"),
  41. ("Andrew", "A"),
  42. ("Anand", "A"),
  43. ("Bob", "B"),
  44. ("Bobette", "B"),
  45. ("Buffy", "B"),
  46. ("Charlie", "C"),
  47. ("Cynthia", "C"),
  48. ("Chris", "C"),
  49. ]
  50. employees_table = cls.tables.employees
  51. connection.execute(
  52. employees_table.insert(),
  53. [
  54. {"employee_id": i, "name": n, "department": d}
  55. for i, (n, d) in enumerate(data)
  56. ],
  57. )
  58. def test_basic(self, connection):
  59. employees_table = self.tables.employees
  60. s = select(
  61. employees_table.c.name, employees_table.c.department
  62. ).order_by(employees_table.c.employee_id)
  63. rows = connection.execute(s).fetchall()
  64. eq_(rows, self.data)
  65. @testing.variation("statement", ["update", "delete", "insert", "select"])
  66. @testing.variation("close_first", [True, False])
  67. def test_non_rowcount_scenarios_no_raise(
  68. self, connection, statement, close_first
  69. ):
  70. employees_table = self.tables.employees
  71. # WHERE matches 3, 3 rows changed
  72. department = employees_table.c.department
  73. if statement.update:
  74. r = connection.execute(
  75. employees_table.update().where(department == "C"),
  76. {"department": "Z"},
  77. )
  78. elif statement.delete:
  79. r = connection.execute(
  80. employees_table.delete().where(department == "C"),
  81. {"department": "Z"},
  82. )
  83. elif statement.insert:
  84. r = connection.execute(
  85. employees_table.insert(),
  86. [
  87. {"employee_id": 25, "name": "none 1", "department": "X"},
  88. {"employee_id": 26, "name": "none 2", "department": "Z"},
  89. {"employee_id": 27, "name": "none 3", "department": "Z"},
  90. ],
  91. )
  92. elif statement.select:
  93. s = select(
  94. employees_table.c.name, employees_table.c.department
  95. ).where(employees_table.c.department == "C")
  96. r = connection.execute(s)
  97. r.all()
  98. else:
  99. statement.fail()
  100. if close_first:
  101. r.close()
  102. assert r.rowcount in (-1, 3)
  103. def test_update_rowcount1(self, connection):
  104. employees_table = self.tables.employees
  105. # WHERE matches 3, 3 rows changed
  106. department = employees_table.c.department
  107. r = connection.execute(
  108. employees_table.update().where(department == "C"),
  109. {"department": "Z"},
  110. )
  111. assert r.rowcount == 3
  112. def test_update_rowcount2(self, connection):
  113. employees_table = self.tables.employees
  114. # WHERE matches 3, 0 rows changed
  115. department = employees_table.c.department
  116. r = connection.execute(
  117. employees_table.update().where(department == "C"),
  118. {"department": "C"},
  119. )
  120. eq_(r.rowcount, 3)
  121. @testing.variation("implicit_returning", [True, False])
  122. @testing.variation(
  123. "dml",
  124. [
  125. ("update", testing.requires.update_returning),
  126. ("delete", testing.requires.delete_returning),
  127. ],
  128. )
  129. def test_update_delete_rowcount_return_defaults(
  130. self, connection, implicit_returning, dml
  131. ):
  132. """note this test should succeed for all RETURNING backends
  133. as of 2.0. In
  134. Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
  135. len(rows) when we have implicit returning
  136. """
  137. if implicit_returning:
  138. employees_table = self.tables.employees
  139. else:
  140. employees_table = Table(
  141. "employees",
  142. MetaData(),
  143. Column(
  144. "employee_id",
  145. Integer,
  146. autoincrement=False,
  147. primary_key=True,
  148. ),
  149. Column("name", String(50)),
  150. Column("department", String(1)),
  151. implicit_returning=False,
  152. )
  153. department = employees_table.c.department
  154. if dml.update:
  155. stmt = (
  156. employees_table.update()
  157. .where(department == "C")
  158. .values(name=employees_table.c.department + "Z")
  159. .return_defaults()
  160. )
  161. elif dml.delete:
  162. stmt = (
  163. employees_table.delete()
  164. .where(department == "C")
  165. .return_defaults()
  166. )
  167. else:
  168. dml.fail()
  169. r = connection.execute(stmt)
  170. eq_(r.rowcount, 3)
  171. def test_raw_sql_rowcount(self, connection):
  172. # test issue #3622, make sure eager rowcount is called for text
  173. result = connection.exec_driver_sql(
  174. "update employees set department='Z' where department='C'"
  175. )
  176. eq_(result.rowcount, 3)
  177. def test_text_rowcount(self, connection):
  178. # test issue #3622, make sure eager rowcount is called for text
  179. result = connection.execute(
  180. text("update employees set department='Z' where department='C'")
  181. )
  182. eq_(result.rowcount, 3)
  183. def test_delete_rowcount(self, connection):
  184. employees_table = self.tables.employees
  185. # WHERE matches 3, 3 rows deleted
  186. department = employees_table.c.department
  187. r = connection.execute(
  188. employees_table.delete().where(department == "C")
  189. )
  190. eq_(r.rowcount, 3)
  191. @testing.requires.sane_multi_rowcount
  192. def test_multi_update_rowcount(self, connection):
  193. employees_table = self.tables.employees
  194. stmt = (
  195. employees_table.update()
  196. .where(employees_table.c.name == bindparam("emp_name"))
  197. .values(department="C")
  198. )
  199. r = connection.execute(
  200. stmt,
  201. [
  202. {"emp_name": "Bob"},
  203. {"emp_name": "Cynthia"},
  204. {"emp_name": "nonexistent"},
  205. ],
  206. )
  207. eq_(r.rowcount, 2)
  208. @testing.requires.sane_multi_rowcount
  209. def test_multi_delete_rowcount(self, connection):
  210. employees_table = self.tables.employees
  211. stmt = employees_table.delete().where(
  212. employees_table.c.name == bindparam("emp_name")
  213. )
  214. r = connection.execute(
  215. stmt,
  216. [
  217. {"emp_name": "Bob"},
  218. {"emp_name": "Cynthia"},
  219. {"emp_name": "nonexistent"},
  220. ],
  221. )
  222. eq_(r.rowcount, 2)