Correctly handle trailing commas that are inside a line's leading non-nested parens (#3370)

- Fixes #1671
- Fixes #3229
This commit is contained in:
Yilei "Dolee" Yang 2022-11-09 15:08:51 -08:00 committed by GitHub
parent ffaaf48382
commit 8091b2503d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 240 additions and 16 deletions

View File

@ -17,6 +17,8 @@
- Enforce empty lines before classes and functions with sticky leading comments (#3302)
- Implicitly concatenated strings used as function args are now wrapped inside
parentheses (#3307)
- Correctly handle trailing commas that are inside a line's leading non-nested parens
(#3370)
### Configuration

View File

@ -2,7 +2,7 @@
import sys
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
if sys.version_info < (3, 8):
from typing_extensions import Final
@ -340,3 +340,35 @@ def max_delimiter_priority_in_atom(node: LN) -> Priority:
except ValueError:
return 0
def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> Set[LeafID]:
"""Return leaves that are inside matching brackets.
The input `leaves` can have non-matching brackets at the head or tail parts.
Matching brackets are included.
"""
try:
# Only track brackets from the first opening bracket to the last closing
# bracket.
start_index = next(
i for i, l in enumerate(leaves) if l.type in OPENING_BRACKETS
)
end_index = next(
len(leaves) - i
for i, l in enumerate(reversed(leaves))
if l.type in CLOSING_BRACKETS
)
except StopIteration:
return set()
ids = set()
depth = 0
for i in range(end_index, start_index - 1, -1):
leaf = leaves[i]
if leaf.type in CLOSING_BRACKETS:
depth += 1
if depth > 0:
ids.add(id(leaf))
if leaf.type in OPENING_BRACKETS:
depth -= 1
return ids

View File

@ -2,10 +2,16 @@
Generating lines of code.
"""
import sys
from enum import Enum, auto
from functools import partial, wraps
from typing import Collection, Iterator, List, Optional, Set, Union, cast
from black.brackets import COMMA_PRIORITY, DOT_PRIORITY, max_delimiter_priority_in_atom
from black.brackets import (
COMMA_PRIORITY,
DOT_PRIORITY,
get_leaves_inside_matching_brackets,
max_delimiter_priority_in_atom,
)
from black.comments import FMT_OFF, generate_comments, list_comments
from black.lines import (
Line,
@ -561,6 +567,12 @@ def _rhs(
yield line
class _BracketSplitComponent(Enum):
head = auto()
body = auto()
tail = auto()
def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.
@ -591,9 +603,15 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator
if not matching_bracket:
raise CannotSplit("No brackets found")
head = bracket_split_build_line(head_leaves, line, matching_bracket)
body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
head = bracket_split_build_line(
head_leaves, line, matching_bracket, component=_BracketSplitComponent.head
)
body = bracket_split_build_line(
body_leaves, line, matching_bracket, component=_BracketSplitComponent.body
)
tail = bracket_split_build_line(
tail_leaves, line, matching_bracket, component=_BracketSplitComponent.tail
)
bracket_split_succeeded_or_raise(head, body, tail)
for result in (head, body, tail):
if result:
@ -639,9 +657,15 @@ def right_hand_split(
tail_leaves.reverse()
body_leaves.reverse()
head_leaves.reverse()
head = bracket_split_build_line(head_leaves, line, opening_bracket)
body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
head = bracket_split_build_line(
head_leaves, line, opening_bracket, component=_BracketSplitComponent.head
)
body = bracket_split_build_line(
body_leaves, line, opening_bracket, component=_BracketSplitComponent.body
)
tail = bracket_split_build_line(
tail_leaves, line, opening_bracket, component=_BracketSplitComponent.tail
)
bracket_split_succeeded_or_raise(head, body, tail)
if (
Feature.FORCE_OPTIONAL_PARENTHESES not in features
@ -715,15 +739,23 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
def bracket_split_build_line(
leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
leaves: List[Leaf],
original: Line,
opening_bracket: Leaf,
*,
component: _BracketSplitComponent,
) -> Line:
"""Return a new line with given `leaves` and respective comments from `original`.
If `is_body` is True, the result line is one-indented inside brackets and as such
has its first leaf's prefix normalized and a trailing comma added when expected.
If it's the head component, brackets will be tracked so trailing commas are
respected.
If it's the body component, the result line is one-indented inside brackets and as
such has its first leaf's prefix normalized and a trailing comma added when
expected.
"""
result = Line(mode=original.mode, depth=original.depth)
if is_body:
if component is _BracketSplitComponent.body:
result.inside_brackets = True
result.depth += 1
if leaves:
@ -761,12 +793,24 @@ def bracket_split_build_line(
leaves.insert(i + 1, new_comma)
break
leaves_to_track: Set[LeafID] = set()
if (
Preview.handle_trailing_commas_in_head in original.mode
and component is _BracketSplitComponent.head
):
leaves_to_track = get_leaves_inside_matching_brackets(leaves)
# Populate the line
for leaf in leaves:
result.append(leaf, preformatted=True)
result.append(
leaf,
preformatted=True,
track_bracket=id(leaf) in leaves_to_track,
)
for comment_after in original.comments_after(leaf):
result.append(comment_after, preformatted=True)
if is_body and should_split_line(result, opening_bracket):
if component is _BracketSplitComponent.body and should_split_line(
result, opening_bracket
):
result.should_split_rhs = True
return result

View File

@ -53,7 +53,9 @@ class Line:
should_split_rhs: bool = False
magic_trailing_comma: Optional[Leaf] = None
def append(self, leaf: Leaf, preformatted: bool = False) -> None:
def append(
self, leaf: Leaf, preformatted: bool = False, track_bracket: bool = False
) -> None:
"""Add a new `leaf` to the end of the line.
Unless `preformatted` is True, the `leaf` will receive a new consistent
@ -75,7 +77,7 @@ def append(self, leaf: Leaf, preformatted: bool = False) -> None:
leaf.prefix += whitespace(
leaf, complex_subscript=self.is_complex_subscript(leaf)
)
if self.inside_brackets or not preformatted:
if self.inside_brackets or not preformatted or track_bracket:
self.bracket_tracker.mark(leaf)
if self.mode.magic_trailing_comma:
if self.has_magic_trailing_comma(leaf):

View File

@ -151,6 +151,7 @@ class Preview(Enum):
annotation_parens = auto()
empty_lines_before_class_or_def_with_leading_comments = auto()
handle_trailing_commas_in_head = auto()
long_docstring_quotes_on_newline = auto()
normalize_docstring_quotes_and_prefixes_properly = auto()
one_element_subscript = auto()

View File

@ -15,6 +15,37 @@
# Except single element tuples
small_tuple = (1,)
# Trailing commas in multiple chained non-nested parens.
zero(
one,
).two(
three,
).four(
five,
)
func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5)
(
a,
b,
c,
d,
) = func1(
arg1
) and func2(arg2)
func(
argument1,
(
one,
two,
),
argument4,
argument5,
argument6,
)
# output
# We should not remove the trailing comma in a single-element subscript.
a: tuple[int,]
@ -32,3 +63,12 @@
# Except single element tuples
small_tuple = (1,)
# Trailing commas in multiple chained non-nested parens.
zero(one).two(three).four(five)
func1(arg1).func2(arg2).func3(arg3).func4(arg4).func5(arg5)
(a, b, c, d) = func1(arg1) and func2(arg2)
func(argument1, (one, two), argument4, argument5, argument6)

View File

@ -0,0 +1,74 @@
zero(one,).two(three,).four(five,)
func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5)
# Inner one-element tuple shouldn't explode
func1(arg1).func2(arg1, (one_tuple,)).func3(arg3)
(a, b, c, d,) = func1(arg1) and func2(arg2)
# Example from https://github.com/psf/black/issues/3229
def refresh_token(self, device_family, refresh_token, api_key):
return self.orchestration.refresh_token(
data={
"refreshToken": refresh_token,
},
api_key=api_key,
)["extensions"]["sdk"]["token"]
# Edge case where a bug in a working-in-progress version of
# https://github.com/psf/black/pull/3370 causes an infinite recursion.
assert (
long_module.long_class.long_func().another_func()
== long_module.long_class.long_func()["some_key"].another_func(arg1)
)
# output
zero(
one,
).two(
three,
).four(
five,
)
func1(arg1).func2(
arg2,
).func3(arg3).func4(
arg4,
).func5(arg5)
# Inner one-element tuple shouldn't explode
func1(arg1).func2(arg1, (one_tuple,)).func3(arg3)
(
a,
b,
c,
d,
) = func1(
arg1
) and func2(arg2)
# Example from https://github.com/psf/black/issues/3229
def refresh_token(self, device_family, refresh_token, api_key):
return self.orchestration.refresh_token(
data={
"refreshToken": refresh_token,
},
api_key=api_key,
)["extensions"]["sdk"]["token"]
# Edge case where a bug in a working-in-progress version of
# https://github.com/psf/black/pull/3370 causes an infinite recursion.
assert (
long_module.long_class.long_func().another_func()
== long_module.long_class.long_func()["some_key"].another_func(arg1)
)

View File

@ -49,6 +49,17 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
):
pass
# Make sure inner one-element tuple won't explode
some_module.some_function(
argument1, (one_element_tuple,), argument4, argument5, argument6
)
# Inner trailing comma causes outer to explode
some_module.some_function(
argument1, (one, two,), argument4, argument5, argument6
)
# output
def f(
@ -151,3 +162,21 @@ def func() -> (
)
):
pass
# Make sure inner one-element tuple won't explode
some_module.some_function(
argument1, (one_element_tuple,), argument4, argument5, argument6
)
# Inner trailing comma causes outer to explode
some_module.some_function(
argument1,
(
one,
two,
),
argument4,
argument5,
argument6,
)