persistence.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788
  1. # orm/persistence.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: ignore-errors
  8. """private module containing functions used to emit INSERT, UPDATE
  9. and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
  10. mappers.
  11. The functions here are called only by the unit of work functions
  12. in unitofwork.py.
  13. """
  14. from __future__ import annotations
  15. from itertools import chain
  16. from itertools import groupby
  17. from itertools import zip_longest
  18. import operator
  19. from . import attributes
  20. from . import exc as orm_exc
  21. from . import loading
  22. from . import sync
  23. from .base import state_str
  24. from .. import exc as sa_exc
  25. from .. import future
  26. from .. import sql
  27. from .. import util
  28. from ..engine import cursor as _cursor
  29. from ..sql import operators
  30. from ..sql.elements import BooleanClauseList
  31. from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    :param base_mapper: base :class:`_orm.Mapper` of the hierarchy.
    :param states: list of instance states to persist.
    :param uowtransaction: the current UOWTransaction.
    :param single: when True, ``states`` is a single-element list and
     no further per-state fan-out occurs.
    """

    # if batch=false, call _save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []

    # split states into UPDATEs (already have an identity, or are taking
    # over the row of an existing state via "row switch") vs. INSERTs
    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

    # emit statements table-by-table across the inheritance hierarchy;
    # tables for which this mapper has no primary key columns are skipped
    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        insert = _collect_insert_commands(table, states_to_insert)

        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            insert,
        )

    # run post-save bookkeeping for all states; the final boolean flag
    # indicates whether the state was an UPDATE (True) or INSERT (False)
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )
def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    :param post_update_cols: set of columns which should be included
     in the UPDATE in addition to columns with pending history.
    """
    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        # augment each state tuple with the committed version id value
        # when the mapper is versioned; only states whose own mapper
        # has PK columns on this table participate
        update = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                (
                    mapper._get_committed_state_attr_by_column(
                        state, state_dict, mapper.version_id_col
                    )
                    if mapper.version_id_col is not None
                    else None
                ),
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )
def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.
    """
    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    # delete in the reverse order of the table sort, so that dependent
    # (child) tables are deleted before their parents
    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        elif mapper.inherits and mapper.passive_deletes:
            # inheriting mapper with passive_deletes relies on the
            # database to cascade the delete; skip this table
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            delete,
        )

    # fire after_delete events only after all DELETE statements have
    # been emitted for all tables
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    Yields 7-tuples of ``(state, dict_, mapper, connection,
    has_identity, row_switch, update_version_id)``.
    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        # a state with a key is already persistent -> UPDATE path
        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    # conflicting persistent instance is not being
                    # deleted in this flush; warn, no row switch
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        # for versioned UPDATEs, capture the committed version id value;
        # when row-switching, read it from the existing (switched) state
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for UPDATE
    corresponding to post_update.

    This includes obtaining key information for each state
    including its dictionary, mapper, the connection to use for
    the execution per state.

    Pure delegation: yields ``(state, dict_, mapper, connection)``
    tuples from _connections_for_states().
    """
    return _connections_for_states(base_mapper, uowtransaction, states)
  247. def _organize_states_for_delete(base_mapper, states, uowtransaction):
  248. """Make an initial pass across a set of states for DELETE.
  249. This includes calling out before_delete and obtaining
  250. key information for each state including its dictionary,
  251. mapper, the connection to use for the execution per state.
  252. """
  253. for state, dict_, mapper, connection in _connections_for_states(
  254. base_mapper, uowtransaction, states
  255. ):
  256. mapper.dispatch.before_delete(mapper, connection, state)
  257. if mapper.version_id_col is not None:
  258. update_version_id = mapper._get_committed_state_attr_by_column(
  259. state, dict_, mapper.version_id_col
  260. )
  261. else:
  262. update_version_id = None
  263. yield (state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    Yields 8-tuples of ``(state, state_dict, params, mapper,
    connection, value_params, has_all_pks, has_all_defaults)`` where
    ``params`` maps column keys to plain values and ``value_params``
    maps columns to SQL expressions to be rendered inline.
    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        # columns for which a Python-side None should be treated as an
        # explicit value rather than "use the default"
        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]

            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                # SQL expression value; render inline via value_params
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.  This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper._prefer_eager_defaults(
                connection.dialect, table
            ):
                has_all_defaults = mapper._server_default_col_keys[
                    table
                ].issubset(params)
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        # generate a new version id for the INSERT when a Python-side
        # generator is configured
        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        if bulk:
            if mapper._set_polymorphic_identity:
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )

            if include_bulk_keys:
                params.update((k, state_dict[k]) for k in include_bulk_keys)

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.
    """
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if use_orm_update_stmt is not None:
            # TODO: ordered values, etc
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            params = {}
            # only attributes with committed history are candidates
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params

            # col._label is the bind name used in the WHERE clause
            # constructed by _emit_update_statements()
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing to update on this table for this state
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # PK changed via a CASCADE from elsewhere; use
                        # the new value in the WHERE clause and don't
                        # SET it
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    Yields ``(state, state_dict, mapper, connection, params)`` only
    for states that have pending history on the targeted columns.
    """
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # PK values keyed by col._label for the WHERE clause
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            # only emit a row when at least one non-PK column changed
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):
                col = mapper.version_id_col
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params
  585. def _collect_delete_commands(
  586. base_mapper, uowtransaction, table, states_to_delete
  587. ):
  588. """Identify values to use in DELETE statements for a list of
  589. states to be deleted."""
  590. for (
  591. state,
  592. state_dict,
  593. mapper,
  594. connection,
  595. update_version_id,
  596. ) in states_to_delete:
  597. if table not in mapper._pks_by_table:
  598. continue
  599. params = {}
  600. for col in mapper._pks_by_table[table]:
  601. params[col.key] = value = (
  602. mapper._get_committed_state_attr_by_column(
  603. state, state_dict, col
  604. )
  605. )
  606. if value is None:
  607. raise orm_exc.FlushError(
  608. "Can't delete from table %s "
  609. "using NULL for primary "
  610. "key value on column %s" % (table, col)
  611. )
  612. if (
  613. update_version_id is not None
  614. and mapper.version_id_col in mapper._cols_by_table[table]
  615. ):
  616. params[mapper.version_id_col.key] = update_version_id
  617. yield params, connection
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    :param bookkeeping: when True, run _postfetch() per row to process
     server-generated defaults and expire/refresh attributes.
    :param use_orm_update_stmt: optional pre-built ORM UPDATE statement
     to augment rather than the memoized table UPDATE.
    :param enable_check_rowcount: when False, skip the StaleDataError
     rowcount verification.
    """
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        # build the WHERE criteria: one bindparam per PK column (named
        # by col._label), plus the version id column when versioned
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)
    else:
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # group records that can share a single statement form: same
    # connection, same parameter keys, same value-params/defaults/pks
    # characteristics.  NOTE: groupby only merges consecutive records.
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        if not has_all_pks:
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            # per-row execution: each row has its own SQL-expression
            # values rendered via statement.values()
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
            check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                # per-row execution required (RETURNING or versioning)
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # executemany: one statement for all records in the group
                multiparams = [rec[2] for rec in records]

                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Records are grouped by (connection, parameter keys, presence of SQL
    expression values, has_all_pks, has_all_defaults) so that each group
    may be emitted as a single executemany() batch where possible.
    Returns a result object only in the ``use_orm_insert_stmt`` case
    (ORM-enabled bulk INSERT); a plain flush returns None.
    """
    if use_orm_insert_stmt is not None:
        # ORM bulk insert: the user-constructed statement is used as-is
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        # plain flush: memoize the table's INSERT construct and use the
        # mapper-level compiled cache
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    return_result = None

    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            # carry the target table/mapper for this split on the
            # statement so downstream compilation can see them
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                # pair each record with the compiled parameters actually
                # sent, in order
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        # no InstanceState (bulk save); just sync inherited
                        # keys on the raw dictionary
                        _postfetch_bulk_save(mapper_rec, state_dict, table)
        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    # explicit RETURNING was requested but this dialect
                    # can't do executemany w/ RETURNING; fail loudly
                    # rather than silently degrade
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    # eager defaults: fetch server defaults via
                    # RETURNING up front
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

                if mapper.version_id_col is not None:
                    statement = statement.return_defaults(
                        mapper.version_id_col,
                        sort_by_parameter_order=bookkeeping,
                    )
                elif do_executemany:
                    statement = statement.return_defaults(
                        *table.primary_key, sort_by_parameter_order=bookkeeping
                    )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        # accumulate results across groups for the ORM
                        # bulk insert caller
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows.  we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING. Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        # copy returned PK values into the state dict,
                        # without overwriting values already present
                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                # single-row execution path: one INSERT per record
                assert not returning_is_required_anyway

                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        # SQL-expression values must be rendered inline
                        # per statement
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked. Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )

                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk

                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                (
                                    result.returned_defaults
                                    if not result.context.executemany
                                    else None
                                ),
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)

    if use_orm_insert_stmt is not None:
        if return_result is None:
            # no groups executed at all; hand back an empty DML result
            return _cursor.null_dml_result()
        else:
            return return_result
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Statements are grouped by (connection, parameter keys) to allow
    executemany() while preserving the original ordering of states,
    and rowcounts are verified when versioning requires it.
    """
    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # WHERE criteria: all primary key columns, plus the version id
        # column when this table participates in versioning
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)
        return stmt

    statement = base_mapper._memo(("post_update", table), update_stmt)

    if mapper._version_id_has_server_side_value:
        # the new version value is generated server-side; fetch it back
        statement = statement.return_defaults(mapper.version_id_col)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # (connection, parameter keys)
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        assert_singlerow = connection.dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        # versioned rows need per-row verifiable counts; only batch when
        # the dialect reports sane multi-row counts
        allow_executemany = not needs_version_id or assert_multirow

        if not allow_executemany:
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection.execute(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection.execute(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            # post-fetch each record against its own compiled parameter set
            for i, (
                state,
                state_dict,
                mapper_rec,
                connection,
                params,
            ) in enumerate(records):
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[i],
                )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    Rows are grouped by connection; version-checked deletes may fall
    back to per-row execution so each rowcount can be verified.
    """
    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # WHERE criteria: all PK columns, plus the version id column
        # when versioning applies to this table
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        # -1 is a sentinel meaning "rowcount unknown / not verifiable"
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
  1313. def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
  1314. """finalize state on states that have been inserted or updated,
  1315. including calling after_insert/after_update events.
  1316. """
  1317. for state, state_dict, mapper, connection, has_identity in states:
  1318. if mapper._readonly_props:
  1319. readonly = state.unmodified_intersection(
  1320. [
  1321. p.key
  1322. for p in mapper._readonly_props
  1323. if (
  1324. p.expire_on_flush
  1325. and (not p.deferred or p.key in state.dict)
  1326. )
  1327. or (
  1328. not p.expire_on_flush
  1329. and not p.deferred
  1330. and p.key not in state.dict
  1331. )
  1332. ]
  1333. )
  1334. if readonly:
  1335. state._expire_attributes(state.dict, readonly)
  1336. # if eager_defaults option is enabled, load
  1337. # all expired cols. Else if we have a version_id_col, make sure
  1338. # it isn't expired.
  1339. toload_now = []
  1340. # this is specifically to emit a second SELECT for eager_defaults,
  1341. # so only if it's set to True, not "auto"
  1342. if base_mapper.eager_defaults is True:
  1343. toload_now.extend(
  1344. state._unloaded_non_object.intersection(
  1345. mapper._server_default_plus_onupdate_propkeys
  1346. )
  1347. )
  1348. if (
  1349. mapper.version_id_col is not None
  1350. and mapper.version_id_generator is False
  1351. ):
  1352. if mapper._version_id_prop.key in state.unloaded:
  1353. toload_now.extend([mapper._version_id_prop.key])
  1354. if toload_now:
  1355. state.key = base_mapper._identity_key_from_state(state)
  1356. stmt = future.select(mapper).set_label_style(
  1357. LABEL_STYLE_TABLENAME_PLUS_COL
  1358. )
  1359. loading.load_on_ident(
  1360. uowtransaction.session,
  1361. stmt,
  1362. state.key,
  1363. refresh_state=state,
  1364. only_load_props=toload_now,
  1365. )
  1366. # call after_XXX extensions
  1367. if not has_identity:
  1368. mapper.dispatch.after_insert(mapper, connection, state)
  1369. else:
  1370. mapper.dispatch.after_update(mapper, connection, state)
  1371. if (
  1372. mapper.version_id_generator is False
  1373. and mapper.version_id_col is not None
  1374. ):
  1375. if state_dict[mapper._version_id_prop.key] is None:
  1376. raise orm_exc.FlushError(
  1377. "Instance does not contain a non-NULL version value"
  1378. )
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Refresh in-memory attribute state after a post-update UPDATE.

    Prefetch columns present in the executed parameters are copied into
    the instance dict; postfetch columns are expired so they re-load on
    next access.  A refresh_flush event fires for any attributes set.
    """
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    if not uowtransaction.is_deleted(state):
        # post updating after a regular INSERT or UPDATE, do a full postfetch
        prefetch_cols = result.context.compiled.prefetch
        postfetch_cols = result.context.compiled.postfetch
    elif needs_version_id:
        # post updating before a DELETE with a version_id_col, need to
        # postfetch just version_id_col
        prefetch_cols = postfetch_cols = ()
    else:
        # post updating before a DELETE without a version_id_col,
        # don't need to postfetch
        return

    if needs_version_id:
        # copy, don't mutate, the compiled prefetch list
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )
def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    RETURNING values and prefetched parameter values are written
    directly into the instance dict; postfetch columns are expired so
    they refresh from the database on next access; finally, newly
    generated key values are synchronized to dependent tables.
    """
    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.effective_returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # copy, don't mutate, the compiled prefetch list
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            pkey = mapper._columntoproperty[c].key

            # set prefetched value in dict and also pop from committed_state,
            # since this is new database state that replaces whatever might
            # have previously been fetched (see #10800). this is essentially a
            # shorthand version of set_committed_value(), which could also be
            # used here directly (with more overhead)
            dict_[pkey] = params[c.key]
            state.committed_state.pop(pkey, None)

            if refresh_flush:
                load_evt_attrs.append(pkey)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        # NOTE(review): extend() mutates the compiled statement's postfetch
        # list in place -- presumably intended; confirm this cannot pollute
        # a cached compiled object across flushes
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often. would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )
  1513. def _postfetch_bulk_save(mapper, dict_, table):
  1514. for m, equated_pairs in mapper._table_to_equated[table]:
  1515. sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
  1516. def _connections_for_states(base_mapper, uowtransaction, states):
  1517. """Return an iterator of (state, state.dict, mapper, connection).
  1518. The states are sorted according to _sort_states, then paired
  1519. with the connection they should be using for the given
  1520. unit of work transaction.
  1521. """
  1522. # if session has a connection callable,
  1523. # organize individual states with the connection
  1524. # to use for update
  1525. if uowtransaction.session.connection_callable:
  1526. connection_callable = uowtransaction.session.connection_callable
  1527. else:
  1528. connection = uowtransaction.transaction.connection(base_mapper)
  1529. connection_callable = None
  1530. for state in _sort_states(base_mapper, states):
  1531. if connection_callable:
  1532. connection = connection_callable(base_mapper, state.obj())
  1533. mapper = state.manager.mapper
  1534. yield state, state.dict, mapper, connection
  1535. def _sort_states(mapper, states):
  1536. pending = set(states)
  1537. persistent = {s for s in pending if s.key is not None}
  1538. pending.difference_update(persistent)
  1539. try:
  1540. persistent_sorted = sorted(
  1541. persistent, key=mapper._persistent_sortkey_fn
  1542. )
  1543. except TypeError as err:
  1544. raise sa_exc.InvalidRequestError(
  1545. "Could not sort objects by primary key; primary key "
  1546. "values must be sortable in Python (was: %s)" % err
  1547. ) from err
  1548. return (
  1549. sorted(pending, key=operator.attrgetter("insert_order"))
  1550. + persistent_sorted
  1551. )