Preserve line endings when formatting a file in place (#288)
This commit is contained in:
parent
dbe26161fa
commit
00a302560b
@ -720,6 +720,8 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
|
|||||||
* fixed stdin handling not working correctly if an old version of Click was
|
* fixed stdin handling not working correctly if an old version of Click was
|
||||||
used (#276)
|
used (#276)
|
||||||
|
|
||||||
|
* *Black* now preserves line endings when formatting a file in place (#258)
|
||||||
|
|
||||||
|
|
||||||
### 18.5b1
|
### 18.5b1
|
||||||
|
|
||||||
|
53
black.py
53
black.py
@ -4,6 +4,7 @@
|
|||||||
from concurrent.futures import Executor, ProcessPoolExecutor
|
from concurrent.futures import Executor, ProcessPoolExecutor
|
||||||
from enum import Enum, Flag
|
from enum import Enum, Flag
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
|
import io
|
||||||
import keyword
|
import keyword
|
||||||
import logging
|
import logging
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
@ -465,8 +466,9 @@ def format_file_in_place(
|
|||||||
"""
|
"""
|
||||||
if src.suffix == ".pyi":
|
if src.suffix == ".pyi":
|
||||||
mode |= FileMode.PYI
|
mode |= FileMode.PYI
|
||||||
with tokenize.open(src) as src_buffer:
|
|
||||||
src_contents = src_buffer.read()
|
with open(src, "rb") as buf:
|
||||||
|
newline, encoding, src_contents = prepare_input(buf.read())
|
||||||
try:
|
try:
|
||||||
dst_contents = format_file_contents(
|
dst_contents = format_file_contents(
|
||||||
src_contents, line_length=line_length, fast=fast, mode=mode
|
src_contents, line_length=line_length, fast=fast, mode=mode
|
||||||
@ -475,7 +477,7 @@ def format_file_in_place(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if write_back == write_back.YES:
|
if write_back == write_back.YES:
|
||||||
with open(src, "w", encoding=src_buffer.encoding) as f:
|
with open(src, "w", encoding=encoding, newline=newline) as f:
|
||||||
f.write(dst_contents)
|
f.write(dst_contents)
|
||||||
elif write_back == write_back.DIFF:
|
elif write_back == write_back.DIFF:
|
||||||
src_name = f"{src} (original)"
|
src_name = f"{src} (original)"
|
||||||
@ -484,7 +486,14 @@ def format_file_in_place(
|
|||||||
if lock:
|
if lock:
|
||||||
lock.acquire()
|
lock.acquire()
|
||||||
try:
|
try:
|
||||||
sys.stdout.write(diff_contents)
|
f = io.TextIOWrapper(
|
||||||
|
sys.stdout.buffer,
|
||||||
|
encoding=encoding,
|
||||||
|
newline=newline,
|
||||||
|
write_through=True,
|
||||||
|
)
|
||||||
|
f.write(diff_contents)
|
||||||
|
f.detach()
|
||||||
finally:
|
finally:
|
||||||
if lock:
|
if lock:
|
||||||
lock.release()
|
lock.release()
|
||||||
@ -503,7 +512,7 @@ def format_stdin_to_stdout(
|
|||||||
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
|
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
|
||||||
:func:`format_file_contents`.
|
:func:`format_file_contents`.
|
||||||
"""
|
"""
|
||||||
src = sys.stdin.read()
|
newline, encoding, src = prepare_input(sys.stdin.buffer.read())
|
||||||
dst = src
|
dst = src
|
||||||
try:
|
try:
|
||||||
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
|
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
|
||||||
@ -514,11 +523,25 @@ def format_stdin_to_stdout(
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
if write_back == WriteBack.YES:
|
if write_back == WriteBack.YES:
|
||||||
sys.stdout.write(dst)
|
f = io.TextIOWrapper(
|
||||||
|
sys.stdout.buffer,
|
||||||
|
encoding=encoding,
|
||||||
|
newline=newline,
|
||||||
|
write_through=True,
|
||||||
|
)
|
||||||
|
f.write(dst)
|
||||||
|
f.detach()
|
||||||
elif write_back == WriteBack.DIFF:
|
elif write_back == WriteBack.DIFF:
|
||||||
src_name = "<stdin> (original)"
|
src_name = "<stdin> (original)"
|
||||||
dst_name = "<stdin> (formatted)"
|
dst_name = "<stdin> (formatted)"
|
||||||
sys.stdout.write(diff(src, dst, src_name, dst_name))
|
f = io.TextIOWrapper(
|
||||||
|
sys.stdout.buffer,
|
||||||
|
encoding=encoding,
|
||||||
|
newline=newline,
|
||||||
|
write_through=True,
|
||||||
|
)
|
||||||
|
f.write(diff(src, dst, src_name, dst_name))
|
||||||
|
f.detach()
|
||||||
|
|
||||||
|
|
||||||
def format_file_contents(
|
def format_file_contents(
|
||||||
@ -579,6 +602,19 @@ def format_str(
|
|||||||
return dst_contents
|
return dst_contents
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_input(src: bytes) -> Tuple[str, str, str]:
|
||||||
|
"""Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
|
||||||
|
|
||||||
|
Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
|
||||||
|
universal newlines (i.e. only LF).
|
||||||
|
"""
|
||||||
|
srcbuf = io.BytesIO(src)
|
||||||
|
encoding, lines = tokenize.detect_encoding(srcbuf.readline)
|
||||||
|
newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
|
||||||
|
srcbuf.seek(0)
|
||||||
|
return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
|
||||||
|
|
||||||
|
|
||||||
GRAMMARS = [
|
GRAMMARS = [
|
||||||
pygram.python_grammar_no_print_statement_no_exec_statement,
|
pygram.python_grammar_no_print_statement_no_exec_statement,
|
||||||
pygram.python_grammar_no_print_statement,
|
pygram.python_grammar_no_print_statement,
|
||||||
@ -590,8 +626,7 @@ def lib2to3_parse(src_txt: str) -> Node:
|
|||||||
"""Given a string with source, return the lib2to3 Node."""
|
"""Given a string with source, return the lib2to3 Node."""
|
||||||
grammar = pygram.python_grammar_no_print_statement
|
grammar = pygram.python_grammar_no_print_statement
|
||||||
if src_txt[-1] != "\n":
|
if src_txt[-1] != "\n":
|
||||||
nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
|
src_txt += "\n"
|
||||||
src_txt += nl
|
|
||||||
for grammar in GRAMMARS:
|
for grammar in GRAMMARS:
|
||||||
drv = driver.Driver(grammar, pytree.convert)
|
drv = driver.Driver(grammar, pytree.convert)
|
||||||
try:
|
try:
|
||||||
|
@ -61,6 +61,8 @@ Parsing
|
|||||||
|
|
||||||
.. autofunction:: black.lib2to3_unparse
|
.. autofunction:: black.lib2to3_unparse
|
||||||
|
|
||||||
|
.. autofunction:: black.prepare_input
|
||||||
|
|
||||||
Split functions
|
Split functions
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from io import StringIO
|
from io import BytesIO, TextIOWrapper
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
@ -121,8 +121,9 @@ def test_piping(self) -> None:
|
|||||||
source, expected = read_data("../black")
|
source, expected = read_data("../black")
|
||||||
hold_stdin, hold_stdout = sys.stdin, sys.stdout
|
hold_stdin, hold_stdout = sys.stdin, sys.stdout
|
||||||
try:
|
try:
|
||||||
sys.stdin, sys.stdout = StringIO(source), StringIO()
|
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
|
||||||
sys.stdin.name = "<stdin>"
|
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
|
||||||
|
sys.stdin.buffer.name = "<stdin>" # type: ignore
|
||||||
black.format_stdin_to_stdout(
|
black.format_stdin_to_stdout(
|
||||||
line_length=ll, fast=True, write_back=black.WriteBack.YES
|
line_length=ll, fast=True, write_back=black.WriteBack.YES
|
||||||
)
|
)
|
||||||
@ -139,8 +140,9 @@ def test_piping_diff(self) -> None:
|
|||||||
expected, _ = read_data("expression.diff")
|
expected, _ = read_data("expression.diff")
|
||||||
hold_stdin, hold_stdout = sys.stdin, sys.stdout
|
hold_stdin, hold_stdout = sys.stdin, sys.stdout
|
||||||
try:
|
try:
|
||||||
sys.stdin, sys.stdout = StringIO(source), StringIO()
|
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
|
||||||
sys.stdin.name = "<stdin>"
|
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
|
||||||
|
sys.stdin.buffer.name = "<stdin>" # type: ignore
|
||||||
black.format_stdin_to_stdout(
|
black.format_stdin_to_stdout(
|
||||||
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
|
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
|
||||||
)
|
)
|
||||||
@ -204,7 +206,7 @@ def test_expression_diff(self) -> None:
|
|||||||
tmp_file = Path(black.dump_to_file(source))
|
tmp_file = Path(black.dump_to_file(source))
|
||||||
hold_stdout = sys.stdout
|
hold_stdout = sys.stdout
|
||||||
try:
|
try:
|
||||||
sys.stdout = StringIO()
|
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
|
||||||
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
|
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
|
||||||
sys.stdout.seek(0)
|
sys.stdout.seek(0)
|
||||||
actual = sys.stdout.read()
|
actual = sys.stdout.read()
|
||||||
@ -1108,6 +1110,18 @@ def test_invalid_include_exclude(self) -> None:
|
|||||||
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
|
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
|
||||||
self.assertEqual(result.exit_code, 2)
|
self.assertEqual(result.exit_code, 2)
|
||||||
|
|
||||||
|
def test_preserves_line_endings(self) -> None:
|
||||||
|
with TemporaryDirectory() as workspace:
|
||||||
|
test_file = Path(workspace) / "test.py"
|
||||||
|
for nl in ["\n", "\r\n"]:
|
||||||
|
contents = nl.join(["def f( ):", " pass"])
|
||||||
|
test_file.write_bytes(contents.encode())
|
||||||
|
ff(test_file, write_back=black.WriteBack.YES)
|
||||||
|
updated_contents: bytes = test_file.read_bytes()
|
||||||
|
self.assertIn(nl.encode(), updated_contents) # type: ignore
|
||||||
|
if nl == "\n":
|
||||||
|
self.assertNotIn(b"\r\n", updated_contents) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user