_winconsole.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
# This module is based on the excellent work by Adam Bartoš, who
# provided much of what went into the implementation here in
# the discussion of issue 1602 in the Python bug tracker.
#
# There are some general differences in how this works compared
# to the original patches: we do not need to patch the entire
# interpreter, only to work within our small world of echo and
# prompt.
  9. from __future__ import annotations
  10. import collections.abc as cabc
  11. import io
  12. import sys
  13. import time
  14. import typing as t
  15. from ctypes import Array
  16. from ctypes import byref
  17. from ctypes import c_char
  18. from ctypes import c_char_p
  19. from ctypes import c_int
  20. from ctypes import c_ssize_t
  21. from ctypes import c_ulong
  22. from ctypes import c_void_p
  23. from ctypes import POINTER
  24. from ctypes import py_object
  25. from ctypes import Structure
  26. from ctypes.wintypes import DWORD
  27. from ctypes.wintypes import HANDLE
  28. from ctypes.wintypes import LPCWSTR
  29. from ctypes.wintypes import LPWSTR
  30. from ._compat import _NonClosingTextIOWrapper
  31. assert sys.platform == "win32"
  32. import msvcrt # noqa: E402
  33. from ctypes import windll # noqa: E402
  34. from ctypes import WINFUNCTYPE # noqa: E402
# Pointer-to-Py_ssize_t type, used for the shape/strides/suboffsets fields of
# the Py_buffer structure declared further down.
c_ssize_p = POINTER(c_ssize_t)

# Win32 entry points. These are looked up on ``windll.kernel32`` without any
# ``argtypes``/``restype`` declarations, so ctypes applies its defaults
# (arguments converted as passed, return value read as a C ``int``).
# NOTE(review): GetStdHandle returns a HANDLE; confirm the default ``int``
# restype is acceptable for those values on 64-bit builds.
kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError

# These three are bound with explicit WINFUNCTYPE prototypes instead. They are
# not referenced in the code visible here — presumably used for command-line
# handling elsewhere in the package; verify before removing.
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
    ("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))

# Win32 standard-device handles, fetched once at import time.
# -10 / -11 / -12 are the Win32 STD_INPUT_HANDLE / STD_OUTPUT_HANDLE /
# STD_ERROR_HANDLE identifiers.
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)

# Request flags passed to PyObject_GetBuffer by get_buffer.
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1

# Win32 error codes compared against GetLastError() results.
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

# Conventional file-descriptor numbers of the standard streams.
STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2

# Ctrl+Z; the console reader compares the start of the data it read against
# this to detect end of input.
EOF = b"\x1a"
# Upper bound on the bytes handed to a single WriteConsoleW call.
# NOTE(review): presumably chosen to stay under WriteConsoleW's buffer-size
# limit — confirm.
MAX_BYTES_WRITTEN = 32767
  60. if t.TYPE_CHECKING:
  61. try:
  62. # Using `typing_extensions.Buffer` instead of `collections.abc`
  63. # on Windows for some reason does not have `Sized` implemented.
  64. from collections.abc import Buffer # type: ignore
  65. except ImportError:
  66. from typing_extensions import Buffer
  67. try:
  68. from ctypes import pythonapi
  69. except ImportError:
  70. # On PyPy we cannot get buffers so our ability to operate here is
  71. # severely limited.
  72. get_buffer = None
  73. else:
  74. class Py_buffer(Structure):
  75. _fields_ = [ # noqa: RUF012
  76. ("buf", c_void_p),
  77. ("obj", py_object),
  78. ("len", c_ssize_t),
  79. ("itemsize", c_ssize_t),
  80. ("readonly", c_int),
  81. ("ndim", c_int),
  82. ("format", c_char_p),
  83. ("shape", c_ssize_p),
  84. ("strides", c_ssize_p),
  85. ("suboffsets", c_ssize_p),
  86. ("internal", c_void_p),
  87. ]
  88. PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
  89. PyBuffer_Release = pythonapi.PyBuffer_Release
  90. def get_buffer(obj: Buffer, writable: bool = False) -> Array[c_char]:
  91. buf = Py_buffer()
  92. flags: int = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
  93. PyObject_GetBuffer(py_object(obj), byref(buf), flags)
  94. try:
  95. buffer_type = c_char * buf.len
  96. out: Array[c_char] = buffer_type.from_address(buf.buf)
  97. return out
  98. finally:
  99. PyBuffer_Release(byref(buf))
  100. class _WindowsConsoleRawIOBase(io.RawIOBase):
  101. def __init__(self, handle: int | None) -> None:
  102. self.handle = handle
  103. def isatty(self) -> t.Literal[True]:
  104. super().isatty()
  105. return True
  106. class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
  107. def readable(self) -> t.Literal[True]:
  108. return True
  109. def readinto(self, b: Buffer) -> int:
  110. bytes_to_be_read = len(b)
  111. if not bytes_to_be_read:
  112. return 0
  113. elif bytes_to_be_read % 2:
  114. raise ValueError(
  115. "cannot read odd number of bytes from UTF-16-LE encoded console"
  116. )
  117. buffer = get_buffer(b, writable=True)
  118. code_units_to_be_read = bytes_to_be_read // 2
  119. code_units_read = c_ulong()
  120. rv = ReadConsoleW(
  121. HANDLE(self.handle),
  122. buffer,
  123. code_units_to_be_read,
  124. byref(code_units_read),
  125. None,
  126. )
  127. if GetLastError() == ERROR_OPERATION_ABORTED:
  128. # wait for KeyboardInterrupt
  129. time.sleep(0.1)
  130. if not rv:
  131. raise OSError(f"Windows error: {GetLastError()}")
  132. if buffer[0] == EOF:
  133. return 0
  134. return 2 * code_units_read.value
  135. class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
  136. def writable(self) -> t.Literal[True]:
  137. return True
  138. @staticmethod
  139. def _get_error_message(errno: int) -> str:
  140. if errno == ERROR_SUCCESS:
  141. return "ERROR_SUCCESS"
  142. elif errno == ERROR_NOT_ENOUGH_MEMORY:
  143. return "ERROR_NOT_ENOUGH_MEMORY"
  144. return f"Windows error {errno}"
  145. def write(self, b: Buffer) -> int:
  146. bytes_to_be_written = len(b)
  147. buf = get_buffer(b)
  148. code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
  149. code_units_written = c_ulong()
  150. WriteConsoleW(
  151. HANDLE(self.handle),
  152. buf,
  153. code_units_to_be_written,
  154. byref(code_units_written),
  155. None,
  156. )
  157. bytes_written = 2 * code_units_written.value
  158. if bytes_written == 0 and bytes_to_be_written > 0:
  159. raise OSError(self._get_error_message(GetLastError()))
  160. return bytes_written
  161. class ConsoleStream:
  162. def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
  163. self._text_stream = text_stream
  164. self.buffer = byte_stream
  165. @property
  166. def name(self) -> str:
  167. return self.buffer.name
  168. def write(self, x: t.AnyStr) -> int:
  169. if isinstance(x, str):
  170. return self._text_stream.write(x)
  171. try:
  172. self.flush()
  173. except Exception:
  174. pass
  175. return self.buffer.write(x)
  176. def writelines(self, lines: cabc.Iterable[t.AnyStr]) -> None:
  177. for line in lines:
  178. self.write(line)
  179. def __getattr__(self, name: str) -> t.Any:
  180. return getattr(self._text_stream, name)
  181. def isatty(self) -> bool:
  182. return self.buffer.isatty()
  183. def __repr__(self) -> str:
  184. return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
  185. def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
  186. text_stream = _NonClosingTextIOWrapper(
  187. io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
  188. "utf-16-le",
  189. "strict",
  190. line_buffering=True,
  191. )
  192. return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
  193. def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
  194. text_stream = _NonClosingTextIOWrapper(
  195. io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
  196. "utf-16-le",
  197. "strict",
  198. line_buffering=True,
  199. )
  200. return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
  201. def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
  202. text_stream = _NonClosingTextIOWrapper(
  203. io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
  204. "utf-16-le",
  205. "strict",
  206. line_buffering=True,
  207. )
  208. return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
  209. _stream_factories: cabc.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
  210. 0: _get_text_stdin,
  211. 1: _get_text_stdout,
  212. 2: _get_text_stderr,
  213. }
  214. def _is_console(f: t.TextIO) -> bool:
  215. if not hasattr(f, "fileno"):
  216. return False
  217. try:
  218. fileno = f.fileno()
  219. except (OSError, io.UnsupportedOperation):
  220. return False
  221. handle = msvcrt.get_osfhandle(fileno)
  222. return bool(GetConsoleMode(handle, byref(DWORD())))
  223. def _get_windows_console_stream(
  224. f: t.TextIO, encoding: str | None, errors: str | None
  225. ) -> t.TextIO | None:
  226. if (
  227. get_buffer is None
  228. or encoding not in {"utf-16-le", None}
  229. or errors not in {"strict", None}
  230. or not _is_console(f)
  231. ):
  232. return None
  233. func = _stream_factories.get(f.fileno())
  234. if func is None:
  235. return None
  236. b = getattr(f, "buffer", None)
  237. if b is None:
  238. return None
  239. return func(b)