Preserve line endings when formatting a file in place (#288)

This commit is contained in:
Zsolt Dollenstein 2018-06-05 00:52:06 +02:00 committed by Łukasz Langa
parent dbe26161fa
commit 00a302560b
4 changed files with 68 additions and 15 deletions

View File

@ -720,6 +720,8 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
* fixed stdin handling not working correctly if an old version of Click was * fixed stdin handling not working correctly if an old version of Click was
used (#276) used (#276)
* *Black* now preserves line endings when formatting a file in place (#258)
### 18.5b1 ### 18.5b1

View File

@ -4,6 +4,7 @@
from concurrent.futures import Executor, ProcessPoolExecutor from concurrent.futures import Executor, ProcessPoolExecutor
from enum import Enum, Flag from enum import Enum, Flag
from functools import partial, wraps from functools import partial, wraps
import io
import keyword import keyword
import logging import logging
from multiprocessing import Manager from multiprocessing import Manager
@ -465,8 +466,9 @@ def format_file_in_place(
""" """
if src.suffix == ".pyi": if src.suffix == ".pyi":
mode |= FileMode.PYI mode |= FileMode.PYI
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read() with open(src, "rb") as buf:
newline, encoding, src_contents = prepare_input(buf.read())
try: try:
dst_contents = format_file_contents( dst_contents = format_file_contents(
src_contents, line_length=line_length, fast=fast, mode=mode src_contents, line_length=line_length, fast=fast, mode=mode
@ -475,7 +477,7 @@ def format_file_in_place(
return False return False
if write_back == write_back.YES: if write_back == write_back.YES:
with open(src, "w", encoding=src_buffer.encoding) as f: with open(src, "w", encoding=encoding, newline=newline) as f:
f.write(dst_contents) f.write(dst_contents)
elif write_back == write_back.DIFF: elif write_back == write_back.DIFF:
src_name = f"{src} (original)" src_name = f"{src} (original)"
@ -484,7 +486,14 @@ def format_file_in_place(
if lock: if lock:
lock.acquire() lock.acquire()
try: try:
sys.stdout.write(diff_contents) f = io.TextIOWrapper(
sys.stdout.buffer,
encoding=encoding,
newline=newline,
write_through=True,
)
f.write(diff_contents)
f.detach()
finally: finally:
if lock: if lock:
lock.release() lock.release()
@ -503,7 +512,7 @@ def format_stdin_to_stdout(
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
:func:`format_file_contents`. :func:`format_file_contents`.
""" """
src = sys.stdin.read() newline, encoding, src = prepare_input(sys.stdin.buffer.read())
dst = src dst = src
try: try:
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode) dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@ -514,11 +523,25 @@ def format_stdin_to_stdout(
finally: finally:
if write_back == WriteBack.YES: if write_back == WriteBack.YES:
sys.stdout.write(dst) f = io.TextIOWrapper(
sys.stdout.buffer,
encoding=encoding,
newline=newline,
write_through=True,
)
f.write(dst)
f.detach()
elif write_back == WriteBack.DIFF: elif write_back == WriteBack.DIFF:
src_name = "<stdin> (original)" src_name = "<stdin> (original)"
dst_name = "<stdin> (formatted)" dst_name = "<stdin> (formatted)"
sys.stdout.write(diff(src, dst, src_name, dst_name)) f = io.TextIOWrapper(
sys.stdout.buffer,
encoding=encoding,
newline=newline,
write_through=True,
)
f.write(diff(src, dst, src_name, dst_name))
f.detach()
def format_file_contents( def format_file_contents(
@ -579,6 +602,19 @@ def format_str(
return dst_contents return dst_contents
def prepare_input(src: bytes) -> Tuple[str, str, str]:
"""Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
universal newlines (i.e. only LF).
"""
srcbuf = io.BytesIO(src)
encoding, lines = tokenize.detect_encoding(srcbuf.readline)
newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
srcbuf.seek(0)
return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
GRAMMARS = [ GRAMMARS = [
pygram.python_grammar_no_print_statement_no_exec_statement, pygram.python_grammar_no_print_statement_no_exec_statement,
pygram.python_grammar_no_print_statement, pygram.python_grammar_no_print_statement,
@ -590,8 +626,7 @@ def lib2to3_parse(src_txt: str) -> Node:
"""Given a string with source, return the lib2to3 Node.""" """Given a string with source, return the lib2to3 Node."""
grammar = pygram.python_grammar_no_print_statement grammar = pygram.python_grammar_no_print_statement
if src_txt[-1] != "\n": if src_txt[-1] != "\n":
nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n" src_txt += "\n"
src_txt += nl
for grammar in GRAMMARS: for grammar in GRAMMARS:
drv = driver.Driver(grammar, pytree.convert) drv = driver.Driver(grammar, pytree.convert)
try: try:

View File

@ -61,6 +61,8 @@ Parsing
.. autofunction:: black.lib2to3_unparse .. autofunction:: black.lib2to3_unparse
.. autofunction:: black.prepare_input
Split functions Split functions
--------------- ---------------

View File

@ -3,7 +3,7 @@
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from io import StringIO from io import BytesIO, TextIOWrapper
import os import os
from pathlib import Path from pathlib import Path
import sys import sys
@ -121,8 +121,9 @@ def test_piping(self) -> None:
source, expected = read_data("../black") source, expected = read_data("../black")
hold_stdin, hold_stdout = sys.stdin, sys.stdout hold_stdin, hold_stdout = sys.stdin, sys.stdout
try: try:
sys.stdin, sys.stdout = StringIO(source), StringIO() sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
sys.stdin.name = "<stdin>" sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
sys.stdin.buffer.name = "<stdin>" # type: ignore
black.format_stdin_to_stdout( black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.YES line_length=ll, fast=True, write_back=black.WriteBack.YES
) )
@ -139,8 +140,9 @@ def test_piping_diff(self) -> None:
expected, _ = read_data("expression.diff") expected, _ = read_data("expression.diff")
hold_stdin, hold_stdout = sys.stdin, sys.stdout hold_stdin, hold_stdout = sys.stdin, sys.stdout
try: try:
sys.stdin, sys.stdout = StringIO(source), StringIO() sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
sys.stdin.name = "<stdin>" sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
sys.stdin.buffer.name = "<stdin>" # type: ignore
black.format_stdin_to_stdout( black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.DIFF line_length=ll, fast=True, write_back=black.WriteBack.DIFF
) )
@ -204,7 +206,7 @@ def test_expression_diff(self) -> None:
tmp_file = Path(black.dump_to_file(source)) tmp_file = Path(black.dump_to_file(source))
hold_stdout = sys.stdout hold_stdout = sys.stdout
try: try:
sys.stdout = StringIO() sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF)) self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
sys.stdout.seek(0) sys.stdout.seek(0)
actual = sys.stdout.read() actual = sys.stdout.read()
@ -1108,6 +1110,18 @@ def test_invalid_include_exclude(self) -> None:
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"]) result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
self.assertEqual(result.exit_code, 2) self.assertEqual(result.exit_code, 2)
def test_preserves_line_endings(self) -> None:
with TemporaryDirectory() as workspace:
test_file = Path(workspace) / "test.py"
for nl in ["\n", "\r\n"]:
contents = nl.join(["def f( ):", " pass"])
test_file.write_bytes(contents.encode())
ff(test_file, write_back=black.WriteBack.YES)
updated_contents: bytes = test_file.read_bytes()
self.assertIn(nl.encode(), updated_contents) # type: ignore
if nl == "\n":
self.assertNotIn(b"\r\n", updated_contents) # type: ignore
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()