track_modifications.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from __future__ import annotations
  2. import typing as t
  3. import sqlalchemy as sa
  4. import sqlalchemy.event as sa_event
  5. import sqlalchemy.orm as sa_orm
  6. from flask import current_app
  7. from flask import has_app_context
  8. from flask.signals import Namespace # type: ignore[attr-defined]
  9. if t.TYPE_CHECKING:
  10. from .session import Session
  11. _signals = Namespace()
  12. models_committed = _signals.signal("models-committed")
  13. """This Blinker signal is sent after the session is committed if there were changed
  14. models in the session.
  15. The sender is the application that emitted the changes. The receiver is passed the
  16. ``changes`` argument with a list of tuples in the form ``(instance, operation)``.
  17. The operations are ``"insert"``, ``"update"``, and ``"delete"``.
  18. """
  19. before_models_committed = _signals.signal("before-models-committed")
  20. """This signal works exactly like :data:`models_committed` but is emitted before the
  21. commit takes place.
  22. """
  23. def _listen(session: sa_orm.scoped_session[Session]) -> None:
  24. sa_event.listen(session, "before_flush", _record_ops, named=True)
  25. sa_event.listen(session, "before_commit", _record_ops, named=True)
  26. sa_event.listen(session, "before_commit", _before_commit)
  27. sa_event.listen(session, "after_commit", _after_commit)
  28. sa_event.listen(session, "after_rollback", _after_rollback)
  29. def _record_ops(session: Session, **kwargs: t.Any) -> None:
  30. if not has_app_context():
  31. return
  32. if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]:
  33. return
  34. for targets, operation in (
  35. (session.new, "insert"),
  36. (session.dirty, "update"),
  37. (session.deleted, "delete"),
  38. ):
  39. for target in targets:
  40. state = sa.inspect(target)
  41. key = state.identity_key if state.has_identity else id(target)
  42. session._model_changes[key] = (target, operation)
  43. def _before_commit(session: Session) -> None:
  44. if not has_app_context():
  45. return
  46. app = current_app._get_current_object() # type: ignore[attr-defined]
  47. if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]:
  48. return
  49. if session._model_changes:
  50. changes = list(session._model_changes.values())
  51. before_models_committed.send(app, changes=changes)
  52. def _after_commit(session: Session) -> None:
  53. if not has_app_context():
  54. return
  55. app = current_app._get_current_object() # type: ignore[attr-defined]
  56. if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]:
  57. return
  58. if session._model_changes:
  59. changes = list(session._model_changes.values())
  60. models_committed.send(app, changes=changes)
  61. session._model_changes.clear()
  62. def _after_rollback(session: Session) -> None:
  63. session._model_changes.clear()