Add support for pyi files (#210)

Fixes #207
This commit is contained in:
Jelle Zijlstra 2018-05-15 15:09:35 -04:00 committed by Łukasz Langa
parent 3eab6d3131
commit 14ba1bf8b6
3 changed files with 137 additions and 19 deletions

108
black.py
View File

@ -329,12 +329,13 @@ def format_file_in_place(
If `write_back` is True, write reformatted code back to stdout.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
is_pyi = src.suffix == ".pyi"
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read()
try:
dst_contents = format_file_contents(
src_contents, line_length=line_length, fast=fast
src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
)
except NothingChanged:
return False
@ -383,7 +384,7 @@ def format_stdin_to_stdout(
def format_file_contents(
src_contents: str, line_length: int, fast: bool
src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
) -> FileContent:
"""Reformat contents a file and return new contents.
@ -394,17 +395,21 @@ def format_file_contents(
if src_contents.strip() == "":
raise NothingChanged
dst_contents = format_str(src_contents, line_length=line_length)
dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
if src_contents == dst_contents:
raise NothingChanged
if not fast:
assert_equivalent(src_contents, dst_contents)
assert_stable(src_contents, dst_contents, line_length=line_length)
assert_stable(
src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
)
return dst_contents
def format_str(src_contents: str, line_length: int) -> FileContent:
def format_str(
src_contents: str, line_length: int, *, is_pyi: bool = False
) -> FileContent:
"""Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed.
@ -412,9 +417,11 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
src_node = lib2to3_parse(src_contents)
dst_contents = ""
future_imports = get_future_imports(src_node)
elt = EmptyLineTracker(is_pyi=is_pyi)
py36 = is_python36(src_node)
lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
elt = EmptyLineTracker()
lines = LineGenerator(
remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
)
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
@ -833,6 +840,14 @@ def is_class(self) -> bool:
and self.leaves[0].value == "class"
)
@property
def is_trivial_class(self) -> bool:
"""Is this line a class definition with a body consisting only of "..."?"""
return (
self.is_class
and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)]
)
@property
def is_def(self) -> bool:
"""Is this a function definition? (Also returns True for async defs.)"""
@ -1100,6 +1115,7 @@ class EmptyLineTracker:
the prefix of the first leaf consists of optional newlines. Those newlines
are consumed by `maybe_empty_lines()` and included in the computation.
"""
is_pyi: bool = False
previous_line: Optional[Line] = None
previous_after: int = 0
previous_defs: List[int] = Factory(list)
@ -1123,7 +1139,7 @@ def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
max_allowed = 1
if current_line.depth == 0:
max_allowed = 2
max_allowed = 1 if self.is_pyi else 2
if current_line.leaves:
# Consume the first leaf's extra newlines.
first_leaf = current_line.leaves[0]
@ -1135,6 +1151,9 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
depth = current_line.depth
while self.previous_defs and self.previous_defs[-1] >= depth:
self.previous_defs.pop()
if self.is_pyi:
before = 0 if depth else 1
else:
before = 1 if depth else 2
is_decorator = current_line.is_decorator
if is_decorator or current_line.is_def or current_line.is_class:
@ -1154,8 +1173,22 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
):
return 0, 0
if self.is_pyi:
if self.previous_line.depth > current_line.depth:
newlines = 1
elif current_line.is_class or self.previous_line.is_class:
if (
current_line.is_trivial_class
and self.previous_line.is_trivial_class
):
newlines = 0
else:
newlines = 1
else:
newlines = 0
else:
newlines = 2
if current_line.depth:
if current_line.depth and newlines:
newlines -= 1
return newlines, 0
@ -1177,6 +1210,7 @@ class LineGenerator(Visitor[Line]):
Note: destroys the tree it's visiting by mutating prefixes of its leaves
in ways that will no longer stringify to valid Python code on the tree.
"""
is_pyi: bool = False
current_line: Line = Factory(Line)
remove_u_prefix: bool = False
@ -1293,15 +1327,65 @@ def visit_stmt(
yield from self.visit(child)
def visit_suite(self, node: Node) -> Iterator[Line]:
"""Visit a suite."""
if self.is_pyi and self.is_trivial_suite(node):
yield from self.visit(node.children[2])
else:
yield from self.visit_default(node)
def is_trivial_suite(self, node: Node) -> bool:
if len(node.children) != 4:
return False
if (
not isinstance(node.children[0], Leaf)
or node.children[0].type != token.NEWLINE
):
return False
if (
not isinstance(node.children[1], Leaf)
or node.children[1].type != token.INDENT
):
return False
if (
not isinstance(node.children[3], Leaf)
or node.children[3].type != token.DEDENT
):
return False
stmt = node.children[2]
if not isinstance(stmt, Node):
return False
return self.is_trivial_body(stmt)
def is_trivial_body(self, stmt: Node) -> bool:
if not isinstance(stmt, Node) or stmt.type != syms.simple_stmt:
return False
if len(stmt.children) != 2:
return False
child = stmt.children[0]
return (
child.type == syms.atom
and len(child.children) == 3
and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
)
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements."""
is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like:
if self.is_pyi and self.is_trivial_body(node):
yield from self.visit_default(node)
else:
yield from self.line(+1)
yield from self.visit_default(node)
yield from self.line(-1)
else:
if (
not self.is_pyi
or not node.parent
or not self.is_trivial_suite(node.parent)
):
yield from self.line()
yield from self.visit_default(node)
@ -2554,7 +2638,7 @@ def get_future_imports(node: Node) -> Set[str]:
return imports
PYTHON_EXTENSIONS = {".py"}
PYTHON_EXTENSIONS = {".py", ".pyi"}
BLACKLISTED_DIRECTORIES = {
"build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
}
@ -2717,9 +2801,9 @@ def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
) from None
def assert_stable(src: str, dst: str, line_length: int) -> None:
def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str(dst, line_length=line_length)
newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
if dst != newdst:
log = dump_to_file(
diff(src, dst, "source", "first pass"),

27
tests/stub.pyi Normal file
View File

@ -0,0 +1,27 @@
class C:
...
class B:
...
class A:
def f(self) -> int:
...
def g(self) -> str: ...
def g():
...
def h(): ...
# output
class C: ...
class B: ...
class A:
def f(self) -> int: ...
def g(self) -> str: ...
def g(): ...
def h(): ...

View File

@ -31,7 +31,7 @@ def dump_to_stderr(*output: str) -> str:
def read_data(name: str) -> Tuple[str, str]:
"""read_data('test_name') -> 'input', 'output'"""
if not name.endswith((".py", ".out", ".diff")):
if not name.endswith((".py", ".pyi", ".out", ".diff")):
name += ".py"
_input: List[str] = []
_output: List[str] = []
@ -340,6 +340,13 @@ def test_python2_unicode_literals(self) -> None:
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll)
@patch("black.dump_to_file", dump_to_stderr)
def test_stub(self) -> None:
source, expected = read_data("stub.pyi")
actual = fs(source, is_pyi=True)
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll, is_pyi=True)
@patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None:
source, expected = read_data("fmtonoff")