mariadbconnector.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # dialects/mysql/mariadbconnector.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. """
  8. .. dialect:: mysql+mariadbconnector
  9. :name: MariaDB Connector/Python
  10. :dbapi: mariadb
  11. :connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>
  12. :url: https://pypi.org/project/mariadb/
  13. Driver Status
  14. -------------
  15. MariaDB Connector/Python enables Python programs to access MariaDB and MySQL
  16. databases using an API which is compliant with the Python DB API 2.0 (PEP-249).
  17. It is written in C and uses the MariaDB Connector/C client library for client-server
  18. communication.
  19. Note that the default driver for a ``mariadb://`` connection URI continues to
  20. be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
  21. .. _mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
  22. """ # noqa
  23. from __future__ import annotations
  24. import re
  25. from typing import Any
  26. from typing import Optional
  27. from typing import Sequence
  28. from typing import Tuple
  29. from typing import TYPE_CHECKING
  30. from typing import Union
  31. from uuid import UUID as _python_UUID
  32. from .base import MySQLCompiler
  33. from .base import MySQLDialect
  34. from .base import MySQLExecutionContext
  35. from ... import sql
  36. from ... import util
  37. from ...sql import sqltypes
  38. if TYPE_CHECKING:
  39. from ...engine.base import Connection
  40. from ...engine.interfaces import ConnectArgsType
  41. from ...engine.interfaces import DBAPIConnection
  42. from ...engine.interfaces import DBAPICursor
  43. from ...engine.interfaces import DBAPIModule
  44. from ...engine.interfaces import Dialect
  45. from ...engine.interfaces import IsolationLevel
  46. from ...engine.interfaces import PoolProxiedConnection
  47. from ...engine.url import URL
  48. from ...sql.compiler import SQLCompiler
  49. from ...sql.type_api import _ResultProcessorType
  50. mariadb_cpy_minimum_version = (1, 0, 1)
  51. class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
  52. # work around JIRA issue
  53. # https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
  54. # this type can be removed.
  55. def result_processor(
  56. self, dialect: Dialect, coltype: object
  57. ) -> Optional[_ResultProcessorType[Any]]:
  58. if self.as_uuid:
  59. def process(value: Any) -> Any:
  60. if value is not None:
  61. if hasattr(value, "decode"):
  62. value = value.decode("ascii")
  63. value = _python_UUID(value)
  64. return value
  65. return process
  66. else:
  67. def process(value: Any) -> Any:
  68. if value is not None:
  69. if hasattr(value, "decode"):
  70. value = value.decode("ascii")
  71. value = str(_python_UUID(value))
  72. return value
  73. return process
  74. class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
  75. _lastrowid: Optional[int] = None
  76. def create_server_side_cursor(self) -> DBAPICursor:
  77. return self._dbapi_connection.cursor(buffered=False)
  78. def create_default_cursor(self) -> DBAPICursor:
  79. return self._dbapi_connection.cursor(buffered=True)
  80. def post_exec(self) -> None:
  81. super().post_exec()
  82. self._rowcount = self.cursor.rowcount
  83. if TYPE_CHECKING:
  84. assert isinstance(self.compiled, SQLCompiler)
  85. if self.isinsert and self.compiled.postfetch_lastrowid:
  86. self._lastrowid = self.cursor.lastrowid
  87. def get_lastrowid(self) -> int:
  88. if TYPE_CHECKING:
  89. assert self._lastrowid is not None
  90. return self._lastrowid
  91. class MySQLCompiler_mariadbconnector(MySQLCompiler):
  92. pass
  93. class MySQLDialect_mariadbconnector(MySQLDialect):
  94. driver = "mariadbconnector"
  95. supports_statement_cache = True
  96. # set this to True at the module level to prevent the driver from running
  97. # against a backend that server detects as MySQL. currently this appears to
  98. # be unnecessary as MariaDB client libraries have always worked against
  99. # MySQL databases. However, if this changes at some point, this can be
  100. # adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if
  101. # this change is made at some point to ensure the correct exception
  102. # is raised at the correct point when running the driver against
  103. # a MySQL backend.
  104. # is_mariadb = True
  105. supports_unicode_statements = True
  106. encoding = "utf8mb4"
  107. convert_unicode = True
  108. supports_sane_rowcount = True
  109. supports_sane_multi_rowcount = True
  110. supports_native_decimal = True
  111. default_paramstyle = "qmark"
  112. execution_ctx_cls = MySQLExecutionContext_mariadbconnector
  113. statement_compiler = MySQLCompiler_mariadbconnector
  114. supports_server_side_cursors = True
  115. colspecs = util.update_copy(
  116. MySQLDialect.colspecs, {sqltypes.Uuid: _MariaDBUUID}
  117. )
  118. @util.memoized_property
  119. def _dbapi_version(self) -> Tuple[int, ...]:
  120. if self.dbapi and hasattr(self.dbapi, "__version__"):
  121. return tuple(
  122. [
  123. int(x)
  124. for x in re.findall(
  125. r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
  126. )
  127. ]
  128. )
  129. else:
  130. return (99, 99, 99)
  131. def __init__(self, **kwargs: Any) -> None:
  132. super().__init__(**kwargs)
  133. self.paramstyle = "qmark"
  134. if self.dbapi is not None:
  135. if self._dbapi_version < mariadb_cpy_minimum_version:
  136. raise NotImplementedError(
  137. "The minimum required version for MariaDB "
  138. "Connector/Python is %s"
  139. % ".".join(str(x) for x in mariadb_cpy_minimum_version)
  140. )
  141. @classmethod
  142. def import_dbapi(cls) -> DBAPIModule:
  143. return __import__("mariadb")
  144. def is_disconnect(
  145. self,
  146. e: DBAPIModule.Error,
  147. connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
  148. cursor: Optional[DBAPICursor],
  149. ) -> bool:
  150. if super().is_disconnect(e, connection, cursor):
  151. return True
  152. elif isinstance(e, self.loaded_dbapi.Error):
  153. str_e = str(e).lower()
  154. return "not connected" in str_e or "isn't valid" in str_e
  155. else:
  156. return False
  157. def create_connect_args(self, url: URL) -> ConnectArgsType:
  158. opts = url.translate_connect_args()
  159. opts.update(url.query)
  160. int_params = [
  161. "connect_timeout",
  162. "read_timeout",
  163. "write_timeout",
  164. "client_flag",
  165. "port",
  166. "pool_size",
  167. ]
  168. bool_params = [
  169. "local_infile",
  170. "ssl_verify_cert",
  171. "ssl",
  172. "pool_reset_connection",
  173. "compress",
  174. ]
  175. for key in int_params:
  176. util.coerce_kw_type(opts, key, int)
  177. for key in bool_params:
  178. util.coerce_kw_type(opts, key, bool)
  179. # FOUND_ROWS must be set in CLIENT_FLAGS to enable
  180. # supports_sane_rowcount.
  181. client_flag = opts.get("client_flag", 0)
  182. if self.dbapi is not None:
  183. try:
  184. CLIENT_FLAGS = __import__(
  185. self.dbapi.__name__ + ".constants.CLIENT"
  186. ).constants.CLIENT
  187. client_flag |= CLIENT_FLAGS.FOUND_ROWS
  188. except (AttributeError, ImportError):
  189. self.supports_sane_rowcount = False
  190. opts["client_flag"] = client_flag
  191. return [], opts
  192. def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
  193. try:
  194. rc: int = exception.errno
  195. except:
  196. rc = -1
  197. return rc
  198. def _detect_charset(self, connection: Connection) -> str:
  199. return "utf8mb4"
  200. def get_isolation_level_values(
  201. self, dbapi_conn: DBAPIConnection
  202. ) -> Sequence[IsolationLevel]:
  203. return (
  204. "SERIALIZABLE",
  205. "READ UNCOMMITTED",
  206. "READ COMMITTED",
  207. "REPEATABLE READ",
  208. "AUTOCOMMIT",
  209. )
  210. def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
  211. return bool(dbapi_conn.autocommit)
  212. def set_isolation_level(
  213. self, dbapi_connection: DBAPIConnection, level: IsolationLevel
  214. ) -> None:
  215. if level == "AUTOCOMMIT":
  216. dbapi_connection.autocommit = True
  217. else:
  218. dbapi_connection.autocommit = False
  219. super().set_isolation_level(dbapi_connection, level)
  220. def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
  221. connection.execute(
  222. sql.text("XA BEGIN :xid").bindparams(
  223. sql.bindparam("xid", xid, literal_execute=True)
  224. )
  225. )
  226. def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
  227. connection.execute(
  228. sql.text("XA END :xid").bindparams(
  229. sql.bindparam("xid", xid, literal_execute=True)
  230. )
  231. )
  232. connection.execute(
  233. sql.text("XA PREPARE :xid").bindparams(
  234. sql.bindparam("xid", xid, literal_execute=True)
  235. )
  236. )
  237. def do_rollback_twophase(
  238. self,
  239. connection: Connection,
  240. xid: Any,
  241. is_prepared: bool = True,
  242. recover: bool = False,
  243. ) -> None:
  244. if not is_prepared:
  245. connection.execute(
  246. sql.text("XA END :xid").bindparams(
  247. sql.bindparam("xid", xid, literal_execute=True)
  248. )
  249. )
  250. connection.execute(
  251. sql.text("XA ROLLBACK :xid").bindparams(
  252. sql.bindparam("xid", xid, literal_execute=True)
  253. )
  254. )
  255. def do_commit_twophase(
  256. self,
  257. connection: Connection,
  258. xid: Any,
  259. is_prepared: bool = True,
  260. recover: bool = False,
  261. ) -> None:
  262. if not is_prepared:
  263. self.do_prepare_twophase(connection, xid)
  264. connection.execute(
  265. sql.text("XA COMMIT :xid").bindparams(
  266. sql.bindparam("xid", xid, literal_execute=True)
  267. )
  268. )
  269. dialect = MySQLDialect_mariadbconnector