provision.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # dialects/sqlite/provision.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. import os
  9. import re
  10. from ... import exc
  11. from ...engine import url as sa_url
  12. from ...testing.provision import create_db
  13. from ...testing.provision import drop_db
  14. from ...testing.provision import follower_url_from_main
  15. from ...testing.provision import generate_driver_url
  16. from ...testing.provision import log
  17. from ...testing.provision import post_configure_engine
  18. from ...testing.provision import run_reap_dbs
  19. from ...testing.provision import stop_test_class_outside_fixtures
  20. from ...testing.provision import temp_table_keyword_args
  21. from ...testing.provision import upsert
  22. # TODO: I can't get this to build dynamically with pytest-xdist procs
  23. _drivernames = {
  24. "pysqlite",
  25. "aiosqlite",
  26. "pysqlcipher",
  27. "pysqlite_numeric",
  28. "pysqlite_dollar",
  29. }
  30. def _format_url(url, driver, ident):
  31. """given a sqlite url + desired driver + ident, make a canonical
  32. URL out of it
  33. """
  34. url = sa_url.make_url(url)
  35. if driver is None:
  36. driver = url.get_driver_name()
  37. filename = url.database
  38. needs_enc = driver == "pysqlcipher"
  39. name_token = None
  40. if filename and filename != ":memory:":
  41. assert "test_schema" not in filename
  42. tokens = re.split(r"[_\.]", filename)
  43. for token in tokens:
  44. if token in _drivernames:
  45. if driver is None:
  46. driver = token
  47. continue
  48. elif token in ("db", "enc"):
  49. continue
  50. elif name_token is None:
  51. name_token = token.strip("_")
  52. assert name_token, f"sqlite filename has no name token: {url.database}"
  53. new_filename = f"{name_token}_{driver}"
  54. if ident:
  55. new_filename += f"_{ident}"
  56. new_filename += ".db"
  57. if needs_enc:
  58. new_filename += ".enc"
  59. url = url.set(database=new_filename)
  60. if needs_enc:
  61. url = url.set(password="test")
  62. url = url.set(drivername="sqlite+%s" % (driver,))
  63. return url
  64. @generate_driver_url.for_db("sqlite")
  65. def generate_driver_url(url, driver, query_str):
  66. url = _format_url(url, driver, None)
  67. try:
  68. url.get_dialect()
  69. except exc.NoSuchModuleError:
  70. return None
  71. else:
  72. return url
  73. @follower_url_from_main.for_db("sqlite")
  74. def _sqlite_follower_url_from_main(url, ident):
  75. return _format_url(url, None, ident)
  76. @post_configure_engine.for_db("sqlite")
  77. def _sqlite_post_configure_engine(url, engine, follower_ident):
  78. from sqlalchemy import event
  79. if follower_ident:
  80. attach_path = f"{follower_ident}_{engine.driver}_test_schema.db"
  81. else:
  82. attach_path = f"{engine.driver}_test_schema.db"
  83. @event.listens_for(engine, "connect")
  84. def connect(dbapi_connection, connection_record):
  85. # use file DBs in all cases, memory acts kind of strangely
  86. # as an attached
  87. # NOTE! this has to be done *per connection*. New sqlite connection,
  88. # as we get with say, QueuePool, the attaches are gone.
  89. # so schemes to delete those attached files have to be done at the
  90. # filesystem level and not rely upon what attachments are in a
  91. # particular SQLite connection
  92. dbapi_connection.execute(
  93. f'ATTACH DATABASE "{attach_path}" AS test_schema'
  94. )
  95. @event.listens_for(engine, "engine_disposed")
  96. def dispose(engine):
  97. """most databases should be dropped using
  98. stop_test_class_outside_fixtures
  99. however a few tests like AttachedDBTest might not get triggered on
  100. that main hook
  101. """
  102. if os.path.exists(attach_path):
  103. os.remove(attach_path)
  104. filename = engine.url.database
  105. if filename and filename != ":memory:" and os.path.exists(filename):
  106. os.remove(filename)
  107. @create_db.for_db("sqlite")
  108. def _sqlite_create_db(cfg, eng, ident):
  109. pass
  110. @drop_db.for_db("sqlite")
  111. def _sqlite_drop_db(cfg, eng, ident):
  112. _drop_dbs_w_ident(eng.url.database, eng.driver, ident)
  113. def _drop_dbs_w_ident(databasename, driver, ident):
  114. for path in os.listdir("."):
  115. fname, ext = os.path.split(path)
  116. if ident in fname and ext in [".db", ".db.enc"]:
  117. log.info("deleting SQLite database file: %s", path)
  118. os.remove(path)
  119. @stop_test_class_outside_fixtures.for_db("sqlite")
  120. def stop_test_class_outside_fixtures(config, db, cls):
  121. db.dispose()
  122. @temp_table_keyword_args.for_db("sqlite")
  123. def _sqlite_temp_table_keyword_args(cfg, eng):
  124. return {"prefixes": ["TEMPORARY"]}
  125. @run_reap_dbs.for_db("sqlite")
  126. def _reap_sqlite_dbs(url, idents):
  127. log.info("db reaper connecting to %r", url)
  128. log.info("identifiers in file: %s", ", ".join(idents))
  129. url = sa_url.make_url(url)
  130. for ident in idents:
  131. for drivername in _drivernames:
  132. _drop_dbs_w_ident(url.database, drivername, ident)
  133. @upsert.for_db("sqlite")
  134. def _upsert(
  135. cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
  136. ):
  137. from sqlalchemy.dialects.sqlite import insert
  138. stmt = insert(table)
  139. if set_lambda:
  140. stmt = stmt.on_conflict_do_update(set_=set_lambda(stmt.excluded))
  141. else:
  142. stmt = stmt.on_conflict_do_nothing()
  143. stmt = stmt.returning(
  144. *returning, sort_by_parameter_order=sort_by_parameter_order
  145. )
  146. return stmt