black/tests/util.py
Łukasz Langa f2ea461e9e
Refactor src/black/__init__.py into many files (#2206)
* Move string-related utility functions to strings.py, const.py
* Move Leaf/Node-related functionality to nodes.py
* Move comment-related functions to comments.py
* Move caching to cache.py and Mode/TargetVersion/Feature to mode.py
* Move some leftover functions to nodes.py, comments.py, strings.py
* Add missing files to source list for test runs
* Move line-related functionality into lines.py, brackets into brackets.py
* Move transformers to trans.py
* Move file handling, output, parsing, concurrency, debug, and report
* Move two more functions to nodes.py
* Add CHANGES
* Add numeric.py
* Add linegen.py
* More docstrings
* Include new files in tests

Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
2021-05-08 11:29:47 +02:00

75 lines
2.4 KiB
Python

import os
import unittest
from pathlib import Path
from typing import List, Tuple, Any
from functools import partial
import black
from black.output import out, err
from black.debug import DebugVisitor
# Directory holding this module; read_data() looks for fixtures in THIS_DIR / "data".
THIS_DIR = Path(__file__).parent
# One level above THIS_DIR (presumably the repository root); read_data() uses
# it when data=False.
PROJECT_ROOT = THIS_DIR.parent

# Marker text that read_data_from_file() deletes from every line it reads.
# NOTE(review): the literal is deliberately split in two, presumably so this
# module never contains the complete marker verbatim -- keep it split.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# Not referenced in this module; presumably consumed by tests elsewhere.
DETERMINISTIC_HEADER = "[Deterministic header]"

# Shared default configuration for the shortcut formatters below.
DEFAULT_MODE = black.Mode()
# fs(...) formats a source string; ff(...) formats a file in place with
# fast=True. Both are pinned to DEFAULT_MODE.
fs = partial(black.format_str, mode=DEFAULT_MODE)
ff = partial(black.format_file_in_place, fast=True, mode=DEFAULT_MODE)
def dump_to_stderr(*output: str) -> str:
    """Join ``output`` with newlines and frame the result with one more on each side.

    Despite its name, this function writes nothing to stderr: it only builds
    and returns the string, leaving it to the caller to decide what to do
    with it. With no arguments the result is ``"\\n\\n"``.
    """
    body = "\n".join(output)
    return "\n{}\n".format(body)
class BlackBaseTestCase(unittest.TestCase):
    """Base ``TestCase`` for Black's tests with a formatting-aware equality assert."""

    # unittest: None means failure messages show the full diff, untruncated.
    maxDiff = None
    # Private unittest knob; raised to 2 ** 20, presumably so long mismatched
    # strings still get a diff instead of a plain "a != b" message.
    _diffThreshold = 2 ** 20

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that ``actual`` equals ``expected``.

        On a mismatch, and unless the ``SKIP_AST_PRINT`` environment variable
        is set to a non-empty value, debug dumps of the lib2to3 trees of both
        strings are printed first (expected, then actual). The comparison
        itself is always delegated to ``assertMultiLineEqual``.

        Raises:
            AssertionError: if the two strings differ.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._print_tree("Expected tree:", "green", expected)
            self._print_tree("Actual tree:", "red", actual)
        self.assertMultiLineEqual(expected, actual)

    @staticmethod
    def _print_tree(header: str, color: str, source: str) -> None:
        """Print ``header``, then a DebugVisitor dump of ``source``'s parse tree.

        This is best-effort. Any exception raised while parsing or visiting is
        reported through ``err`` instead of being raised, so the caller's
        real assertion still runs and reports the diff.
        """
        out(header, fg=color)
        try:
            node = black.lib2to3_parse(source)
            visitor: DebugVisitor[Any] = DebugVisitor()
            # Exhaust the iterator returned by visit() so the visitor runs
            # over the whole tree.
            list(visitor.visit(node))
        except Exception as exc:
            err(str(exc))
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a test case and split it into its source and expected-output parts.
    If ``name`` does not already end in ``.py``, ``.pyi``, ``.out`` or
    ``.diff``, then ``.py`` is appended. The file is looked up under
    ``THIS_DIR / "data"`` when ``data`` is true, and under ``PROJECT_ROOT``
    otherwise.
    """
    recognised_suffixes = (".py", ".pyi", ".out", ".diff")
    file_name = name if name.endswith(recognised_suffixes) else name + ".py"
    if data:
        base_dir = THIS_DIR / "data"
    else:
        base_dir = PROJECT_ROOT
    return read_data_from_file(base_dir / file_name)
def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split a UTF-8 test-data file into ``(source, expected)`` text.

    Every occurrence of ``EMPTY_LINE`` is removed from each line first. Each
    line equal to ``# output`` (ignoring trailing whitespace) is dropped and
    switches collection to the expected section; lines seen before the first
    such marker form the source. If there are source lines but the expected
    section is empty (including when there is no marker at all), the source
    doubles as the expected output. Both halves are stripped and terminated
    with a single newline.
    """
    with open(file_name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    source_lines: List[str] = []
    expected_lines: List[str] = []
    in_expected = False
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            in_expected = True
            continue
        (expected_lines if in_expected else source_lines).append(cleaned)

    if source_lines and not expected_lines:
        # No output section: the file is already in its formatted form.
        expected_lines = list(source_lines)

    source = "".join(source_lines).strip() + "\n"
    expected = "".join(expected_lines).strip() + "\n"
    return source, expected