Add support for pyi files (#210)

Fixes #207
This commit is contained in:
Jelle Zijlstra 2018-05-15 15:09:35 -04:00 committed by Łukasz Langa
parent 3eab6d3131
commit 14ba1bf8b6
3 changed files with 137 additions and 19 deletions

120
black.py
View File

@ -329,12 +329,13 @@ def format_file_in_place(
If `write_back` is True, write reformatted code back to stdout. If `write_back` is True, write reformatted code back to stdout.
`line_length` and `fast` options are passed to :func:`format_file_contents`. `line_length` and `fast` options are passed to :func:`format_file_contents`.
""" """
is_pyi = src.suffix == ".pyi"
with tokenize.open(src) as src_buffer: with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read() src_contents = src_buffer.read()
try: try:
dst_contents = format_file_contents( dst_contents = format_file_contents(
src_contents, line_length=line_length, fast=fast src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
) )
except NothingChanged: except NothingChanged:
return False return False
@ -383,7 +384,7 @@ def format_stdin_to_stdout(
def format_file_contents( def format_file_contents(
src_contents: str, line_length: int, fast: bool src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
) -> FileContent: ) -> FileContent:
"""Reformat contents a file and return new contents. """Reformat contents a file and return new contents.
@ -394,17 +395,21 @@ def format_file_contents(
if src_contents.strip() == "": if src_contents.strip() == "":
raise NothingChanged raise NothingChanged
dst_contents = format_str(src_contents, line_length=line_length) dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
if src_contents == dst_contents: if src_contents == dst_contents:
raise NothingChanged raise NothingChanged
if not fast: if not fast:
assert_equivalent(src_contents, dst_contents) assert_equivalent(src_contents, dst_contents)
assert_stable(src_contents, dst_contents, line_length=line_length) assert_stable(
src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
)
return dst_contents return dst_contents
def format_str(src_contents: str, line_length: int) -> FileContent: def format_str(
src_contents: str, line_length: int, *, is_pyi: bool = False
) -> FileContent:
"""Reformat a string and return new contents. """Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed. `line_length` determines how many characters per line are allowed.
@ -412,9 +417,11 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
src_node = lib2to3_parse(src_contents) src_node = lib2to3_parse(src_contents)
dst_contents = "" dst_contents = ""
future_imports = get_future_imports(src_node) future_imports = get_future_imports(src_node)
elt = EmptyLineTracker(is_pyi=is_pyi)
py36 = is_python36(src_node) py36 = is_python36(src_node)
lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports) lines = LineGenerator(
elt = EmptyLineTracker() remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
)
empty_line = Line() empty_line = Line()
after = 0 after = 0
for current_line in lines.visit(src_node): for current_line in lines.visit(src_node):
@ -833,6 +840,14 @@ def is_class(self) -> bool:
and self.leaves[0].value == "class" and self.leaves[0].value == "class"
) )
@property
def is_trivial_class(self) -> bool:
"""Is this line a class definition with a body consisting only of "..."?"""
return (
self.is_class
and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)]
)
@property @property
def is_def(self) -> bool: def is_def(self) -> bool:
"""Is this a function definition? (Also returns True for async defs.)""" """Is this a function definition? (Also returns True for async defs.)"""
@ -1100,6 +1115,7 @@ class EmptyLineTracker:
the prefix of the first leaf consists of optional newlines. Those newlines the prefix of the first leaf consists of optional newlines. Those newlines
are consumed by `maybe_empty_lines()` and included in the computation. are consumed by `maybe_empty_lines()` and included in the computation.
""" """
is_pyi: bool = False
previous_line: Optional[Line] = None previous_line: Optional[Line] = None
previous_after: int = 0 previous_after: int = 0
previous_defs: List[int] = Factory(list) previous_defs: List[int] = Factory(list)
@ -1123,7 +1139,7 @@ def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
max_allowed = 1 max_allowed = 1
if current_line.depth == 0: if current_line.depth == 0:
max_allowed = 2 max_allowed = 1 if self.is_pyi else 2
if current_line.leaves: if current_line.leaves:
# Consume the first leaf's extra newlines. # Consume the first leaf's extra newlines.
first_leaf = current_line.leaves[0] first_leaf = current_line.leaves[0]
@ -1135,7 +1151,10 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
depth = current_line.depth depth = current_line.depth
while self.previous_defs and self.previous_defs[-1] >= depth: while self.previous_defs and self.previous_defs[-1] >= depth:
self.previous_defs.pop() self.previous_defs.pop()
before = 1 if depth else 2 if self.is_pyi:
before = 0 if depth else 1
else:
before = 1 if depth else 2
is_decorator = current_line.is_decorator is_decorator = current_line.is_decorator
if is_decorator or current_line.is_def or current_line.is_class: if is_decorator or current_line.is_def or current_line.is_class:
if not is_decorator: if not is_decorator:
@ -1154,8 +1173,22 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
): ):
return 0, 0 return 0, 0
newlines = 2 if self.is_pyi:
if current_line.depth: if self.previous_line.depth > current_line.depth:
newlines = 1
elif current_line.is_class or self.previous_line.is_class:
if (
current_line.is_trivial_class
and self.previous_line.is_trivial_class
):
newlines = 0
else:
newlines = 1
else:
newlines = 0
else:
newlines = 2
if current_line.depth and newlines:
newlines -= 1 newlines -= 1
return newlines, 0 return newlines, 0
@ -1177,6 +1210,7 @@ class LineGenerator(Visitor[Line]):
Note: destroys the tree it's visiting by mutating prefixes of its leaves Note: destroys the tree it's visiting by mutating prefixes of its leaves
in ways that will no longer stringify to valid Python code on the tree. in ways that will no longer stringify to valid Python code on the tree.
""" """
is_pyi: bool = False
current_line: Line = Factory(Line) current_line: Line = Factory(Line)
remove_u_prefix: bool = False remove_u_prefix: bool = False
@ -1293,16 +1327,66 @@ def visit_stmt(
yield from self.visit(child) yield from self.visit(child)
def visit_suite(self, node: Node) -> Iterator[Line]:
"""Visit a suite."""
if self.is_pyi and self.is_trivial_suite(node):
yield from self.visit(node.children[2])
else:
yield from self.visit_default(node)
def is_trivial_suite(self, node: Node) -> bool:
if len(node.children) != 4:
return False
if (
not isinstance(node.children[0], Leaf)
or node.children[0].type != token.NEWLINE
):
return False
if (
not isinstance(node.children[1], Leaf)
or node.children[1].type != token.INDENT
):
return False
if (
not isinstance(node.children[3], Leaf)
or node.children[3].type != token.DEDENT
):
return False
stmt = node.children[2]
if not isinstance(stmt, Node):
return False
return self.is_trivial_body(stmt)
def is_trivial_body(self, stmt: Node) -> bool:
if not isinstance(stmt, Node) or stmt.type != syms.simple_stmt:
return False
if len(stmt.children) != 2:
return False
child = stmt.children[0]
return (
child.type == syms.atom
and len(child.children) == 3
and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
)
def visit_simple_stmt(self, node: Node) -> Iterator[Line]: def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements.""" """Visit a statement without nested statements."""
is_suite_like = node.parent and node.parent.type in STATEMENT is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like: if is_suite_like:
yield from self.line(+1) if self.is_pyi and self.is_trivial_body(node):
yield from self.visit_default(node) yield from self.visit_default(node)
yield from self.line(-1) else:
yield from self.line(+1)
yield from self.visit_default(node)
yield from self.line(-1)
else: else:
yield from self.line() if (
not self.is_pyi
or not node.parent
or not self.is_trivial_suite(node.parent)
):
yield from self.line()
yield from self.visit_default(node) yield from self.visit_default(node)
def visit_async_stmt(self, node: Node) -> Iterator[Line]: def visit_async_stmt(self, node: Node) -> Iterator[Line]:
@ -2554,7 +2638,7 @@ def get_future_imports(node: Node) -> Set[str]:
return imports return imports
PYTHON_EXTENSIONS = {".py"} PYTHON_EXTENSIONS = {".py", ".pyi"}
BLACKLISTED_DIRECTORIES = { BLACKLISTED_DIRECTORIES = {
"build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
} }
@ -2717,9 +2801,9 @@ def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
) from None ) from None
def assert_stable(src: str, dst: str, line_length: int) -> None: def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
"""Raise AssertionError if `dst` reformats differently the second time.""" """Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str(dst, line_length=line_length) newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
if dst != newdst: if dst != newdst:
log = dump_to_file( log = dump_to_file(
diff(src, dst, "source", "first pass"), diff(src, dst, "source", "first pass"),

27
tests/stub.pyi Normal file
View File

@ -0,0 +1,27 @@
class C:
...
class B:
...
class A:
def f(self) -> int:
...
def g(self) -> str: ...
def g():
...
def h(): ...
# output
class C: ...
class B: ...
class A:
def f(self) -> int: ...
def g(self) -> str: ...
def g(): ...
def h(): ...

View File

@ -31,7 +31,7 @@ def dump_to_stderr(*output: str) -> str:
def read_data(name: str) -> Tuple[str, str]: def read_data(name: str) -> Tuple[str, str]:
"""read_data('test_name') -> 'input', 'output'""" """read_data('test_name') -> 'input', 'output'"""
if not name.endswith((".py", ".out", ".diff")): if not name.endswith((".py", ".pyi", ".out", ".diff")):
name += ".py" name += ".py"
_input: List[str] = [] _input: List[str] = []
_output: List[str] = [] _output: List[str] = []
@ -340,6 +340,13 @@ def test_python2_unicode_literals(self) -> None:
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, line_length=ll)
@patch("black.dump_to_file", dump_to_stderr)
def test_stub(self) -> None:
source, expected = read_data("stub.pyi")
actual = fs(source, is_pyi=True)
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll, is_pyi=True)
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None: def test_fmtonoff(self) -> None:
source, expected = read_data("fmtonoff") source, expected = read_data("fmtonoff")