Preserve line endings when formatting a file in place (#288)
This commit is contained in:
parent
dbe26161fa
commit
00a302560b
@ -720,6 +720,8 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
|
||||
* fixed stdin handling not working correctly if an old version of Click was
|
||||
used (#276)
|
||||
|
||||
* *Black* now preserves line endings when formatting a file in place (#258)
|
||||
|
||||
|
||||
### 18.5b1
|
||||
|
||||
|
53
black.py
53
black.py
@ -4,6 +4,7 @@
|
||||
from concurrent.futures import Executor, ProcessPoolExecutor
|
||||
from enum import Enum, Flag
|
||||
from functools import partial, wraps
|
||||
import io
|
||||
import keyword
|
||||
import logging
|
||||
from multiprocessing import Manager
|
||||
@ -465,8 +466,9 @@ def format_file_in_place(
|
||||
"""
|
||||
if src.suffix == ".pyi":
|
||||
mode |= FileMode.PYI
|
||||
with tokenize.open(src) as src_buffer:
|
||||
src_contents = src_buffer.read()
|
||||
|
||||
with open(src, "rb") as buf:
|
||||
newline, encoding, src_contents = prepare_input(buf.read())
|
||||
try:
|
||||
dst_contents = format_file_contents(
|
||||
src_contents, line_length=line_length, fast=fast, mode=mode
|
||||
@ -475,7 +477,7 @@ def format_file_in_place(
|
||||
return False
|
||||
|
||||
if write_back == write_back.YES:
|
||||
with open(src, "w", encoding=src_buffer.encoding) as f:
|
||||
with open(src, "w", encoding=encoding, newline=newline) as f:
|
||||
f.write(dst_contents)
|
||||
elif write_back == write_back.DIFF:
|
||||
src_name = f"{src} (original)"
|
||||
@ -484,7 +486,14 @@ def format_file_in_place(
|
||||
if lock:
|
||||
lock.acquire()
|
||||
try:
|
||||
sys.stdout.write(diff_contents)
|
||||
f = io.TextIOWrapper(
|
||||
sys.stdout.buffer,
|
||||
encoding=encoding,
|
||||
newline=newline,
|
||||
write_through=True,
|
||||
)
|
||||
f.write(diff_contents)
|
||||
f.detach()
|
||||
finally:
|
||||
if lock:
|
||||
lock.release()
|
||||
@ -503,7 +512,7 @@ def format_stdin_to_stdout(
|
||||
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
|
||||
:func:`format_file_contents`.
|
||||
"""
|
||||
src = sys.stdin.read()
|
||||
newline, encoding, src = prepare_input(sys.stdin.buffer.read())
|
||||
dst = src
|
||||
try:
|
||||
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
|
||||
@ -514,11 +523,25 @@ def format_stdin_to_stdout(
|
||||
|
||||
finally:
|
||||
if write_back == WriteBack.YES:
|
||||
sys.stdout.write(dst)
|
||||
f = io.TextIOWrapper(
|
||||
sys.stdout.buffer,
|
||||
encoding=encoding,
|
||||
newline=newline,
|
||||
write_through=True,
|
||||
)
|
||||
f.write(dst)
|
||||
f.detach()
|
||||
elif write_back == WriteBack.DIFF:
|
||||
src_name = "<stdin> (original)"
|
||||
dst_name = "<stdin> (formatted)"
|
||||
sys.stdout.write(diff(src, dst, src_name, dst_name))
|
||||
f = io.TextIOWrapper(
|
||||
sys.stdout.buffer,
|
||||
encoding=encoding,
|
||||
newline=newline,
|
||||
write_through=True,
|
||||
)
|
||||
f.write(diff(src, dst, src_name, dst_name))
|
||||
f.detach()
|
||||
|
||||
|
||||
def format_file_contents(
|
||||
@ -579,6 +602,19 @@ def format_str(
|
||||
return dst_contents
|
||||
|
||||
|
||||
def prepare_input(src: bytes) -> Tuple[str, str, str]:
|
||||
"""Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
|
||||
|
||||
Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
|
||||
universal newlines (i.e. only LF).
|
||||
"""
|
||||
srcbuf = io.BytesIO(src)
|
||||
encoding, lines = tokenize.detect_encoding(srcbuf.readline)
|
||||
newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
|
||||
srcbuf.seek(0)
|
||||
return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
|
||||
|
||||
|
||||
GRAMMARS = [
|
||||
pygram.python_grammar_no_print_statement_no_exec_statement,
|
||||
pygram.python_grammar_no_print_statement,
|
||||
@ -590,8 +626,7 @@ def lib2to3_parse(src_txt: str) -> Node:
|
||||
"""Given a string with source, return the lib2to3 Node."""
|
||||
grammar = pygram.python_grammar_no_print_statement
|
||||
if src_txt[-1] != "\n":
|
||||
nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
|
||||
src_txt += nl
|
||||
src_txt += "\n"
|
||||
for grammar in GRAMMARS:
|
||||
drv = driver.Driver(grammar, pytree.convert)
|
||||
try:
|
||||
|
@ -61,6 +61,8 @@ Parsing
|
||||
|
||||
.. autofunction:: black.lib2to3_unparse
|
||||
|
||||
.. autofunction:: black.prepare_input
|
||||
|
||||
Split functions
|
||||
---------------
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from io import StringIO
|
||||
from io import BytesIO, TextIOWrapper
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
@ -121,8 +121,9 @@ def test_piping(self) -> None:
|
||||
source, expected = read_data("../black")
|
||||
hold_stdin, hold_stdout = sys.stdin, sys.stdout
|
||||
try:
|
||||
sys.stdin, sys.stdout = StringIO(source), StringIO()
|
||||
sys.stdin.name = "<stdin>"
|
||||
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
|
||||
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
|
||||
sys.stdin.buffer.name = "<stdin>" # type: ignore
|
||||
black.format_stdin_to_stdout(
|
||||
line_length=ll, fast=True, write_back=black.WriteBack.YES
|
||||
)
|
||||
@ -139,8 +140,9 @@ def test_piping_diff(self) -> None:
|
||||
expected, _ = read_data("expression.diff")
|
||||
hold_stdin, hold_stdout = sys.stdin, sys.stdout
|
||||
try:
|
||||
sys.stdin, sys.stdout = StringIO(source), StringIO()
|
||||
sys.stdin.name = "<stdin>"
|
||||
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
|
||||
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
|
||||
sys.stdin.buffer.name = "<stdin>" # type: ignore
|
||||
black.format_stdin_to_stdout(
|
||||
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
|
||||
)
|
||||
@ -204,7 +206,7 @@ def test_expression_diff(self) -> None:
|
||||
tmp_file = Path(black.dump_to_file(source))
|
||||
hold_stdout = sys.stdout
|
||||
try:
|
||||
sys.stdout = StringIO()
|
||||
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
|
||||
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
|
||||
sys.stdout.seek(0)
|
||||
actual = sys.stdout.read()
|
||||
@ -1108,6 +1110,18 @@ def test_invalid_include_exclude(self) -> None:
|
||||
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
|
||||
self.assertEqual(result.exit_code, 2)
|
||||
|
||||
def test_preserves_line_endings(self) -> None:
|
||||
with TemporaryDirectory() as workspace:
|
||||
test_file = Path(workspace) / "test.py"
|
||||
for nl in ["\n", "\r\n"]:
|
||||
contents = nl.join(["def f( ):", " pass"])
|
||||
test_file.write_bytes(contents.encode())
|
||||
ff(test_file, write_back=black.WriteBack.YES)
|
||||
updated_contents: bytes = test_file.read_bytes()
|
||||
self.assertIn(nl.encode(), updated_contents) # type: ignore
|
||||
if nl == "\n":
|
||||
self.assertNotIn(b"\r\n", updated_contents) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user