parent
3eab6d3131
commit
14ba1bf8b6
120
black.py
120
black.py
@ -329,12 +329,13 @@ def format_file_in_place(
|
||||
If `write_back` is True, write reformatted code back to stdout.
|
||||
`line_length` and `fast` options are passed to :func:`format_file_contents`.
|
||||
"""
|
||||
is_pyi = src.suffix == ".pyi"
|
||||
|
||||
with tokenize.open(src) as src_buffer:
|
||||
src_contents = src_buffer.read()
|
||||
try:
|
||||
dst_contents = format_file_contents(
|
||||
src_contents, line_length=line_length, fast=fast
|
||||
src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
|
||||
)
|
||||
except NothingChanged:
|
||||
return False
|
||||
@ -383,7 +384,7 @@ def format_stdin_to_stdout(
|
||||
|
||||
|
||||
def format_file_contents(
|
||||
src_contents: str, line_length: int, fast: bool
|
||||
src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
|
||||
) -> FileContent:
|
||||
"""Reformat contents a file and return new contents.
|
||||
|
||||
@ -394,17 +395,21 @@ def format_file_contents(
|
||||
if src_contents.strip() == "":
|
||||
raise NothingChanged
|
||||
|
||||
dst_contents = format_str(src_contents, line_length=line_length)
|
||||
dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
|
||||
if src_contents == dst_contents:
|
||||
raise NothingChanged
|
||||
|
||||
if not fast:
|
||||
assert_equivalent(src_contents, dst_contents)
|
||||
assert_stable(src_contents, dst_contents, line_length=line_length)
|
||||
assert_stable(
|
||||
src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
|
||||
)
|
||||
return dst_contents
|
||||
|
||||
|
||||
def format_str(src_contents: str, line_length: int) -> FileContent:
|
||||
def format_str(
|
||||
src_contents: str, line_length: int, *, is_pyi: bool = False
|
||||
) -> FileContent:
|
||||
"""Reformat a string and return new contents.
|
||||
|
||||
`line_length` determines how many characters per line are allowed.
|
||||
@ -412,9 +417,11 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
|
||||
src_node = lib2to3_parse(src_contents)
|
||||
dst_contents = ""
|
||||
future_imports = get_future_imports(src_node)
|
||||
elt = EmptyLineTracker(is_pyi=is_pyi)
|
||||
py36 = is_python36(src_node)
|
||||
lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
|
||||
elt = EmptyLineTracker()
|
||||
lines = LineGenerator(
|
||||
remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
|
||||
)
|
||||
empty_line = Line()
|
||||
after = 0
|
||||
for current_line in lines.visit(src_node):
|
||||
@ -833,6 +840,14 @@ def is_class(self) -> bool:
|
||||
and self.leaves[0].value == "class"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_trivial_class(self) -> bool:
|
||||
"""Is this line a class definition with a body consisting only of "..."?"""
|
||||
return (
|
||||
self.is_class
|
||||
and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)]
|
||||
)
|
||||
|
||||
@property
|
||||
def is_def(self) -> bool:
|
||||
"""Is this a function definition? (Also returns True for async defs.)"""
|
||||
@ -1100,6 +1115,7 @@ class EmptyLineTracker:
|
||||
the prefix of the first leaf consists of optional newlines. Those newlines
|
||||
are consumed by `maybe_empty_lines()` and included in the computation.
|
||||
"""
|
||||
is_pyi: bool = False
|
||||
previous_line: Optional[Line] = None
|
||||
previous_after: int = 0
|
||||
previous_defs: List[int] = Factory(list)
|
||||
@ -1123,7 +1139,7 @@ def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
|
||||
def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
|
||||
max_allowed = 1
|
||||
if current_line.depth == 0:
|
||||
max_allowed = 2
|
||||
max_allowed = 1 if self.is_pyi else 2
|
||||
if current_line.leaves:
|
||||
# Consume the first leaf's extra newlines.
|
||||
first_leaf = current_line.leaves[0]
|
||||
@ -1135,7 +1151,10 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
|
||||
depth = current_line.depth
|
||||
while self.previous_defs and self.previous_defs[-1] >= depth:
|
||||
self.previous_defs.pop()
|
||||
before = 1 if depth else 2
|
||||
if self.is_pyi:
|
||||
before = 0 if depth else 1
|
||||
else:
|
||||
before = 1 if depth else 2
|
||||
is_decorator = current_line.is_decorator
|
||||
if is_decorator or current_line.is_def or current_line.is_class:
|
||||
if not is_decorator:
|
||||
@ -1154,8 +1173,22 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
|
||||
):
|
||||
return 0, 0
|
||||
|
||||
newlines = 2
|
||||
if current_line.depth:
|
||||
if self.is_pyi:
|
||||
if self.previous_line.depth > current_line.depth:
|
||||
newlines = 1
|
||||
elif current_line.is_class or self.previous_line.is_class:
|
||||
if (
|
||||
current_line.is_trivial_class
|
||||
and self.previous_line.is_trivial_class
|
||||
):
|
||||
newlines = 0
|
||||
else:
|
||||
newlines = 1
|
||||
else:
|
||||
newlines = 0
|
||||
else:
|
||||
newlines = 2
|
||||
if current_line.depth and newlines:
|
||||
newlines -= 1
|
||||
return newlines, 0
|
||||
|
||||
@ -1177,6 +1210,7 @@ class LineGenerator(Visitor[Line]):
|
||||
Note: destroys the tree it's visiting by mutating prefixes of its leaves
|
||||
in ways that will no longer stringify to valid Python code on the tree.
|
||||
"""
|
||||
is_pyi: bool = False
|
||||
current_line: Line = Factory(Line)
|
||||
remove_u_prefix: bool = False
|
||||
|
||||
@ -1293,16 +1327,66 @@ def visit_stmt(
|
||||
|
||||
yield from self.visit(child)
|
||||
|
||||
def visit_suite(self, node: Node) -> Iterator[Line]:
|
||||
"""Visit a suite."""
|
||||
if self.is_pyi and self.is_trivial_suite(node):
|
||||
yield from self.visit(node.children[2])
|
||||
else:
|
||||
yield from self.visit_default(node)
|
||||
|
||||
def is_trivial_suite(self, node: Node) -> bool:
|
||||
if len(node.children) != 4:
|
||||
return False
|
||||
if (
|
||||
not isinstance(node.children[0], Leaf)
|
||||
or node.children[0].type != token.NEWLINE
|
||||
):
|
||||
return False
|
||||
if (
|
||||
not isinstance(node.children[1], Leaf)
|
||||
or node.children[1].type != token.INDENT
|
||||
):
|
||||
return False
|
||||
if (
|
||||
not isinstance(node.children[3], Leaf)
|
||||
or node.children[3].type != token.DEDENT
|
||||
):
|
||||
return False
|
||||
stmt = node.children[2]
|
||||
if not isinstance(stmt, Node):
|
||||
return False
|
||||
return self.is_trivial_body(stmt)
|
||||
|
||||
def is_trivial_body(self, stmt: Node) -> bool:
|
||||
if not isinstance(stmt, Node) or stmt.type != syms.simple_stmt:
|
||||
return False
|
||||
if len(stmt.children) != 2:
|
||||
return False
|
||||
child = stmt.children[0]
|
||||
return (
|
||||
child.type == syms.atom
|
||||
and len(child.children) == 3
|
||||
and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
|
||||
)
|
||||
|
||||
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
|
||||
"""Visit a statement without nested statements."""
|
||||
is_suite_like = node.parent and node.parent.type in STATEMENT
|
||||
if is_suite_like:
|
||||
yield from self.line(+1)
|
||||
yield from self.visit_default(node)
|
||||
yield from self.line(-1)
|
||||
if self.is_pyi and self.is_trivial_body(node):
|
||||
yield from self.visit_default(node)
|
||||
else:
|
||||
yield from self.line(+1)
|
||||
yield from self.visit_default(node)
|
||||
yield from self.line(-1)
|
||||
|
||||
else:
|
||||
yield from self.line()
|
||||
if (
|
||||
not self.is_pyi
|
||||
or not node.parent
|
||||
or not self.is_trivial_suite(node.parent)
|
||||
):
|
||||
yield from self.line()
|
||||
yield from self.visit_default(node)
|
||||
|
||||
def visit_async_stmt(self, node: Node) -> Iterator[Line]:
|
||||
@ -2554,7 +2638,7 @@ def get_future_imports(node: Node) -> Set[str]:
|
||||
return imports
|
||||
|
||||
|
||||
PYTHON_EXTENSIONS = {".py"}
|
||||
PYTHON_EXTENSIONS = {".py", ".pyi"}
|
||||
BLACKLISTED_DIRECTORIES = {
|
||||
"build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
|
||||
}
|
||||
@ -2717,9 +2801,9 @@ def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
||||
) from None
|
||||
|
||||
|
||||
def assert_stable(src: str, dst: str, line_length: int) -> None:
|
||||
def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
|
||||
"""Raise AssertionError if `dst` reformats differently the second time."""
|
||||
newdst = format_str(dst, line_length=line_length)
|
||||
newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
|
||||
if dst != newdst:
|
||||
log = dump_to_file(
|
||||
diff(src, dst, "source", "first pass"),
|
||||
|
27
tests/stub.pyi
Normal file
27
tests/stub.pyi
Normal file
@ -0,0 +1,27 @@
|
||||
class C:
|
||||
...
|
||||
|
||||
class B:
|
||||
...
|
||||
|
||||
class A:
|
||||
def f(self) -> int:
|
||||
...
|
||||
|
||||
def g(self) -> str: ...
|
||||
|
||||
def g():
|
||||
...
|
||||
|
||||
def h(): ...
|
||||
|
||||
# output
|
||||
class C: ...
|
||||
class B: ...
|
||||
|
||||
class A:
|
||||
def f(self) -> int: ...
|
||||
def g(self) -> str: ...
|
||||
|
||||
def g(): ...
|
||||
def h(): ...
|
@ -31,7 +31,7 @@ def dump_to_stderr(*output: str) -> str:
|
||||
|
||||
def read_data(name: str) -> Tuple[str, str]:
|
||||
"""read_data('test_name') -> 'input', 'output'"""
|
||||
if not name.endswith((".py", ".out", ".diff")):
|
||||
if not name.endswith((".py", ".pyi", ".out", ".diff")):
|
||||
name += ".py"
|
||||
_input: List[str] = []
|
||||
_output: List[str] = []
|
||||
@ -340,6 +340,13 @@ def test_python2_unicode_literals(self) -> None:
|
||||
self.assertFormatEqual(expected, actual)
|
||||
black.assert_stable(source, actual, line_length=ll)
|
||||
|
||||
@patch("black.dump_to_file", dump_to_stderr)
|
||||
def test_stub(self) -> None:
|
||||
source, expected = read_data("stub.pyi")
|
||||
actual = fs(source, is_pyi=True)
|
||||
self.assertFormatEqual(expected, actual)
|
||||
black.assert_stable(source, actual, line_length=ll, is_pyi=True)
|
||||
|
||||
@patch("black.dump_to_file", dump_to_stderr)
|
||||
def test_fmtonoff(self) -> None:
|
||||
source, expected = read_data("fmtonoff")
|
||||
|
Loading…
Reference in New Issue
Block a user