enumerated.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # dialects/mysql/enumerated.py
  2. # Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from __future__ import annotations
  8. import enum
  9. import re
  10. from typing import Any
  11. from typing import Dict
  12. from typing import Optional
  13. from typing import Set
  14. from typing import Type
  15. from typing import TYPE_CHECKING
  16. from typing import Union
  17. from .types import _StringType
  18. from ... import exc
  19. from ... import sql
  20. from ... import util
  21. from ...sql import sqltypes
  22. from ...sql import type_api
  23. if TYPE_CHECKING:
  24. from ...engine.interfaces import Dialect
  25. from ...sql.elements import ColumnElement
  26. from ...sql.type_api import _BindProcessorType
  27. from ...sql.type_api import _ResultProcessorType
  28. from ...sql.type_api import TypeEngine
  29. from ...sql.type_api import TypeEngineMixin
  30. class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
  31. """MySQL ENUM type."""
  32. __visit_name__ = "ENUM"
  33. native_enum = True
  34. def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
  35. """Construct an ENUM.
  36. E.g.::
  37. Column("myenum", ENUM("foo", "bar", "baz"))
  38. :param enums: The range of valid values for this ENUM. Values in
  39. enums are not quoted, they will be escaped and surrounded by single
  40. quotes when generating the schema. This object may also be a
  41. PEP-435-compliant enumerated type.
  42. .. versionadded: 1.1 added support for PEP-435-compliant enumerated
  43. types.
  44. :param strict: This flag has no effect.
  45. .. versionchanged:: The MySQL ENUM type as well as the base Enum
  46. type now validates all Python data values.
  47. :param charset: Optional, a column-level character set for this string
  48. value. Takes precedence to 'ascii' or 'unicode' short-hand.
  49. :param collation: Optional, a column-level collation for this string
  50. value. Takes precedence to 'binary' short-hand.
  51. :param ascii: Defaults to False: short-hand for the ``latin1``
  52. character set, generates ASCII in schema.
  53. :param unicode: Defaults to False: short-hand for the ``ucs2``
  54. character set, generates UNICODE in schema.
  55. :param binary: Defaults to False: short-hand, pick the binary
  56. collation type that matches the column's character set. Generates
  57. BINARY in schema. This does not affect the type of data stored,
  58. only the collation of character data.
  59. """
  60. kw.pop("strict", None)
  61. self._enum_init(enums, kw) # type: ignore[arg-type]
  62. _StringType.__init__(self, length=self.length, **kw)
  63. @classmethod
  64. def adapt_emulated_to_native(
  65. cls,
  66. impl: Union[TypeEngine[Any], TypeEngineMixin],
  67. **kw: Any,
  68. ) -> ENUM:
  69. """Produce a MySQL native :class:`.mysql.ENUM` from plain
  70. :class:`.Enum`.
  71. """
  72. if TYPE_CHECKING:
  73. assert isinstance(impl, ENUM)
  74. kw.setdefault("validate_strings", impl.validate_strings)
  75. kw.setdefault("values_callable", impl.values_callable)
  76. kw.setdefault("omit_aliases", impl._omit_aliases)
  77. return cls(**kw)
  78. def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
  79. # mysql sends back a blank string for any value that
  80. # was persisted that was not in the enums; that is, it does no
  81. # validation on the incoming data, it "truncates" it to be
  82. # the blank string. Return it straight.
  83. if elem == "":
  84. return elem
  85. else:
  86. return super()._object_value_for_elem(elem)
  87. def __repr__(self) -> str:
  88. return util.generic_repr(
  89. self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
  90. )
  91. # TODO: SET is a string as far as configuration but does not act like
  92. # a string at the python level. We either need to make a py-type agnostic
  93. # version of String as a base to be used for this, make this some kind of
  94. # TypeDecorator, or just vendor it out as its own type.
  95. class SET(_StringType):
  96. """MySQL SET type."""
  97. __visit_name__ = "SET"
  98. def __init__(self, *values: str, **kw: Any):
  99. """Construct a SET.
  100. E.g.::
  101. Column("myset", SET("foo", "bar", "baz"))
  102. The list of potential values is required in the case that this
  103. set will be used to generate DDL for a table, or if the
  104. :paramref:`.SET.retrieve_as_bitwise` flag is set to True.
  105. :param values: The range of valid values for this SET. The values
  106. are not quoted, they will be escaped and surrounded by single
  107. quotes when generating the schema.
  108. :param convert_unicode: Same flag as that of
  109. :paramref:`.String.convert_unicode`.
  110. :param collation: same as that of :paramref:`.String.collation`
  111. :param charset: same as that of :paramref:`.VARCHAR.charset`.
  112. :param ascii: same as that of :paramref:`.VARCHAR.ascii`.
  113. :param unicode: same as that of :paramref:`.VARCHAR.unicode`.
  114. :param binary: same as that of :paramref:`.VARCHAR.binary`.
  115. :param retrieve_as_bitwise: if True, the data for the set type will be
  116. persisted and selected using an integer value, where a set is coerced
  117. into a bitwise mask for persistence. MySQL allows this mode which
  118. has the advantage of being able to store values unambiguously,
  119. such as the blank string ``''``. The datatype will appear
  120. as the expression ``col + 0`` in a SELECT statement, so that the
  121. value is coerced into an integer value in result sets.
  122. This flag is required if one wishes
  123. to persist a set that can store the blank string ``''`` as a value.
  124. .. warning::
  125. When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
  126. essential that the list of set values is expressed in the
  127. **exact same order** as exists on the MySQL database.
  128. """
  129. self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False)
  130. self.values = tuple(values)
  131. if not self.retrieve_as_bitwise and "" in values:
  132. raise exc.ArgumentError(
  133. "Can't use the blank value '' in a SET without "
  134. "setting retrieve_as_bitwise=True"
  135. )
  136. if self.retrieve_as_bitwise:
  137. self._inversed_bitmap: Dict[str, int] = {
  138. value: 2**idx for idx, value in enumerate(self.values)
  139. }
  140. self._bitmap: Dict[int, str] = {
  141. 2**idx: value for idx, value in enumerate(self.values)
  142. }
  143. length = max([len(v) for v in values] + [0])
  144. kw.setdefault("length", length)
  145. super().__init__(**kw)
  146. def column_expression(
  147. self, colexpr: ColumnElement[Any]
  148. ) -> ColumnElement[Any]:
  149. if self.retrieve_as_bitwise:
  150. return sql.type_coerce(
  151. sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
  152. )
  153. else:
  154. return colexpr
  155. def result_processor(
  156. self, dialect: Dialect, coltype: Any
  157. ) -> Optional[_ResultProcessorType[Any]]:
  158. if self.retrieve_as_bitwise:
  159. def process(value: Union[str, int, None]) -> Optional[Set[str]]:
  160. if value is not None:
  161. value = int(value)
  162. return set(util.map_bits(self._bitmap.__getitem__, value))
  163. else:
  164. return None
  165. else:
  166. super_convert = super().result_processor(dialect, coltype)
  167. def process(value: Union[str, Set[str], None]) -> Optional[Set[str]]: # type: ignore[misc] # noqa: E501
  168. if isinstance(value, str):
  169. # MySQLdb returns a string, let's parse
  170. if super_convert:
  171. value = super_convert(value)
  172. assert value is not None
  173. if TYPE_CHECKING:
  174. assert isinstance(value, str)
  175. return set(re.findall(r"[^,]+", value))
  176. else:
  177. # mysql-connector-python does a naive
  178. # split(",") which throws in an empty string
  179. if value is not None:
  180. value.discard("")
  181. return value
  182. return process
  183. def bind_processor(
  184. self, dialect: Dialect
  185. ) -> _BindProcessorType[Union[str, int]]:
  186. super_convert = super().bind_processor(dialect)
  187. if self.retrieve_as_bitwise:
  188. def process(
  189. value: Union[str, int, set[str], None],
  190. ) -> Union[str, int, None]:
  191. if value is None:
  192. return None
  193. elif isinstance(value, (int, str)):
  194. if super_convert:
  195. return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501
  196. else:
  197. return value
  198. else:
  199. int_value = 0
  200. for v in value:
  201. int_value |= self._inversed_bitmap[v]
  202. return int_value
  203. else:
  204. def process(
  205. value: Union[str, int, set[str], None],
  206. ) -> Union[str, int, None]:
  207. # accept strings and int (actually bitflag) values directly
  208. if value is not None and not isinstance(value, (int, str)):
  209. value = ",".join(value)
  210. if super_convert:
  211. return super_convert(value) # type: ignore
  212. else:
  213. return value
  214. return process
  215. def adapt(self, cls: type, **kw: Any) -> Any:
  216. kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
  217. return util.constructor_copy(self, cls, *self.values, **kw)
  218. def __repr__(self) -> str:
  219. return util.generic_repr(
  220. self,
  221. to_inspect=[SET, _StringType],
  222. additional_kw=[
  223. ("retrieve_as_bitwise", False),
  224. ],
  225. )