# testing/fixtures/sql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import itertools
import random
import re
import sys

import sqlalchemy as sa

from .base import TestBase
from .. import config
from .. import mock
from ..assertions import eq_
from ..assertions import ne_
from ..util import adict
from ..util import drop_all_tables_from_metadata
from ... import event
from ... import util
from ...schema import sort_tables_and_constraints
from ...sql import visitors
from ...sql.elements import ClauseElement


class TablesTest(TestBase):
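    """Test base that defines tables, optionally creates and drops them,
    and optionally loads fixture rows.

    The ``run_*`` flags below control whether each step runs once per
    test class (``"once"``), once per test function (``"each"``), or not
    at all (``None``); the comment above each flag lists its accepted
    values.

    """
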
    # 'once', None
    run_setup_bind = "once"

    # 'once', 'each', None
    run_define_tables = "once"

    # 'once', 'each', None
    run_create_tables = "once"

    # 'once', 'each', None
    run_inserts = "each"

    # 'each', None
    run_deletes = "each"

    # 'once', None
    run_dispose_bind = None

    bind = None
    _tables_metadata = None
    tables = None
    other = None
    sequences = None
    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        cls = self.__class__
        cls._init_class()

        cls._setup_once_tables()

        cls._setup_once_inserts()

        yield

        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        self._setup_each_tables()
        self._setup_each_inserts()

        yield

        self._teardown_each_tables()

    @property
    def tables_test_metadata(self):
        return self._tables_metadata

    @classmethod
    def _init_class(cls):
        if cls.run_define_tables == "each":
            if cls.run_create_tables == "once":
                cls.run_create_tables = "each"
            assert cls.run_inserts in ("each", None)

        cls.other = adict()
        cls.tables = adict()
        cls.sequences = adict()

        cls.bind = cls.setup_bind()
        cls._tables_metadata = sa.MetaData()

    @classmethod
    def _setup_once_inserts(cls):
        if cls.run_inserts == "once":
            cls._load_fixtures()
            with cls.bind.begin() as conn:
                cls.insert_data(conn)

    @classmethod
    def _setup_once_tables(cls):
        if cls.run_define_tables == "once":
            cls.define_tables(cls._tables_metadata)
            if cls.run_create_tables == "once":
                cls._tables_metadata.create_all(cls.bind)
            cls.tables.update(cls._tables_metadata.tables)
            cls.sequences.update(cls._tables_metadata._sequences)

    def _setup_each_tables(self):
        if self.run_define_tables == "each":
            self.define_tables(self._tables_metadata)
            if self.run_create_tables == "each":
                self._tables_metadata.create_all(self.bind)
            self.tables.update(self._tables_metadata.tables)
            self.sequences.update(self._tables_metadata._sequences)
        elif self.run_create_tables == "each":
            self._tables_metadata.create_all(self.bind)

    def _setup_each_inserts(self):
        if self.run_inserts == "each":
            self._load_fixtures()
            with self.bind.begin() as conn:
                self.insert_data(conn)
    def _teardown_each_tables(self):
        if self.run_define_tables == "each":
            self.tables.clear()
            if self.run_create_tables == "each":
                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
            self._tables_metadata.clear()
        elif self.run_create_tables == "each":
            drop_all_tables_from_metadata(self._tables_metadata, self.bind)

        savepoints = getattr(config.requirements, "savepoints", False)
        if savepoints:
            savepoints = savepoints.enabled

        # no need to run deletes if tables are recreated on setup
        if (
            self.run_define_tables != "each"
            and self.run_create_tables != "each"
            and self.run_deletes == "each"
        ):
            with self.bind.begin() as conn:
                for table in reversed(
                    [
                        t
                        for (t, fks) in sort_tables_and_constraints(
                            self._tables_metadata.tables.values()
                        )
                        if t is not None
                    ]
                ):
                    try:
                        if savepoints:
                            with conn.begin_nested():
                                conn.execute(table.delete())
                        else:
                            conn.execute(table.delete())
                    except sa.exc.DBAPIError as ex:
                        print(
                            ("Error emptying table %s: %r" % (table, ex)),
                            file=sys.stderr,
                        )
    @classmethod
    def _teardown_once_metadata_bind(cls):
        if cls.run_create_tables:
            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)

        if cls.run_dispose_bind == "once":
            cls.dispose_bind(cls.bind)

        cls._tables_metadata.bind = None

        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def setup_bind(cls):
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        if hasattr(bind, "dispose"):
            bind.dispose()
        elif hasattr(bind, "close"):
            bind.close()
    @classmethod
    def define_tables(cls, metadata):
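        """Define the ``Table`` objects for this test class.

        Subclasses override this to add tables to the given ``metadata``;
        the base implementation defines nothing.

        """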
        pass

    @classmethod
    def fixtures(cls):
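        """Return fixture rows to be inserted by ``_load_fixtures()``.

        Based on how ``_load_fixtures()`` consumes the return value, the
        expected format is a dict keyed by table (or table name) whose
        value is a tuple of rows: the first row holds column names and the
        remaining rows hold values, e.g. (illustrative table and column
        names only)::

            return {
                "users": (
                    ("id", "name"),
                    (1, "spongebob"),
                    (2, "sandy"),
                ),
            }

        """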
        return {}

    @classmethod
    def insert_data(cls, connection):
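        """Insert additional data using the given connection.

        Called after ``_load_fixtures()`` per the ``run_inserts`` policy;
        the base implementation is a no-op.

        """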
        pass

    def sql_count_(self, count, fn):
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements):
        self.assert_sql(self.bind, callable_, statements)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method."""
        headers, rows = {}, {}
        for table, data in cls.fixtures().items():
            if len(data) < 2:
                continue
            if isinstance(table, str):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        for table, fks in sort_tables_and_constraints(
            cls._tables_metadata.tables.values()
        ):
            if table is None:
                continue
            if table not in headers:
                continue
            with cls.bind.begin() as conn:
                conn.execute(
                    table.insert(),
                    [
                        dict(zip(headers[table], column_values))
                        for column_values in rows[table]
                    ],
                )


class NoCache:
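    """Mixin that disables the compiled statement cache on ``config.db``
    for the duration of each test function.

    """
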
    @config.fixture(autouse=True, scope="function")
    def _disable_cache(self):
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache


class RemovesEvents:
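    """Mixin that tracks event listeners registered via ``event_listen()``
    and removes them after each test function.

    """
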
    @util.memoized_property
    def _event_fns(self):
        return set()

    def event_listen(self, target, name, fn, **kw):
        self._event_fns.add((target, name, fn))
        event.listen(target, name, fn, **kw)

    @config.fixture(autouse=True, scope="function")
    def _remove_events(self):
        yield
        for key in self._event_fns:
            event.remove(*key)


class ComputedReflectionFixtureTest(TablesTest):
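    """Fixture defining tables with ``Computed`` columns, intended for
    reflection tests against backends that support computed columns.

    """
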
    run_inserts = run_deletes = None

    __backend__ = True
    __requires__ = ("computed_columns", "table_reflection")

    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")

    def normalize(self, text):
        return self.regexp.sub("", text).lower()

    @classmethod
    def define_tables(cls, metadata):
        from ... import Integer
        from ... import testing
        from ...schema import Column
        from ...schema import Computed
        from ...schema import Table
        Table(
            "computed_default_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_col", Integer, Computed("normal + 42")),
            Column("with_default", Integer, server_default="42"),
        )

        t = Table(
            "computed_column_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_no_flag", Integer, Computed("normal + 42")),
        )

        if testing.requires.schemas.enabled:
            t2 = Table(
                "computed_column_table",
                metadata,
                Column("id", Integer, primary_key=True),
                Column("normal", Integer),
                Column("computed_no_flag", Integer, Computed("normal / 42")),
                schema=config.test_schema,
            )

        if testing.requires.computed_columns_virtual.enabled:
            t.append_column(
                Column(
                    "computed_virtual",
                    Integer,
                    Computed("normal + 2", persisted=False),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_virtual",
                        Integer,
                        Computed("normal / 2", persisted=False),
                    )
                )
        if testing.requires.computed_columns_stored.enabled:
            t.append_column(
                Column(
                    "computed_stored",
                    Integer,
                    Computed("normal - 42", persisted=True),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_stored",
                        Integer,
                        Computed("normal * 42", persisted=True),
                    )
                )


class CacheKeyFixture:
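    """Mixin providing helpers that compare the cache keys generated by
    SQL expression constructs produced by a fixture callable.

    """
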
    def _compare_equal(self, a, b, compare_values):
        a_key = a._generate_cache_key()
        b_key = b._generate_cache_key()

        if a_key is None:
            assert a._annotations.get("nocache")

            assert b_key is None
        else:
            eq_(a_key.key, b_key.key)
            eq_(hash(a_key.key), hash(b_key.key))

            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
                assert a_param.compare(b_param, compare_values=compare_values)
        return a_key, b_key
    def _run_cache_key_fixture(self, fixture, compare_values):
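        """Compare cache keys pairwise across two invocations of ``fixture``.

        Cases at the same index are expected to produce equal cache keys;
        cases at different indexes are expected to produce distinct keys,
        unless they differ only in bound parameter values or are marked
        "nocache".

        """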
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            if a == b:
                a_key, b_key = self._compare_equal(
                    case_a[a], case_b[b], compare_values
                )
                if a_key is None:
                    continue
            else:
                a_key = case_a[a]._generate_cache_key()
                b_key = case_b[b]._generate_cache_key()

                if a_key is None or b_key is None:
                    if a_key is None:
                        assert case_a[a]._annotations.get("nocache")
                    if b_key is None:
                        assert case_b[b]._annotations.get("nocache")

                    continue

                if a_key.key == b_key.key:
                    for a_param, b_param in zip(
                        a_key.bindparams, b_key.bindparams
                    ):
                        if not a_param.compare(
                            b_param, compare_values=compare_values
                        ):
                            break
                    else:
                        # this fails unconditionally since we could not
                        # find bound parameter values that differed.
                        # Usually we intended to get two distinct keys here
                        # so the failure will be more descriptive using the
                        # ne_() assertion.
                        ne_(a_key.key, b_key.key)
                else:
                    ne_(a_key.key, b_key.key)

            # ClauseElement-specific test to ensure the cache key
            # collected all the bound parameters that aren't marked
            # as "literal execute"
            if isinstance(case_a[a], ClauseElement) and isinstance(
                case_b[b], ClauseElement
            ):
                assert_a_params = []
                assert_b_params = []

                for elem in visitors.iterate(case_a[a]):
                    if elem.__visit_name__ == "bindparam":
                        assert_a_params.append(elem)

                for elem in visitors.iterate(case_b[b]):
                    if elem.__visit_name__ == "bindparam":
                        assert_b_params.append(elem)

                # note we're asserting the order of the params as well as
                # whether there are dupes or not; ordering has to be
                # deterministic and match what a traversal would provide.
                eq_(
                    sorted(a_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_a_params), key=lambda b: b.key
                    ),
                )
                eq_(
                    sorted(b_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_b_params), key=lambda b: b.key
                    ),
                )
    def _run_cache_key_equal_fixture(self, fixture, compare_values):
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            self._compare_equal(case_a[a], case_b[b], compare_values)


def insertmanyvalues_fixture(
    connection, randomize_rows=False, warn_on_downgraded=False
):
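    """Patch the given connection so that its "insertmanyvalues" batches
    can be observed in tests: optionally shuffle the rows returned for
    each batch and/or warn when a batch has been downgraded.

    """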
    dialect = connection.dialect

    orig_dialect = dialect._deliver_insertmanyvalues_batches
    orig_conn = connection._exec_insertmany_context

    class RandomCursor:
        __slots__ = ("cursor",)

        def __init__(self, cursor):
            self.cursor = cursor

        # only these members are called by the deliver method; by not
        # having the other cursor methods we assert that those aren't
        # being used
        @property
        def description(self):
            return self.cursor.description

        def fetchall(self):
            rows = self.cursor.fetchall()
            rows = list(rows)
            random.shuffle(rows)
            return rows

    def _deliver_insertmanyvalues_batches(
        connection,
        cursor,
        statement,
        parameters,
        generic_setinputsizes,
        context,
    ):
        if randomize_rows:
            cursor = RandomCursor(cursor)
        for batch in orig_dialect(
            connection,
            cursor,
            statement,
            parameters,
            generic_setinputsizes,
            context,
        ):
            if warn_on_downgraded and batch.is_downgraded:
                util.warn("Batches were downgraded for sorted INSERT")

            yield batch

    def _exec_insertmany_context(dialect, context):
        with mock.patch.object(
            dialect,
            "_deliver_insertmanyvalues_batches",
            new=_deliver_insertmanyvalues_batches,
        ):
            return orig_conn(dialect, context)

    connection._exec_insertmany_context = _exec_insertmany_context