Preserve line endings when formatting a file in place (#288)

This commit is contained in:
Zsolt Dollenstein 2018-06-05 00:52:06 +02:00 committed by Łukasz Langa
parent dbe26161fa
commit 00a302560b
4 changed files with 68 additions and 15 deletions

View File

@ -720,6 +720,8 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
* fixed stdin handling not working correctly if an old version of Click was
used (#276)
* *Black* now preserves line endings when formatting a file in place (#258)
### 18.5b1

View File

@ -4,6 +4,7 @@
from concurrent.futures import Executor, ProcessPoolExecutor
from enum import Enum, Flag
from functools import partial, wraps
import io
import keyword
import logging
from multiprocessing import Manager
@ -465,8 +466,9 @@ def format_file_in_place(
"""
if src.suffix == ".pyi":
mode |= FileMode.PYI
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read()
with open(src, "rb") as buf:
newline, encoding, src_contents = prepare_input(buf.read())
try:
dst_contents = format_file_contents(
src_contents, line_length=line_length, fast=fast, mode=mode
@ -475,7 +477,7 @@ def format_file_in_place(
return False
if write_back == write_back.YES:
with open(src, "w", encoding=src_buffer.encoding) as f:
with open(src, "w", encoding=encoding, newline=newline) as f:
f.write(dst_contents)
elif write_back == write_back.DIFF:
src_name = f"{src} (original)"
@ -484,7 +486,14 @@ def format_file_in_place(
if lock:
lock.acquire()
try:
sys.stdout.write(diff_contents)
f = io.TextIOWrapper(
sys.stdout.buffer,
encoding=encoding,
newline=newline,
write_through=True,
)
f.write(diff_contents)
f.detach()
finally:
if lock:
lock.release()
@ -503,7 +512,7 @@ def format_stdin_to_stdout(
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
:func:`format_file_contents`.
"""
src = sys.stdin.read()
newline, encoding, src = prepare_input(sys.stdin.buffer.read())
dst = src
try:
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@ -514,11 +523,25 @@ def format_stdin_to_stdout(
finally:
if write_back == WriteBack.YES:
sys.stdout.write(dst)
f = io.TextIOWrapper(
sys.stdout.buffer,
encoding=encoding,
newline=newline,
write_through=True,
)
f.write(dst)
f.detach()
elif write_back == WriteBack.DIFF:
src_name = "<stdin> (original)"
dst_name = "<stdin> (formatted)"
sys.stdout.write(diff(src, dst, src_name, dst_name))
f = io.TextIOWrapper(
sys.stdout.buffer,
encoding=encoding,
newline=newline,
write_through=True,
)
f.write(diff(src, dst, src_name, dst_name))
f.detach()
def format_file_contents(
@ -579,6 +602,19 @@ def format_str(
return dst_contents
def prepare_input(src: bytes) -> Tuple[str, str, str]:
"""Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
universal newlines (i.e. only LF).
"""
srcbuf = io.BytesIO(src)
encoding, lines = tokenize.detect_encoding(srcbuf.readline)
newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
srcbuf.seek(0)
return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
GRAMMARS = [
pygram.python_grammar_no_print_statement_no_exec_statement,
pygram.python_grammar_no_print_statement,
@ -590,8 +626,7 @@ def lib2to3_parse(src_txt: str) -> Node:
"""Given a string with source, return the lib2to3 Node."""
grammar = pygram.python_grammar_no_print_statement
if src_txt[-1] != "\n":
nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
src_txt += nl
src_txt += "\n"
for grammar in GRAMMARS:
drv = driver.Driver(grammar, pytree.convert)
try:

View File

@ -61,6 +61,8 @@ Parsing
.. autofunction:: black.lib2to3_unparse
.. autofunction:: black.prepare_input
Split functions
---------------

View File

@ -3,7 +3,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from functools import partial
from io import StringIO
from io import BytesIO, TextIOWrapper
import os
from pathlib import Path
import sys
@ -121,8 +121,9 @@ def test_piping(self) -> None:
source, expected = read_data("../black")
hold_stdin, hold_stdout = sys.stdin, sys.stdout
try:
sys.stdin, sys.stdout = StringIO(source), StringIO()
sys.stdin.name = "<stdin>"
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
sys.stdin.buffer.name = "<stdin>" # type: ignore
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.YES
)
@ -139,8 +140,9 @@ def test_piping_diff(self) -> None:
expected, _ = read_data("expression.diff")
hold_stdin, hold_stdout = sys.stdin, sys.stdout
try:
sys.stdin, sys.stdout = StringIO(source), StringIO()
sys.stdin.name = "<stdin>"
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
sys.stdin.buffer.name = "<stdin>" # type: ignore
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
)
@ -204,7 +206,7 @@ def test_expression_diff(self) -> None:
tmp_file = Path(black.dump_to_file(source))
hold_stdout = sys.stdout
try:
sys.stdout = StringIO()
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
sys.stdout.seek(0)
actual = sys.stdout.read()
@ -1108,6 +1110,18 @@ def test_invalid_include_exclude(self) -> None:
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
self.assertEqual(result.exit_code, 2)
def test_preserves_line_endings(self) -> None:
with TemporaryDirectory() as workspace:
test_file = Path(workspace) / "test.py"
for nl in ["\n", "\r\n"]:
contents = nl.join(["def f( ):", " pass"])
test_file.write_bytes(contents.encode())
ff(test_file, write_back=black.WriteBack.YES)
updated_contents: bytes = test_file.read_bytes()
self.assertIn(nl.encode(), updated_contents) # type: ignore
if nl == "\n":
self.assertNotIn(b"\r\n", updated_contents) # type: ignore
if __name__ == "__main__":
unittest.main()