asyncio.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. # connectors/asyncio.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. """generic asyncio-adapted versions of DBAPI connection and cursor"""
  8. from __future__ import annotations
  9. import asyncio
  10. import collections
  11. import sys
  12. from typing import Any
  13. from typing import AsyncIterator
  14. from typing import Deque
  15. from typing import Iterator
  16. from typing import NoReturn
  17. from typing import Optional
  18. from typing import Sequence
  19. from typing import TYPE_CHECKING
  20. from ..engine import AdaptedConnection
  21. from ..util.concurrency import await_fallback
  22. from ..util.concurrency import await_only
  23. from ..util.typing import Protocol
  24. if TYPE_CHECKING:
  25. from ..engine.interfaces import _DBAPICursorDescription
  26. from ..engine.interfaces import _DBAPIMultiExecuteParams
  27. from ..engine.interfaces import _DBAPISingleExecuteParams
  28. from ..engine.interfaces import DBAPIModule
  29. from ..util.typing import Self
  30. class AsyncIODBAPIConnection(Protocol):
  31. """protocol representing an async adapted version of a
  32. :pep:`249` database connection.
  33. """
  34. # note that async DBAPIs dont agree if close() should be awaitable,
  35. # so it is omitted here and picked up by the __getattr__ hook below
  36. async def commit(self) -> None: ...
  37. def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...
  38. async def rollback(self) -> None: ...
  39. def __getattr__(self, key: str) -> Any: ...
  40. def __setattr__(self, key: str, value: Any) -> None: ...
  41. class AsyncIODBAPICursor(Protocol):
  42. """protocol representing an async adapted version
  43. of a :pep:`249` database cursor.
  44. """
  45. def __aenter__(self) -> Any: ...
  46. @property
  47. def description(
  48. self,
  49. ) -> _DBAPICursorDescription:
  50. """The description attribute of the Cursor."""
  51. ...
  52. @property
  53. def rowcount(self) -> int: ...
  54. arraysize: int
  55. lastrowid: int
  56. async def close(self) -> None: ...
  57. async def execute(
  58. self,
  59. operation: Any,
  60. parameters: Optional[_DBAPISingleExecuteParams] = None,
  61. ) -> Any: ...
  62. async def executemany(
  63. self,
  64. operation: Any,
  65. parameters: _DBAPIMultiExecuteParams,
  66. ) -> Any: ...
  67. async def fetchone(self) -> Optional[Any]: ...
  68. async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...
  69. async def fetchall(self) -> Sequence[Any]: ...
  70. async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...
  71. def setoutputsize(self, size: Any, column: Any) -> None: ...
  72. async def callproc(
  73. self, procname: str, parameters: Sequence[Any] = ...
  74. ) -> Any: ...
  75. async def nextset(self) -> Optional[bool]: ...
  76. def __aiter__(self) -> AsyncIterator[Any]: ...
  77. class AsyncAdapt_dbapi_module:
  78. if TYPE_CHECKING:
  79. Error = DBAPIModule.Error
  80. OperationalError = DBAPIModule.OperationalError
  81. InterfaceError = DBAPIModule.InterfaceError
  82. IntegrityError = DBAPIModule.IntegrityError
  83. def __getattr__(self, key: str) -> Any: ...
  84. class AsyncAdapt_dbapi_cursor:
  85. server_side = False
  86. __slots__ = (
  87. "_adapt_connection",
  88. "_connection",
  89. "await_",
  90. "_cursor",
  91. "_rows",
  92. )
  93. _cursor: AsyncIODBAPICursor
  94. _adapt_connection: AsyncAdapt_dbapi_connection
  95. _connection: AsyncIODBAPIConnection
  96. _rows: Deque[Any]
  97. def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
  98. self._adapt_connection = adapt_connection
  99. self._connection = adapt_connection._connection
  100. self.await_ = adapt_connection.await_
  101. cursor = self._make_new_cursor(self._connection)
  102. self._cursor = self._aenter_cursor(cursor)
  103. if not self.server_side:
  104. self._rows = collections.deque()
  105. def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
  106. return self.await_(cursor.__aenter__()) # type: ignore[no-any-return]
  107. def _make_new_cursor(
  108. self, connection: AsyncIODBAPIConnection
  109. ) -> AsyncIODBAPICursor:
  110. return connection.cursor()
  111. @property
  112. def description(self) -> Optional[_DBAPICursorDescription]:
  113. return self._cursor.description
  114. @property
  115. def rowcount(self) -> int:
  116. return self._cursor.rowcount
  117. @property
  118. def arraysize(self) -> int:
  119. return self._cursor.arraysize
  120. @arraysize.setter
  121. def arraysize(self, value: int) -> None:
  122. self._cursor.arraysize = value
  123. @property
  124. def lastrowid(self) -> int:
  125. return self._cursor.lastrowid
  126. def close(self) -> None:
  127. # note we aren't actually closing the cursor here,
  128. # we are just letting GC do it. see notes in aiomysql dialect
  129. self._rows.clear()
  130. def execute(
  131. self,
  132. operation: Any,
  133. parameters: Optional[_DBAPISingleExecuteParams] = None,
  134. ) -> Any:
  135. try:
  136. return self.await_(self._execute_async(operation, parameters))
  137. except Exception as error:
  138. self._adapt_connection._handle_exception(error)
  139. def executemany(
  140. self,
  141. operation: Any,
  142. seq_of_parameters: _DBAPIMultiExecuteParams,
  143. ) -> Any:
  144. try:
  145. return self.await_(
  146. self._executemany_async(operation, seq_of_parameters)
  147. )
  148. except Exception as error:
  149. self._adapt_connection._handle_exception(error)
  150. async def _execute_async(
  151. self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
  152. ) -> Any:
  153. async with self._adapt_connection._execute_mutex:
  154. if parameters is None:
  155. result = await self._cursor.execute(operation)
  156. else:
  157. result = await self._cursor.execute(operation, parameters)
  158. if self._cursor.description and not self.server_side:
  159. self._rows = collections.deque(await self._cursor.fetchall())
  160. return result
  161. async def _executemany_async(
  162. self,
  163. operation: Any,
  164. seq_of_parameters: _DBAPIMultiExecuteParams,
  165. ) -> Any:
  166. async with self._adapt_connection._execute_mutex:
  167. return await self._cursor.executemany(operation, seq_of_parameters)
  168. def nextset(self) -> None:
  169. self.await_(self._cursor.nextset())
  170. if self._cursor.description and not self.server_side:
  171. self._rows = collections.deque(
  172. self.await_(self._cursor.fetchall())
  173. )
  174. def setinputsizes(self, *inputsizes: Any) -> None:
  175. # NOTE: this is overrridden in aioodbc due to
  176. # see https://github.com/aio-libs/aioodbc/issues/451
  177. # right now
  178. return self.await_(self._cursor.setinputsizes(*inputsizes))
  179. def __enter__(self) -> Self:
  180. return self
  181. def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
  182. self.close()
  183. def __iter__(self) -> Iterator[Any]:
  184. while self._rows:
  185. yield self._rows.popleft()
  186. def fetchone(self) -> Optional[Any]:
  187. if self._rows:
  188. return self._rows.popleft()
  189. else:
  190. return None
  191. def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
  192. if size is None:
  193. size = self.arraysize
  194. rr = self._rows
  195. return [rr.popleft() for _ in range(min(size, len(rr)))]
  196. def fetchall(self) -> Sequence[Any]:
  197. retval = list(self._rows)
  198. self._rows.clear()
  199. return retval
  200. class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
  201. __slots__ = ()
  202. server_side = True
  203. def close(self) -> None:
  204. if self._cursor is not None:
  205. self.await_(self._cursor.close())
  206. self._cursor = None # type: ignore
  207. def fetchone(self) -> Optional[Any]:
  208. return self.await_(self._cursor.fetchone())
  209. def fetchmany(self, size: Optional[int] = None) -> Any:
  210. return self.await_(self._cursor.fetchmany(size=size))
  211. def fetchall(self) -> Sequence[Any]:
  212. return self.await_(self._cursor.fetchall())
  213. def __iter__(self) -> Iterator[Any]:
  214. iterator = self._cursor.__aiter__()
  215. while True:
  216. try:
  217. yield self.await_(iterator.__anext__())
  218. except StopAsyncIteration:
  219. break
  220. class AsyncAdapt_dbapi_connection(AdaptedConnection):
  221. _cursor_cls = AsyncAdapt_dbapi_cursor
  222. _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor
  223. await_ = staticmethod(await_only)
  224. __slots__ = ("dbapi", "_execute_mutex")
  225. _connection: AsyncIODBAPIConnection
  226. def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
  227. self.dbapi = dbapi
  228. self._connection = connection
  229. self._execute_mutex = asyncio.Lock()
  230. def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
  231. if server_side:
  232. return self._ss_cursor_cls(self)
  233. else:
  234. return self._cursor_cls(self)
  235. def execute(
  236. self,
  237. operation: Any,
  238. parameters: Optional[_DBAPISingleExecuteParams] = None,
  239. ) -> Any:
  240. """lots of DBAPIs seem to provide this, so include it"""
  241. cursor = self.cursor()
  242. cursor.execute(operation, parameters)
  243. return cursor
  244. def _handle_exception(self, error: Exception) -> NoReturn:
  245. exc_info = sys.exc_info()
  246. raise error.with_traceback(exc_info[2])
  247. def rollback(self) -> None:
  248. try:
  249. self.await_(self._connection.rollback())
  250. except Exception as error:
  251. self._handle_exception(error)
  252. def commit(self) -> None:
  253. try:
  254. self.await_(self._connection.commit())
  255. except Exception as error:
  256. self._handle_exception(error)
  257. def close(self) -> None:
  258. self.await_(self._connection.close())
  259. class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection):
  260. __slots__ = ()
  261. await_ = staticmethod(await_fallback)