- # orm/bulk_persistence.py
- # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- # mypy: ignore-errors
- """additional ORM persistence classes related to "bulk" operations,
- specifically outside of the flush() process.
- """
- from __future__ import annotations
- from typing import Any
- from typing import cast
- from typing import Dict
- from typing import Iterable
- from typing import Optional
- from typing import overload
- from typing import TYPE_CHECKING
- from typing import TypeVar
- from typing import Union
- from . import attributes
- from . import context
- from . import evaluator
- from . import exc as orm_exc
- from . import loading
- from . import persistence
- from .base import NO_VALUE
- from .context import AbstractORMCompileState
- from .context import FromStatement
- from .context import ORMFromStatementCompileState
- from .context import QueryContext
- from .. import exc as sa_exc
- from .. import util
- from ..engine import Dialect
- from ..engine import result as _result
- from ..sql import coercions
- from ..sql import dml
- from ..sql import expression
- from ..sql import roles
- from ..sql import select
- from ..sql import sqltypes
- from ..sql.base import _entity_namespace_key
- from ..sql.base import CompileState
- from ..sql.base import Options
- from ..sql.dml import DeleteDMLState
- from ..sql.dml import InsertDMLState
- from ..sql.dml import UpdateDMLState
- from ..util import EMPTY_DICT
- from ..util.typing import Literal
- if TYPE_CHECKING:
- from ._typing import DMLStrategyArgument
- from ._typing import OrmExecuteOptionsParameter
- from ._typing import SynchronizeSessionArgument
- from .mapper import Mapper
- from .session import _BindArguments
- from .session import ORMExecuteState
- from .session import Session
- from .session import SessionTransaction
- from .state import InstanceState
- from ..engine import Connection
- from ..engine import cursor
- from ..engine.interfaces import _CoreAnyExecuteParams
- _O = TypeVar("_O", bound=object)
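- # Illustrative sketch, not part of SQLAlchemy itself: the "_demo_*"
- # functions interspersed below show, using only public API, how the
- # private machinery in this module is reached. The mapped classes and
- # in-memory SQLite engines they declare are assumptions made purely for
- # demonstration; nothing in the module calls them.
- def _demo_legacy_bulk_insert():  # pragma: no cover
-     """sketch: Session.bulk_insert_mappings() -> _bulk_insert()"""
-     from sqlalchemy import Integer, String, create_engine
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         # hands plain dictionaries to _bulk_insert() with isstates=False,
-         # bypassing unit-of-work flush bookkeeping
-         session.bulk_insert_mappings(
-             User,
-             [{"id": 1, "name": "spongebob"}, {"id": 2, "name": "sandy"}],
-         )
-         session.commit()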
- @overload
- def _bulk_insert(
- mapper: Mapper[_O],
- mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
- session_transaction: SessionTransaction,
- *,
- isstates: bool,
- return_defaults: bool,
- render_nulls: bool,
- use_orm_insert_stmt: Literal[None] = ...,
- execution_options: Optional[OrmExecuteOptionsParameter] = ...,
- ) -> None: ...
- @overload
- def _bulk_insert(
- mapper: Mapper[_O],
- mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
- session_transaction: SessionTransaction,
- *,
- isstates: bool,
- return_defaults: bool,
- render_nulls: bool,
- use_orm_insert_stmt: Optional[dml.Insert] = ...,
- execution_options: Optional[OrmExecuteOptionsParameter] = ...,
- ) -> cursor.CursorResult[Any]: ...
- def _bulk_insert(
- mapper: Mapper[_O],
- mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
- session_transaction: SessionTransaction,
- *,
- isstates: bool,
- return_defaults: bool,
- render_nulls: bool,
- use_orm_insert_stmt: Optional[dml.Insert] = None,
- execution_options: Optional[OrmExecuteOptionsParameter] = None,
- ) -> Optional[cursor.CursorResult[Any]]:
- base_mapper = mapper.base_mapper
- if session_transaction.session.connection_callable:
- raise NotImplementedError(
- "connection_callable / per-instance sharding "
- "not supported in bulk_insert()"
- )
- if isstates:
- if TYPE_CHECKING:
- mappings = cast(Iterable[InstanceState[_O]], mappings)
- if return_defaults:
- # list of states allows us to attach .key for return_defaults case
- states = [(state, state.dict) for state in mappings]
- mappings = [dict_ for (state, dict_) in states]
- else:
- mappings = [state.dict for state in mappings]
- else:
- if TYPE_CHECKING:
- mappings = cast(Iterable[Dict[str, Any]], mappings)
- if return_defaults:
- # use dictionaries given, so that newly populated defaults
- # can be delivered back to the caller (see #11661). This is **not**
- # compatible with other use cases such as a session-executed
- # insert() construct, as this will confuse the case of
- # insert-per-subclass for joined inheritance cases (see
- # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
- #
- # So in this conditional, we have **only** called
- # session.bulk_insert_mappings() which does not have this
- # requirement
- mappings = list(mappings)
- else:
- # for all other cases we need to establish a local dictionary
- # so that the incoming dictionaries aren't mutated
- mappings = [dict(m) for m in mappings]
- _expand_composites(mapper, mappings)
- connection = session_transaction.connection(base_mapper)
- return_result: Optional[cursor.CursorResult[Any]] = None
- mappers_to_run = [
- (table, mp)
- for table, mp in base_mapper._sorted_tables.items()
- if table in mapper._pks_by_table
- ]
- if return_defaults:
- # not used by new-style bulk inserts, only used for legacy
- bookkeeping = True
- elif len(mappers_to_run) > 1:
- # if we have more than one table / mapper to run, where we will be
- # either horizontally splicing or copying values between tables,
- # we need the "bookkeeping" / deterministic RETURNING order
- bookkeeping = True
- else:
- bookkeeping = False
- for table, super_mapper in mappers_to_run:
- # find bindparams in the statement. For bulk, we don't really know if
- # a key in the params applies to a different table since we are
- # potentially inserting for multiple tables here; looking at the
- # bindparam() is a lot more direct. In most cases this will
- # use _generate_cache_key() which is memoized, although in practice
- # the ultimate statement that's executed is probably not the same
- # object so that memoization might not matter much.
- extra_bp_names = (
- [
- b.key
- for b in use_orm_insert_stmt._get_embedded_bindparams()
- if b.key in mappings[0]
- ]
- if use_orm_insert_stmt is not None
- else ()
- )
- records = (
- (
- None,
- state_dict,
- params,
- mapper,
- connection,
- value_params,
- has_all_pks,
- has_all_defaults,
- )
- for (
- state,
- state_dict,
- params,
- mp,
- conn,
- value_params,
- has_all_pks,
- has_all_defaults,
- ) in persistence._collect_insert_commands(
- table,
- ((None, mapping, mapper, connection) for mapping in mappings),
- bulk=True,
- return_defaults=bookkeeping,
- render_nulls=render_nulls,
- include_bulk_keys=extra_bp_names,
- )
- )
- result = persistence._emit_insert_statements(
- base_mapper,
- None,
- super_mapper,
- table,
- records,
- bookkeeping=bookkeeping,
- use_orm_insert_stmt=use_orm_insert_stmt,
- execution_options=execution_options,
- )
- if use_orm_insert_stmt is not None:
- if not use_orm_insert_stmt._returning or return_result is None:
- return_result = result
- elif result.returns_rows:
- assert bookkeeping
- return_result = return_result.splice_horizontally(result)
- if return_defaults and isstates:
- identity_cls = mapper._identity_class
- identity_props = [p.key for p in mapper._identity_key_props]
- for state, dict_ in states:
- state.key = (
- identity_cls,
- tuple([dict_[key] for key in identity_props]),
- None,
- )
- if use_orm_insert_stmt is not None:
- assert return_result is not None
- return return_result
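- # Illustrative sketch (assumed model/engine, public API only): the
- # modern entry point into _bulk_insert() above. Passing a list of
- # parameter dictionaries to Session.execute() with an insert() construct
- # selects the "bulk" DML strategy, which arrives here with
- # use_orm_insert_stmt set to the user's statement.
- def _demo_orm_bulk_insert_statement():  # pragma: no cover
-     from sqlalchemy import Integer, String, create_engine, insert
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         # executemany-style parameters trigger the bulk path
-         session.execute(
-             insert(User),
-             [{"id": 1, "name": "spongebob"}, {"id": 2, "name": "sandy"}],
-         )
-         session.commit()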
- @overload
- def _bulk_update(
- mapper: Mapper[Any],
- mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
- session_transaction: SessionTransaction,
- *,
- isstates: bool,
- update_changed_only: bool,
- use_orm_update_stmt: Literal[None] = ...,
- enable_check_rowcount: bool = True,
- ) -> None: ...
- @overload
- def _bulk_update(
- mapper: Mapper[Any],
- mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
- session_transaction: SessionTransaction,
- *,
- isstates: bool,
- update_changed_only: bool,
- use_orm_update_stmt: Optional[dml.Update] = ...,
- enable_check_rowcount: bool = True,
- ) -> _result.Result[Any]: ...
- def _bulk_update(
- mapper: Mapper[Any],
- mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
- session_transaction: SessionTransaction,
- *,
- isstates: bool,
- update_changed_only: bool,
- use_orm_update_stmt: Optional[dml.Update] = None,
- enable_check_rowcount: bool = True,
- ) -> Optional[_result.Result[Any]]:
- base_mapper = mapper.base_mapper
- search_keys = mapper._primary_key_propkeys
- if mapper._version_id_prop:
- search_keys = {mapper._version_id_prop.key}.union(search_keys)
- def _changed_dict(mapper, state):
- return {
- k: v
- for k, v in state.dict.items()
- if k in state.committed_state or k in search_keys
- }
- if isstates:
- if update_changed_only:
- mappings = [_changed_dict(mapper, state) for state in mappings]
- else:
- mappings = [state.dict for state in mappings]
- else:
- mappings = [dict(m) for m in mappings]
- _expand_composites(mapper, mappings)
- if session_transaction.session.connection_callable:
- raise NotImplementedError(
- "connection_callable / per-instance sharding "
- "not supported in bulk_update()"
- )
- connection = session_transaction.connection(base_mapper)
- # find bindparams in the statement. see _bulk_insert for similar
- # notes for the insert case
- extra_bp_names = (
- [
- b.key
- for b in use_orm_update_stmt._get_embedded_bindparams()
- if b.key in mappings[0]
- ]
- if use_orm_update_stmt is not None
- else ()
- )
- for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
- continue
- records = persistence._collect_update_commands(
- None,
- table,
- (
- (
- None,
- mapping,
- mapper,
- connection,
- (
- mapping[mapper._version_id_prop.key]
- if mapper._version_id_prop
- else None
- ),
- )
- for mapping in mappings
- ),
- bulk=True,
- use_orm_update_stmt=use_orm_update_stmt,
- include_bulk_keys=extra_bp_names,
- )
- persistence._emit_update_statements(
- base_mapper,
- None,
- super_mapper,
- table,
- records,
- bookkeeping=False,
- use_orm_update_stmt=use_orm_update_stmt,
- enable_check_rowcount=enable_check_rowcount,
- )
- if use_orm_update_stmt is not None:
- return _result.null_result()
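- # Illustrative sketch (assumed model/engine, public API only): the
- # legacy caller of _bulk_update() above. Each dictionary must carry the
- # primary key; the remaining keys become the SET clause.
- def _demo_legacy_bulk_update():  # pragma: no cover
-     from sqlalchemy import Integer, String, create_engine
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         session.bulk_insert_mappings(User, [{"id": 1, "name": "old"}])
-         session.bulk_update_mappings(User, [{"id": 1, "name": "new"}])
-         session.commit()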
- def _expand_composites(mapper, mappings):
- composite_attrs = mapper.composites
- if not composite_attrs:
- return
- composite_keys = set(composite_attrs.keys())
- populators = {
- key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
- for key in composite_keys
- }
- for mapping in mappings:
- for key in composite_keys.intersection(mapping):
- populators[key](mapping)
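- # Illustrative sketch (assumed dataclass/model/engine, public API only)
- # of what _expand_composites() above accomplishes: a composite attribute
- # present in an incoming bulk mapping is broken out into its constituent
- # column values before the statement is rendered.
- def _demo_composite_expansion():  # pragma: no cover
-     import dataclasses
-     from sqlalchemy import Integer, create_engine
-     from sqlalchemy.orm import (
-         DeclarativeBase,
-         Session,
-         composite,
-         mapped_column,
-     )
-     @dataclasses.dataclass
-     class Point:
-         x: int
-         y: int
-     class Base(DeclarativeBase):
-         pass
-     class Vertex(Base):
-         __tablename__ = "vertex"
-         id = mapped_column(Integer, primary_key=True)
-         start = composite(
-             Point, mapped_column("x1", Integer), mapped_column("y1", Integer)
-         )
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         # "start" is expanded into "x1" / "y1" parameters
-         session.bulk_insert_mappings(
-             Vertex, [{"id": 1, "start": Point(1, 2)}]
-         )
-         session.commit()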
- class ORMDMLState(AbstractORMCompileState):
- is_dml_returning = True
- from_statement_ctx: Optional[ORMFromStatementCompileState] = None
- @classmethod
- def _get_orm_crud_kv_pairs(
- cls, mapper, statement, kv_iterator, needs_to_be_cacheable
- ):
- core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
- for k, v in kv_iterator:
- k = coercions.expect(roles.DMLColumnRole, k)
- if isinstance(k, str):
- desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
- if desc is NO_VALUE:
- yield (
- coercions.expect(roles.DMLColumnRole, k),
- (
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- )
- if needs_to_be_cacheable
- else v
- ),
- )
- else:
- yield from core_get_crud_kv_pairs(
- statement,
- desc._bulk_update_tuples(v),
- needs_to_be_cacheable,
- )
- elif "entity_namespace" in k._annotations:
- k_anno = k._annotations
- attr = _entity_namespace_key(
- k_anno["entity_namespace"], k_anno["proxy_key"]
- )
- yield from core_get_crud_kv_pairs(
- statement,
- attr._bulk_update_tuples(v),
- needs_to_be_cacheable,
- )
- else:
- yield (
- k,
- (
- v
- if not needs_to_be_cacheable
- else coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- )
- ),
- )
- @classmethod
- def _get_dml_plugin_subject(cls, statement):
- plugin_subject = statement.table._propagate_attrs.get("plugin_subject")
- if (
- not plugin_subject
- or not plugin_subject.mapper
- or plugin_subject
- is not statement._propagate_attrs["plugin_subject"]
- ):
- return None
- return plugin_subject
- @classmethod
- def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
- plugin_subject = cls._get_dml_plugin_subject(statement)
- if not plugin_subject:
- return UpdateDMLState._get_multi_crud_kv_pairs(
- statement, kv_iterator
- )
- return [
- dict(
- cls._get_orm_crud_kv_pairs(
- plugin_subject.mapper, statement, value_dict.items(), False
- )
- )
- for value_dict in kv_iterator
- ]
- @classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
- assert (
- needs_to_be_cacheable
- ), "no test coverage for needs_to_be_cacheable=False"
- plugin_subject = cls._get_dml_plugin_subject(statement)
- if not plugin_subject:
- return UpdateDMLState._get_crud_kv_pairs(
- statement, kv_iterator, needs_to_be_cacheable
- )
- return list(
- cls._get_orm_crud_kv_pairs(
- plugin_subject.mapper,
- statement,
- kv_iterator,
- needs_to_be_cacheable,
- )
- )
- @classmethod
- def get_entity_description(cls, statement):
- ext_info = statement.table._annotations["parententity"]
- mapper = ext_info.mapper
- if ext_info.is_aliased_class:
- _label_name = ext_info.name
- else:
- _label_name = mapper.class_.__name__
- return {
- "name": _label_name,
- "type": mapper.class_,
- "expr": ext_info.entity,
- "entity": ext_info.entity,
- "table": mapper.local_table,
- }
- @classmethod
- def get_returning_column_descriptions(cls, statement):
- def _ent_for_col(c):
- return c._annotations.get("parententity", None)
- def _attr_for_col(c, ent):
- if ent is None:
- return c
- proxy_key = c._annotations.get("proxy_key", None)
- if not proxy_key:
- return c
- else:
- return getattr(ent.entity, proxy_key, c)
- return [
- {
- "name": c.key,
- "type": c.type,
- "expr": _attr_for_col(c, ent),
- "aliased": ent.is_aliased_class,
- "entity": ent.entity,
- }
- for c, ent in [
- (c, _ent_for_col(c)) for c in statement._all_selected_columns
- ]
- ]
- def _setup_orm_returning(
- self,
- compiler,
- orm_level_statement,
- dml_level_statement,
- dml_mapper,
- *,
- use_supplemental_cols=True,
- ):
- """establish ORM column handlers for an INSERT, UPDATE, or DELETE
- which uses explicit returning().
- Called within the compile-level create_for_statement().
- The _return_orm_returning() method then receives the Result
- after the statement was executed, and applies ORM loading to the
- state that we first established here.
- """
- if orm_level_statement._returning:
- fs = FromStatement(
- orm_level_statement._returning,
- dml_level_statement,
- _adapt_on_names=False,
- )
- fs = fs.execution_options(**orm_level_statement._execution_options)
- fs = fs.options(*orm_level_statement._with_options)
- self.select_statement = fs
- self.from_statement_ctx = fsc = (
- ORMFromStatementCompileState.create_for_statement(fs, compiler)
- )
- fsc.setup_dml_returning_compile_state(dml_mapper)
- dml_level_statement = dml_level_statement._generate()
- dml_level_statement._returning = ()
- cols_to_return = [c for c in fsc.primary_columns if c is not None]
- # since we are splicing result sets together, make sure there
- # are columns of some kind returned in each result set
- if not cols_to_return:
- cols_to_return.extend(dml_mapper.primary_key)
- if use_supplemental_cols:
- dml_level_statement = dml_level_statement.return_defaults(
- # this is a little weird looking, but by passing
- # primary key as the main list of cols, this tells
- # return_defaults to omit server-default cols (and
- # actually all cols, due to some weird thing we should
- # clean up in crud.py).
- # Since we have cols_to_return, just return what we asked
- # for (plus primary key, which ORM persistence needs since
- # we likely set bookkeeping=True here, which is another
- # whole thing...). We don't want to clutter the
- # statement up with lots of other cols the user didn't
- # ask for. see #9685
- *dml_mapper.primary_key,
- supplemental_cols=cols_to_return,
- )
- else:
- dml_level_statement = dml_level_statement.returning(
- *cols_to_return
- )
- return dml_level_statement
- @classmethod
- def _return_orm_returning(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- ):
- execution_context = result.context
- compile_state = execution_context.compiled.compile_state
- if (
- compile_state.from_statement_ctx
- and not compile_state.from_statement_ctx.compile_options._is_star
- ):
- load_options = execution_options.get(
- "_sa_orm_load_options", QueryContext.default_load_options
- )
- querycontext = QueryContext(
- compile_state.from_statement_ctx,
- compile_state.select_statement,
- statement,
- params,
- session,
- load_options,
- execution_options,
- bind_arguments,
- )
- return loading.instances(result, querycontext)
- else:
- return result
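- # Illustrative sketches (assumed model/engine, public API only) of two
- # things ORMDMLState provides: get_entity_description() /
- # get_returning_column_descriptions() back the public
- # .entity_description / .returning_column_descriptions accessors, and
- # _setup_orm_returning() / _return_orm_returning() are what allow
- # returning(<entity>) to load ORM instances, on a backend that supports
- # RETURNING (e.g. recent SQLite).
- def _demo_dml_introspection_and_returning():  # pragma: no cover
-     from sqlalchemy import Integer, String, create_engine, insert, update
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     stmt = update(User).values(name="x").returning(User.id, User.name)
-     # backed by ORMDMLState.get_entity_description()
-     assert stmt.entity_description["name"] == "User"
-     # backed by ORMDMLState.get_returning_column_descriptions()
-     assert [d["name"] for d in stmt.returning_column_descriptions] == [
-         "id",
-         "name",
-     ]
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         # rows come back as User instances via _return_orm_returning()
-         user = session.scalars(
-             insert(User).returning(User), [{"id": 1, "name": "spongebob"}]
-         ).one()
-         assert user.name == "spongebob"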
- class BulkUDCompileState(ORMDMLState):
- class default_update_options(Options):
- _dml_strategy: DMLStrategyArgument = "auto"
- _synchronize_session: SynchronizeSessionArgument = "auto"
- _can_use_returning: bool = False
- _is_delete_using: bool = False
- _is_update_from: bool = False
- _autoflush: bool = True
- _subject_mapper: Optional[Mapper[Any]] = None
- _resolved_values = EMPTY_DICT
- _eval_condition = None
- _matched_rows = None
- _identity_token = None
- _populate_existing: bool = False
- @classmethod
- def can_use_returning(
- cls,
- dialect: Dialect,
- mapper: Mapper[Any],
- *,
- is_multitable: bool = False,
- is_update_from: bool = False,
- is_delete_using: bool = False,
- is_executemany: bool = False,
- ) -> bool:
- raise NotImplementedError()
- @classmethod
- def orm_pre_session_exec(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- is_pre_event,
- ):
- (
- update_options,
- execution_options,
- ) = BulkUDCompileState.default_update_options.from_execution_options(
- "_sa_orm_update_options",
- {
- "synchronize_session",
- "autoflush",
- "populate_existing",
- "identity_token",
- "is_delete_using",
- "is_update_from",
- "dml_strategy",
- },
- execution_options,
- statement._execution_options,
- )
- bind_arguments["clause"] = statement
- try:
- plugin_subject = statement._propagate_attrs["plugin_subject"]
- except KeyError:
- assert False, "statement had 'orm' plugin but no plugin_subject"
- else:
- if plugin_subject:
- bind_arguments["mapper"] = plugin_subject.mapper
- update_options += {"_subject_mapper": plugin_subject.mapper}
- if "parententity" not in statement.table._annotations:
- update_options += {"_dml_strategy": "core_only"}
- elif not isinstance(params, list):
- if update_options._dml_strategy == "auto":
- update_options += {"_dml_strategy": "orm"}
- elif update_options._dml_strategy == "bulk":
- raise sa_exc.InvalidRequestError(
- 'Can\'t use "bulk" ORM insert strategy without '
- "passing separate parameters"
- )
- else:
- if update_options._dml_strategy == "auto":
- update_options += {"_dml_strategy": "bulk"}
- sync = update_options._synchronize_session
- if sync is not None:
- if sync not in ("auto", "evaluate", "fetch", False):
- raise sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are 'auto', 'evaluate', 'fetch', False"
- )
- if update_options._dml_strategy == "bulk" and sync == "fetch":
- raise sa_exc.InvalidRequestError(
- "The 'fetch' synchronization strategy is not available "
- "for 'bulk' ORM updates (i.e. multiple parameter sets)"
- )
- if not is_pre_event:
- if update_options._autoflush:
- session._autoflush()
- if update_options._dml_strategy == "orm":
- if update_options._synchronize_session == "auto":
- update_options = cls._do_pre_synchronize_auto(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._synchronize_session == "evaluate":
- update_options = cls._do_pre_synchronize_evaluate(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._synchronize_session == "fetch":
- update_options = cls._do_pre_synchronize_fetch(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._dml_strategy == "bulk":
- if update_options._synchronize_session == "auto":
- update_options += {"_synchronize_session": "evaluate"}
- # indicators from the "pre exec" step that are then
- # added to the DML statement, which will also be part of the cache
- # key. The compile-level create_for_statement() method will then
- # consume these at compile time.
- statement = statement._annotate(
- {
- "synchronize_session": update_options._synchronize_session,
- "is_delete_using": update_options._is_delete_using,
- "is_update_from": update_options._is_update_from,
- "dml_strategy": update_options._dml_strategy,
- "can_use_returning": update_options._can_use_returning,
- }
- )
- return (
- statement,
- util.immutabledict(execution_options).union(
- {"_sa_orm_update_options": update_options}
- ),
- )
- @classmethod
- def orm_setup_cursor_result(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- ):
- # this stage of the execution is called after the
- # do_orm_execute event hook, meaning that for an extension like
- # horizontal sharding, this step happens *within* the horizontal
- # sharding event handler which calls session.execute() re-entrantly
- # and will occur for each backend individually.
- # the sharding extension then returns its own merged result from the
- # individual ones we return here.
- update_options = execution_options["_sa_orm_update_options"]
- if update_options._dml_strategy == "orm":
- if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_evaluate(
- session, statement, result, update_options
- )
- elif update_options._synchronize_session == "fetch":
- cls._do_post_synchronize_fetch(
- session, statement, result, update_options
- )
- elif update_options._dml_strategy == "bulk":
- if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_bulk_evaluate(
- session, params, result, update_options
- )
- return result
- return cls._return_orm_returning(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- )
- @classmethod
- def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
- """Apply extra criteria filtering.
- For all distinct single-table-inheritance mappers represented in the
- table being updated or deleted, produce additional WHERE criteria such
- that only the appropriate subtypes are selected from the total results.
- Additionally, add WHERE criteria originating from LoaderCriteriaOptions
- collected from the statement.
- """
- return_crit = ()
- adapter = ext_info._adapter if ext_info.is_aliased_class else None
- if (
- "additional_entity_criteria",
- ext_info.mapper,
- ) in global_attributes:
- return_crit += tuple(
- ae._resolve_where_criteria(ext_info)
- for ae in global_attributes[
- ("additional_entity_criteria", ext_info.mapper)
- ]
- if ae.include_aliases or ae.entity is ext_info
- )
- if ext_info.mapper._single_table_criterion is not None:
- return_crit += (ext_info.mapper._single_table_criterion,)
- if adapter:
- return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
- return return_crit
- @classmethod
- def _interpret_returning_rows(cls, result, mapper, rows):
- """return rows that indicate PK cols in mapper.primary_key position
- for RETURNING rows.
- Prior to 2.0.36, this method seemed to be written for some kind of
- inheritance scenario but the scenario was unused for actual joined
- inheritance, and the function instead seemed to perform some kind of
- partial translation that would remove non-PK cols if the PK cols
- happened to be first in the row, but not otherwise. The joined
- inheritance walk feature here seems to have never been used as it was
- always skipped by the "local_table" check.
- As of 2.0.36 the function strips away non-PK cols and provides the
- PK cols for the table in mapper PK order.
- """
- try:
- if mapper.local_table is not mapper.base_mapper.local_table:
- # TODO: dive more into how a local table PK is used for fetch
- # sync, not clear if this is correct as it depends on the
- # downstream routine to fetch rows using
- # local_table.primary_key order
- pk_keys = result._tuple_getter(mapper.local_table.primary_key)
- else:
- pk_keys = result._tuple_getter(mapper.primary_key)
- except KeyError:
- # can't use these rows, they don't have PK cols in them
- # this is an unusual case where the user would have used
- # .return_defaults()
- return []
- return [pk_keys(row) for row in rows]
- @classmethod
- def _get_matched_objects_on_criteria(cls, update_options, states):
- mapper = update_options._subject_mapper
- eval_condition = update_options._eval_condition
- raw_data = [
- (state.obj(), state, state.dict)
- for state in states
- if state.mapper.isa(mapper) and not state.expired
- ]
- identity_token = update_options._identity_token
- if identity_token is not None:
- raw_data = [
- (obj, state, dict_)
- for obj, state, dict_ in raw_data
- if state.identity_token == identity_token
- ]
- result = []
- for obj, state, dict_ in raw_data:
- evaled_condition = eval_condition(obj)
- # caution: don't use "in ()" or == here, _EXPIRE_OBJECT
- # evaluates as True for all comparisons
- if (
- evaled_condition is True
- or evaled_condition is evaluator._EXPIRED_OBJECT
- ):
- result.append(
- (
- obj,
- state,
- dict_,
- evaled_condition is evaluator._EXPIRED_OBJECT,
- )
- )
- return result
- @classmethod
- def _eval_condition_from_statement(cls, update_options, statement):
- mapper = update_options._subject_mapper
- target_cls = mapper.class_
- evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
- crit = ()
- if statement._where_criteria:
- crit += statement._where_criteria
- global_attributes = {}
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(global_attributes)
- if global_attributes:
- crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
- if crit:
- eval_condition = evaluator_compiler.process(*crit)
- else:
- # workaround for mypy https://github.com/python/mypy/issues/14027
- def _eval_condition(obj):
- return True
- eval_condition = _eval_condition
- return eval_condition
- @classmethod
- def _do_pre_synchronize_auto(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- ):
- """setup auto sync strategy
- "auto" checks if we can use "evaluate" first, then falls back
- to "fetch"
- evaluate is vastly more efficient for the common case
- where session is empty, only has a few objects, and the UPDATE
- statement can potentially match thousands/millions of rows.
- OTOH more complex criteria that fails to work with "evaluate"
- we would hope usually correlates with fewer net rows.
- """
- try:
- eval_condition = cls._eval_condition_from_statement(
- update_options, statement
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- return update_options + {
- "_eval_condition": eval_condition,
- "_synchronize_session": "evaluate",
- }
- update_options += {"_synchronize_session": "fetch"}
- return cls._do_pre_synchronize_fetch(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- @classmethod
- def _do_pre_synchronize_evaluate(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- ):
- try:
- eval_condition = cls._eval_condition_from_statement(
- update_options, statement
- )
- except evaluator.UnevaluatableError as err:
- raise sa_exc.InvalidRequestError(
- 'Could not evaluate current criteria in Python: "%s". '
- "Specify 'fetch' or False for the "
- "synchronize_session execution option." % err
- ) from err
- return update_options + {
- "_eval_condition": eval_condition,
- }
- @classmethod
- def _get_resolved_values(cls, mapper, statement):
- if statement._multi_values:
- return []
- elif statement._ordered_values:
- return list(statement._ordered_values)
- elif statement._values:
- return list(statement._values.items())
- else:
- return []
- @classmethod
- def _resolved_keys_as_propnames(cls, mapper, resolved_values):
- values = []
- for k, v in resolved_values:
- if mapper and isinstance(k, expression.ColumnElement):
- try:
- attr = mapper._columntoproperty[k]
- except orm_exc.UnmappedColumnError:
- pass
- else:
- values.append((attr.key, v))
- else:
- raise sa_exc.InvalidRequestError(
- "Attribute name not found, can't be "
- "synchronized back to objects: %r" % k
- )
- return values
- @classmethod
- def _do_pre_synchronize_fetch(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- ):
- mapper = update_options._subject_mapper
- select_stmt = (
- select(*(mapper.primary_key + (mapper.select_identity_token,)))
- .select_from(mapper)
- .options(*statement._with_options)
- )
- select_stmt._where_criteria = statement._where_criteria
- # conditionally run the SELECT statement for pre-fetch, testing the
- # "bind" for whether we can use RETURNING or not, using the
- # do_orm_execute event. If RETURNING is available, the do_orm_execute
- # event will cancel the SELECT from actually being run.
- #
- # The way this is organized seems strange; why don't we just
- # call can_use_returning() before invoking the statement and get the
- # answer, instead of going through the whole execute phase using an
- # event? Answer: because we are integrating with extensions such
- # as the horizontal sharding extension that "multiplexes" an individual
- # statement run through multiple engines, and it uses
- # do_orm_execute() to do that.
- can_use_returning = None
- def skip_for_returning(orm_context: ORMExecuteState) -> Any:
- bind = orm_context.session.get_bind(**orm_context.bind_arguments)
- nonlocal can_use_returning
- per_bind_result = cls.can_use_returning(
- bind.dialect,
- mapper,
- is_update_from=update_options._is_update_from,
- is_delete_using=update_options._is_delete_using,
- is_executemany=orm_context.is_executemany,
- )
- if can_use_returning is not None:
- if can_use_returning != per_bind_result:
- raise sa_exc.InvalidRequestError(
- "For synchronize_session='fetch', can't mix multiple "
- "backends where some support RETURNING and others "
- "don't"
- )
- elif orm_context.is_executemany and not per_bind_result:
- raise sa_exc.InvalidRequestError(
- "For synchronize_session='fetch', can't use multiple "
- "parameter sets in ORM mode, which this backend does not "
- "support with RETURNING"
- )
- else:
- can_use_returning = per_bind_result
- if per_bind_result:
- return _result.null_result()
- else:
- return None
- result = session.execute(
- select_stmt,
- params,
- execution_options=execution_options,
- bind_arguments=bind_arguments,
- _add_event=skip_for_returning,
- )
- matched_rows = result.fetchall()
- return update_options + {
- "_matched_rows": matched_rows,
- "_can_use_returning": can_use_returning,
- }
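- # Illustrative sketch (assumed model/engine, public API only) of the
- # synchronize_session execution option handled by BulkUDCompileState.
- # With "fetch", the pre-synchronize step either runs the SELECT built in
- # _do_pre_synchronize_fetch() above or, where the backend supports
- # RETURNING, cancels it via the do_orm_execute hook.
- def _demo_synchronize_session_fetch():  # pragma: no cover
-     from sqlalchemy import Integer, String, create_engine, update
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         session.add(User(id=1, name="old"))
-         session.commit()
-         user = session.get(User, 1)
-         session.execute(
-             update(User).where(User.id == 1).values(name="new"),
-             execution_options={"synchronize_session": "fetch"},
-         )
-         # the matched in-memory object was synchronized in place
-         assert user.name == "new"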
- @CompileState.plugin_for("orm", "insert")
- class BulkORMInsert(ORMDMLState, InsertDMLState):
- class default_insert_options(Options):
- _dml_strategy: DMLStrategyArgument = "auto"
- _render_nulls: bool = False
- _return_defaults: bool = False
- _subject_mapper: Optional[Mapper[Any]] = None
- _autoflush: bool = True
- _populate_existing: bool = False
- select_statement: Optional[FromStatement] = None
- @classmethod
- def orm_pre_session_exec(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- is_pre_event,
- ):
- (
- insert_options,
- execution_options,
- ) = BulkORMInsert.default_insert_options.from_execution_options(
- "_sa_orm_insert_options",
- {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
- execution_options,
- statement._execution_options,
- )
- bind_arguments["clause"] = statement
- try:
- plugin_subject = statement._propagate_attrs["plugin_subject"]
- except KeyError:
- assert False, "statement had 'orm' plugin but no plugin_subject"
- else:
- if plugin_subject:
- bind_arguments["mapper"] = plugin_subject.mapper
- insert_options += {"_subject_mapper": plugin_subject.mapper}
- if not params:
- if insert_options._dml_strategy == "auto":
- insert_options += {"_dml_strategy": "orm"}
- elif insert_options._dml_strategy == "bulk":
- raise sa_exc.InvalidRequestError(
- 'Can\'t use "bulk" ORM insert strategy without '
- "passing separate parameters"
- )
- else:
- if insert_options._dml_strategy == "auto":
- insert_options += {"_dml_strategy": "bulk"}
- if insert_options._dml_strategy != "raw":
- # for ORM object loading, like ORMContext, we have to disable
- # result set adapt_to_context, because we will be generating a
- # new statement with specific columns that's cached inside of
- # an ORMFromStatementCompileState, which we will re-use for
- # each result.
- if not execution_options:
- execution_options = context._orm_load_exec_options
- else:
- execution_options = execution_options.union(
- context._orm_load_exec_options
- )
- if not is_pre_event and insert_options._autoflush:
- session._autoflush()
- statement = statement._annotate(
- {"dml_strategy": insert_options._dml_strategy}
- )
- return (
- statement,
- util.immutabledict(execution_options).union(
- {"_sa_orm_insert_options": insert_options}
- ),
- )
- @classmethod
- def orm_execute_statement(
- cls,
- session: Session,
- statement: dml.Insert,
- params: _CoreAnyExecuteParams,
- execution_options: OrmExecuteOptionsParameter,
- bind_arguments: _BindArguments,
- conn: Connection,
- ) -> _result.Result:
- insert_options = execution_options.get(
- "_sa_orm_insert_options", cls.default_insert_options
- )
- if insert_options._dml_strategy not in (
- "raw",
- "bulk",
- "orm",
- "auto",
- ):
- raise sa_exc.ArgumentError(
- "Valid strategies for ORM insert strategy "
- "are 'raw', 'orm', 'bulk', 'auto"
- )
- result: _result.Result[Any]
- if insert_options._dml_strategy == "raw":
- result = conn.execute(
- statement, params or {}, execution_options=execution_options
- )
- return result
- if insert_options._dml_strategy == "bulk":
- mapper = insert_options._subject_mapper
- if (
- statement._post_values_clause is not None
- and mapper._multiple_persistence_tables
- ):
- raise sa_exc.InvalidRequestError(
- "bulk INSERT with a 'post values' clause "
- "(typically upsert) not supported for multi-table "
- f"mapper {mapper}"
- )
- assert mapper is not None
- assert session._transaction is not None
- result = _bulk_insert(
- mapper,
- cast(
- "Iterable[Dict[str, Any]]",
- [params] if isinstance(params, dict) else params,
- ),
- session._transaction,
- isstates=False,
- return_defaults=insert_options._return_defaults,
- render_nulls=insert_options._render_nulls,
- use_orm_insert_stmt=statement,
- execution_options=execution_options,
- )
- elif insert_options._dml_strategy == "orm":
- result = conn.execute(
- statement, params or {}, execution_options=execution_options
- )
- else:
- raise AssertionError()
- if not bool(statement._returning):
- return result
- if insert_options._populate_existing:
- load_options = execution_options.get(
- "_sa_orm_load_options", QueryContext.default_load_options
- )
- load_options += {"_populate_existing": True}
- execution_options = execution_options.union(
- {"_sa_orm_load_options": load_options}
- )
- return cls._return_orm_returning(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- )
- @classmethod
- def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
- self = cast(
- BulkORMInsert,
- super().create_for_statement(statement, compiler, **kw),
- )
- if compiler is not None:
- toplevel = not compiler.stack
- else:
- toplevel = True
- if not toplevel:
- return self
- mapper = statement._propagate_attrs["plugin_subject"]
- dml_strategy = statement._annotations.get("dml_strategy", "raw")
- if dml_strategy == "bulk":
- self._setup_for_bulk_insert(compiler)
- elif dml_strategy == "orm":
- self._setup_for_orm_insert(compiler, mapper)
- return self
- @classmethod
- def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
- return {
- col.key if col is not None else k: v
- for col, k, v in (
- (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
- )
- }
- def _setup_for_orm_insert(self, compiler, mapper):
- statement = orm_level_statement = cast(dml.Insert, self.statement)
- statement = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- statement,
- dml_mapper=mapper,
- use_supplemental_cols=False,
- )
- self.statement = statement
- def _setup_for_bulk_insert(self, compiler):
- """establish an INSERT statement within the context of
- bulk insert.
- This method will be within the "conn.execute()" call that is invoked
- by persistence._emit_insert_statement().
- """
- statement = orm_level_statement = cast(dml.Insert, self.statement)
- an = statement._annotations
- emit_insert_table, emit_insert_mapper = (
- an["_emit_insert_table"],
- an["_emit_insert_mapper"],
- )
- statement = statement._clone()
- statement.table = emit_insert_table
- if self._dict_parameters:
- self._dict_parameters = {
- col: val
- for col, val in self._dict_parameters.items()
- if col.table is emit_insert_table
- }
- statement = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- statement,
- dml_mapper=emit_insert_mapper,
- use_supplemental_cols=True,
- )
- if (
- self.from_statement_ctx is not None
- and self.from_statement_ctx.compile_options._is_star
- ):
- raise sa_exc.CompileError(
- "Can't use RETURNING * with bulk ORM INSERT. "
- "Please use a different INSERT form, such as INSERT..VALUES "
- "or INSERT with a Core Connection"
- )
- self.statement = statement
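- # Illustrative sketch (assumed model/engine, public API only): bulk
- # INSERT combined with RETURNING, which exercises the "supplemental
- # columns" path established in _setup_for_bulk_insert() above so that
- # the per-table results can be spliced into one result set.
- def _demo_bulk_insert_returning():  # pragma: no cover
-     from sqlalchemy import Integer, String, create_engine, insert
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         names = session.scalars(
-             insert(User).returning(User.name),
-             [{"id": 1, "name": "spongebob"}, {"id": 2, "name": "sandy"}],
-         ).all()
-         assert set(names) == {"spongebob", "sandy"}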
- @CompileState.plugin_for("orm", "update")
- class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
- @classmethod
- def create_for_statement(cls, statement, compiler, **kw):
- self = cls.__new__(cls)
- dml_strategy = statement._annotations.get(
- "dml_strategy", "unspecified"
- )
- toplevel = not compiler.stack
- if toplevel and dml_strategy == "bulk":
- self._setup_for_bulk_update(statement, compiler)
- elif (
- dml_strategy == "core_only"
- or dml_strategy == "unspecified"
- and "parententity" not in statement.table._annotations
- ):
- UpdateDMLState.__init__(self, statement, compiler, **kw)
- elif not toplevel or dml_strategy in ("orm", "unspecified"):
- self._setup_for_orm_update(statement, compiler)
- return self
- def _setup_for_orm_update(self, statement, compiler, **kw):
- orm_level_statement = statement
- toplevel = not compiler.stack
- ext_info = statement.table._annotations["parententity"]
- self.mapper = mapper = ext_info.mapper
- self._resolved_values = self._get_resolved_values(mapper, statement)
- self._init_global_attributes(
- statement,
- compiler,
- toplevel=toplevel,
- process_criteria_for_toplevel=toplevel,
- )
- if statement._values:
- self._resolved_values = dict(self._resolved_values)
- new_stmt = statement._clone()
- if new_stmt.table._annotations["parententity"] is mapper:
- new_stmt.table = mapper.local_table
- # note if the statement has _multi_values, these
- # are passed through to the new statement, which will then raise
- # InvalidRequestError because UPDATE doesn't support multi_values
- # right now.
- if statement._ordered_values:
- new_stmt._ordered_values = self._resolved_values
- elif statement._values:
- new_stmt._values = self._resolved_values
- new_crit = self._adjust_for_extra_criteria(
- self.global_attributes, mapper
- )
- if new_crit:
- new_stmt = new_stmt.where(*new_crit)
- # if we are against a lambda statement we might not be the
- # topmost object that received per-execute annotations.
- #
- # do this first as we need to determine if there is
- # UPDATE..FROM
- UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
- use_supplemental_cols = False
- if not toplevel:
- synchronize_session = None
- else:
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
- can_use_returning = compiler._annotations.get(
- "can_use_returning", None
- )
- if can_use_returning is not False:
- # even though pre_exec has determined basic
- # can_use_returning for the dialect, if we are to use
- # RETURNING we need to run can_use_returning() at this level
- # unconditionally because is_delete_using was not known
- # at the pre_exec level
- can_use_returning = (
- synchronize_session == "fetch"
- and self.can_use_returning(
- compiler.dialect, mapper, is_multitable=self.is_multitable
- )
- )
- if synchronize_session == "fetch" and can_use_returning:
- use_supplemental_cols = True
- # NOTE: we might want to RETURNING the actual columns to be
- # synchronized also. however this is complicated and difficult
- # to align against the behavior of "evaluate". Additionally,
- # in a large number (if not the majority) of cases, we have the
- # "evaluate" answer, usually a fixed value, in memory already and
- # there's no need to re-fetch the same value
- # over and over again. so perhaps if it could be RETURNING just
- # the elements that were based on a SQL expression and not
- # a constant. For now it doesn't quite seem worth it
- new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
- if toplevel:
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
- self.statement = new_stmt
- def _setup_for_bulk_update(self, statement, compiler, **kw):
- """establish an UPDATE statement within the context of
- bulk insert.
- This method will be within the "conn.execute()" call that is invoked
- by persistence._emit_update_statement().
- """
- statement = cast(dml.Update, statement)
- an = statement._annotations
- emit_update_table, _ = (
- an["_emit_update_table"],
- an["_emit_update_mapper"],
- )
- statement = statement._clone()
- statement.table = emit_update_table
- UpdateDMLState.__init__(self, statement, compiler, **kw)
- if self._ordered_values:
- raise sa_exc.InvalidRequestError(
- "bulk ORM UPDATE does not support ordered_values() for "
- "custom UPDATE statements with bulk parameter sets. Use a "
- "non-bulk UPDATE statement or use values()."
- )
- if self._dict_parameters:
- self._dict_parameters = {
- col: val
- for col, val in self._dict_parameters.items()
- if col.table is emit_update_table
- }
- self.statement = statement
- @classmethod
- def orm_execute_statement(
- cls,
- session: Session,
- statement: dml.Update,
- params: _CoreAnyExecuteParams,
- execution_options: OrmExecuteOptionsParameter,
- bind_arguments: _BindArguments,
- conn: Connection,
- ) -> _result.Result:
- update_options = execution_options.get(
- "_sa_orm_update_options", cls.default_update_options
- )
- if update_options._populate_existing:
- load_options = execution_options.get(
- "_sa_orm_load_options", QueryContext.default_load_options
- )
- load_options += {"_populate_existing": True}
- execution_options = execution_options.union(
- {"_sa_orm_load_options": load_options}
- )
- if update_options._dml_strategy not in (
- "orm",
- "auto",
- "bulk",
- "core_only",
- ):
- raise sa_exc.ArgumentError(
- "Valid strategies for ORM UPDATE strategy "
- "are 'orm', 'auto', 'bulk', 'core_only'"
- )
- result: _result.Result[Any]
- if update_options._dml_strategy == "bulk":
- enable_check_rowcount = not statement._where_criteria
- assert update_options._synchronize_session != "fetch"
- if (
- statement._where_criteria
- and update_options._synchronize_session == "evaluate"
- ):
- raise sa_exc.InvalidRequestError(
- "bulk synchronize of persistent objects not supported "
- "when using bulk update with additional WHERE "
- "criteria right now. add synchronize_session=None "
- "execution option to bypass synchronize of persistent "
- "objects."
- )
- mapper = update_options._subject_mapper
- assert mapper is not None
- assert session._transaction is not None
- result = _bulk_update(
- mapper,
- cast(
- "Iterable[Dict[str, Any]]",
- [params] if isinstance(params, dict) else params,
- ),
- session._transaction,
- isstates=False,
- update_changed_only=False,
- use_orm_update_stmt=statement,
- enable_check_rowcount=enable_check_rowcount,
- )
- return cls.orm_setup_cursor_result(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- )
- else:
- return super().orm_execute_statement(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- conn,
- )
- @classmethod
- def can_use_returning(
- cls,
- dialect: Dialect,
- mapper: Mapper[Any],
- *,
- is_multitable: bool = False,
- is_update_from: bool = False,
- is_delete_using: bool = False,
- is_executemany: bool = False,
- ) -> bool:
- # normal answer for "should we use RETURNING" at all.
- normal_answer = (
- dialect.update_returning and mapper.local_table.implicit_returning
- )
- if not normal_answer:
- return False
- if is_executemany:
- return dialect.update_executemany_returning
- # these workarounds are currently hypothetical for UPDATE,
- # unlike DELETE where they impact MariaDB
- if is_update_from:
- return dialect.update_returning_multifrom
- elif is_multitable and not dialect.update_returning_multifrom:
- raise sa_exc.CompileError(
- f'Dialect "{dialect.name}" does not support RETURNING '
- "with UPDATE..FROM; for synchronize_session='fetch', "
- "please add the additional execution option "
- "'is_update_from=True' to the statement to indicate that "
- "a separate SELECT should be used for this backend."
- )
- return True
- @classmethod
- def _do_post_synchronize_bulk_evaluate(
- cls, session, params, result, update_options
- ):
- if not params:
- return
- mapper = update_options._subject_mapper
- pk_keys = [prop.key for prop in mapper._identity_key_props]
- identity_map = session.identity_map
- for param in params:
- identity_key = mapper.identity_key_from_primary_key(
- (param[key] for key in pk_keys),
- update_options._identity_token,
- )
- state = identity_map.fast_get_state(identity_key)
- if not state:
- continue
- evaluated_keys = set(param).difference(pk_keys)
- dict_ = state.dict
- # only evaluate unmodified attributes
- to_evaluate = state.unmodified.intersection(evaluated_keys)
- for key in to_evaluate:
- if key in dict_:
- dict_[key] = param[key]
- state.manager.dispatch.refresh(state, None, to_evaluate)
- state._commit(dict_, list(to_evaluate))
- # attributes that were formerly modified instead get expired.
- # this only gets hit if the session had pending changes
- # and autoflush was set to False.
- to_expire = evaluated_keys.intersection(dict_).difference(
- to_evaluate
- )
- if to_expire:
- state._expire_attributes(dict_, to_expire)
- @classmethod
- def _do_post_synchronize_evaluate(
- cls, session, statement, result, update_options
- ):
- matched_objects = cls._get_matched_objects_on_criteria(
- update_options,
- session.identity_map.all_states(),
- )
- cls._apply_update_set_values_to_objects(
- session,
- update_options,
- statement,
- result.context.compiled_parameters[0],
- [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
- result.prefetch_cols(),
- result.postfetch_cols(),
- )
- @classmethod
- def _do_post_synchronize_fetch(
- cls, session, statement, result, update_options
- ):
- target_mapper = update_options._subject_mapper
- returned_defaults_rows = result.returned_defaults_rows
- if returned_defaults_rows:
- pk_rows = cls._interpret_returning_rows(
- result, target_mapper, returned_defaults_rows
- )
- matched_rows = [
- tuple(row) + (update_options._identity_token,)
- for row in pk_rows
- ]
- else:
- matched_rows = update_options._matched_rows
- objs = [
- session.identity_map[identity_key]
- for identity_key in [
- target_mapper.identity_key_from_primary_key(
- list(primary_key),
- identity_token=identity_token,
- )
- for primary_key, identity_token in [
- (row[0:-1], row[-1]) for row in matched_rows
- ]
- if update_options._identity_token is None
- or identity_token == update_options._identity_token
- ]
- if identity_key in session.identity_map
- ]
- if not objs:
- return
- cls._apply_update_set_values_to_objects(
- session,
- update_options,
- statement,
- result.context.compiled_parameters[0],
- [
- (
- obj,
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
- for obj in objs
- ],
- result.prefetch_cols(),
- result.postfetch_cols(),
- )
- @classmethod
- def _apply_update_set_values_to_objects(
- cls,
- session,
- update_options,
- statement,
- effective_params,
- matched_objects,
- prefetch_cols,
- postfetch_cols,
- ):
- """apply values to objects derived from an update statement, e.g.
- UPDATE..SET <values>
- """
- mapper = update_options._subject_mapper
- target_cls = mapper.class_
- evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
- resolved_values = cls._get_resolved_values(mapper, statement)
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- value_evaluators = {}
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
- evaluated_keys = list(value_evaluators.keys())
- attrib = {k for k, v in resolved_keys_as_propnames}
- states = set()
- to_prefetch = {
- c
- for c in prefetch_cols
- if c.key in effective_params
- and c in mapper._columntoproperty
- and c.key not in evaluated_keys
- }
- to_expire = {
- mapper._columntoproperty[c].key
- for c in postfetch_cols
- if c in mapper._columntoproperty
- }.difference(evaluated_keys)
- prefetch_transfer = [
- (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
- ]
- for obj, state, dict_ in matched_objects:
- dict_.update(
- {
- col_to_prop: effective_params[c_key]
- for col_to_prop, c_key in prefetch_transfer
- }
- )
- state._expire_attributes(state.dict, to_expire)
- to_evaluate = state.unmodified.intersection(evaluated_keys)
- for key in to_evaluate:
- if key in dict_:
- # only run eval for attributes that are present.
- dict_[key] = value_evaluators[key](obj)
- state.manager.dispatch.refresh(state, None, to_evaluate)
- state._commit(dict_, list(to_evaluate))
- # attributes that were formerly modified instead get expired.
- # this only gets hit if the session had pending changes
- # and autoflush was set to False.
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
- if to_expire:
- state._expire_attributes(dict_, to_expire)
- states.add(state)
- session._register_altered(states)
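- # Illustrative sketches (assumed model/engine, public API only) of two
- # BulkORMUpdate behaviors: "bulk UPDATE by primary key", where a list of
- # parameter dictionaries is routed to _bulk_update(), and "evaluate"
- # synchronization, where _apply_update_set_values_to_objects() above
- # updates matching in-memory objects without a round trip.
- def _demo_orm_bulk_update():  # pragma: no cover
-     from sqlalchemy import Integer, String, create_engine, update
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         session.add_all([User(id=1, name="a"), User(id=2, name="b")])
-         session.commit()
-         # bulk UPDATE by primary key: each dictionary carries the PK plus
-         # the values to SET
-         session.execute(
-             update(User),
-             [{"id": 1, "name": "a2"}, {"id": 2, "name": "b2"}],
-         )
-         # criteria-oriented UPDATE with "evaluate" synchronization
-         user = session.get(User, 1)
-         session.execute(
-             update(User).where(User.id == 1).values(name="a3"),
-             execution_options={"synchronize_session": "evaluate"},
-         )
-         assert user.name == "a3"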
- @CompileState.plugin_for("orm", "delete")
- class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
- @classmethod
- def create_for_statement(cls, statement, compiler, **kw):
- self = cls.__new__(cls)
- dml_strategy = statement._annotations.get(
- "dml_strategy", "unspecified"
- )
- if (
- dml_strategy == "core_only"
- or dml_strategy == "unspecified"
- and "parententity" not in statement.table._annotations
- ):
- DeleteDMLState.__init__(self, statement, compiler, **kw)
- return self
- toplevel = not compiler.stack
- orm_level_statement = statement
- ext_info = statement.table._annotations["parententity"]
- self.mapper = mapper = ext_info.mapper
- self._init_global_attributes(
- statement,
- compiler,
- toplevel=toplevel,
- process_criteria_for_toplevel=toplevel,
- )
- new_stmt = statement._clone()
- if new_stmt.table._annotations["parententity"] is mapper:
- new_stmt.table = mapper.local_table
- new_crit = cls._adjust_for_extra_criteria(
- self.global_attributes, mapper
- )
- if new_crit:
- new_stmt = new_stmt.where(*new_crit)
- # do this first as we need to determine if there is
- # DELETE..FROM
- DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
- use_supplemental_cols = False
- if not toplevel:
- synchronize_session = None
- else:
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
- can_use_returning = compiler._annotations.get(
- "can_use_returning", None
- )
- if can_use_returning is not False:
- # even though pre_exec has determined basic
- # can_use_returning for the dialect, if we are to use
- # RETURNING we need to run can_use_returning() at this level
- # unconditionally because is_delete_using was not known
- # at the pre_exec level
- can_use_returning = (
- synchronize_session == "fetch"
- and self.can_use_returning(
- compiler.dialect,
- mapper,
- is_multitable=self.is_multitable,
- is_delete_using=compiler._annotations.get(
- "is_delete_using", False
- ),
- )
- )
- if can_use_returning:
- use_supplemental_cols = True
- new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
- if toplevel:
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
- self.statement = new_stmt
- return self
- @classmethod
- def orm_execute_statement(
- cls,
- session: Session,
- statement: dml.Delete,
- params: _CoreAnyExecuteParams,
- execution_options: OrmExecuteOptionsParameter,
- bind_arguments: _BindArguments,
- conn: Connection,
- ) -> _result.Result:
- update_options = execution_options.get(
- "_sa_orm_update_options", cls.default_update_options
- )
- if update_options._dml_strategy == "bulk":
- raise sa_exc.InvalidRequestError(
- "Bulk ORM DELETE not supported right now. "
- "Statement may be invoked at the "
- "Core level using "
- "session.connection().execute(stmt, parameters)"
- )
- if update_options._dml_strategy not in ("orm", "auto", "core_only"):
- raise sa_exc.ArgumentError(
- "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
- "'core_only'"
- )
- return super().orm_execute_statement(
- session, statement, params, execution_options, bind_arguments, conn
- )
- @classmethod
- def can_use_returning(
- cls,
- dialect: Dialect,
- mapper: Mapper[Any],
- *,
- is_multitable: bool = False,
- is_update_from: bool = False,
- is_delete_using: bool = False,
- is_executemany: bool = False,
- ) -> bool:
- # normal answer for "should we use RETURNING" at all.
- normal_answer = (
- dialect.delete_returning and mapper.local_table.implicit_returning
- )
- if not normal_answer:
- return False
- # now get into special workarounds because MariaDB supports
- # DELETE...RETURNING but not DELETE...USING...RETURNING.
- if is_delete_using:
- # is_delete_using hint was passed. use
- # additional dialect feature (True for PG, False for MariaDB)
- return dialect.delete_returning_multifrom
- elif is_multitable and not dialect.delete_returning_multifrom:
- # is_delete_using hint was not passed, but we determined
- # at compile time that this is in fact a DELETE..USING.
- # it's too late to continue since we did not pre-SELECT.
- # raise that we need that hint up front.
- raise sa_exc.CompileError(
- f'Dialect "{dialect.name}" does not support RETURNING '
- "with DELETE..USING; for synchronize_session='fetch', "
- "please add the additional execution option "
- "'is_delete_using=True' to the statement to indicate that "
- "a separate SELECT should be used for this backend."
- )
- return True
- @classmethod
- def _do_post_synchronize_evaluate(
- cls, session, statement, result, update_options
- ):
- matched_objects = cls._get_matched_objects_on_criteria(
- update_options,
- session.identity_map.all_states(),
- )
- to_delete = []
- for _, state, dict_, is_partially_expired in matched_objects:
- if is_partially_expired:
- state._expire(dict_, session.identity_map._modified)
- else:
- to_delete.append(state)
- if to_delete:
- session._remove_newly_deleted(to_delete)
- @classmethod
- def _do_post_synchronize_fetch(
- cls, session, statement, result, update_options
- ):
- target_mapper = update_options._subject_mapper
- returned_defaults_rows = result.returned_defaults_rows
- if returned_defaults_rows:
- pk_rows = cls._interpret_returning_rows(
- result, target_mapper, returned_defaults_rows
- )
- matched_rows = [
- tuple(row) + (update_options._identity_token,)
- for row in pk_rows
- ]
- else:
- matched_rows = update_options._matched_rows
- for row in matched_rows:
- primary_key = row[0:-1]
- identity_token = row[-1]
- # TODO: inline this and call remove_newly_deleted
- # once
- identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key),
- identity_token=identity_token,
- )
- if identity_key in session.identity_map:
- session._remove_newly_deleted(
- [
- attributes.instance_state(
- session.identity_map[identity_key]
- )
- ]
- )
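- # Illustrative sketch (assumed model/engine, public API only) of the
- # DELETE path handled by BulkORMDelete above. Note that for backends
- # lacking DELETE..USING..RETURNING, a multi-table DELETE additionally
- # needs the is_delete_using=True execution option, the hint consumed by
- # can_use_returning() above.
- def _demo_orm_delete():  # pragma: no cover
-     from sqlalchemy import Integer, String, create_engine, delete
-     from sqlalchemy.orm import DeclarativeBase, Session, mapped_column
-     class Base(DeclarativeBase):
-         pass
-     class User(Base):
-         __tablename__ = "user_account"
-         id = mapped_column(Integer, primary_key=True)
-         name = mapped_column(String(30))
-     engine = create_engine("sqlite://")
-     Base.metadata.create_all(engine)
-     with Session(engine) as session:
-         session.add(User(id=1, name="a"))
-         session.commit()
-         user = session.get(User, 1)
-         session.execute(
-             delete(User).where(User.id == 1),
-             execution_options={"synchronize_session": "fetch"},
-         )
-         # post-synchronize evicted the matched object from the session
-         assert user not in session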
|