provision.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # dialects/postgresql/provision.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. import time
  9. from ... import exc
  10. from ... import inspect
  11. from ... import text
  12. from ...testing import warn_test_suite
  13. from ...testing.provision import create_db
  14. from ...testing.provision import drop_all_schema_objects_post_tables
  15. from ...testing.provision import drop_all_schema_objects_pre_tables
  16. from ...testing.provision import drop_db
  17. from ...testing.provision import log
  18. from ...testing.provision import post_configure_engine
  19. from ...testing.provision import prepare_for_drop_tables
  20. from ...testing.provision import set_default_schema_on_connection
  21. from ...testing.provision import temp_table_keyword_args
  22. from ...testing.provision import upsert
  23. @create_db.for_db("postgresql")
  24. def _pg_create_db(cfg, eng, ident):
  25. template_db = cfg.options.postgresql_templatedb
  26. with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
  27. if not template_db:
  28. template_db = conn.exec_driver_sql(
  29. "select current_database()"
  30. ).scalar()
  31. attempt = 0
  32. while True:
  33. try:
  34. conn.exec_driver_sql(
  35. "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
  36. )
  37. except exc.OperationalError as err:
  38. attempt += 1
  39. if attempt >= 3:
  40. raise
  41. if "accessed by other users" in str(err):
  42. log.info(
  43. "Waiting to create %s, URI %r, "
  44. "template DB %s is in use sleeping for .5",
  45. ident,
  46. eng.url,
  47. template_db,
  48. )
  49. time.sleep(0.5)
  50. except:
  51. raise
  52. else:
  53. break
  54. @drop_db.for_db("postgresql")
  55. def _pg_drop_db(cfg, eng, ident):
  56. with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
  57. with conn.begin():
  58. conn.execute(
  59. text(
  60. "select pg_terminate_backend(pid) from pg_stat_activity "
  61. "where usename=current_user and pid != pg_backend_pid() "
  62. "and datname=:dname"
  63. ),
  64. dict(dname=ident),
  65. )
  66. conn.exec_driver_sql("DROP DATABASE %s" % ident)
  67. @temp_table_keyword_args.for_db("postgresql")
  68. def _postgresql_temp_table_keyword_args(cfg, eng):
  69. return {"prefixes": ["TEMPORARY"]}
  70. @set_default_schema_on_connection.for_db("postgresql")
  71. def _postgresql_set_default_schema_on_connection(
  72. cfg, dbapi_connection, schema_name
  73. ):
  74. existing_autocommit = dbapi_connection.autocommit
  75. dbapi_connection.autocommit = True
  76. cursor = dbapi_connection.cursor()
  77. cursor.execute("SET SESSION search_path='%s'" % schema_name)
  78. cursor.close()
  79. dbapi_connection.autocommit = existing_autocommit
  80. @drop_all_schema_objects_pre_tables.for_db("postgresql")
  81. def drop_all_schema_objects_pre_tables(cfg, eng):
  82. with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
  83. for xid in conn.exec_driver_sql(
  84. "select gid from pg_prepared_xacts"
  85. ).scalars():
  86. conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
  87. @drop_all_schema_objects_post_tables.for_db("postgresql")
  88. def drop_all_schema_objects_post_tables(cfg, eng):
  89. from sqlalchemy.dialects import postgresql
  90. inspector = inspect(eng)
  91. with eng.begin() as conn:
  92. for enum in inspector.get_enums("*"):
  93. conn.execute(
  94. postgresql.DropEnumType(
  95. postgresql.ENUM(name=enum["name"], schema=enum["schema"])
  96. )
  97. )
  98. @prepare_for_drop_tables.for_db("postgresql")
  99. def prepare_for_drop_tables(config, connection):
  100. """Ensure there are no locks on the current username/database."""
  101. result = connection.exec_driver_sql(
  102. "select pid, state, wait_event_type, query "
  103. # "select pg_terminate_backend(pid), state, wait_event_type "
  104. "from pg_stat_activity where "
  105. "usename=current_user "
  106. "and datname=current_database() and state='idle in transaction' "
  107. "and pid != pg_backend_pid()"
  108. )
  109. rows = result.all() # noqa
  110. if rows:
  111. warn_test_suite(
  112. "PostgreSQL may not be able to DROP tables due to "
  113. "idle in transaction: %s"
  114. % ("; ".join(row._mapping["query"] for row in rows))
  115. )
  116. @upsert.for_db("postgresql")
  117. def _upsert(
  118. cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
  119. ):
  120. from sqlalchemy.dialects.postgresql import insert
  121. stmt = insert(table)
  122. table_pk = inspect(table).selectable
  123. if set_lambda:
  124. stmt = stmt.on_conflict_do_update(
  125. index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
  126. )
  127. else:
  128. stmt = stmt.on_conflict_do_nothing()
  129. stmt = stmt.returning(
  130. *returning, sort_by_parameter_order=sort_by_parameter_order
  131. )
  132. return stmt
  133. _extensions = [
  134. ("citext", (13,)),
  135. ("hstore", (13,)),
  136. ]
  137. @post_configure_engine.for_db("postgresql")
  138. def _create_citext_extension(url, engine, follower_ident):
  139. with engine.connect() as conn:
  140. for extension, min_version in _extensions:
  141. if conn.dialect.server_version_info >= min_version:
  142. conn.execute(
  143. text(f"CREATE EXTENSION IF NOT EXISTS {extension}")
  144. )
  145. conn.commit()