Support --diff for both files and stdin

Fixes #87
This commit is contained in:
Łukasz Langa 2018-03-31 02:24:01 -07:00
parent a42aef7806
commit a20a3eeb0f
5 changed files with 371 additions and 22 deletions

View File

@ -1,3 +1,3 @@
include *.rst *.md LICENSE
recursive-include blib2to3 *.txt *.py
recursive-include tests *.txt *.out *.py
recursive-include tests *.txt *.out *.diff *.py

View File

@ -53,11 +53,13 @@ black [OPTIONS] [SRC]...
Options:
-l, --line-length INTEGER Where to wrap around. [default: 88]
--check Don't write back the files, just return the
--check Don't write the files back, just return the
status. Return code 0 means nothing would
change. Return code 1 means some files would be
reformatted. Return code 123 means there was an
internal error.
--diff Don't write the files back, just output a diff
for each file on stdout.
--fast / --safe If --fast given, skip temporary sanity checks.
[default: --safe]
--version Show the version and exit.
@ -394,8 +396,10 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
### 18.3a5 (unreleased)
* added `--diff` (#87)
* add line breaks before all delimiters, except in cases like commas, to better
comply with PEP8 (#73)
comply with PEP 8 (#73)
* fixed handling of standalone comments within nested bracketed
expressions; Black will no longer produce super long lines or put all

View File

@ -3,15 +3,18 @@
import asyncio
from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
from enum import Enum
from functools import partial, wraps
import keyword
import logging
from multiprocessing import Manager
import os
from pathlib import Path
import tokenize
import signal
import sys
from typing import (
Any,
Callable,
Dict,
Generic,
@ -92,6 +95,12 @@ class FormatOff(FormatError):
"""Found a comment like `# fmt: off` in the file."""
class WriteBack(Enum):
# What to do with formatted output once a source has been processed.
# main() picks the member from the --check / --diff flags.
# Write nothing back; selected when --check is given.
NO = 0
# Write the reformatted code back; used when neither --check nor --diff is given.
YES = 1
# Write a diff to stdout instead of the reformatted code; selected by --diff.
DIFF = 2
@click.command()
@click.option(
"-l",
@ -105,11 +114,16 @@ class FormatOff(FormatError):
"--check",
is_flag=True,
help=(
"Don't write back the files, just return the status. Return code 0 "
"Don't write the files back, just return the status. Return code 0 "
"means nothing would change. Return code 1 means some files would be "
"reformatted. Return code 123 means there was an internal error."
),
)
@click.option(
"--diff",
is_flag=True,
help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
"--fast/--safe",
is_flag=True,
@ -125,7 +139,12 @@ class FormatOff(FormatError):
)
@click.pass_context
def main(
ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
ctx: click.Context,
line_length: int,
check: bool,
diff: bool,
fast: bool,
src: List[str],
) -> None:
"""The uncompromising code formatter."""
sources: List[Path] = []
@ -140,6 +159,17 @@ def main(
sources.append(Path("-"))
else:
err(f"invalid path: {s}")
if check and diff:
exc = click.ClickException("Options --check and --diff are mutually exclusive")
exc.exit_code = 2
raise exc
if check:
write_back = WriteBack.NO
elif diff:
write_back = WriteBack.DIFF
else:
write_back = WriteBack.YES
if len(sources) == 0:
ctx.exit(0)
elif len(sources) == 1:
@ -148,11 +178,11 @@ def main(
try:
if not p.is_file() and str(p) == "-":
changed = format_stdin_to_stdout(
line_length=line_length, fast=fast, write_back=not check
line_length=line_length, fast=fast, write_back=write_back
)
else:
changed = format_file_in_place(
p, line_length=line_length, fast=fast, write_back=not check
p, line_length=line_length, fast=fast, write_back=write_back
)
report.done(p, changed)
except Exception as exc:
@ -165,7 +195,7 @@ def main(
try:
return_code = loop.run_until_complete(
schedule_formatting(
sources, line_length, not check, fast, loop, executor
sources, line_length, write_back, fast, loop, executor
)
)
finally:
@ -176,7 +206,7 @@ def main(
async def schedule_formatting(
sources: List[Path],
line_length: int,
write_back: bool,
write_back: WriteBack,
fast: bool,
loop: BaseEventLoop,
executor: Executor,
@ -188,9 +218,15 @@ async def schedule_formatting(
`line_length`, `write_back`, and `fast` options are passed to
:func:`format_file_in_place`.
"""
lock = None
if write_back == WriteBack.DIFF:
# For diff output, we need locks to ensure we don't interleave output
# from different processes.
manager = Manager()
lock = manager.Lock()
tasks = {
src: loop.run_in_executor(
executor, format_file_in_place, src, line_length, fast, write_back
executor, format_file_in_place, src, line_length, fast, write_back, lock
)
for src in sources
}
@ -220,7 +256,11 @@ async def schedule_formatting(
def format_file_in_place(
src: Path, line_length: int, fast: bool, write_back: bool = False
src: Path,
line_length: int,
fast: bool,
write_back: WriteBack = WriteBack.NO,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
"""Format file under `src` path. Return True if changed.
@ -230,37 +270,53 @@ def format_file_in_place(
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read()
try:
contents = format_file_contents(
dst_contents = format_file_contents(
src_contents, line_length=line_length, fast=fast
)
except NothingChanged:
return False
if write_back:
if write_back == write_back.YES:
with open(src, "w", encoding=src_buffer.encoding) as f:
f.write(contents)
f.write(dst_contents)
elif write_back == write_back.DIFF:
src_name = f"{src.name} (original)"
dst_name = f"{src.name} (formatted)"
diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
if lock:
lock.acquire()
try:
sys.stdout.write(diff_contents)
finally:
if lock:
lock.release()
return True
def format_stdin_to_stdout(
line_length: int, fast: bool, write_back: bool = False
line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
) -> bool:
"""Format file on stdin. Return True if changed.
If `write_back` is YES, write reformatted code back to stdout. If it is DIFF, write a diff to stdout.
`line_length` and `fast` arguments are passed to :func:`format_file_contents`.
"""
contents = sys.stdin.read()
src = sys.stdin.read()
try:
contents = format_file_contents(contents, line_length=line_length, fast=fast)
dst = format_file_contents(src, line_length=line_length, fast=fast)
return True
except NothingChanged:
dst = src
return False
finally:
if write_back:
sys.stdout.write(contents)
if write_back == WriteBack.YES:
sys.stdout.write(dst)
elif write_back == WriteBack.DIFF:
src_name = "<stdin> (original)"
dst_name = "<stdin> (formatted)"
sys.stdout.write(diff(src, dst, src_name, dst_name))
def format_file_contents(
@ -2064,7 +2120,8 @@ def dump_to_file(*output: str) -> str:
) as f:
for lines in output:
f.write(lines)
f.write("\n")
if lines and lines[-1] != "\n":
f.write("\n")
return f.name

238
tests/expression.diff Normal file
View File

@ -0,0 +1,238 @@
--- <stdin> (original)
+++ <stdin> (formatted)
@@ -1,8 +1,8 @@
...
-'some_string'
-b'\\xa3'
+"some_string"
+b"\\xa3"
Name
None
True
False
1
@@ -29,65 +29,74 @@
~great
+value
-1
~int and not v1 ^ 123 + v2 | True
(~int) and (not ((v1 ^ (123 + v2)) | True))
-flags & ~ select.EPOLLIN and waiters.write_task is not None
+flags & ~select.EPOLLIN and waiters.write_task is not None
lambda arg: None
lambda a=True: a
lambda a, b, c=True: a
-lambda a, b, c=True, *, d=(1 << v2), e='str': a
-lambda a, b, c=True, *vararg, d=(v1 << 2), e='str', **kwargs: a + b
+lambda a, b, c=True, *, d=(1 << v2), e="str": a
+lambda a, b, c=True, *vararg, d=(v1 << 2), e="str", **kwargs: a + b
1 if True else 2
str or None if True else str or bytes or None
(str or None) if True else (str or bytes or None)
str or None if (1 if True else 2) else str or bytes or None
(str or None) if (1 if True else 2) else (str or bytes or None)
-{'2.7': dead, '3.7': (long_live or die_hard)}
-{'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}}
+{"2.7": dead, "3.7": (long_live or die_hard)}
+{"2.7": dead, "3.7": (long_live or die_hard), **{"3.6": verygood}}
{**a, **b, **c}
-{'2.7', '3.6', '3.7', '3.8', '3.9', ('4.0' if gilectomy else '3.10')}
-({'a': 'b'}, (True or False), (+value), 'string', b'bytes') or None
+{"2.7", "3.6", "3.7", "3.8", "3.9", ("4.0" if gilectomy else "3.10")}
+({"a": "b"}, (True or False), (+value), "string", b"bytes") or None
()
(1,)
(1, 2)
(1, 2, 3)
[]
[1, 2, 3, 4, 5, 6, 7, 8, 9, (10 or A), (11 or B), (12 or C)]
-[1, 2, 3,]
+[1, 2, 3]
{i for i in (1, 2, 3)}
{(i ** 2) for i in (1, 2, 3)}
-{(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))}
+{(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))}
{((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)}
[i for i in (1, 2, 3)]
[(i ** 2) for i in (1, 2, 3)]
-[(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))]
+[(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))]
[((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)]
{i: 0 for i in (1, 2, 3)}
-{i: j for i, j in ((1, 'a'), (2, 'b'), (3, 'c'))}
+{i: j for i, j in ((1, "a"), (2, "b"), (3, "c"))}
{a: b * 2 for a, b in dictionary.items()}
{a: b * -2 for a, b in dictionary.items()}
-{k: v for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension}
+{
+ k: v
+ for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension
+}
Python3 > Python2 > COBOL
Life is Life
call()
call(arg)
-call(kwarg='hey')
-call(arg, kwarg='hey')
-call(arg, another, kwarg='hey', **kwargs)
-call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs) # note: no trailing comma pre-3.6
+call(kwarg="hey")
+call(arg, kwarg="hey")
+call(arg, another, kwarg="hey", **kwargs)
+call(
+ this_is_a_very_long_variable_which_will_force_a_delimiter_split,
+ arg,
+ another,
+ kwarg="hey",
+ **kwargs
+) # note: no trailing comma pre-3.6
call(*gidgets[:2])
call(**self.screen_kwargs)
lukasz.langa.pl
call.me(maybe)
1 .real
1.0 .real
....__class__
list[str]
dict[str, int]
tuple[str, ...]
-tuple[str, int, float, dict[str, int],]
+tuple[str, int, float, dict[str, int]]
very_long_variable_name_filters: t.List[
t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]],
]
slice[0]
slice[0:1]
@@ -114,71 +123,90 @@
numpy[-(c + 1):, d]
numpy[:, l[-2]]
numpy[:, ::-1]
numpy[np.newaxis, :]
(str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
-{'2.7': dead, '3.7': long_live or die_hard}
-{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
+{"2.7": dead, "3.7": long_live or die_hard}
+{"2.7", "3.6", "3.7", "3.8", "3.9", "4.0" if gilectomy else "3.10"}
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
(SomeName)
SomeName
(Good, Bad, Ugly)
(i for i in (1, 2, 3))
((i ** 2) for i in (1, 2, 3))
-((i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c')))
+((i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c")))
(((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3))
(*starred)
a = (1,)
b = 1,
c = 1
d = (1,) + a + (2,)
e = (1,).count(1)
-what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(vars_to_remove)
-what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove)
-result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc(),).all()
+what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(
+ vars_to_remove
+)
+what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(
+ vars_to_remove
+)
+result = session.query(models.Customer.id).filter(
+ models.Customer.account_id == account_id, models.Customer.email == email_address
+).order_by(
+ models.Customer.id.asc()
+).all()
+
def gen():
yield from outside_of_generator
+
a = (yield)
+
async def f():
await some.complicated[0].call(with_args=(True or (1 is not 1)))
-if (
- threading.current_thread() != threading.main_thread() and
- threading.current_thread() != threading.main_thread() or
- signal.getsignal(signal.SIGINT) != signal.default_int_handler
-):
- return True
-if (
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa |
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
-):
- return True
-if (
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa &
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
-):
- return True
-if (
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
-):
- return True
-if (
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
-):
- return True
-if (
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa *
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
-):
- return True
-if (
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa /
- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
-):
- return True
+
+if (
+ threading.current_thread() != threading.main_thread()
+ and threading.current_thread() != threading.main_thread()
+ or signal.getsignal(signal.SIGINT) != signal.default_int_handler
+):
+ return True
+
+if (
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ | aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+):
+ return True
+
+if (
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ & aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+):
+ return True
+
+if (
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+):
+ return True
+
+if (
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ - aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+):
+ return True
+
+if (
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ * aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+):
+ return True
+
+if (
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ / aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+):
+ return True
+
last_call()
# standalone comment at ENDMARKER

View File

@ -26,7 +26,7 @@ def dump_to_stderr(*output: str) -> str:
def read_data(name: str) -> Tuple[str, str]:
"""read_data('test_name') -> 'input', 'output'"""
if not name.endswith((".py", ".out")):
if not name.endswith((".py", ".out", ".diff")):
name += ".py"
_input: List[str] = []
_output: List[str] = []
@ -92,7 +92,9 @@ def test_piping(self) -> None:
try:
sys.stdin, sys.stdout = StringIO(source), StringIO()
sys.stdin.name = "<stdin>"
black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.YES
)
sys.stdout.seek(0)
actual = sys.stdout.read()
finally:
@ -101,6 +103,23 @@ def test_piping(self) -> None:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
def test_piping_diff(self) -> None:
# Pipe expression.py through stdin in DIFF mode and compare what lands on
# stdout with the stored expression.diff fixture.
source, _ = read_data("expression.py")
expected, _ = read_data("expression.diff")
# Remember the real streams so they can be put back afterwards.
hold_stdin, hold_stdout = sys.stdin, sys.stdout
try:
# Swap in in-memory buffers: the source as stdin, an empty sink as stdout.
sys.stdin, sys.stdout = StringIO(source), StringIO()
sys.stdin.name = "<stdin>"
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
)
# Rewind the sink so the read below returns everything that was written.
sys.stdout.seek(0)
actual = sys.stdout.read()
finally:
# Restore the real streams.
sys.stdin, sys.stdout = hold_stdin, hold_stdout
actual = actual.rstrip() + "\n" # the diff output has trailing whitespace; normalize it to one newline
self.assertEqual(expected, actual)
@patch("black.dump_to_file", dump_to_stderr)
def test_setup(self) -> None:
source, expected = read_data("../setup")
@ -126,6 +145,37 @@ def test_expression(self) -> None:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
def test_expression_ff(self) -> None:
# Run the "expression" input through `ff` with WriteBack.YES and check the
# rewritten file against the expected output.
# NOTE(review): `ff` is defined outside this view; presumably a shortcut for
# black.format_file_in_place - confirm.
source, expected = read_data("expression")
# Dump the source to a file so there is a real path to format.
tmp_file = Path(black.dump_to_file(source))
try:
# The call must report a change (truthy result).
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
# Read the file back to see what was written to it.
with open(tmp_file) as f:
actual = f.read()
finally:
# Remove the dumped file.
os.unlink(tmp_file)
self.assertFormatEqual(expected, actual)
# While the sanity checks run, black.dump_to_file is replaced by dump_to_stderr.
with patch("black.dump_to_file", dump_to_stderr):
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
def test_expression_diff(self) -> None:
# Run the "expression" input through `ff` with WriteBack.DIFF and compare the
# diff captured from stdout with the stored expression.diff fixture.
# NOTE(review): `ff` is defined outside this view; presumably a shortcut for
# black.format_file_in_place - confirm.
source, _ = read_data("expression.py")
expected, _ = read_data("expression.diff")
# Dump the source to a file so there is a real path to format.
tmp_file = Path(black.dump_to_file(source))
# Remember the real stdout so it can be put back afterwards.
hold_stdout = sys.stdout
try:
# Capture everything written to stdout in an in-memory buffer.
sys.stdout = StringIO()
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
# Rewind the buffer so the read below returns everything that was written.
sys.stdout.seek(0)
actual = sys.stdout.read()
# The fixture uses "<stdin>" as the file name, so substitute it for the
# dumped file's name before comparing.
actual = actual.replace(tmp_file.name, "<stdin>")
finally:
# Restore stdout and remove the dumped file.
sys.stdout = hold_stdout
os.unlink(tmp_file)
actual = actual.rstrip() + "\n" # the diff output has trailing whitespace; normalize it to one newline
self.assertEqual(expected, actual)
@patch("black.dump_to_file", dump_to_stderr)
def test_fstring(self) -> None:
source, expected = read_data("fstring")