horizontal_shard.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. # ext/horizontal_shard.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. """Horizontal sharding support.
  8. Defines a rudimentary 'horizontal sharding' system which allows a Session to
  9. distribute queries and persistence operations across multiple databases.
  10. For a usage example, see the :ref:`examples_sharding` example included in
  11. the source distribution.
  12. .. deepalchemy:: The horizontal sharding extension is an advanced feature,
  13. involving a complex statement -> database interaction as well as
  14. use of semi-public APIs for non-trivial cases. Simpler approaches to
  15. referring to multiple database "shards", most commonly using a distinct
  16. :class:`_orm.Session` per "shard", should always be considered first
  17. before using this more complex and less-production-tested system.
  18. """
  19. from __future__ import annotations
  20. from typing import Any
  21. from typing import Callable
  22. from typing import Dict
  23. from typing import Iterable
  24. from typing import Optional
  25. from typing import Tuple
  26. from typing import Type
  27. from typing import TYPE_CHECKING
  28. from typing import TypeVar
  29. from typing import Union
  30. from .. import event
  31. from .. import exc
  32. from .. import inspect
  33. from .. import util
  34. from ..orm import PassiveFlag
  35. from ..orm._typing import OrmExecuteOptionsParameter
  36. from ..orm.interfaces import ORMOption
  37. from ..orm.mapper import Mapper
  38. from ..orm.query import Query
  39. from ..orm.session import _BindArguments
  40. from ..orm.session import _PKIdentityArgument
  41. from ..orm.session import Session
  42. from ..util.typing import Protocol
  43. from ..util.typing import Self
  44. if TYPE_CHECKING:
  45. from ..engine.base import Connection
  46. from ..engine.base import Engine
  47. from ..engine.base import OptionEngine
  48. from ..engine.result import IteratorResult
  49. from ..engine.result import Result
  50. from ..orm import LoaderCallableStatus
  51. from ..orm._typing import _O
  52. from ..orm.bulk_persistence import BulkUDCompileState
  53. from ..orm.context import QueryContext
  54. from ..orm.session import _EntityBindKey
  55. from ..orm.session import _SessionBind
  56. from ..orm.session import ORMExecuteState
  57. from ..orm.state import InstanceState
  58. from ..sql import Executable
  59. from ..sql._typing import _TP
  60. from ..sql.elements import ClauseElement
  61. __all__ = ["ShardedSession", "ShardedQuery"]
  62. _T = TypeVar("_T", bound=Any)
  63. ShardIdentifier = str
  64. class ShardChooser(Protocol):
  65. def __call__(
  66. self,
  67. mapper: Optional[Mapper[_T]],
  68. instance: Any,
  69. clause: Optional[ClauseElement],
  70. ) -> Any: ...
  71. class IdentityChooser(Protocol):
  72. def __call__(
  73. self,
  74. mapper: Mapper[_T],
  75. primary_key: _PKIdentityArgument,
  76. *,
  77. lazy_loaded_from: Optional[InstanceState[Any]],
  78. execution_options: OrmExecuteOptionsParameter,
  79. bind_arguments: _BindArguments,
  80. **kw: Any,
  81. ) -> Any: ...
  82. class ShardedQuery(Query[_T]):
  83. """Query class used with :class:`.ShardedSession`.
  84. .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy
  85. :class:`.Query` class. The :class:`.ShardedSession` now supports
  86. 2.0 style execution via the :meth:`.ShardedSession.execute` method.
  87. """
  88. def __init__(self, *args: Any, **kwargs: Any) -> None:
  89. super().__init__(*args, **kwargs)
  90. assert isinstance(self.session, ShardedSession)
  91. self.identity_chooser = self.session.identity_chooser
  92. self.execute_chooser = self.session.execute_chooser
  93. self._shard_id = None
  94. def set_shard(self, shard_id: ShardIdentifier) -> Self:
  95. """Return a new query, limited to a single shard ID.
  96. All subsequent operations with the returned query will
  97. be against the single shard regardless of other state.
  98. The shard_id can be passed for a 2.0 style execution to the
  99. bind_arguments dictionary of :meth:`.Session.execute`::
  100. results = session.execute(stmt, bind_arguments={"shard_id": "my_shard"})
  101. """ # noqa: E501
  102. return self.execution_options(_sa_shard_id=shard_id)
  103. class ShardedSession(Session):
  104. shard_chooser: ShardChooser
  105. identity_chooser: IdentityChooser
  106. execute_chooser: Callable[[ORMExecuteState], Iterable[Any]]
  107. def __init__(
  108. self,
  109. shard_chooser: ShardChooser,
  110. identity_chooser: Optional[IdentityChooser] = None,
  111. execute_chooser: Optional[
  112. Callable[[ORMExecuteState], Iterable[Any]]
  113. ] = None,
  114. shards: Optional[Dict[str, Any]] = None,
  115. query_cls: Type[Query[_T]] = ShardedQuery,
  116. *,
  117. id_chooser: Optional[
  118. Callable[[Query[_T], Iterable[_T]], Iterable[Any]]
  119. ] = None,
  120. query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
  121. **kwargs: Any,
  122. ) -> None:
  123. """Construct a ShardedSession.
  124. :param shard_chooser: A callable which, passed a Mapper, a mapped
  125. instance, and possibly a SQL clause, returns a shard ID. This id
  126. may be based off of the attributes present within the object, or on
  127. some round-robin scheme. If the scheme is based on a selection, it
  128. should set whatever state on the instance to mark it in the future as
  129. participating in that shard.
  130. :param identity_chooser: A callable, passed a Mapper and primary key
  131. argument, which should return a list of shard ids where this
  132. primary key might reside.
  133. .. versionchanged:: 2.0 The ``identity_chooser`` parameter
  134. supersedes the ``id_chooser`` parameter.
  135. :param execute_chooser: For a given :class:`.ORMExecuteState`,
  136. returns the list of shard_ids
  137. where the query should be issued. Results from all shards returned
  138. will be combined together into a single listing.
  139. .. versionchanged:: 1.4 The ``execute_chooser`` parameter
  140. supersedes the ``query_chooser`` parameter.
  141. :param shards: A dictionary of string shard names
  142. to :class:`~sqlalchemy.engine.Engine` objects.
  143. """
  144. super().__init__(query_cls=query_cls, **kwargs)
  145. event.listen(
  146. self, "do_orm_execute", execute_and_instances, retval=True
  147. )
  148. self.shard_chooser = shard_chooser
  149. if id_chooser:
  150. _id_chooser = id_chooser
  151. util.warn_deprecated(
  152. "The ``id_chooser`` parameter is deprecated; "
  153. "please use ``identity_chooser``.",
  154. "2.0",
  155. )
  156. def _legacy_identity_chooser(
  157. mapper: Mapper[_T],
  158. primary_key: _PKIdentityArgument,
  159. *,
  160. lazy_loaded_from: Optional[InstanceState[Any]],
  161. execution_options: OrmExecuteOptionsParameter,
  162. bind_arguments: _BindArguments,
  163. **kw: Any,
  164. ) -> Any:
  165. q = self.query(mapper)
  166. if lazy_loaded_from:
  167. q = q._set_lazyload_from(lazy_loaded_from)
  168. return _id_chooser(q, primary_key)
  169. self.identity_chooser = _legacy_identity_chooser
  170. elif identity_chooser:
  171. self.identity_chooser = identity_chooser
  172. else:
  173. raise exc.ArgumentError(
  174. "identity_chooser or id_chooser is required"
  175. )
  176. if query_chooser:
  177. _query_chooser = query_chooser
  178. util.warn_deprecated(
  179. "The ``query_chooser`` parameter is deprecated; "
  180. "please use ``execute_chooser``.",
  181. "1.4",
  182. )
  183. if execute_chooser:
  184. raise exc.ArgumentError(
  185. "Can't pass query_chooser and execute_chooser "
  186. "at the same time."
  187. )
  188. def _default_execute_chooser(
  189. orm_context: ORMExecuteState,
  190. ) -> Iterable[Any]:
  191. return _query_chooser(orm_context.statement)
  192. if execute_chooser is None:
  193. execute_chooser = _default_execute_chooser
  194. if execute_chooser is None:
  195. raise exc.ArgumentError(
  196. "execute_chooser or query_chooser is required"
  197. )
  198. self.execute_chooser = execute_chooser
  199. self.__shards: Dict[ShardIdentifier, _SessionBind] = {}
  200. if shards is not None:
  201. for k in shards:
  202. self.bind_shard(k, shards[k])
  203. def _identity_lookup(
  204. self,
  205. mapper: Mapper[_O],
  206. primary_key_identity: Union[Any, Tuple[Any, ...]],
  207. identity_token: Optional[Any] = None,
  208. passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
  209. lazy_loaded_from: Optional[InstanceState[Any]] = None,
  210. execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
  211. bind_arguments: Optional[_BindArguments] = None,
  212. **kw: Any,
  213. ) -> Union[Optional[_O], LoaderCallableStatus]:
  214. """override the default :meth:`.Session._identity_lookup` method so
  215. that we search for a given non-token primary key identity across all
  216. possible identity tokens (e.g. shard ids).
  217. .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
  218. the :class:`_query.Query` object to the :class:`.Session`.
  219. """
  220. if identity_token is not None:
  221. obj = super()._identity_lookup(
  222. mapper,
  223. primary_key_identity,
  224. identity_token=identity_token,
  225. **kw,
  226. )
  227. return obj
  228. else:
  229. for shard_id in self.identity_chooser(
  230. mapper,
  231. primary_key_identity,
  232. lazy_loaded_from=lazy_loaded_from,
  233. execution_options=execution_options,
  234. bind_arguments=dict(bind_arguments) if bind_arguments else {},
  235. ):
  236. obj2 = super()._identity_lookup(
  237. mapper,
  238. primary_key_identity,
  239. identity_token=shard_id,
  240. lazy_loaded_from=lazy_loaded_from,
  241. **kw,
  242. )
  243. if obj2 is not None:
  244. return obj2
  245. return None
  246. def _choose_shard_and_assign(
  247. self,
  248. mapper: Optional[_EntityBindKey[_O]],
  249. instance: Any,
  250. **kw: Any,
  251. ) -> Any:
  252. if instance is not None:
  253. state = inspect(instance)
  254. if state.key:
  255. token = state.key[2]
  256. assert token is not None
  257. return token
  258. elif state.identity_token:
  259. return state.identity_token
  260. assert isinstance(mapper, Mapper)
  261. shard_id = self.shard_chooser(mapper, instance, **kw)
  262. if instance is not None:
  263. state.identity_token = shard_id
  264. return shard_id
  265. def connection_callable(
  266. self,
  267. mapper: Optional[Mapper[_T]] = None,
  268. instance: Optional[Any] = None,
  269. shard_id: Optional[ShardIdentifier] = None,
  270. **kw: Any,
  271. ) -> Connection:
  272. """Provide a :class:`_engine.Connection` to use in the unit of work
  273. flush process.
  274. """
  275. if shard_id is None:
  276. shard_id = self._choose_shard_and_assign(mapper, instance)
  277. if self.in_transaction():
  278. trans = self.get_transaction()
  279. assert trans is not None
  280. return trans.connection(mapper, shard_id=shard_id)
  281. else:
  282. bind = self.get_bind(
  283. mapper=mapper, shard_id=shard_id, instance=instance
  284. )
  285. if isinstance(bind, Engine):
  286. return bind.connect(**kw)
  287. else:
  288. assert isinstance(bind, Connection)
  289. return bind
  290. def get_bind(
  291. self,
  292. mapper: Optional[_EntityBindKey[_O]] = None,
  293. *,
  294. shard_id: Optional[ShardIdentifier] = None,
  295. instance: Optional[Any] = None,
  296. clause: Optional[ClauseElement] = None,
  297. **kw: Any,
  298. ) -> _SessionBind:
  299. if shard_id is None:
  300. shard_id = self._choose_shard_and_assign(
  301. mapper, instance=instance, clause=clause
  302. )
  303. assert shard_id is not None
  304. return self.__shards[shard_id]
  305. def bind_shard(
  306. self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine]
  307. ) -> None:
  308. self.__shards[shard_id] = bind
  309. class set_shard_id(ORMOption):
  310. """a loader option for statements to apply a specific shard id to the
  311. primary query as well as for additional relationship and column
  312. loaders.
  313. The :class:`_horizontal.set_shard_id` option may be applied using
  314. the :meth:`_sql.Executable.options` method of any executable statement::
  315. stmt = (
  316. select(MyObject)
  317. .where(MyObject.name == "some name")
  318. .options(set_shard_id("shard1"))
  319. )
  320. Above, the statement when invoked will limit to the "shard1" shard
  321. identifier for the primary query as well as for all relationship and
  322. column loading strategies, including eager loaders such as
  323. :func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`,
  324. and the lazy relationship loader :func:`_orm.lazyload`.
  325. In this way, the :class:`_horizontal.set_shard_id` option has much wider
  326. scope than using the "shard_id" argument within the
  327. :paramref:`_orm.Session.execute.bind_arguments` dictionary.
  328. .. versionadded:: 2.0.0
  329. """
  330. __slots__ = ("shard_id", "propagate_to_loaders")
  331. def __init__(
  332. self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True
  333. ):
  334. """Construct a :class:`_horizontal.set_shard_id` option.
  335. :param shard_id: shard identifier
  336. :param propagate_to_loaders: if left at its default of ``True``, the
  337. shard option will take place for lazy loaders such as
  338. :func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option
  339. will not be propagated to loaded objects. Note that :func:`_orm.defer`
  340. always limits to the shard_id of the parent row in any case, so the
  341. parameter only has a net effect on the behavior of the
  342. :func:`_orm.lazyload` strategy.
  343. """
  344. self.shard_id = shard_id
  345. self.propagate_to_loaders = propagate_to_loaders
  346. def execute_and_instances(
  347. orm_context: ORMExecuteState,
  348. ) -> Union[Result[_T], IteratorResult[_TP]]:
  349. active_options: Union[
  350. None,
  351. QueryContext.default_load_options,
  352. Type[QueryContext.default_load_options],
  353. BulkUDCompileState.default_update_options,
  354. Type[BulkUDCompileState.default_update_options],
  355. ]
  356. if orm_context.is_select:
  357. active_options = orm_context.load_options
  358. elif orm_context.is_update or orm_context.is_delete:
  359. active_options = orm_context.update_delete_options
  360. else:
  361. active_options = None
  362. session = orm_context.session
  363. assert isinstance(session, ShardedSession)
  364. def iter_for_shard(
  365. shard_id: ShardIdentifier,
  366. ) -> Union[Result[_T], IteratorResult[_TP]]:
  367. bind_arguments = dict(orm_context.bind_arguments)
  368. bind_arguments["shard_id"] = shard_id
  369. orm_context.update_execution_options(identity_token=shard_id)
  370. return orm_context.invoke_statement(bind_arguments=bind_arguments)
  371. for orm_opt in orm_context._non_compile_orm_options:
  372. # TODO: if we had an ORMOption that gets applied at ORM statement
  373. # execution time, that would allow this to be more generalized.
  374. # for now just iterate and look for our options
  375. if isinstance(orm_opt, set_shard_id):
  376. shard_id = orm_opt.shard_id
  377. break
  378. else:
  379. if active_options and active_options._identity_token is not None:
  380. shard_id = active_options._identity_token
  381. elif "_sa_shard_id" in orm_context.execution_options:
  382. shard_id = orm_context.execution_options["_sa_shard_id"]
  383. elif "shard_id" in orm_context.bind_arguments:
  384. shard_id = orm_context.bind_arguments["shard_id"]
  385. else:
  386. shard_id = None
  387. if shard_id is not None:
  388. return iter_for_shard(shard_id)
  389. else:
  390. partial = []
  391. for shard_id in session.execute_chooser(orm_context):
  392. result_ = iter_for_shard(shard_id)
  393. partial.append(result_)
  394. return partial[0].merge(*partial[1:])