_compat.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. from __future__ import annotations
  2. import codecs
  3. import collections.abc as cabc
  4. import io
  5. import os
  6. import re
  7. import sys
  8. import typing as t
  9. from types import TracebackType
  10. from weakref import WeakKeyDictionary
  11. CYGWIN = sys.platform.startswith("cygwin")
  12. WIN = sys.platform.startswith("win")
  13. auto_wrap_for_ansi: t.Callable[[t.TextIO], t.TextIO] | None = None
  14. _ansi_re = re.compile(r"\033\[[;?0-9]*[a-zA-Z]")
  15. def _make_text_stream(
  16. stream: t.BinaryIO,
  17. encoding: str | None,
  18. errors: str | None,
  19. force_readable: bool = False,
  20. force_writable: bool = False,
  21. ) -> t.TextIO:
  22. if encoding is None:
  23. encoding = get_best_encoding(stream)
  24. if errors is None:
  25. errors = "replace"
  26. return _NonClosingTextIOWrapper(
  27. stream,
  28. encoding,
  29. errors,
  30. line_buffering=True,
  31. force_readable=force_readable,
  32. force_writable=force_writable,
  33. )
  34. def is_ascii_encoding(encoding: str) -> bool:
  35. """Checks if a given encoding is ascii."""
  36. try:
  37. return codecs.lookup(encoding).name == "ascii"
  38. except LookupError:
  39. return False
  40. def get_best_encoding(stream: t.IO[t.Any]) -> str:
  41. """Returns the default stream encoding if not found."""
  42. rv = getattr(stream, "encoding", None) or sys.getdefaultencoding()
  43. if is_ascii_encoding(rv):
  44. return "utf-8"
  45. return rv
  46. class _NonClosingTextIOWrapper(io.TextIOWrapper):
  47. def __init__(
  48. self,
  49. stream: t.BinaryIO,
  50. encoding: str | None,
  51. errors: str | None,
  52. force_readable: bool = False,
  53. force_writable: bool = False,
  54. **extra: t.Any,
  55. ) -> None:
  56. self._stream = stream = t.cast(
  57. t.BinaryIO, _FixupStream(stream, force_readable, force_writable)
  58. )
  59. super().__init__(stream, encoding, errors, **extra)
  60. def __del__(self) -> None:
  61. try:
  62. self.detach()
  63. except Exception:
  64. pass
  65. def isatty(self) -> bool:
  66. # https://bitbucket.org/pypy/pypy/issue/1803
  67. return self._stream.isatty()
  68. class _FixupStream:
  69. """The new io interface needs more from streams than streams
  70. traditionally implement. As such, this fix-up code is necessary in
  71. some circumstances.
  72. The forcing of readable and writable flags are there because some tools
  73. put badly patched objects on sys (one such offender are certain version
  74. of jupyter notebook).
  75. """
  76. def __init__(
  77. self,
  78. stream: t.BinaryIO,
  79. force_readable: bool = False,
  80. force_writable: bool = False,
  81. ):
  82. self._stream = stream
  83. self._force_readable = force_readable
  84. self._force_writable = force_writable
  85. def __getattr__(self, name: str) -> t.Any:
  86. return getattr(self._stream, name)
  87. def read1(self, size: int) -> bytes:
  88. f = getattr(self._stream, "read1", None)
  89. if f is not None:
  90. return t.cast(bytes, f(size))
  91. return self._stream.read(size)
  92. def readable(self) -> bool:
  93. if self._force_readable:
  94. return True
  95. x = getattr(self._stream, "readable", None)
  96. if x is not None:
  97. return t.cast(bool, x())
  98. try:
  99. self._stream.read(0)
  100. except Exception:
  101. return False
  102. return True
  103. def writable(self) -> bool:
  104. if self._force_writable:
  105. return True
  106. x = getattr(self._stream, "writable", None)
  107. if x is not None:
  108. return t.cast(bool, x())
  109. try:
  110. self._stream.write(b"")
  111. except Exception:
  112. try:
  113. self._stream.write(b"")
  114. except Exception:
  115. return False
  116. return True
  117. def seekable(self) -> bool:
  118. x = getattr(self._stream, "seekable", None)
  119. if x is not None:
  120. return t.cast(bool, x())
  121. try:
  122. self._stream.seek(self._stream.tell())
  123. except Exception:
  124. return False
  125. return True
  126. def _is_binary_reader(stream: t.IO[t.Any], default: bool = False) -> bool:
  127. try:
  128. return isinstance(stream.read(0), bytes)
  129. except Exception:
  130. return default
  131. # This happens in some cases where the stream was already
  132. # closed. In this case, we assume the default.
  133. def _is_binary_writer(stream: t.IO[t.Any], default: bool = False) -> bool:
  134. try:
  135. stream.write(b"")
  136. except Exception:
  137. try:
  138. stream.write("")
  139. return False
  140. except Exception:
  141. pass
  142. return default
  143. return True
  144. def _find_binary_reader(stream: t.IO[t.Any]) -> t.BinaryIO | None:
  145. # We need to figure out if the given stream is already binary.
  146. # This can happen because the official docs recommend detaching
  147. # the streams to get binary streams. Some code might do this, so
  148. # we need to deal with this case explicitly.
  149. if _is_binary_reader(stream, False):
  150. return t.cast(t.BinaryIO, stream)
  151. buf = getattr(stream, "buffer", None)
  152. # Same situation here; this time we assume that the buffer is
  153. # actually binary in case it's closed.
  154. if buf is not None and _is_binary_reader(buf, True):
  155. return t.cast(t.BinaryIO, buf)
  156. return None
  157. def _find_binary_writer(stream: t.IO[t.Any]) -> t.BinaryIO | None:
  158. # We need to figure out if the given stream is already binary.
  159. # This can happen because the official docs recommend detaching
  160. # the streams to get binary streams. Some code might do this, so
  161. # we need to deal with this case explicitly.
  162. if _is_binary_writer(stream, False):
  163. return t.cast(t.BinaryIO, stream)
  164. buf = getattr(stream, "buffer", None)
  165. # Same situation here; this time we assume that the buffer is
  166. # actually binary in case it's closed.
  167. if buf is not None and _is_binary_writer(buf, True):
  168. return t.cast(t.BinaryIO, buf)
  169. return None
  170. def _stream_is_misconfigured(stream: t.TextIO) -> bool:
  171. """A stream is misconfigured if its encoding is ASCII."""
  172. # If the stream does not have an encoding set, we assume it's set
  173. # to ASCII. This appears to happen in certain unittest
  174. # environments. It's not quite clear what the correct behavior is
  175. # but this at least will force Click to recover somehow.
  176. return is_ascii_encoding(getattr(stream, "encoding", None) or "ascii")
  177. def _is_compat_stream_attr(stream: t.TextIO, attr: str, value: str | None) -> bool:
  178. """A stream attribute is compatible if it is equal to the
  179. desired value or the desired value is unset and the attribute
  180. has a value.
  181. """
  182. stream_value = getattr(stream, attr, None)
  183. return stream_value == value or (value is None and stream_value is not None)
  184. def _is_compatible_text_stream(
  185. stream: t.TextIO, encoding: str | None, errors: str | None
  186. ) -> bool:
  187. """Check if a stream's encoding and errors attributes are
  188. compatible with the desired values.
  189. """
  190. return _is_compat_stream_attr(
  191. stream, "encoding", encoding
  192. ) and _is_compat_stream_attr(stream, "errors", errors)
  193. def _force_correct_text_stream(
  194. text_stream: t.IO[t.Any],
  195. encoding: str | None,
  196. errors: str | None,
  197. is_binary: t.Callable[[t.IO[t.Any], bool], bool],
  198. find_binary: t.Callable[[t.IO[t.Any]], t.BinaryIO | None],
  199. force_readable: bool = False,
  200. force_writable: bool = False,
  201. ) -> t.TextIO:
  202. if is_binary(text_stream, False):
  203. binary_reader = t.cast(t.BinaryIO, text_stream)
  204. else:
  205. text_stream = t.cast(t.TextIO, text_stream)
  206. # If the stream looks compatible, and won't default to a
  207. # misconfigured ascii encoding, return it as-is.
  208. if _is_compatible_text_stream(text_stream, encoding, errors) and not (
  209. encoding is None and _stream_is_misconfigured(text_stream)
  210. ):
  211. return text_stream
  212. # Otherwise, get the underlying binary reader.
  213. possible_binary_reader = find_binary(text_stream)
  214. # If that's not possible, silently use the original reader
  215. # and get mojibake instead of exceptions.
  216. if possible_binary_reader is None:
  217. return text_stream
  218. binary_reader = possible_binary_reader
  219. # Default errors to replace instead of strict in order to get
  220. # something that works.
  221. if errors is None:
  222. errors = "replace"
  223. # Wrap the binary stream in a text stream with the correct
  224. # encoding parameters.
  225. return _make_text_stream(
  226. binary_reader,
  227. encoding,
  228. errors,
  229. force_readable=force_readable,
  230. force_writable=force_writable,
  231. )
  232. def _force_correct_text_reader(
  233. text_reader: t.IO[t.Any],
  234. encoding: str | None,
  235. errors: str | None,
  236. force_readable: bool = False,
  237. ) -> t.TextIO:
  238. return _force_correct_text_stream(
  239. text_reader,
  240. encoding,
  241. errors,
  242. _is_binary_reader,
  243. _find_binary_reader,
  244. force_readable=force_readable,
  245. )
  246. def _force_correct_text_writer(
  247. text_writer: t.IO[t.Any],
  248. encoding: str | None,
  249. errors: str | None,
  250. force_writable: bool = False,
  251. ) -> t.TextIO:
  252. return _force_correct_text_stream(
  253. text_writer,
  254. encoding,
  255. errors,
  256. _is_binary_writer,
  257. _find_binary_writer,
  258. force_writable=force_writable,
  259. )
  260. def get_binary_stdin() -> t.BinaryIO:
  261. reader = _find_binary_reader(sys.stdin)
  262. if reader is None:
  263. raise RuntimeError("Was not able to determine binary stream for sys.stdin.")
  264. return reader
  265. def get_binary_stdout() -> t.BinaryIO:
  266. writer = _find_binary_writer(sys.stdout)
  267. if writer is None:
  268. raise RuntimeError("Was not able to determine binary stream for sys.stdout.")
  269. return writer
  270. def get_binary_stderr() -> t.BinaryIO:
  271. writer = _find_binary_writer(sys.stderr)
  272. if writer is None:
  273. raise RuntimeError("Was not able to determine binary stream for sys.stderr.")
  274. return writer
  275. def get_text_stdin(encoding: str | None = None, errors: str | None = None) -> t.TextIO:
  276. rv = _get_windows_console_stream(sys.stdin, encoding, errors)
  277. if rv is not None:
  278. return rv
  279. return _force_correct_text_reader(sys.stdin, encoding, errors, force_readable=True)
  280. def get_text_stdout(encoding: str | None = None, errors: str | None = None) -> t.TextIO:
  281. rv = _get_windows_console_stream(sys.stdout, encoding, errors)
  282. if rv is not None:
  283. return rv
  284. return _force_correct_text_writer(sys.stdout, encoding, errors, force_writable=True)
  285. def get_text_stderr(encoding: str | None = None, errors: str | None = None) -> t.TextIO:
  286. rv = _get_windows_console_stream(sys.stderr, encoding, errors)
  287. if rv is not None:
  288. return rv
  289. return _force_correct_text_writer(sys.stderr, encoding, errors, force_writable=True)
  290. def _wrap_io_open(
  291. file: str | os.PathLike[str] | int,
  292. mode: str,
  293. encoding: str | None,
  294. errors: str | None,
  295. ) -> t.IO[t.Any]:
  296. """Handles not passing ``encoding`` and ``errors`` in binary mode."""
  297. if "b" in mode:
  298. return open(file, mode)
  299. return open(file, mode, encoding=encoding, errors=errors)
  300. def open_stream(
  301. filename: str | os.PathLike[str],
  302. mode: str = "r",
  303. encoding: str | None = None,
  304. errors: str | None = "strict",
  305. atomic: bool = False,
  306. ) -> tuple[t.IO[t.Any], bool]:
  307. binary = "b" in mode
  308. filename = os.fspath(filename)
  309. # Standard streams first. These are simple because they ignore the
  310. # atomic flag. Use fsdecode to handle Path("-").
  311. if os.fsdecode(filename) == "-":
  312. if any(m in mode for m in ["w", "a", "x"]):
  313. if binary:
  314. return get_binary_stdout(), False
  315. return get_text_stdout(encoding=encoding, errors=errors), False
  316. if binary:
  317. return get_binary_stdin(), False
  318. return get_text_stdin(encoding=encoding, errors=errors), False
  319. # Non-atomic writes directly go out through the regular open functions.
  320. if not atomic:
  321. return _wrap_io_open(filename, mode, encoding, errors), True
  322. # Some usability stuff for atomic writes
  323. if "a" in mode:
  324. raise ValueError(
  325. "Appending to an existing file is not supported, because that"
  326. " would involve an expensive `copy`-operation to a temporary"
  327. " file. Open the file in normal `w`-mode and copy explicitly"
  328. " if that's what you're after."
  329. )
  330. if "x" in mode:
  331. raise ValueError("Use the `overwrite`-parameter instead.")
  332. if "w" not in mode:
  333. raise ValueError("Atomic writes only make sense with `w`-mode.")
  334. # Atomic writes are more complicated. They work by opening a file
  335. # as a proxy in the same folder and then using the fdopen
  336. # functionality to wrap it in a Python file. Then we wrap it in an
  337. # atomic file that moves the file over on close.
  338. import errno
  339. import random
  340. try:
  341. perm: int | None = os.stat(filename).st_mode
  342. except OSError:
  343. perm = None
  344. flags = os.O_RDWR | os.O_CREAT | os.O_EXCL
  345. if binary:
  346. flags |= getattr(os, "O_BINARY", 0)
  347. while True:
  348. tmp_filename = os.path.join(
  349. os.path.dirname(filename),
  350. f".__atomic-write{random.randrange(1 << 32):08x}",
  351. )
  352. try:
  353. fd = os.open(tmp_filename, flags, 0o666 if perm is None else perm)
  354. break
  355. except OSError as e:
  356. if e.errno == errno.EEXIST or (
  357. os.name == "nt"
  358. and e.errno == errno.EACCES
  359. and os.path.isdir(e.filename)
  360. and os.access(e.filename, os.W_OK)
  361. ):
  362. continue
  363. raise
  364. if perm is not None:
  365. os.chmod(tmp_filename, perm) # in case perm includes bits in umask
  366. f = _wrap_io_open(fd, mode, encoding, errors)
  367. af = _AtomicFile(f, tmp_filename, os.path.realpath(filename))
  368. return t.cast(t.IO[t.Any], af), True
  369. class _AtomicFile:
  370. def __init__(self, f: t.IO[t.Any], tmp_filename: str, real_filename: str) -> None:
  371. self._f = f
  372. self._tmp_filename = tmp_filename
  373. self._real_filename = real_filename
  374. self.closed = False
  375. @property
  376. def name(self) -> str:
  377. return self._real_filename
  378. def close(self, delete: bool = False) -> None:
  379. if self.closed:
  380. return
  381. self._f.close()
  382. os.replace(self._tmp_filename, self._real_filename)
  383. self.closed = True
  384. def __getattr__(self, name: str) -> t.Any:
  385. return getattr(self._f, name)
  386. def __enter__(self) -> _AtomicFile:
  387. return self
  388. def __exit__(
  389. self,
  390. exc_type: type[BaseException] | None,
  391. exc_value: BaseException | None,
  392. tb: TracebackType | None,
  393. ) -> None:
  394. self.close(delete=exc_type is not None)
  395. def __repr__(self) -> str:
  396. return repr(self._f)
  397. def strip_ansi(value: str) -> str:
  398. return _ansi_re.sub("", value)
  399. def _is_jupyter_kernel_output(stream: t.IO[t.Any]) -> bool:
  400. while isinstance(stream, (_FixupStream, _NonClosingTextIOWrapper)):
  401. stream = stream._stream
  402. return stream.__class__.__module__.startswith("ipykernel.")
  403. def should_strip_ansi(
  404. stream: t.IO[t.Any] | None = None, color: bool | None = None
  405. ) -> bool:
  406. if color is None:
  407. if stream is None:
  408. stream = sys.stdin
  409. return not isatty(stream) and not _is_jupyter_kernel_output(stream)
  410. return not color
  411. # On Windows, wrap the output streams with colorama to support ANSI
  412. # color codes.
  413. # NOTE: double check is needed so mypy does not analyze this on Linux
  414. if sys.platform.startswith("win") and WIN:
  415. from ._winconsole import _get_windows_console_stream
  416. def _get_argv_encoding() -> str:
  417. import locale
  418. return locale.getpreferredencoding()
  419. _ansi_stream_wrappers: cabc.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary()
  420. def auto_wrap_for_ansi(stream: t.TextIO, color: bool | None = None) -> t.TextIO:
  421. """Support ANSI color and style codes on Windows by wrapping a
  422. stream with colorama.
  423. """
  424. try:
  425. cached = _ansi_stream_wrappers.get(stream)
  426. except Exception:
  427. cached = None
  428. if cached is not None:
  429. return cached
  430. import colorama
  431. strip = should_strip_ansi(stream, color)
  432. ansi_wrapper = colorama.AnsiToWin32(stream, strip=strip)
  433. rv = t.cast(t.TextIO, ansi_wrapper.stream)
  434. _write = rv.write
  435. def _safe_write(s: str) -> int:
  436. try:
  437. return _write(s)
  438. except BaseException:
  439. ansi_wrapper.reset_all()
  440. raise
  441. rv.write = _safe_write # type: ignore[method-assign]
  442. try:
  443. _ansi_stream_wrappers[stream] = rv
  444. except Exception:
  445. pass
  446. return rv
  447. else:
  448. def _get_argv_encoding() -> str:
  449. return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
  450. def _get_windows_console_stream(
  451. f: t.TextIO, encoding: str | None, errors: str | None
  452. ) -> t.TextIO | None:
  453. return None
  454. def term_len(x: str) -> int:
  455. return len(strip_ansi(x))
  456. def isatty(stream: t.IO[t.Any]) -> bool:
  457. try:
  458. return stream.isatty()
  459. except Exception:
  460. return False
  461. def _make_cached_stream_func(
  462. src_func: t.Callable[[], t.TextIO | None],
  463. wrapper_func: t.Callable[[], t.TextIO],
  464. ) -> t.Callable[[], t.TextIO | None]:
  465. cache: cabc.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary()
  466. def func() -> t.TextIO | None:
  467. stream = src_func()
  468. if stream is None:
  469. return None
  470. try:
  471. rv = cache.get(stream)
  472. except Exception:
  473. rv = None
  474. if rv is not None:
  475. return rv
  476. rv = wrapper_func()
  477. try:
  478. cache[stream] = rv
  479. except Exception:
  480. pass
  481. return rv
  482. return func
  483. _default_text_stdin = _make_cached_stream_func(lambda: sys.stdin, get_text_stdin)
  484. _default_text_stdout = _make_cached_stream_func(lambda: sys.stdout, get_text_stdout)
  485. _default_text_stderr = _make_cached_stream_func(lambda: sys.stderr, get_text_stderr)
  486. binary_streams: cabc.Mapping[str, t.Callable[[], t.BinaryIO]] = {
  487. "stdin": get_binary_stdin,
  488. "stdout": get_binary_stdout,
  489. "stderr": get_binary_stderr,
  490. }
  491. text_streams: cabc.Mapping[str, t.Callable[[str | None, str | None], t.TextIO]] = {
  492. "stdin": get_text_stdin,
  493. "stdout": get_text_stdout,
  494. "stderr": get_text_stderr,
  495. }