session.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. from __future__ import annotations
  2. import typing as t
  3. import sqlalchemy as sa
  4. import sqlalchemy.exc as sa_exc
  5. import sqlalchemy.orm as sa_orm
  6. from flask.globals import app_ctx
  7. if t.TYPE_CHECKING:
  8. from .extension import SQLAlchemy
  9. class Session(sa_orm.Session):
  10. """A SQLAlchemy :class:`~sqlalchemy.orm.Session` class that chooses what engine to
  11. use based on the bind key associated with the metadata associated with the thing
  12. being queried.
  13. To customize ``db.session``, subclass this and pass it as the ``class_`` key in the
  14. ``session_options`` to :class:`.SQLAlchemy`.
  15. .. versionchanged:: 3.0
  16. Renamed from ``SignallingSession``.
  17. """
  18. def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None:
  19. super().__init__(**kwargs)
  20. self._db = db
  21. self._model_changes: dict[object, tuple[t.Any, str]] = {}
  22. def get_bind(
  23. self,
  24. mapper: t.Any | None = None,
  25. clause: t.Any | None = None,
  26. bind: sa.engine.Engine | sa.engine.Connection | None = None,
  27. **kwargs: t.Any,
  28. ) -> sa.engine.Engine | sa.engine.Connection:
  29. """Select an engine based on the ``bind_key`` of the metadata associated with
  30. the model or table being queried. If no bind key is set, uses the default bind.
  31. .. versionchanged:: 3.0.3
  32. Fix finding the bind for a joined inheritance model.
  33. .. versionchanged:: 3.0
  34. The implementation more closely matches the base SQLAlchemy implementation.
  35. .. versionchanged:: 2.1
  36. Support joining an external transaction.
  37. """
  38. if bind is not None:
  39. return bind
  40. engines = self._db.engines
  41. if mapper is not None:
  42. try:
  43. mapper = sa.inspect(mapper)
  44. except sa_exc.NoInspectionAvailable as e:
  45. if isinstance(mapper, type):
  46. raise sa_orm.exc.UnmappedClassError(mapper) from e
  47. raise
  48. engine = _clause_to_engine(mapper.local_table, engines)
  49. if engine is not None:
  50. return engine
  51. if clause is not None:
  52. engine = _clause_to_engine(clause, engines)
  53. if engine is not None:
  54. return engine
  55. if None in engines:
  56. return engines[None]
  57. return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs)
  58. def _clause_to_engine(
  59. clause: sa.ClauseElement | None,
  60. engines: t.Mapping[str | None, sa.engine.Engine],
  61. ) -> sa.engine.Engine | None:
  62. """If the clause is a table, return the engine associated with the table's
  63. metadata's bind key.
  64. """
  65. table = None
  66. if clause is not None:
  67. if isinstance(clause, sa.Table):
  68. table = clause
  69. elif isinstance(clause, sa.UpdateBase) and isinstance(clause.table, sa.Table):
  70. table = clause.table
  71. if table is not None and "bind_key" in table.metadata.info:
  72. key = table.metadata.info["bind_key"]
  73. if key not in engines:
  74. raise sa_exc.UnboundExecutionError(
  75. f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
  76. )
  77. return engines[key]
  78. return None
  79. def _app_ctx_id() -> int:
  80. """Get the id of the current Flask application context for the session scope."""
  81. return id(app_ctx._get_current_object()) # type: ignore[attr-defined]