Refactor --pyi and --py36 into FileMode

This commit is contained in:
Łukasz Langa 2018-05-29 01:53:54 -07:00
parent ad01a51868
commit 023e61a254
2 changed files with 92 additions and 99 deletions

117
black.py
View File

@ -2,7 +2,7 @@
import pickle import pickle
from asyncio.base_events import BaseEventLoop from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor from concurrent.futures import Executor, ProcessPoolExecutor
from enum import Enum from enum import Enum, Flag
from functools import partial, wraps from functools import partial, wraps
import keyword import keyword
import logging import logging
@ -122,6 +122,12 @@ class Changed(Enum):
YES = 2 YES = 2
class FileMode(Flag):
AUTO_DETECT = 0
PYTHON36 = 1
PYI = 2
@click.command() @click.command()
@click.option( @click.option(
"-l", "-l",
@ -216,6 +222,11 @@ def main(
write_back = WriteBack.DIFF write_back = WriteBack.DIFF
else: else:
write_back = WriteBack.YES write_back = WriteBack.YES
mode = FileMode.AUTO_DETECT
if py36:
mode |= FileMode.PYTHON36
if pyi:
mode |= FileMode.PYI
report = Report(check=check, quiet=quiet) report = Report(check=check, quiet=quiet)
if len(sources) == 0: if len(sources) == 0:
out("No paths given. Nothing to do 😴") out("No paths given. Nothing to do 😴")
@ -227,9 +238,8 @@ def main(
src=sources[0], src=sources[0],
line_length=line_length, line_length=line_length,
fast=fast, fast=fast,
pyi=pyi,
py36=py36,
write_back=write_back, write_back=write_back,
mode=mode,
report=report, report=report,
) )
else: else:
@ -241,9 +251,8 @@ def main(
sources=sources, sources=sources,
line_length=line_length, line_length=line_length,
fast=fast, fast=fast,
pyi=pyi,
py36=py36,
write_back=write_back, write_back=write_back,
mode=mode,
report=report, report=report,
loop=loop, loop=loop,
executor=executor, executor=executor,
@ -261,9 +270,8 @@ def reformat_one(
src: Path, src: Path,
line_length: int, line_length: int,
fast: bool, fast: bool,
pyi: bool,
py36: bool,
write_back: WriteBack, write_back: WriteBack,
mode: FileMode,
report: "Report", report: "Report",
) -> None: ) -> None:
"""Reformat a single file under `src` without spawning child processes. """Reformat a single file under `src` without spawning child processes.
@ -276,17 +284,13 @@ def reformat_one(
changed = Changed.NO changed = Changed.NO
if not src.is_file() and str(src) == "-": if not src.is_file() and str(src) == "-":
if format_stdin_to_stdout( if format_stdin_to_stdout(
line_length=line_length, line_length=line_length, fast=fast, write_back=write_back, mode=mode
fast=fast,
is_pyi=pyi,
force_py36=py36,
write_back=write_back,
): ):
changed = Changed.YES changed = Changed.YES
else: else:
cache: Cache = {} cache: Cache = {}
if write_back != WriteBack.DIFF: if write_back != WriteBack.DIFF:
cache = read_cache(line_length, pyi, py36) cache = read_cache(line_length, mode)
src = src.resolve() src = src.resolve()
if src in cache and cache[src] == get_cache_info(src): if src in cache and cache[src] == get_cache_info(src):
changed = Changed.CACHED changed = Changed.CACHED
@ -294,13 +298,12 @@ def reformat_one(
src, src,
line_length=line_length, line_length=line_length,
fast=fast, fast=fast,
force_pyi=pyi,
force_py36=py36,
write_back=write_back, write_back=write_back,
mode=mode,
): ):
changed = Changed.YES changed = Changed.YES
if write_back == WriteBack.YES and changed is not Changed.NO: if write_back == WriteBack.YES and changed is not Changed.NO:
write_cache(cache, [src], line_length, pyi, py36) write_cache(cache, [src], line_length, mode)
report.done(src, changed) report.done(src, changed)
except Exception as exc: except Exception as exc:
report.failed(src, str(exc)) report.failed(src, str(exc))
@ -310,9 +313,8 @@ async def schedule_formatting(
sources: List[Path], sources: List[Path],
line_length: int, line_length: int,
fast: bool, fast: bool,
pyi: bool,
py36: bool,
write_back: WriteBack, write_back: WriteBack,
mode: FileMode,
report: "Report", report: "Report",
loop: BaseEventLoop, loop: BaseEventLoop,
executor: Executor, executor: Executor,
@ -326,7 +328,7 @@ async def schedule_formatting(
""" """
cache: Cache = {} cache: Cache = {}
if write_back != WriteBack.DIFF: if write_back != WriteBack.DIFF:
cache = read_cache(line_length, pyi, py36) cache = read_cache(line_length, mode)
sources, cached = filter_cached(cache, sources) sources, cached = filter_cached(cache, sources)
for src in cached: for src in cached:
report.done(src, Changed.CACHED) report.done(src, Changed.CACHED)
@ -346,9 +348,8 @@ async def schedule_formatting(
src, src,
line_length, line_length,
fast, fast,
pyi,
py36,
write_back, write_back,
mode,
lock, lock,
): src ): src
for src in sorted(sources) for src in sorted(sources)
@ -374,16 +375,15 @@ async def schedule_formatting(
if cancelled: if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
if write_back == WriteBack.YES and formatted: if write_back == WriteBack.YES and formatted:
write_cache(cache, formatted, line_length, pyi, py36) write_cache(cache, formatted, line_length, mode)
def format_file_in_place( def format_file_in_place(
src: Path, src: Path,
line_length: int, line_length: int,
fast: bool, fast: bool,
force_pyi: bool = False,
force_py36: bool = False,
write_back: WriteBack = WriteBack.NO, write_back: WriteBack = WriteBack.NO,
mode: FileMode = FileMode.AUTO_DETECT,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool: ) -> bool:
"""Format file under `src` path. Return True if changed. """Format file under `src` path. Return True if changed.
@ -391,17 +391,13 @@ def format_file_in_place(
If `write_back` is True, write reformatted code back to stdout. If `write_back` is True, write reformatted code back to stdout.
`line_length` and `fast` options are passed to :func:`format_file_contents`. `line_length` and `fast` options are passed to :func:`format_file_contents`.
""" """
is_pyi = force_pyi or src.suffix == ".pyi" if src.suffix == ".pyi":
mode |= FileMode.PYI
with tokenize.open(src) as src_buffer: with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read() src_contents = src_buffer.read()
try: try:
dst_contents = format_file_contents( dst_contents = format_file_contents(
src_contents, src_contents, line_length=line_length, fast=fast, mode=mode
line_length=line_length,
fast=fast,
is_pyi=is_pyi,
force_py36=force_py36,
) )
except NothingChanged: except NothingChanged:
return False return False
@ -426,9 +422,8 @@ def format_file_in_place(
def format_stdin_to_stdout( def format_stdin_to_stdout(
line_length: int, line_length: int,
fast: bool, fast: bool,
is_pyi: bool = False,
force_py36: bool = False,
write_back: WriteBack = WriteBack.NO, write_back: WriteBack = WriteBack.NO,
mode: FileMode = FileMode.AUTO_DETECT,
) -> bool: ) -> bool:
"""Format file on stdin. Return True if changed. """Format file on stdin. Return True if changed.
@ -439,13 +434,7 @@ def format_stdin_to_stdout(
src = sys.stdin.read() src = sys.stdin.read()
dst = src dst = src
try: try:
dst = format_file_contents( dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
src,
line_length=line_length,
fast=fast,
is_pyi=is_pyi,
force_py36=force_py36,
)
return True return True
except NothingChanged: except NothingChanged:
@ -465,8 +454,7 @@ def format_file_contents(
*, *,
line_length: int, line_length: int,
fast: bool, fast: bool,
is_pyi: bool = False, mode: FileMode = FileMode.AUTO_DETECT,
force_py36: bool = False,
) -> FileContent: ) -> FileContent:
"""Reformat contents a file and return new contents. """Reformat contents a file and return new contents.
@ -477,30 +465,18 @@ def format_file_contents(
if src_contents.strip() == "": if src_contents.strip() == "":
raise NothingChanged raise NothingChanged
dst_contents = format_str( dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
)
if src_contents == dst_contents: if src_contents == dst_contents:
raise NothingChanged raise NothingChanged
if not fast: if not fast:
assert_equivalent(src_contents, dst_contents) assert_equivalent(src_contents, dst_contents)
assert_stable( assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
src_contents,
dst_contents,
line_length=line_length,
is_pyi=is_pyi,
force_py36=force_py36,
)
return dst_contents return dst_contents
def format_str( def format_str(
src_contents: str, src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
line_length: int,
*,
is_pyi: bool = False,
force_py36: bool = False,
) -> FileContent: ) -> FileContent:
"""Reformat a string and return new contents. """Reformat a string and return new contents.
@ -509,11 +485,12 @@ def format_str(
src_node = lib2to3_parse(src_contents) src_node = lib2to3_parse(src_contents)
dst_contents = "" dst_contents = ""
future_imports = get_future_imports(src_node) future_imports = get_future_imports(src_node)
elt = EmptyLineTracker(is_pyi=is_pyi) is_pyi = bool(mode & FileMode.PYI)
py36 = force_py36 or is_python36(src_node) py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
lines = LineGenerator( lines = LineGenerator(
remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
) )
elt = EmptyLineTracker(is_pyi=is_pyi)
empty_line = Line() empty_line = Line()
after = 0 after = 0
for current_line in lines.visit(src_node): for current_line in lines.visit(src_node):
@ -2932,12 +2909,10 @@ def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
def assert_stable( def assert_stable(
src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
) -> None: ) -> None:
"""Raise AssertionError if `dst` reformats differently the second time.""" """Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str( newdst = format_str(dst, line_length=line_length, mode=mode)
dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
)
if dst != newdst: if dst != newdst:
log = dump_to_file( log = dump_to_file(
diff(src, dst, "source", "first pass"), diff(src, dst, "source", "first pass"),
@ -3148,19 +3123,21 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
return False return False
def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path: def get_cache_file(line_length: int, mode: FileMode) -> Path:
pyi = bool(mode & FileMode.PYI)
py36 = bool(mode & FileMode.PYTHON36)
return ( return (
CACHE_DIR CACHE_DIR
/ f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle" / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
) )
def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache: def read_cache(line_length: int, mode: FileMode) -> Cache:
"""Read the cache if it exists and is well formed. """Read the cache if it exists and is well formed.
If it is not well formed, the call to write_cache later should resolve the issue. If it is not well formed, the call to write_cache later should resolve the issue.
""" """
cache_file = get_cache_file(line_length, pyi, py36) cache_file = get_cache_file(line_length, mode)
if not cache_file.exists(): if not cache_file.exists():
return {} return {}
@ -3198,14 +3175,10 @@ def filter_cached(
def write_cache( def write_cache(
cache: Cache, cache: Cache, sources: List[Path], line_length: int, mode: FileMode
sources: List[Path],
line_length: int,
pyi: bool = False,
py36: bool = False,
) -> None: ) -> None:
"""Update the cache file.""" """Update the cache file."""
cache_file = get_cache_file(line_length, pyi, py36) cache_file = get_cache_file(line_length, mode)
try: try:
if not CACHE_DIR.exists(): if not CACHE_DIR.exists():
CACHE_DIR.mkdir(parents=True) CACHE_DIR.mkdir(parents=True)

View File

@ -342,10 +342,11 @@ def test_python2_unicode_literals(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_stub(self) -> None: def test_stub(self) -> None:
mode = black.FileMode.PYI
source, expected = read_data("stub.pyi") source, expected = read_data("stub.pyi")
actual = fs(source, is_pyi=True) actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll, is_pyi=True) black.assert_stable(source, actual, line_length=ll, mode=mode)
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None: def test_fmtonoff(self) -> None:
@ -566,25 +567,27 @@ def err(msg: str, **kwargs: Any) -> None:
self.assertEqual("".join(err_lines), "") self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None: def test_cache_broken_file(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir() as workspace: with cache_dir() as workspace:
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH) cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
with cache_file.open("w") as fobj: with cache_file.open("w") as fobj:
fobj.write("this is not a pickle") fobj.write("this is not a pickle")
self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {}) self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
with src.open("w") as fobj: with src.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
result = CliRunner().invoke(black.main, [str(src)]) result = CliRunner().invoke(black.main, [str(src)])
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH) cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.assertIn(src, cache) self.assertIn(src, cache)
def test_cache_single_file_already_cached(self) -> None: def test_cache_single_file_already_cached(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir() as workspace: with cache_dir() as workspace:
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
with src.open("w") as fobj: with src.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH) black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
result = CliRunner().invoke(black.main, [str(src)]) result = CliRunner().invoke(black.main, [str(src)])
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
with src.open("r") as fobj: with src.open("r") as fobj:
@ -592,6 +595,7 @@ def test_cache_single_file_already_cached(self) -> None:
@event_loop(close=False) @event_loop(close=False)
def test_cache_multiple_files(self) -> None: def test_cache_multiple_files(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir() as workspace, patch( with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor "black.ProcessPoolExecutor", new=ThreadPoolExecutor
): ):
@ -601,44 +605,48 @@ def test_cache_multiple_files(self) -> None:
two = (workspace / "two.py").resolve() two = (workspace / "two.py").resolve()
with two.open("w") as fobj: with two.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH) black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
result = CliRunner().invoke(black.main, [str(workspace)]) result = CliRunner().invoke(black.main, [str(workspace)])
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
with one.open("r") as fobj: with one.open("r") as fobj:
self.assertEqual(fobj.read(), "print('hello')") self.assertEqual(fobj.read(), "print('hello')")
with two.open("r") as fobj: with two.open("r") as fobj:
self.assertEqual(fobj.read(), 'print("hello")\n') self.assertEqual(fobj.read(), 'print("hello")\n')
cache = black.read_cache(black.DEFAULT_LINE_LENGTH) cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.assertIn(one, cache) self.assertIn(one, cache)
self.assertIn(two, cache) self.assertIn(two, cache)
def test_no_cache_when_writeback_diff(self) -> None: def test_no_cache_when_writeback_diff(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir() as workspace: with cache_dir() as workspace:
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
with src.open("w") as fobj: with src.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
result = CliRunner().invoke(black.main, [str(src), "--diff"]) result = CliRunner().invoke(black.main, [str(src), "--diff"])
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH) cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
self.assertFalse(cache_file.exists()) self.assertFalse(cache_file.exists())
def test_no_cache_when_stdin(self) -> None: def test_no_cache_when_stdin(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir(): with cache_dir():
result = CliRunner().invoke(black.main, ["-"], input="print('hello')") result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH) cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
self.assertFalse(cache_file.exists()) self.assertFalse(cache_file.exists())
def test_read_cache_no_cachefile(self) -> None: def test_read_cache_no_cachefile(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir(): with cache_dir():
self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {}) self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
def test_write_cache_read_cache(self) -> None: def test_write_cache_read_cache(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir() as workspace: with cache_dir() as workspace:
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
src.touch() src.touch()
black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH) black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH) cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.assertIn(src, cache) self.assertIn(src, cache)
self.assertEqual(cache[src], black.get_cache_info(src)) self.assertEqual(cache[src], black.get_cache_info(src))
@ -659,13 +667,15 @@ def test_filter_cached(self) -> None:
self.assertEqual(done, [cached]) self.assertEqual(done, [cached])
def test_write_cache_creates_directory_if_needed(self) -> None: def test_write_cache_creates_directory_if_needed(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir(exists=False) as workspace: with cache_dir(exists=False) as workspace:
self.assertFalse(workspace.exists()) self.assertFalse(workspace.exists())
black.write_cache({}, [], black.DEFAULT_LINE_LENGTH) black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
self.assertTrue(workspace.exists()) self.assertTrue(workspace.exists())
@event_loop(close=False) @event_loop(close=False)
def test_failed_formatting_does_not_get_cached(self) -> None: def test_failed_formatting_does_not_get_cached(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir() as workspace, patch( with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor "black.ProcessPoolExecutor", new=ThreadPoolExecutor
): ):
@ -677,14 +687,15 @@ def test_failed_formatting_does_not_get_cached(self) -> None:
fobj.write('print("hello")\n') fobj.write('print("hello")\n')
result = CliRunner().invoke(black.main, [str(workspace)]) result = CliRunner().invoke(black.main, [str(workspace)])
self.assertEqual(result.exit_code, 123) self.assertEqual(result.exit_code, 123)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH) cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.assertNotIn(failing, cache) self.assertNotIn(failing, cache)
self.assertIn(clean, cache) self.assertIn(clean, cache)
def test_write_cache_write_fail(self) -> None: def test_write_cache_write_fail(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir(), patch.object(Path, "open") as mock: with cache_dir(), patch.object(Path, "open") as mock:
mock.side_effect = OSError mock.side_effect = OSError
black.write_cache({}, [], black.DEFAULT_LINE_LENGTH) black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
@event_loop(close=False) @event_loop(close=False)
def test_check_diff_use_together(self) -> None: def test_check_diff_use_together(self) -> None:
@ -719,16 +730,19 @@ def test_broken_symlink(self) -> None:
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
def test_read_cache_line_lengths(self) -> None: def test_read_cache_line_lengths(self) -> None:
mode = black.FileMode.AUTO_DETECT
with cache_dir() as workspace: with cache_dir() as workspace:
path = (workspace / "file.py").resolve() path = (workspace / "file.py").resolve()
path.touch() path.touch()
black.write_cache({}, [path], 1) black.write_cache({}, [path], 1, mode)
one = black.read_cache(1) one = black.read_cache(1, mode)
self.assertIn(path, one) self.assertIn(path, one)
two = black.read_cache(2) two = black.read_cache(2, mode)
self.assertNotIn(path, two) self.assertNotIn(path, two)
def test_single_file_force_pyi(self) -> None: def test_single_file_force_pyi(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
pyi_mode = black.FileMode.PYI
contents, expected = read_data("force_pyi") contents, expected = read_data("force_pyi")
with cache_dir() as workspace: with cache_dir() as workspace:
path = (workspace / "file.py").resolve() path = (workspace / "file.py").resolve()
@ -739,14 +753,16 @@ def test_single_file_force_pyi(self) -> None:
with open(path, "r") as fh: with open(path, "r") as fh:
actual = fh.read() actual = fh.read()
# verify cache with --pyi is separate # verify cache with --pyi is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi=True) pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
self.assertIn(path, pyi_cache) self.assertIn(path, pyi_cache)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH) normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@event_loop(close=False) @event_loop(close=False)
def test_multi_file_force_pyi(self) -> None: def test_multi_file_force_pyi(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
pyi_mode = black.FileMode.PYI
contents, expected = read_data("force_pyi") contents, expected = read_data("force_pyi")
with cache_dir() as workspace: with cache_dir() as workspace:
paths = [ paths = [
@ -763,8 +779,8 @@ def test_multi_file_force_pyi(self) -> None:
actual = fh.read() actual = fh.read()
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
# verify cache with --pyi is separate # verify cache with --pyi is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi=True) pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH) normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
for path in paths: for path in paths:
self.assertIn(path, pyi_cache) self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)
@ -777,6 +793,8 @@ def test_pipe_force_pyi(self) -> None:
self.assertFormatEqual(actual, expected) self.assertFormatEqual(actual, expected)
def test_single_file_force_py36(self) -> None: def test_single_file_force_py36(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
py36_mode = black.FileMode.PYTHON36
source, expected = read_data("force_py36") source, expected = read_data("force_py36")
with cache_dir() as workspace: with cache_dir() as workspace:
path = (workspace / "file.py").resolve() path = (workspace / "file.py").resolve()
@ -787,14 +805,16 @@ def test_single_file_force_py36(self) -> None:
with open(path, "r") as fh: with open(path, "r") as fh:
actual = fh.read() actual = fh.read()
# verify cache with --py36 is separate # verify cache with --py36 is separate
py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36=True) py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
self.assertIn(path, py36_cache) self.assertIn(path, py36_cache)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH) normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@event_loop(close=False) @event_loop(close=False)
def test_multi_file_force_py36(self) -> None: def test_multi_file_force_py36(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
py36_mode = black.FileMode.PYTHON36
source, expected = read_data("force_py36") source, expected = read_data("force_py36")
with cache_dir() as workspace: with cache_dir() as workspace:
paths = [ paths = [
@ -813,8 +833,8 @@ def test_multi_file_force_py36(self) -> None:
actual = fh.read() actual = fh.read()
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
# verify cache with --py36 is separate # verify cache with --py36 is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36=True) pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH) normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
for path in paths: for path in paths:
self.assertIn(path, pyi_cache) self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)