Fix crashes with comments in parentheses (#4453)

Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
This commit is contained in:
Shantanu 2024-09-15 19:34:02 -07:00 committed by GitHub
parent b4d6d8632d
commit 2a45cecf29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 185 additions and 34 deletions

View File

@ -19,6 +19,9 @@
<!-- Changes that affect Black's stable style -->
- Fix crashes involving comments in parenthesised return types or `X | Y` style unions.
(#4453)
### Preview style
<!-- Changes that affect Black's preview style -->

View File

@ -1079,6 +1079,47 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
)
def _ensure_trailing_comma(
leaves: List[Leaf], original: Line, opening_bracket: Leaf
) -> bool:
if not leaves:
return False
# Ensure a trailing comma for imports
if original.is_import:
return True
# ...and standalone function arguments
if not original.is_def:
return False
if opening_bracket.value != "(":
return False
# Don't add commas if we already have any commas
if any(
leaf.type == token.COMMA
and (
Preview.typed_params_trailing_comma not in original.mode
or not is_part_of_annotation(leaf)
)
for leaf in leaves
):
return False
# Find a leaf with a parent (comments don't have parents)
leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None)
if leaf_with_parent is None:
return True
# Don't add commas inside parenthesized return annotations
if get_annotation_type(leaf_with_parent) == "return":
return False
# Don't add commas inside PEP 604 unions
if (
leaf_with_parent.parent
and leaf_with_parent.parent.next_sibling
and leaf_with_parent.parent.next_sibling.type == token.VBAR
):
return False
return True
def bracket_split_build_line(
leaves: List[Leaf],
original: Line,
@ -1099,32 +1140,7 @@ def bracket_split_build_line(
if component is _BracketSplitComponent.body:
result.inside_brackets = True
result.depth += 1
if leaves:
no_commas = (
# Ensure a trailing comma for imports and standalone function arguments
original.is_def
# Don't add one after any comments or within type annotations
and opening_bracket.value == "("
# Don't add one if there's already one there
and not any(
leaf.type == token.COMMA
and (
Preview.typed_params_trailing_comma not in original.mode
or not is_part_of_annotation(leaf)
)
for leaf in leaves
)
# Don't add one inside parenthesized return annotations
and get_annotation_type(leaves[0]) != "return"
# Don't add one inside PEP 604 unions
and not (
leaves[0].parent
and leaves[0].parent.next_sibling
and leaves[0].parent.next_sibling.type == token.VBAR
)
)
if original.is_import or no_commas:
if _ensure_trailing_comma(leaves, original, opening_bracket):
for i in range(len(leaves) - 1, -1, -1):
if leaves[i].type == STANDALONE_COMMENT:
continue

View File

@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]:
def is_part_of_annotation(leaf: Leaf) -> bool:
"""Returns whether this leaf is part of a type annotation."""
assert leaf.parent is not None
return get_annotation_type(leaf) is not None

View File

@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult:
break
i += 1
if not is_part_of_annotation(leaf) and not contains_comment:
if not contains_comment and not is_part_of_annotation(leaf):
string_indices.append(idx)
# Advance to the next non-STRING leaf.

View File

@ -142,6 +142,7 @@ def SimplePyFn(
Buffer[UInt8, 2],
Buffer[UInt8, 2],
]: ...
# output
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...

View File

@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
argument1, (one, two,), argument4, argument5, argument6
)
def foo() -> (
# comment inside parenthesised return type
int
):
...
def foo() -> (
# comment inside parenthesised return type
# more
int
# another
):
...
def foo() -> (
# comment inside parenthesised new union return type
int | str | bytes
):
...
def foo() -> (
# comment inside plain tuple
):
pass
def foo(arg: (# comment with non-return annotation
int
# comment with non-return annotation
)):
pass
def foo(arg: (# comment with non-return annotation
int | range | memoryview
# comment with non-return annotation
)):
pass
def foo(arg: (# only before
int
)):
pass
def foo(arg: (
int
# only after
)):
pass
variable: ( # annotation
because
# why not
)
variable: (
because
# why not
)
# output
def f(
@ -176,3 +234,75 @@ def func() -> (
argument5,
argument6,
)
def foo() -> (
# comment inside parenthesised return type
int
): ...
def foo() -> (
# comment inside parenthesised return type
# more
int
# another
): ...
def foo() -> (
# comment inside parenthesised new union return type
int
| str
| bytes
): ...
def foo() -> (
# comment inside plain tuple
):
pass
def foo(
arg: ( # comment with non-return annotation
int
# comment with non-return annotation
),
):
pass
def foo(
arg: ( # comment with non-return annotation
int
| range
| memoryview
# comment with non-return annotation
),
):
pass
def foo(arg: int): # only before
pass
def foo(
arg: (
int
# only after
),
):
pass
variable: ( # annotation
because
# why not
)
variable: (
because
# why not
)