# testing/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations

import collections
import logging

from . import config
from . import engines
from . import util
from .. import exc
from .. import inspect
from ..engine import url as sa_url
from ..sql import ddl
from ..sql import schema

log = logging.getLogger(__name__)

FOLLOWER_IDENT = None


class register:
    def __init__(self, decorator=None):
        self.fns = {}
        self.decorator = decorator

    @classmethod
    def init(cls, fn):
        return register().for_db("*")(fn)

    @classmethod
    def init_decorator(cls, decorator):
        return register(decorator).for_db("*")

    def for_db(self, *dbnames):
        def decorate(fn):
            if self.decorator:
                fn = self.decorator(fn)
            for dbname in dbnames:
                self.fns[dbname] = fn
            return self

        return decorate

    def __call__(self, cfg, *arg, **kw):
        if isinstance(cfg, str):
            url = sa_url.make_url(cfg)
        elif isinstance(cfg, sa_url.URL):
            url = cfg
        else:
            url = cfg.db.url
        backend = url.get_backend_name()
        if backend in self.fns:
            return self.fns[backend](cfg, *arg, **kw)
        else:
            return self.fns["*"](cfg, *arg, **kw)
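

# Usage sketch: a dialect's provision module registers backend-specific
# implementations of the module-level hooks below via ``for_db()``, while
# ``@register.init`` establishes the "*" fallback.  The backend name
# "somedb" and the exact CREATE DATABASE statement here are placeholders,
# not any real dialect's implementation:
#
#     from sqlalchemy.testing.provision import create_db
#
#     @create_db.for_db("somedb")
#     def _somedb_create_db(cfg, eng, ident):
#         with eng.begin() as conn:
#             conn.exec_driver_sql(f"CREATE DATABASE {ident}")
#
# A call such as ``create_db(cfg, cfg.db, ident)`` then dispatches on the
# backend name of the configured URL, falling back to the "*" entry.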


def create_follower_db(follower_ident):
    for cfg in _configs_for_db_operation():
        log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
        create_db(cfg, cfg.db, follower_ident)


def setup_config(db_url, options, file_config, follower_ident):
    # load the dialect, which should also set up its provision hooks
    dialect = sa_url.make_url(db_url).get_dialect()
    dialect.load_provisioning()

    if follower_ident:
        db_url = follower_url_from_main(db_url, follower_ident)
    db_opts = {}
    update_db_opts(db_url, db_opts, options)
    db_opts["scope"] = "global"
    eng = engines.testing_engine(db_url, db_opts)
    post_configure_engine(db_url, eng, follower_ident)
    eng.connect().close()

    cfg = config.Config.register(eng, db_opts, options, file_config)

    # a symbolic name that tests can use if they need to disambiguate
    # names across databases
    if follower_ident:
        config.ident = follower_ident

    if follower_ident:
        configure_follower(cfg, follower_ident)
    return cfg


def drop_follower_db(follower_ident):
    for cfg in _configs_for_db_operation():
        log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
        drop_db(cfg, cfg.db, follower_ident)


def generate_db_urls(db_urls, extra_drivers):
    r"""Generate a set of URLs to test given configured URLs plus additional
    driver names.

    Given:

    .. sourcecode:: text

        --dburi postgresql://db1 \
        --dburi postgresql://db2 \
        --dburi postgresql://db3 \
        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true

    Noting that the default postgresql driver is psycopg2, the output
    would be:

    .. sourcecode:: text

        postgresql+psycopg2://db1
        postgresql+asyncpg://db1
        postgresql+psycopg2://db2
        postgresql+psycopg2://db3

    That is, for the driver in a --dburi, we want to keep that driver and
    use it for each URL it's part of.  For a driver that is only in
    --dbdriver, we want to use it just once for one of the URLs.  For a
    driver that comes from both --dburi and --dbdriver, we want to keep it
    in that dburi.

    Driver-specific query options can be specified by adding them to the
    driver name.  For example, to enable the async fallback option for
    asyncpg:

    .. sourcecode:: text

        --dburi postgresql://db1 \
        --dbdriver=asyncpg?async_fallback=true

    """
    urls = set()

    backend_to_driver_we_already_have = collections.defaultdict(set)

    urls_plus_dialects = [
        (url_obj, url_obj.get_dialect())
        for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
    ]

    for url_obj, dialect in urls_plus_dialects:
        # use get_driver_name instead of dialect.driver to account for
        # "_async" virtual drivers like oracledb and psycopg
        driver_name = url_obj.get_driver_name()
        backend_to_driver_we_already_have[dialect.name].add(driver_name)

    backend_to_driver_we_need = {}

    for url_obj, dialect in urls_plus_dialects:
        backend = dialect.name
        dialect.load_provisioning()

        if backend not in backend_to_driver_we_need:
            backend_to_driver_we_need[backend] = extra_per_backend = set(
                extra_drivers
            ).difference(backend_to_driver_we_already_have[backend])
        else:
            extra_per_backend = backend_to_driver_we_need[backend]

        for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
            if driver_url in urls:
                continue
            urls.add(driver_url)
            yield driver_url


def _generate_driver_urls(url, extra_drivers):
    main_driver = url.get_driver_name()
    extra_drivers.discard(main_driver)
    url = generate_driver_url(url, main_driver, "")
    yield url

    for drv in list(extra_drivers):
        if "?" in drv:
            driver_only, query_str = drv.split("?", 1)
        else:
            driver_only = drv
            query_str = None

        new_url = generate_driver_url(url, driver_only, query_str)
        if new_url:
            extra_drivers.remove(drv)
            yield new_url


@register.init
def generate_driver_url(url, driver, query_str):
    backend = url.get_backend_name()

    new_url = url.set(
        drivername="%s+%s" % (backend, driver),
    )
    if query_str:
        new_url = new_url.update_query_string(query_str)

    try:
        new_url.get_dialect()
    except exc.NoSuchModuleError:
        return None
    else:
        return new_url
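

# Illustrative example (values are hypothetical): given the URL
# ``postgresql://scott:tiger@localhost/db1`` and driver "asyncpg",
# generate_driver_url() returns
# ``postgresql+asyncpg://scott:tiger@localhost/db1``, or None if no dialect
# module can be imported for that backend/driver combination.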


def _configs_for_db_operation():
    hosts = set()

    for cfg in config.Config.all_configs():
        cfg.db.dispose()

    for cfg in config.Config.all_configs():
        url = cfg.db.url
        backend = url.get_backend_name()
        host_conf = (backend, url.username, url.host, url.database)

        if host_conf not in hosts:
            yield cfg
            hosts.add(host_conf)

    for cfg in config.Config.all_configs():
        cfg.db.dispose()


@register.init
def drop_all_schema_objects_pre_tables(cfg, eng):
    pass


@register.init
def drop_all_schema_objects_post_tables(cfg, eng):
    pass


def drop_all_schema_objects(cfg, eng):
    drop_all_schema_objects_pre_tables(cfg, eng)

    drop_views(cfg, eng)

    if config.requirements.materialized_views.enabled:
        drop_materialized_views(cfg, eng)

    inspector = inspect(eng)

    consider_schemas = (None,)
    if config.requirements.schemas.enabled_for_config(cfg):
        consider_schemas += (cfg.test_schema, cfg.test_schema_2)
    util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas)

    drop_all_schema_objects_post_tables(cfg, eng)

    if config.requirements.sequences.enabled_for_config(cfg):
        with eng.begin() as conn:
            for seq in inspector.get_sequence_names():
                conn.execute(ddl.DropSequence(schema.Sequence(seq)))
            if config.requirements.schemas.enabled_for_config(cfg):
                for schema_name in [cfg.test_schema, cfg.test_schema_2]:
                    for seq in inspector.get_sequence_names(
                        schema=schema_name
                    ):
                        conn.execute(
                            ddl.DropSequence(
                                schema.Sequence(seq, schema=schema_name)
                            )
                        )


def drop_views(cfg, eng):
    inspector = inspect(eng)

    try:
        view_names = inspector.get_view_names()
    except NotImplementedError:
        pass
    else:
        with eng.begin() as conn:
            for vname in view_names:
                conn.execute(
                    ddl._DropView(schema.Table(vname, schema.MetaData()))
                )

    if config.requirements.schemas.enabled_for_config(cfg):
        try:
            view_names = inspector.get_view_names(schema=cfg.test_schema)
        except NotImplementedError:
            pass
        else:
            with eng.begin() as conn:
                for vname in view_names:
                    conn.execute(
                        ddl._DropView(
                            schema.Table(
                                vname,
                                schema.MetaData(),
                                schema=cfg.test_schema,
                            )
                        )
                    )


def drop_materialized_views(cfg, eng):
    inspector = inspect(eng)

    mview_names = inspector.get_materialized_view_names()

    with eng.begin() as conn:
        for vname in mview_names:
            conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}")

    if config.requirements.schemas.enabled_for_config(cfg):
        mview_names = inspector.get_materialized_view_names(
            schema=cfg.test_schema
        )
        with eng.begin() as conn:
            for vname in mview_names:
                conn.exec_driver_sql(
                    f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}"
                )


@register.init
def create_db(cfg, eng, ident):
    """Dynamically create a database for testing.

    Used when a test run will employ multiple processes, e.g., when run
    via `tox` or `pytest -n4`.
    """
    raise NotImplementedError(
        "no DB creation routine for cfg: %s" % (eng.url,)
    )


@register.init
def drop_db(cfg, eng, ident):
    """Drop a database that we dynamically created for testing."""
    raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))


def _adapt_update_db_opts(fn):
    insp = util.inspect_getfullargspec(fn)
    if len(insp.args) == 3:
        return fn
    else:
        return lambda db_url, db_opts, _options: fn(db_url, db_opts)
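

# Compatibility sketch: _adapt_update_db_opts lets older two-argument
# implementations of update_db_opts keep working.  A hypothetical dialect
# hook written as
#
#     @update_db_opts.for_db("somedb")
#     def _somedb_update_db_opts(db_url, db_opts):
#         db_opts["isolation_level"] = "AUTOCOMMIT"
#
# is wrapped so that the three-argument call
# ``update_db_opts(db_url, db_opts, options)`` simply drops ``options``.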


@register.init_decorator(_adapt_update_db_opts)
def update_db_opts(db_url, db_opts, options):
    """Set database options (db_opts) for a test database that we created."""


@register.init
def post_configure_engine(url, engine, follower_ident):
    """Perform extra steps after configuring an engine for testing.

    (For the internal dialects, currently only used by sqlite, oracle, mssql)
    """


@register.init
def follower_url_from_main(url, ident):
    """Create a connection URL for a dynamically-created test database.

    :param url: the connection URL specified when the test run was invoked
    :param ident: the pytest-xdist "worker identifier" to be used as the
     database name
    """
    url = sa_url.make_url(url)
    return url.set(database=ident)
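

# Illustrative example (URL and ident are hypothetical): for a main URL of
# ``postgresql://scott:tiger@localhost/test`` and ident "test_gw0", the
# default implementation returns
# ``postgresql://scott:tiger@localhost/test_gw0``.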


@register.init
def configure_follower(cfg, ident):
    """Create dialect-specific config settings for a follower database."""
    pass


@register.init
def run_reap_dbs(url, ident):
    """Remove databases that were created during the test process, after the
    process has ended.

    This is an optional step that is invoked for certain backends that do not
    reliably release locks on the database as long as a process is still in
    use. For the internal dialects, this is currently only necessary for
    mssql and oracle.
    """


def reap_dbs(idents_file):
    log.info("Reaping databases...")

    urls = collections.defaultdict(set)
    idents = collections.defaultdict(set)
    dialects = {}

    with open(idents_file) as file_:
        for line in file_:
            line = line.strip()
            db_name, db_url = line.split(" ")
            url_obj = sa_url.make_url(db_url)
            if db_name not in dialects:
                dialects[db_name] = url_obj.get_dialect()
                dialects[db_name].load_provisioning()
            url_key = (url_obj.get_backend_name(), url_obj.host)
            urls[url_key].add(db_url)
            idents[url_key].add(db_name)

    for url_key in urls:
        url = list(urls[url_key])[0]
        ident = idents[url_key]
        run_reap_dbs(url, ident)
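

# File format sketch: reap_dbs() expects each line of the idents file to be a
# space-separated "<database name> <URL>" pair; the names and URLs below are
# hypothetical examples only:
#
#     test_gw0 mssql+pyodbc://scott:tiger@localhost/test
#     test_gw1 mssql+pyodbc://scott:tiger@localhost/test
#
# Entries are grouped by (backend, host) and one run_reap_dbs() call is made
# per group with the full set of database names.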


@register.init
def temp_table_keyword_args(cfg, eng):
    """Specify keyword arguments for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    kwargs that are passed to the Table method when creating a temporary
    table for testing, e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py
    """
    raise NotImplementedError(
        "no temp table keyword args routine for cfg: %s" % (eng.url,)
    )
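

# Illustrative sketch: a dialect whose temporary tables use the TEMPORARY
# keyword might return Table kwargs along the lines of
#
#     {"prefixes": ["TEMPORARY"]}
#
# This is an assumed example, not the implementation of any particular
# dialect.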


@register.init
def prepare_for_drop_tables(config, connection):
    pass


@register.init
def stop_test_class_outside_fixtures(config, db, testcls):
    pass


@register.init
def get_temp_table_name(cfg, eng, base_name):
    """Specify table name for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    name to use when creating a temporary table for testing,
    e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py

    Default to just the base name since that's what most dialects will
    use. The mssql dialect's implementation will need a "#" prepended.
    """
    return base_name


@register.init
def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
    raise NotImplementedError(
        "backend does not implement a schema name set function: %s"
        % (cfg.db.url,)
    )


@register.init
def upsert(
    cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
):
    """Return the backend's INSERT..ON CONFLICT / ON DUPLICATE KEY etc.
    construct.

    While we should add a backend-neutral upsert construct as well, such as
    insert().upsert(), it's important that we continue to test the
    backend-specific insert() constructs, since if we do implement
    insert().upsert(), it would use a different codepath than the things we
    need to test, like insertmanyvalues, etc.
    """
    raise NotImplementedError(
        f"backend does not include an upsert implementation: {cfg.db.url}"
    )
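

# Hedged sketch of what a backend-specific registration might look like,
# loosely modeled on PostgreSQL's INSERT..ON CONFLICT support; this is an
# illustration, not the actual provision hook of any dialect:
#
#     from sqlalchemy.dialects.postgresql import insert as pg_insert
#
#     @upsert.for_db("postgresql")
#     def _pg_upsert(
#         cfg, table, returning, *, set_lambda=None,
#         sort_by_parameter_order=False
#     ):
#         stmt = pg_insert(table)
#         if set_lambda:
#             stmt = stmt.on_conflict_do_update(
#                 index_elements=table.primary_key,
#                 set_=set_lambda(stmt.excluded),
#             )
#         else:
#             stmt = stmt.on_conflict_do_nothing()
#         return stmt.returning(
#             *returning, sort_by_parameter_order=sort_by_parameter_order
#         )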


@register.init
def normalize_sequence(cfg, sequence):
    """Normalize sequence parameters for dialects whose sequences don't
    start with 1 by default.

    The default implementation does nothing.
    """
    return sequence
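

# Illustrative sketch: a backend whose sequences do not default to starting
# at 1 could normalize the Sequence in place; "somedb" is a placeholder
# backend name:
#
#     @normalize_sequence.for_db("somedb")
#     def _somedb_normalize_sequence(cfg, sequence):
#         if sequence.start is None:
#             sequence.start = 1
#         return sequence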