engines.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # testing/engines.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. from __future__ import annotations
  9. import collections
  10. import re
  11. import typing
  12. from typing import Any
  13. from typing import Dict
  14. from typing import Optional
  15. import warnings
  16. import weakref
  17. from . import config
  18. from .util import decorator
  19. from .util import gc_collect
  20. from .. import event
  21. from .. import pool
  22. from ..util import await_only
  23. from ..util.typing import Literal
  24. if typing.TYPE_CHECKING:
  25. from ..engine import Engine
  26. from ..engine.url import URL
  27. from ..ext.asyncio import AsyncEngine
  28. class ConnectionKiller:
  29. def __init__(self):
  30. self.proxy_refs = weakref.WeakKeyDictionary()
  31. self.testing_engines = collections.defaultdict(set)
  32. self.dbapi_connections = set()
  33. def add_pool(self, pool):
  34. event.listen(pool, "checkout", self._add_conn)
  35. event.listen(pool, "checkin", self._remove_conn)
  36. event.listen(pool, "close", self._remove_conn)
  37. event.listen(pool, "close_detached", self._remove_conn)
  38. # note we are keeping "invalidated" here, as those are still
  39. # opened connections we would like to roll back
  40. def _add_conn(self, dbapi_con, con_record, con_proxy):
  41. self.dbapi_connections.add(dbapi_con)
  42. self.proxy_refs[con_proxy] = True
  43. def _remove_conn(self, dbapi_conn, *arg):
  44. self.dbapi_connections.discard(dbapi_conn)
  45. def add_engine(self, engine, scope):
  46. self.add_pool(engine.pool)
  47. assert scope in ("class", "global", "function", "fixture")
  48. self.testing_engines[scope].add(engine)
  49. def _safe(self, fn):
  50. try:
  51. fn()
  52. except Exception as e:
  53. warnings.warn(
  54. "testing_reaper couldn't rollback/close connection: %s" % e
  55. )
  56. def rollback_all(self):
  57. for rec in list(self.proxy_refs):
  58. if rec is not None and rec.is_valid:
  59. self._safe(rec.rollback)
  60. def checkin_all(self):
  61. # run pool.checkin() for all ConnectionFairy instances we have
  62. # tracked.
  63. for rec in list(self.proxy_refs):
  64. if rec is not None and rec.is_valid:
  65. self.dbapi_connections.discard(rec.dbapi_connection)
  66. self._safe(rec._checkin)
  67. # for fairy refs that were GCed and could not close the connection,
  68. # such as asyncio, roll back those remaining connections
  69. for con in self.dbapi_connections:
  70. self._safe(con.rollback)
  71. self.dbapi_connections.clear()
  72. def close_all(self):
  73. self.checkin_all()
  74. def prepare_for_drop_tables(self, connection):
  75. # don't do aggressive checks for third party test suites
  76. if not config.bootstrapped_as_sqlalchemy:
  77. return
  78. from . import provision
  79. provision.prepare_for_drop_tables(connection.engine.url, connection)
  80. def _drop_testing_engines(self, scope):
  81. eng = self.testing_engines[scope]
  82. for rec in list(eng):
  83. for proxy_ref in list(self.proxy_refs):
  84. if proxy_ref is not None and proxy_ref.is_valid:
  85. if (
  86. proxy_ref._pool is not None
  87. and proxy_ref._pool is rec.pool
  88. ):
  89. self._safe(proxy_ref._checkin)
  90. if hasattr(rec, "sync_engine"):
  91. await_only(rec.dispose())
  92. else:
  93. rec.dispose()
  94. eng.clear()
  95. def after_test(self):
  96. self._drop_testing_engines("function")
  97. def after_test_outside_fixtures(self, test):
  98. # don't do aggressive checks for third party test suites
  99. if not config.bootstrapped_as_sqlalchemy:
  100. return
  101. if test.__class__.__leave_connections_for_teardown__:
  102. return
  103. self.checkin_all()
  104. # on PostgreSQL, this will test for any "idle in transaction"
  105. # connections. useful to identify tests with unusual patterns
  106. # that can't be cleaned up correctly.
  107. from . import provision
  108. with config.db.connect() as conn:
  109. provision.prepare_for_drop_tables(conn.engine.url, conn)
  110. def stop_test_class_inside_fixtures(self):
  111. self.checkin_all()
  112. self._drop_testing_engines("function")
  113. self._drop_testing_engines("class")
  114. def stop_test_class_outside_fixtures(self):
  115. # ensure no refs to checked out connections at all.
  116. if pool.base._strong_ref_connection_records:
  117. gc_collect()
  118. if pool.base._strong_ref_connection_records:
  119. ln = len(pool.base._strong_ref_connection_records)
  120. pool.base._strong_ref_connection_records.clear()
  121. assert (
  122. False
  123. ), "%d connection recs not cleared after test suite" % (ln)
  124. def final_cleanup(self):
  125. self.checkin_all()
  126. for scope in self.testing_engines:
  127. self._drop_testing_engines(scope)
  128. def assert_all_closed(self):
  129. for rec in self.proxy_refs:
  130. if rec.is_valid:
  131. assert False
  132. testing_reaper = ConnectionKiller()
  133. @decorator
  134. def assert_conns_closed(fn, *args, **kw):
  135. try:
  136. fn(*args, **kw)
  137. finally:
  138. testing_reaper.assert_all_closed()
  139. @decorator
  140. def rollback_open_connections(fn, *args, **kw):
  141. """Decorator that rolls back all open connections after fn execution."""
  142. try:
  143. fn(*args, **kw)
  144. finally:
  145. testing_reaper.rollback_all()
  146. @decorator
  147. def close_first(fn, *args, **kw):
  148. """Decorator that closes all connections before fn execution."""
  149. testing_reaper.checkin_all()
  150. fn(*args, **kw)
  151. @decorator
  152. def close_open_connections(fn, *args, **kw):
  153. """Decorator that closes all connections after fn execution."""
  154. try:
  155. fn(*args, **kw)
  156. finally:
  157. testing_reaper.checkin_all()
  158. def all_dialects(exclude=None):
  159. import sqlalchemy.dialects as d
  160. for name in d.__all__:
  161. # TEMPORARY
  162. if exclude and name in exclude:
  163. continue
  164. mod = getattr(d, name, None)
  165. if not mod:
  166. mod = getattr(
  167. __import__("sqlalchemy.dialects.%s" % name).dialects, name
  168. )
  169. yield mod.dialect()
  170. class ReconnectFixture:
  171. def __init__(self, dbapi):
  172. self.dbapi = dbapi
  173. self.connections = []
  174. self.is_stopped = False
  175. def __getattr__(self, key):
  176. return getattr(self.dbapi, key)
  177. def connect(self, *args, **kwargs):
  178. conn = self.dbapi.connect(*args, **kwargs)
  179. if self.is_stopped:
  180. self._safe(conn.close)
  181. curs = conn.cursor() # should fail on Oracle etc.
  182. # should fail for everything that didn't fail
  183. # above, connection is closed
  184. curs.execute("select 1")
  185. assert False, "simulated connect failure didn't work"
  186. else:
  187. self.connections.append(conn)
  188. return conn
  189. def _safe(self, fn):
  190. try:
  191. fn()
  192. except Exception as e:
  193. warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
  194. def shutdown(self, stop=False):
  195. # TODO: this doesn't cover all cases
  196. # as nicely as we'd like, namely MySQLdb.
  197. # would need to implement R. Brewer's
  198. # proxy server idea to get better
  199. # coverage.
  200. self.is_stopped = stop
  201. for c in list(self.connections):
  202. self._safe(c.close)
  203. self.connections = []
  204. def restart(self):
  205. self.is_stopped = False
  206. def reconnecting_engine(url=None, options=None):
  207. url = url or config.db.url
  208. dbapi = config.db.dialect.dbapi
  209. if not options:
  210. options = {}
  211. options["module"] = ReconnectFixture(dbapi)
  212. engine = testing_engine(url, options)
  213. _dispose = engine.dispose
  214. def dispose():
  215. engine.dialect.dbapi.shutdown()
  216. engine.dialect.dbapi.is_stopped = False
  217. _dispose()
  218. engine.test_shutdown = engine.dialect.dbapi.shutdown
  219. engine.test_restart = engine.dialect.dbapi.restart
  220. engine.dispose = dispose
  221. return engine
  222. @typing.overload
  223. def testing_engine(
  224. url: Optional[URL] = None,
  225. options: Optional[Dict[str, Any]] = None,
  226. asyncio: Literal[False] = False,
  227. transfer_staticpool: bool = False,
  228. ) -> Engine: ...
  229. @typing.overload
  230. def testing_engine(
  231. url: Optional[URL] = None,
  232. options: Optional[Dict[str, Any]] = None,
  233. asyncio: Literal[True] = True,
  234. transfer_staticpool: bool = False,
  235. ) -> AsyncEngine: ...
  236. def testing_engine(
  237. url=None,
  238. options=None,
  239. asyncio=False,
  240. transfer_staticpool=False,
  241. share_pool=False,
  242. _sqlite_savepoint=False,
  243. ):
  244. if asyncio:
  245. assert not _sqlite_savepoint
  246. from sqlalchemy.ext.asyncio import (
  247. create_async_engine as create_engine,
  248. )
  249. else:
  250. from sqlalchemy import create_engine
  251. from sqlalchemy.engine.url import make_url
  252. if not options:
  253. use_reaper = True
  254. scope = "function"
  255. sqlite_savepoint = False
  256. else:
  257. use_reaper = options.pop("use_reaper", True)
  258. scope = options.pop("scope", "function")
  259. sqlite_savepoint = options.pop("sqlite_savepoint", False)
  260. url = url or config.db.url
  261. url = make_url(url)
  262. if (
  263. config.db is None or url.drivername == config.db.url.drivername
  264. ) and config.db_opts:
  265. use_options = config.db_opts.copy()
  266. else:
  267. use_options = {}
  268. if options is not None:
  269. use_options.update(options)
  270. engine = create_engine(url, **use_options)
  271. if sqlite_savepoint and engine.name == "sqlite":
  272. # apply SQLite savepoint workaround
  273. @event.listens_for(engine, "connect")
  274. def do_connect(dbapi_connection, connection_record):
  275. dbapi_connection.isolation_level = None
  276. @event.listens_for(engine, "begin")
  277. def do_begin(conn):
  278. conn.exec_driver_sql("BEGIN")
  279. if transfer_staticpool:
  280. from sqlalchemy.pool import StaticPool
  281. if config.db is not None and isinstance(config.db.pool, StaticPool):
  282. use_reaper = False
  283. engine.pool._transfer_from(config.db.pool)
  284. elif share_pool:
  285. engine.pool = config.db.pool
  286. if scope == "global":
  287. if asyncio:
  288. engine.sync_engine._has_events = True
  289. else:
  290. engine._has_events = (
  291. True # enable event blocks, helps with profiling
  292. )
  293. if (
  294. isinstance(engine.pool, pool.QueuePool)
  295. and "pool" not in use_options
  296. and "pool_timeout" not in use_options
  297. and "max_overflow" not in use_options
  298. ):
  299. engine.pool._timeout = 0
  300. engine.pool._max_overflow = 0
  301. if use_reaper:
  302. testing_reaper.add_engine(engine, scope)
  303. return engine
  304. def mock_engine(dialect_name=None):
  305. """Provides a mocking engine based on the current testing.db.
  306. This is normally used to test DDL generation flow as emitted
  307. by an Engine.
  308. It should not be used in other cases, as assert_compile() and
  309. assert_sql_execution() are much better choices with fewer
  310. moving parts.
  311. """
  312. from sqlalchemy import create_mock_engine
  313. if not dialect_name:
  314. dialect_name = config.db.name
  315. buffer = []
  316. def executor(sql, *a, **kw):
  317. buffer.append(sql)
  318. def assert_sql(stmts):
  319. recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
  320. assert recv == stmts, recv
  321. def print_sql():
  322. d = engine.dialect
  323. return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
  324. engine = create_mock_engine(dialect_name + "://", executor)
  325. assert not hasattr(engine, "mock")
  326. engine.mock = buffer
  327. engine.assert_sql = assert_sql
  328. engine.print_sql = print_sql
  329. return engine
  330. class DBAPIProxyCursor:
  331. """Proxy a DBAPI cursor.
  332. Tests can provide subclasses of this to intercept
  333. DBAPI-level cursor operations.
  334. """
  335. def __init__(self, engine, conn, *args, **kwargs):
  336. self.engine = engine
  337. self.connection = conn
  338. self.cursor = conn.cursor(*args, **kwargs)
  339. def execute(self, stmt, parameters=None, **kw):
  340. if parameters:
  341. return self.cursor.execute(stmt, parameters, **kw)
  342. else:
  343. return self.cursor.execute(stmt, **kw)
  344. def executemany(self, stmt, params, **kw):
  345. return self.cursor.executemany(stmt, params, **kw)
  346. def __iter__(self):
  347. return iter(self.cursor)
  348. def __getattr__(self, key):
  349. return getattr(self.cursor, key)
  350. class DBAPIProxyConnection:
  351. """Proxy a DBAPI connection.
  352. Tests can provide subclasses of this to intercept
  353. DBAPI-level connection operations.
  354. """
  355. def __init__(self, engine, conn, cursor_cls):
  356. self.conn = conn
  357. self.engine = engine
  358. self.cursor_cls = cursor_cls
  359. def cursor(self, *args, **kwargs):
  360. return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
  361. def close(self):
  362. self.conn.close()
  363. def __getattr__(self, key):
  364. return getattr(self.conn, key)