orm.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # testing/fixtures/orm.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. from __future__ import annotations
  9. from typing import Any
  10. import sqlalchemy as sa
  11. from .base import TestBase
  12. from .sql import TablesTest
  13. from .. import assertions
  14. from .. import config
  15. from .. import schema
  16. from ..entities import BasicEntity
  17. from ..entities import ComparableEntity
  18. from ..util import adict
  19. from ... import orm
  20. from ...orm import DeclarativeBase
  21. from ...orm import events as orm_events
  22. from ...orm import registry
  23. class ORMTest(TestBase):
  24. @config.fixture
  25. def fixture_session(self):
  26. return fixture_session()
  27. class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults):
  28. # 'once', 'each', None
  29. run_setup_classes = "once"
  30. # 'once', 'each', None
  31. run_setup_mappers = "each"
  32. classes: Any = None
  33. @config.fixture(autouse=True, scope="class")
  34. def _setup_tables_test_class(self):
  35. cls = self.__class__
  36. cls._init_class()
  37. if cls.classes is None:
  38. cls.classes = adict()
  39. cls._setup_once_tables()
  40. cls._setup_once_classes()
  41. cls._setup_once_mappers()
  42. cls._setup_once_inserts()
  43. yield
  44. cls._teardown_once_class()
  45. cls._teardown_once_metadata_bind()
  46. @config.fixture(autouse=True, scope="function")
  47. def _setup_tables_test_instance(self):
  48. self._setup_each_tables()
  49. self._setup_each_classes()
  50. self._setup_each_mappers()
  51. self._setup_each_inserts()
  52. yield
  53. orm.session.close_all_sessions()
  54. self._teardown_each_mappers()
  55. self._teardown_each_classes()
  56. self._teardown_each_tables()
  57. @classmethod
  58. def _teardown_once_class(cls):
  59. cls.classes.clear()
  60. @classmethod
  61. def _setup_once_classes(cls):
  62. if cls.run_setup_classes == "once":
  63. cls._with_register_classes(cls.setup_classes)
  64. @classmethod
  65. def _setup_once_mappers(cls):
  66. if cls.run_setup_mappers == "once":
  67. cls.mapper_registry, cls.mapper = cls._generate_registry()
  68. cls._with_register_classes(cls.setup_mappers)
  69. def _setup_each_mappers(self):
  70. if self.run_setup_mappers != "once":
  71. (
  72. self.__class__.mapper_registry,
  73. self.__class__.mapper,
  74. ) = self._generate_registry()
  75. if self.run_setup_mappers == "each":
  76. self._with_register_classes(self.setup_mappers)
  77. def _setup_each_classes(self):
  78. if self.run_setup_classes == "each":
  79. self._with_register_classes(self.setup_classes)
  80. @classmethod
  81. def _generate_registry(cls):
  82. decl = registry(metadata=cls._tables_metadata)
  83. return decl, decl.map_imperatively
  84. @classmethod
  85. def _with_register_classes(cls, fn):
  86. """Run a setup method, framing the operation with a Base class
  87. that will catch new subclasses to be established within
  88. the "classes" registry.
  89. """
  90. cls_registry = cls.classes
  91. class _Base:
  92. def __init_subclass__(cls) -> None:
  93. assert cls_registry is not None
  94. cls_registry[cls.__name__] = cls
  95. super().__init_subclass__()
  96. class Basic(BasicEntity, _Base):
  97. pass
  98. class Comparable(ComparableEntity, _Base):
  99. pass
  100. cls.Basic = Basic
  101. cls.Comparable = Comparable
  102. fn()
  103. def _teardown_each_mappers(self):
  104. # some tests create mappers in the test bodies
  105. # and will define setup_mappers as None -
  106. # clear mappers in any case
  107. if self.run_setup_mappers != "once":
  108. orm.clear_mappers()
  109. def _teardown_each_classes(self):
  110. if self.run_setup_classes != "once":
  111. self.classes.clear()
  112. @classmethod
  113. def setup_classes(cls):
  114. pass
  115. @classmethod
  116. def setup_mappers(cls):
  117. pass
  118. class DeclarativeMappedTest(MappedTest):
  119. run_setup_classes = "once"
  120. run_setup_mappers = "once"
  121. @classmethod
  122. def _setup_once_tables(cls):
  123. pass
  124. @classmethod
  125. def _with_register_classes(cls, fn):
  126. cls_registry = cls.classes
  127. class _DeclBase(DeclarativeBase):
  128. __table_cls__ = schema.Table
  129. metadata = cls._tables_metadata
  130. type_annotation_map = {
  131. str: sa.String().with_variant(
  132. sa.String(50), "mysql", "mariadb", "oracle"
  133. )
  134. }
  135. def __init_subclass__(cls, **kw) -> None:
  136. assert cls_registry is not None
  137. cls_registry[cls.__name__] = cls
  138. super().__init_subclass__(**kw)
  139. cls.DeclarativeBasic = _DeclBase
  140. # sets up cls.Basic which is helpful for things like composite
  141. # classes
  142. super()._with_register_classes(fn)
  143. if cls._tables_metadata.tables and cls.run_create_tables:
  144. cls._tables_metadata.create_all(config.db)
  145. class RemoveORMEventsGlobally:
  146. @config.fixture(autouse=True)
  147. def _remove_listeners(self):
  148. yield
  149. orm_events.MapperEvents._clear()
  150. orm_events.InstanceEvents._clear()
  151. orm_events.SessionEvents._clear()
  152. orm_events.InstrumentationEvents._clear()
  153. orm_events.QueryEvents._clear()
  154. _fixture_sessions = set()
  155. def fixture_session(**kw):
  156. kw.setdefault("autoflush", True)
  157. kw.setdefault("expire_on_commit", True)
  158. bind = kw.pop("bind", config.db)
  159. sess = orm.Session(bind, **kw)
  160. _fixture_sessions.add(sess)
  161. return sess
  162. def close_all_sessions():
  163. # will close all still-referenced sessions
  164. orm.close_all_sessions()
  165. _fixture_sessions.clear()
  166. def stop_test_class_inside_fixtures(cls):
  167. close_all_sessions()
  168. orm.clear_mappers()
  169. def after_test():
  170. if _fixture_sessions:
  171. close_all_sessions()