black/tests/util.py
Jelle Zijlstra 65abd1006b
add context manager to temporarily change the cwd (#2377)
Commit history before merge:

* add context manager to temporarily change the cwd
* Iterator, not Iterable
2021-07-16 22:21:34 -04:00

87 lines
2.7 KiB
Python

import os
import unittest
from pathlib import Path
from typing import Iterator, List, Tuple, Any
from contextlib import contextmanager
from functools import partial
import black
from black.output import out, err
from black.debug import DebugVisitor
# Directory that contains this helper module.
THIS_DIR = Path(__file__).parent
# One level above THIS_DIR; read_data() resolves names against it when
# called with data=False.
PROJECT_ROOT = THIS_DIR.parent
# Marker text that read_data_from_file() deletes from every line it reads.
# NOTE(review): presumably split into two literals so this file never
# contains the full marker verbatim — confirm before joining them.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# Not referenced in this file; presumably consumed by other test modules.
DETERMINISTIC_HEADER = "[Deterministic header]"
# black's default formatting mode, shared by the two shortcuts below.
DEFAULT_MODE = black.Mode()
# Shortcut: format a file in place with DEFAULT_MODE and fast=True.
ff = partial(black.format_file_in_place, fast=True, mode=DEFAULT_MODE)
# Shortcut: format a source string with DEFAULT_MODE.
fs = partial(black.format_str, mode=DEFAULT_MODE)
def dump_to_stderr(*output: str) -> str:
    """Join the given strings with newlines and frame the result in newlines.

    Nothing is printed here; the caller decides where the returned text goes.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
class BlackBaseTestCase(unittest.TestCase):
    """TestCase base class with an assertion tailored to formatter output."""

    # Never truncate the failure diff, however long it is.
    maxDiff = None
    # Raise unittest's private size cutoff (default 2 ** 16) above which
    # assertMultiLineEqual skips building a line diff, so large formatted
    # sources still get a readable diff.
    _diffThreshold = 2 ** 20

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that ``actual`` equals ``expected``, with extra diagnostics.

        On mismatch, and unless the ``SKIP_AST_PRINT`` environment variable
        is set to a non-empty value, the parse trees of both strings are
        dumped before the comparison fails.

        Raises:
            AssertionError: if the two strings differ.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._print_tree("Expected tree:", expected, fg="green")
            self._print_tree("Actual tree:", actual, fg="red")
        self.assertMultiLineEqual(expected, actual)

    @staticmethod
    def _print_tree(header: str, source: str, fg: str) -> None:
        """Print ``header`` in colour ``fg``, then a debug dump of ``source``'s tree.

        Dumping is best-effort diagnostics: if parsing or visiting fails, the
        error text is reported through ``err`` instead of being raised, so the
        real assertion that follows still runs.
        """
        out(header, fg=fg)
        try:
            node = black.lib2to3_parse(source)
            visitor: DebugVisitor[Any] = DebugVisitor()
            # visit() is consumed only for its side effect of printing nodes.
            list(visitor.visit(node))
        except Exception as exc:
            err(str(exc))
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    A ``.py`` suffix is appended unless ``name`` already ends in ``.py``,
    ``.pyi``, ``.out`` or ``.diff``. The file is looked up under
    ``THIS_DIR / "data"`` when ``data`` is true, otherwise under
    ``PROJECT_ROOT``, and split by ``read_data_from_file``.
    """
    known_suffixes = (".py", ".pyi", ".out", ".diff")
    file_name = name if name.endswith(known_suffixes) else name + ".py"
    if data:
        base_dir = THIS_DIR / "data"
    else:
        base_dir = PROJECT_ROOT
    return read_data_from_file(base_dir / file_name)
def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split a test-data file into its (input, expected output) sources.

    The file is read as UTF-8 and every occurrence of ``EMPTY_LINE`` is
    removed from each line. Lines before the first line that reads
    ``# output`` (ignoring trailing whitespace) form the input, and lines
    after it form the output. Any further ``# output`` lines are dropped.

    If the input is non-empty but the output section is empty or absent, the
    file is treated as already formatted and the output is a copy of the
    input. Both returned strings are stripped and end with exactly one
    newline.
    """
    with open(file_name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()
    # sections[0] collects input lines, sections[1] expected-output lines.
    sections: Tuple[List[str], List[str]] = ([], [])
    target = 0
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            target = 1
            continue
        sections[target].append(cleaned)
    source_lines, expected_lines = sections
    if source_lines and not expected_lines:
        expected_lines = list(source_lines)
    source_text = "".join(source_lines).strip() + "\n"
    expected_text = "".join(expected_lines).strip() + "\n"
    return source_text, expected_text
@contextmanager
def change_directory(path: Path) -> Iterator[None]:
    """Temporarily make ``path`` the process's current working directory.

    The directory that was current on entry is restored on exit, including
    when the body of the ``with`` statement raises.
    """
    original_cwd = Path.cwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(original_cwd)