testing.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. from __future__ import annotations
  2. import collections.abc as cabc
  3. import contextlib
  4. import io
  5. import os
  6. import shlex
  7. import sys
  8. import tempfile
  9. import typing as t
  10. from types import TracebackType
  11. from . import _compat
  12. from . import formatting
  13. from . import termui
  14. from . import utils
  15. from ._compat import _find_binary_reader
  16. if t.TYPE_CHECKING:
  17. from _typeshed import ReadableBuffer
  18. from .core import Command
  19. class EchoingStdin:
  20. def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
  21. self._input = input
  22. self._output = output
  23. self._paused = False
  24. def __getattr__(self, x: str) -> t.Any:
  25. return getattr(self._input, x)
  26. def _echo(self, rv: bytes) -> bytes:
  27. if not self._paused:
  28. self._output.write(rv)
  29. return rv
  30. def read(self, n: int = -1) -> bytes:
  31. return self._echo(self._input.read(n))
  32. def read1(self, n: int = -1) -> bytes:
  33. return self._echo(self._input.read1(n)) # type: ignore
  34. def readline(self, n: int = -1) -> bytes:
  35. return self._echo(self._input.readline(n))
  36. def readlines(self) -> list[bytes]:
  37. return [self._echo(x) for x in self._input.readlines()]
  38. def __iter__(self) -> cabc.Iterator[bytes]:
  39. return iter(self._echo(x) for x in self._input)
  40. def __repr__(self) -> str:
  41. return repr(self._input)
  42. @contextlib.contextmanager
  43. def _pause_echo(stream: EchoingStdin | None) -> cabc.Iterator[None]:
  44. if stream is None:
  45. yield
  46. else:
  47. stream._paused = True
  48. yield
  49. stream._paused = False
  50. class BytesIOCopy(io.BytesIO):
  51. """Patch ``io.BytesIO`` to let the written stream be copied to another.
  52. .. versionadded:: 8.2
  53. """
  54. def __init__(self, copy_to: io.BytesIO) -> None:
  55. super().__init__()
  56. self.copy_to = copy_to
  57. def flush(self) -> None:
  58. super().flush()
  59. self.copy_to.flush()
  60. def write(self, b: ReadableBuffer) -> int:
  61. self.copy_to.write(b)
  62. return super().write(b)
  63. class StreamMixer:
  64. """Mixes `<stdout>` and `<stderr>` streams.
  65. The result is available in the ``output`` attribute.
  66. .. versionadded:: 8.2
  67. """
  68. def __init__(self) -> None:
  69. self.output: io.BytesIO = io.BytesIO()
  70. self.stdout: io.BytesIO = BytesIOCopy(copy_to=self.output)
  71. self.stderr: io.BytesIO = BytesIOCopy(copy_to=self.output)
  72. def __del__(self) -> None:
  73. """
  74. Guarantee that embedded file-like objects are closed in a
  75. predictable order, protecting against races between
  76. self.output being closed and other streams being flushed on close
  77. .. versionadded:: 8.2.2
  78. """
  79. self.stderr.close()
  80. self.stdout.close()
  81. self.output.close()
  82. class _NamedTextIOWrapper(io.TextIOWrapper):
  83. def __init__(
  84. self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
  85. ) -> None:
  86. super().__init__(buffer, **kwargs)
  87. self._name = name
  88. self._mode = mode
  89. @property
  90. def name(self) -> str:
  91. return self._name
  92. @property
  93. def mode(self) -> str:
  94. return self._mode
  95. def make_input_stream(
  96. input: str | bytes | t.IO[t.Any] | None, charset: str
  97. ) -> t.BinaryIO:
  98. # Is already an input stream.
  99. if hasattr(input, "read"):
  100. rv = _find_binary_reader(t.cast("t.IO[t.Any]", input))
  101. if rv is not None:
  102. return rv
  103. raise TypeError("Could not find binary reader for input stream.")
  104. if input is None:
  105. input = b""
  106. elif isinstance(input, str):
  107. input = input.encode(charset)
  108. return io.BytesIO(input)
  109. class Result:
  110. """Holds the captured result of an invoked CLI script.
  111. :param runner: The runner that created the result
  112. :param stdout_bytes: The standard output as bytes.
  113. :param stderr_bytes: The standard error as bytes.
  114. :param output_bytes: A mix of ``stdout_bytes`` and ``stderr_bytes``, as the
  115. user would see it in its terminal.
  116. :param return_value: The value returned from the invoked command.
  117. :param exit_code: The exit code as integer.
  118. :param exception: The exception that happened if one did.
  119. :param exc_info: Exception information (exception type, exception instance,
  120. traceback type).
  121. .. versionchanged:: 8.2
  122. ``stderr_bytes`` no longer optional, ``output_bytes`` introduced and
  123. ``mix_stderr`` has been removed.
  124. .. versionadded:: 8.0
  125. Added ``return_value``.
  126. """
  127. def __init__(
  128. self,
  129. runner: CliRunner,
  130. stdout_bytes: bytes,
  131. stderr_bytes: bytes,
  132. output_bytes: bytes,
  133. return_value: t.Any,
  134. exit_code: int,
  135. exception: BaseException | None,
  136. exc_info: tuple[type[BaseException], BaseException, TracebackType]
  137. | None = None,
  138. ):
  139. self.runner = runner
  140. self.stdout_bytes = stdout_bytes
  141. self.stderr_bytes = stderr_bytes
  142. self.output_bytes = output_bytes
  143. self.return_value = return_value
  144. self.exit_code = exit_code
  145. self.exception = exception
  146. self.exc_info = exc_info
  147. @property
  148. def output(self) -> str:
  149. """The terminal output as unicode string, as the user would see it.
  150. .. versionchanged:: 8.2
  151. No longer a proxy for ``self.stdout``. Now has its own independent stream
  152. that is mixing `<stdout>` and `<stderr>`, in the order they were written.
  153. """
  154. return self.output_bytes.decode(self.runner.charset, "replace").replace(
  155. "\r\n", "\n"
  156. )
  157. @property
  158. def stdout(self) -> str:
  159. """The standard output as unicode string."""
  160. return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
  161. "\r\n", "\n"
  162. )
  163. @property
  164. def stderr(self) -> str:
  165. """The standard error as unicode string.
  166. .. versionchanged:: 8.2
  167. No longer raise an exception, always returns the `<stderr>` string.
  168. """
  169. return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
  170. "\r\n", "\n"
  171. )
  172. def __repr__(self) -> str:
  173. exc_str = repr(self.exception) if self.exception else "okay"
  174. return f"<{type(self).__name__} {exc_str}>"
  175. class CliRunner:
  176. """The CLI runner provides functionality to invoke a Click command line
  177. script for unittesting purposes in a isolated environment. This only
  178. works in single-threaded systems without any concurrency as it changes the
  179. global interpreter state.
  180. :param charset: the character set for the input and output data.
  181. :param env: a dictionary with environment variables for overriding.
  182. :param echo_stdin: if this is set to `True`, then reading from `<stdin>` writes
  183. to `<stdout>`. This is useful for showing examples in
  184. some circumstances. Note that regular prompts
  185. will automatically echo the input.
  186. :param catch_exceptions: Whether to catch any exceptions other than
  187. ``SystemExit`` when running :meth:`~CliRunner.invoke`.
  188. .. versionchanged:: 8.2
  189. Added the ``catch_exceptions`` parameter.
  190. .. versionchanged:: 8.2
  191. ``mix_stderr`` parameter has been removed.
  192. """
  193. def __init__(
  194. self,
  195. charset: str = "utf-8",
  196. env: cabc.Mapping[str, str | None] | None = None,
  197. echo_stdin: bool = False,
  198. catch_exceptions: bool = True,
  199. ) -> None:
  200. self.charset = charset
  201. self.env: cabc.Mapping[str, str | None] = env or {}
  202. self.echo_stdin = echo_stdin
  203. self.catch_exceptions = catch_exceptions
  204. def get_default_prog_name(self, cli: Command) -> str:
  205. """Given a command object it will return the default program name
  206. for it. The default is the `name` attribute or ``"root"`` if not
  207. set.
  208. """
  209. return cli.name or "root"
  210. def make_env(
  211. self, overrides: cabc.Mapping[str, str | None] | None = None
  212. ) -> cabc.Mapping[str, str | None]:
  213. """Returns the environment overrides for invoking a script."""
  214. rv = dict(self.env)
  215. if overrides:
  216. rv.update(overrides)
  217. return rv
  218. @contextlib.contextmanager
  219. def isolation(
  220. self,
  221. input: str | bytes | t.IO[t.Any] | None = None,
  222. env: cabc.Mapping[str, str | None] | None = None,
  223. color: bool = False,
  224. ) -> cabc.Iterator[tuple[io.BytesIO, io.BytesIO, io.BytesIO]]:
  225. """A context manager that sets up the isolation for invoking of a
  226. command line tool. This sets up `<stdin>` with the given input data
  227. and `os.environ` with the overrides from the given dictionary.
  228. This also rebinds some internals in Click to be mocked (like the
  229. prompt functionality).
  230. This is automatically done in the :meth:`invoke` method.
  231. :param input: the input stream to put into `sys.stdin`.
  232. :param env: the environment overrides as dictionary.
  233. :param color: whether the output should contain color codes. The
  234. application can still override this explicitly.
  235. .. versionadded:: 8.2
  236. An additional output stream is returned, which is a mix of
  237. `<stdout>` and `<stderr>` streams.
  238. .. versionchanged:: 8.2
  239. Always returns the `<stderr>` stream.
  240. .. versionchanged:: 8.0
  241. `<stderr>` is opened with ``errors="backslashreplace"``
  242. instead of the default ``"strict"``.
  243. .. versionchanged:: 4.0
  244. Added the ``color`` parameter.
  245. """
  246. bytes_input = make_input_stream(input, self.charset)
  247. echo_input = None
  248. old_stdin = sys.stdin
  249. old_stdout = sys.stdout
  250. old_stderr = sys.stderr
  251. old_forced_width = formatting.FORCED_WIDTH
  252. formatting.FORCED_WIDTH = 80
  253. env = self.make_env(env)
  254. stream_mixer = StreamMixer()
  255. if self.echo_stdin:
  256. bytes_input = echo_input = t.cast(
  257. t.BinaryIO, EchoingStdin(bytes_input, stream_mixer.stdout)
  258. )
  259. sys.stdin = text_input = _NamedTextIOWrapper(
  260. bytes_input, encoding=self.charset, name="<stdin>", mode="r"
  261. )
  262. if self.echo_stdin:
  263. # Force unbuffered reads, otherwise TextIOWrapper reads a
  264. # large chunk which is echoed early.
  265. text_input._CHUNK_SIZE = 1 # type: ignore
  266. sys.stdout = _NamedTextIOWrapper(
  267. stream_mixer.stdout, encoding=self.charset, name="<stdout>", mode="w"
  268. )
  269. sys.stderr = _NamedTextIOWrapper(
  270. stream_mixer.stderr,
  271. encoding=self.charset,
  272. name="<stderr>",
  273. mode="w",
  274. errors="backslashreplace",
  275. )
  276. @_pause_echo(echo_input) # type: ignore
  277. def visible_input(prompt: str | None = None) -> str:
  278. sys.stdout.write(prompt or "")
  279. try:
  280. val = next(text_input).rstrip("\r\n")
  281. except StopIteration as e:
  282. raise EOFError() from e
  283. sys.stdout.write(f"{val}\n")
  284. sys.stdout.flush()
  285. return val
  286. @_pause_echo(echo_input) # type: ignore
  287. def hidden_input(prompt: str | None = None) -> str:
  288. sys.stdout.write(f"{prompt or ''}\n")
  289. sys.stdout.flush()
  290. try:
  291. return next(text_input).rstrip("\r\n")
  292. except StopIteration as e:
  293. raise EOFError() from e
  294. @_pause_echo(echo_input) # type: ignore
  295. def _getchar(echo: bool) -> str:
  296. char = sys.stdin.read(1)
  297. if echo:
  298. sys.stdout.write(char)
  299. sys.stdout.flush()
  300. return char
  301. default_color = color
  302. def should_strip_ansi(
  303. stream: t.IO[t.Any] | None = None, color: bool | None = None
  304. ) -> bool:
  305. if color is None:
  306. return not default_color
  307. return not color
  308. old_visible_prompt_func = termui.visible_prompt_func
  309. old_hidden_prompt_func = termui.hidden_prompt_func
  310. old__getchar_func = termui._getchar
  311. old_should_strip_ansi = utils.should_strip_ansi # type: ignore
  312. old__compat_should_strip_ansi = _compat.should_strip_ansi
  313. termui.visible_prompt_func = visible_input
  314. termui.hidden_prompt_func = hidden_input
  315. termui._getchar = _getchar
  316. utils.should_strip_ansi = should_strip_ansi # type: ignore
  317. _compat.should_strip_ansi = should_strip_ansi
  318. old_env = {}
  319. try:
  320. for key, value in env.items():
  321. old_env[key] = os.environ.get(key)
  322. if value is None:
  323. try:
  324. del os.environ[key]
  325. except Exception:
  326. pass
  327. else:
  328. os.environ[key] = value
  329. yield (stream_mixer.stdout, stream_mixer.stderr, stream_mixer.output)
  330. finally:
  331. for key, value in old_env.items():
  332. if value is None:
  333. try:
  334. del os.environ[key]
  335. except Exception:
  336. pass
  337. else:
  338. os.environ[key] = value
  339. sys.stdout = old_stdout
  340. sys.stderr = old_stderr
  341. sys.stdin = old_stdin
  342. termui.visible_prompt_func = old_visible_prompt_func
  343. termui.hidden_prompt_func = old_hidden_prompt_func
  344. termui._getchar = old__getchar_func
  345. utils.should_strip_ansi = old_should_strip_ansi # type: ignore
  346. _compat.should_strip_ansi = old__compat_should_strip_ansi
  347. formatting.FORCED_WIDTH = old_forced_width
  348. def invoke(
  349. self,
  350. cli: Command,
  351. args: str | cabc.Sequence[str] | None = None,
  352. input: str | bytes | t.IO[t.Any] | None = None,
  353. env: cabc.Mapping[str, str | None] | None = None,
  354. catch_exceptions: bool | None = None,
  355. color: bool = False,
  356. **extra: t.Any,
  357. ) -> Result:
  358. """Invokes a command in an isolated environment. The arguments are
  359. forwarded directly to the command line script, the `extra` keyword
  360. arguments are passed to the :meth:`~clickpkg.Command.main` function of
  361. the command.
  362. This returns a :class:`Result` object.
  363. :param cli: the command to invoke
  364. :param args: the arguments to invoke. It may be given as an iterable
  365. or a string. When given as string it will be interpreted
  366. as a Unix shell command. More details at
  367. :func:`shlex.split`.
  368. :param input: the input data for `sys.stdin`.
  369. :param env: the environment overrides.
  370. :param catch_exceptions: Whether to catch any other exceptions than
  371. ``SystemExit``. If :data:`None`, the value
  372. from :class:`CliRunner` is used.
  373. :param extra: the keyword arguments to pass to :meth:`main`.
  374. :param color: whether the output should contain color codes. The
  375. application can still override this explicitly.
  376. .. versionadded:: 8.2
  377. The result object has the ``output_bytes`` attribute with
  378. the mix of ``stdout_bytes`` and ``stderr_bytes``, as the user would
  379. see it in its terminal.
  380. .. versionchanged:: 8.2
  381. The result object always returns the ``stderr_bytes`` stream.
  382. .. versionchanged:: 8.0
  383. The result object has the ``return_value`` attribute with
  384. the value returned from the invoked command.
  385. .. versionchanged:: 4.0
  386. Added the ``color`` parameter.
  387. .. versionchanged:: 3.0
  388. Added the ``catch_exceptions`` parameter.
  389. .. versionchanged:: 3.0
  390. The result object has the ``exc_info`` attribute with the
  391. traceback if available.
  392. """
  393. exc_info = None
  394. if catch_exceptions is None:
  395. catch_exceptions = self.catch_exceptions
  396. with self.isolation(input=input, env=env, color=color) as outstreams:
  397. return_value = None
  398. exception: BaseException | None = None
  399. exit_code = 0
  400. if isinstance(args, str):
  401. args = shlex.split(args)
  402. try:
  403. prog_name = extra.pop("prog_name")
  404. except KeyError:
  405. prog_name = self.get_default_prog_name(cli)
  406. try:
  407. return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
  408. except SystemExit as e:
  409. exc_info = sys.exc_info()
  410. e_code = t.cast("int | t.Any | None", e.code)
  411. if e_code is None:
  412. e_code = 0
  413. if e_code != 0:
  414. exception = e
  415. if not isinstance(e_code, int):
  416. sys.stdout.write(str(e_code))
  417. sys.stdout.write("\n")
  418. e_code = 1
  419. exit_code = e_code
  420. except Exception as e:
  421. if not catch_exceptions:
  422. raise
  423. exception = e
  424. exit_code = 1
  425. exc_info = sys.exc_info()
  426. finally:
  427. sys.stdout.flush()
  428. sys.stderr.flush()
  429. stdout = outstreams[0].getvalue()
  430. stderr = outstreams[1].getvalue()
  431. output = outstreams[2].getvalue()
  432. return Result(
  433. runner=self,
  434. stdout_bytes=stdout,
  435. stderr_bytes=stderr,
  436. output_bytes=output,
  437. return_value=return_value,
  438. exit_code=exit_code,
  439. exception=exception,
  440. exc_info=exc_info, # type: ignore
  441. )
  442. @contextlib.contextmanager
  443. def isolated_filesystem(
  444. self, temp_dir: str | os.PathLike[str] | None = None
  445. ) -> cabc.Iterator[str]:
  446. """A context manager that creates a temporary directory and
  447. changes the current working directory to it. This isolates tests
  448. that affect the contents of the CWD to prevent them from
  449. interfering with each other.
  450. :param temp_dir: Create the temporary directory under this
  451. directory. If given, the created directory is not removed
  452. when exiting.
  453. .. versionchanged:: 8.0
  454. Added the ``temp_dir`` parameter.
  455. """
  456. cwd = os.getcwd()
  457. dt = tempfile.mkdtemp(dir=temp_dir)
  458. os.chdir(dt)
  459. try:
  460. yield dt
  461. finally:
  462. os.chdir(cwd)
  463. if temp_dir is None:
  464. import shutil
  465. try:
  466. shutil.rmtree(dt)
  467. except OSError:
  468. pass