# testing/fixtures/sql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations

import itertools
import random
import re
import sys

import sqlalchemy as sa

from .base import TestBase
from .. import config
from .. import mock
from ..assertions import eq_
from ..assertions import ne_
from ..util import adict
from ..util import drop_all_tables_from_metadata
from ... import event
from ... import util
from ...schema import sort_tables_and_constraints
from ...sql import visitors
from ...sql.elements import ClauseElement

class TablesTest(TestBase):
    # 'once', None
    run_setup_bind = "once"

    # 'once', 'each', None
    run_define_tables = "once"

    # 'once', 'each', None
    run_create_tables = "once"

    # 'once', 'each', None
    run_inserts = "each"

    # 'each', None
    run_deletes = "each"

    # 'once', None
    run_dispose_bind = None

    bind = None
    _tables_metadata = None
    tables = None
    other = None
    sequences = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        cls = self.__class__
        cls._init_class()

        cls._setup_once_tables()

        cls._setup_once_inserts()

        yield

        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        self._setup_each_tables()
        self._setup_each_inserts()

        yield

        self._teardown_each_tables()

    @property
    def tables_test_metadata(self):
        return self._tables_metadata

    @classmethod
    def _init_class(cls):
        if cls.run_define_tables == "each":
            if cls.run_create_tables == "once":
                cls.run_create_tables = "each"
            assert cls.run_inserts in ("each", None)

        cls.other = adict()
        cls.tables = adict()
        cls.sequences = adict()

        cls.bind = cls.setup_bind()
        cls._tables_metadata = sa.MetaData()

    @classmethod
    def _setup_once_inserts(cls):
        if cls.run_inserts == "once":
            cls._load_fixtures()
            with cls.bind.begin() as conn:
                cls.insert_data(conn)

    @classmethod
    def _setup_once_tables(cls):
        if cls.run_define_tables == "once":
            cls.define_tables(cls._tables_metadata)
            if cls.run_create_tables == "once":
                cls._tables_metadata.create_all(cls.bind)
            cls.tables.update(cls._tables_metadata.tables)
            cls.sequences.update(cls._tables_metadata._sequences)

    def _setup_each_tables(self):
        if self.run_define_tables == "each":
            self.define_tables(self._tables_metadata)
            if self.run_create_tables == "each":
                self._tables_metadata.create_all(self.bind)
            self.tables.update(self._tables_metadata.tables)
            self.sequences.update(self._tables_metadata._sequences)
        elif self.run_create_tables == "each":
            self._tables_metadata.create_all(self.bind)

    def _setup_each_inserts(self):
        if self.run_inserts == "each":
            self._load_fixtures()
            with self.bind.begin() as conn:
                self.insert_data(conn)

    def _teardown_each_tables(self):
        if self.run_define_tables == "each":
            self.tables.clear()
            if self.run_create_tables == "each":
                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
            self._tables_metadata.clear()
        elif self.run_create_tables == "each":
            drop_all_tables_from_metadata(self._tables_metadata, self.bind)

        savepoints = getattr(config.requirements, "savepoints", False)
        if savepoints:
            savepoints = savepoints.enabled

        # no need to run deletes if tables are recreated on setup
        if (
            self.run_define_tables != "each"
            and self.run_create_tables != "each"
            and self.run_deletes == "each"
        ):
            with self.bind.begin() as conn:
                for table in reversed(
                    [
                        t
                        for (t, fks) in sort_tables_and_constraints(
                            self._tables_metadata.tables.values()
                        )
                        if t is not None
                    ]
                ):
                    try:
                        if savepoints:
                            with conn.begin_nested():
                                conn.execute(table.delete())
                        else:
                            conn.execute(table.delete())
                    except sa.exc.DBAPIError as ex:
                        print(
                            ("Error emptying table %s: %r" % (table, ex)),
                            file=sys.stderr,
                        )

    @classmethod
    def _teardown_once_metadata_bind(cls):
        if cls.run_create_tables:
            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)

        if cls.run_dispose_bind == "once":
            cls.dispose_bind(cls.bind)

        cls._tables_metadata.bind = None

        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def setup_bind(cls):
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        if hasattr(bind, "dispose"):
            bind.dispose()
        elif hasattr(bind, "close"):
            bind.close()

    @classmethod
    def define_tables(cls, metadata):
        pass

    @classmethod
    def fixtures(cls):
        return {}

    @classmethod
    def insert_data(cls, connection):
        pass

    def sql_count_(self, count, fn):
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements):
        self.assert_sql(self.bind, callable_, statements)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method."""
        headers, rows = {}, {}
        for table, data in cls.fixtures().items():
            if len(data) < 2:
                continue
            if isinstance(table, str):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        for table, fks in sort_tables_and_constraints(
            cls._tables_metadata.tables.values()
        ):
            if table is None:
                continue
            if table not in headers:
                continue
            with cls.bind.begin() as conn:
                conn.execute(
                    table.insert(),
                    [
                        dict(zip(headers[table], column_values))
                        for column_values in rows[table]
                    ],
                )
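

# Example usage (illustrative sketch, not normative): a test module would
# typically subclass TablesTest, define tables and fixture rows, and let the
# autouse fixtures above drive the create/insert/delete lifecycle.  The
# class, table and column names below are hypothetical, and the test method
# assumes the ``connection`` fixture provided by TestBase.
#
#     from sqlalchemy import Column, Integer, String, Table, select
#     from sqlalchemy.testing import fixtures
#
#     class UserTablesTest(fixtures.TablesTest):
#         @classmethod
#         def define_tables(cls, metadata):
#             Table(
#                 "users",
#                 metadata,
#                 Column("id", Integer, primary_key=True),
#                 Column("name", String(50)),
#             )
#
#         @classmethod
#         def fixtures(cls):
#             # first element is the tuple of column names; the rest are rows
#             return {
#                 "users": (
#                     ("id", "name"),
#                     (1, "spongebob"),
#                     (2, "sandy"),
#                 )
#             }
#
#         def test_rows_loaded(self, connection):
#             users = self.tables.users
#             assert len(connection.execute(select(users)).fetchall()) == 2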

class NoCache:
    @config.fixture(autouse=True, scope="function")
    def _disable_cache(self):
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache
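

# Example usage (illustrative sketch): mixing NoCache into a test class
# disables the compiled-statement cache on config.db for the duration of
# each test function, which is useful when a test needs to observe a fresh
# compilation every time.  The class name below is hypothetical.
#
#     class FreshCompileTest(NoCache, TablesTest):
#         ...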

class RemovesEvents:
    @util.memoized_property
    def _event_fns(self):
        return set()

    def event_listen(self, target, name, fn, **kw):
        self._event_fns.add((target, name, fn))
        event.listen(target, name, fn, **kw)

    @config.fixture(autouse=True, scope="function")
    def _remove_events(self):
        yield
        for key in self._event_fns:
            event.remove(*key)
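

# Example usage (illustrative sketch): tests register listeners through
# event_listen() rather than event.listen() directly, so the autouse fixture
# above removes them when the test function ends.  Names below are
# hypothetical and assume the ``connection`` fixture from TestBase.
#
#     class ExecuteEventsTest(RemovesEvents, TestBase):
#         def test_cursor_execute_hook(self, connection):
#             canary = []
#             self.event_listen(
#                 connection,
#                 "before_cursor_execute",
#                 lambda *arg, **kw: canary.append(True),
#             )
#             connection.exec_driver_sql("SELECT 1")
#             assert canary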

class ComputedReflectionFixtureTest(TablesTest):
    run_inserts = run_deletes = None

    __backend__ = True
    __requires__ = ("computed_columns", "table_reflection")

    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")

    def normalize(self, text):
        return self.regexp.sub("", text).lower()

    @classmethod
    def define_tables(cls, metadata):
        from ... import Integer
        from ... import testing
        from ...schema import Column
        from ...schema import Computed
        from ...schema import Table

        Table(
            "computed_default_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_col", Integer, Computed("normal + 42")),
            Column("with_default", Integer, server_default="42"),
        )

        t = Table(
            "computed_column_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_no_flag", Integer, Computed("normal + 42")),
        )

        if testing.requires.schemas.enabled:
            t2 = Table(
                "computed_column_table",
                metadata,
                Column("id", Integer, primary_key=True),
                Column("normal", Integer),
                Column("computed_no_flag", Integer, Computed("normal / 42")),
                schema=config.test_schema,
            )

        if testing.requires.computed_columns_virtual.enabled:
            t.append_column(
                Column(
                    "computed_virtual",
                    Integer,
                    Computed("normal + 2", persisted=False),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_virtual",
                        Integer,
                        Computed("normal / 2", persisted=False),
                    )
                )
        if testing.requires.computed_columns_stored.enabled:
            t.append_column(
                Column(
                    "computed_stored",
                    Integer,
                    Computed("normal - 42", persisted=True),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_stored",
                        Integer,
                        Computed("normal * 42", persisted=True),
                    )
                )
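

# Example usage (illustrative sketch): dialect suites subclass this fixture
# and reflect the computed columns created in define_tables() above.  The
# class and test names below are hypothetical and assume the ``connection``
# fixture from TestBase.
#
#     class ComputedReflectionTest(ComputedReflectionFixtureTest):
#         def test_reflect_computed(self, connection):
#             insp = sa.inspect(connection)
#             cols = insp.get_columns("computed_default_table")
#             by_name = {c["name"]: c for c in cols}
#             assert "computed" in by_name["computed_col"]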

class CacheKeyFixture:
    def _compare_equal(self, a, b, compare_values):
        a_key = a._generate_cache_key()
        b_key = b._generate_cache_key()

        if a_key is None:
            assert a._annotations.get("nocache")

            assert b_key is None
        else:
            eq_(a_key.key, b_key.key)
            eq_(hash(a_key.key), hash(b_key.key))

            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
                assert a_param.compare(b_param, compare_values=compare_values)
        return a_key, b_key

    def _run_cache_key_fixture(self, fixture, compare_values):
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            if a == b:
                a_key, b_key = self._compare_equal(
                    case_a[a], case_b[b], compare_values
                )
                if a_key is None:
                    continue
            else:
                a_key = case_a[a]._generate_cache_key()
                b_key = case_b[b]._generate_cache_key()

                if a_key is None or b_key is None:
                    if a_key is None:
                        assert case_a[a]._annotations.get("nocache")
                    if b_key is None:
                        assert case_b[b]._annotations.get("nocache")

                    continue

                if a_key.key == b_key.key:
                    for a_param, b_param in zip(
                        a_key.bindparams, b_key.bindparams
                    ):
                        if not a_param.compare(
                            b_param, compare_values=compare_values
                        ):
                            break
                    else:
                        # this fails unconditionally since we could not
                        # find bound parameter values that differed.
                        # Usually we intended to get two distinct keys here
                        # so the failure will be more descriptive using the
                        # ne_() assertion.
                        ne_(a_key.key, b_key.key)
                else:
                    ne_(a_key.key, b_key.key)

            # ClauseElement-specific test to ensure the cache key
            # collected all the bound parameters that aren't marked
            # as "literal execute"
            if isinstance(case_a[a], ClauseElement) and isinstance(
                case_b[b], ClauseElement
            ):
                assert_a_params = []
                assert_b_params = []

                for elem in visitors.iterate(case_a[a]):
                    if elem.__visit_name__ == "bindparam":
                        assert_a_params.append(elem)

                for elem in visitors.iterate(case_b[b]):
                    if elem.__visit_name__ == "bindparam":
                        assert_b_params.append(elem)

                # note we're asserting the order of the params as well as
                # if there are dupes or not. ordering has to be
                # deterministic and matches what a traversal would provide.
                eq_(
                    sorted(a_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_a_params), key=lambda b: b.key
                    ),
                )
                eq_(
                    sorted(b_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_b_params), key=lambda b: b.key
                    ),
                )

    def _run_cache_key_equal_fixture(self, fixture, compare_values):
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            self._compare_equal(case_a[a], case_b[b], compare_values)
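

# Example usage (illustrative sketch): a cache-key test hands
# _run_cache_key_fixture() a callable returning a list of constructs;
# elements at the same index across the two invocations must generate equal
# cache keys, while structurally distinct elements must generate distinct
# keys.  The class and table names below are hypothetical.
#
#     class SelectCacheKeyTest(CacheKeyFixture, TestBase):
#         def test_select_cache_keys(self):
#             t = sa.table("t", sa.column("q"))
#             self._run_cache_key_fixture(
#                 lambda: [
#                     t.select(),
#                     t.select().where(t.c.q == 5),
#                     t.select().order_by(t.c.q),
#                 ],
#                 compare_values=True,
#             )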

def insertmanyvalues_fixture(
    connection, randomize_rows=False, warn_on_downgraded=False
):
    dialect = connection.dialect
    orig_dialect = dialect._deliver_insertmanyvalues_batches
    orig_conn = connection._exec_insertmany_context

    class RandomCursor:
        __slots__ = ("cursor",)

        def __init__(self, cursor):
            self.cursor = cursor

        # only this method is called by the deliver method.
        # by not having the other methods we assert that those aren't being
        # used
        @property
        def description(self):
            return self.cursor.description

        def fetchall(self):
            rows = self.cursor.fetchall()
            rows = list(rows)
            random.shuffle(rows)
            return rows

    def _deliver_insertmanyvalues_batches(
        connection,
        cursor,
        statement,
        parameters,
        generic_setinputsizes,
        context,
    ):
        if randomize_rows:
            cursor = RandomCursor(cursor)
        for batch in orig_dialect(
            connection,
            cursor,
            statement,
            parameters,
            generic_setinputsizes,
            context,
        ):
            if warn_on_downgraded and batch.is_downgraded:
                util.warn("Batches were downgraded for sorted INSERT")

            yield batch

    def _exec_insertmany_context(dialect, context):
        with mock.patch.object(
            dialect,
            "_deliver_insertmanyvalues_batches",
            new=_deliver_insertmanyvalues_batches,
        ):
            return orig_conn(dialect, context)

    connection._exec_insertmany_context = _exec_insertmany_context
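

# Example usage (illustrative sketch): a test applies this fixture to a
# Connection so that each insertmanyvalues batch returns its rows in random
# order, exercising the ordering behavior of RETURNING rows; it can also
# warn when batches are downgraded to row-at-a-time execution.  The table
# and test names below are hypothetical.
#
#     def test_imv_round_trip(self, connection):
#         insertmanyvalues_fixture(
#             connection, randomize_rows=True, warn_on_downgraded=True
#         )
#         users = self.tables.users
#         result = connection.execute(
#             users.insert().returning(users.c.id),
#             [{"name": "a"}, {"name": "b"}, {"name": "c"}],
#         )
#         assert len(result.all()) == 3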