Fix crashes with comments in parentheses (#4453)
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
This commit is contained in:
parent
b4d6d8632d
commit
2a45cecf29
@ -19,6 +19,9 @@
|
|||||||
|
|
||||||
<!-- Changes that affect Black's stable style -->
|
<!-- Changes that affect Black's stable style -->
|
||||||
|
|
||||||
|
- Fix crashes involving comments in parenthesised return types or `X | Y` style unions.
|
||||||
|
(#4453)
|
||||||
|
|
||||||
### Preview style
|
### Preview style
|
||||||
|
|
||||||
<!-- Changes that affect Black's preview style -->
|
<!-- Changes that affect Black's preview style -->
|
||||||
|
@ -1079,6 +1079,47 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_trailing_comma(
|
||||||
|
leaves: List[Leaf], original: Line, opening_bracket: Leaf
|
||||||
|
) -> bool:
|
||||||
|
if not leaves:
|
||||||
|
return False
|
||||||
|
# Ensure a trailing comma for imports
|
||||||
|
if original.is_import:
|
||||||
|
return True
|
||||||
|
# ...and standalone function arguments
|
||||||
|
if not original.is_def:
|
||||||
|
return False
|
||||||
|
if opening_bracket.value != "(":
|
||||||
|
return False
|
||||||
|
# Don't add commas if we already have any commas
|
||||||
|
if any(
|
||||||
|
leaf.type == token.COMMA
|
||||||
|
and (
|
||||||
|
Preview.typed_params_trailing_comma not in original.mode
|
||||||
|
or not is_part_of_annotation(leaf)
|
||||||
|
)
|
||||||
|
for leaf in leaves
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Find a leaf with a parent (comments don't have parents)
|
||||||
|
leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None)
|
||||||
|
if leaf_with_parent is None:
|
||||||
|
return True
|
||||||
|
# Don't add commas inside parenthesized return annotations
|
||||||
|
if get_annotation_type(leaf_with_parent) == "return":
|
||||||
|
return False
|
||||||
|
# Don't add commas inside PEP 604 unions
|
||||||
|
if (
|
||||||
|
leaf_with_parent.parent
|
||||||
|
and leaf_with_parent.parent.next_sibling
|
||||||
|
and leaf_with_parent.parent.next_sibling.type == token.VBAR
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def bracket_split_build_line(
|
def bracket_split_build_line(
|
||||||
leaves: List[Leaf],
|
leaves: List[Leaf],
|
||||||
original: Line,
|
original: Line,
|
||||||
@ -1099,40 +1140,15 @@ def bracket_split_build_line(
|
|||||||
if component is _BracketSplitComponent.body:
|
if component is _BracketSplitComponent.body:
|
||||||
result.inside_brackets = True
|
result.inside_brackets = True
|
||||||
result.depth += 1
|
result.depth += 1
|
||||||
if leaves:
|
if _ensure_trailing_comma(leaves, original, opening_bracket):
|
||||||
no_commas = (
|
for i in range(len(leaves) - 1, -1, -1):
|
||||||
# Ensure a trailing comma for imports and standalone function arguments
|
if leaves[i].type == STANDALONE_COMMENT:
|
||||||
original.is_def
|
continue
|
||||||
# Don't add one after any comments or within type annotations
|
|
||||||
and opening_bracket.value == "("
|
|
||||||
# Don't add one if there's already one there
|
|
||||||
and not any(
|
|
||||||
leaf.type == token.COMMA
|
|
||||||
and (
|
|
||||||
Preview.typed_params_trailing_comma not in original.mode
|
|
||||||
or not is_part_of_annotation(leaf)
|
|
||||||
)
|
|
||||||
for leaf in leaves
|
|
||||||
)
|
|
||||||
# Don't add one inside parenthesized return annotations
|
|
||||||
and get_annotation_type(leaves[0]) != "return"
|
|
||||||
# Don't add one inside PEP 604 unions
|
|
||||||
and not (
|
|
||||||
leaves[0].parent
|
|
||||||
and leaves[0].parent.next_sibling
|
|
||||||
and leaves[0].parent.next_sibling.type == token.VBAR
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if original.is_import or no_commas:
|
if leaves[i].type != token.COMMA:
|
||||||
for i in range(len(leaves) - 1, -1, -1):
|
new_comma = Leaf(token.COMMA, ",")
|
||||||
if leaves[i].type == STANDALONE_COMMENT:
|
leaves.insert(i + 1, new_comma)
|
||||||
continue
|
break
|
||||||
|
|
||||||
if leaves[i].type != token.COMMA:
|
|
||||||
new_comma = Leaf(token.COMMA, ",")
|
|
||||||
leaves.insert(i + 1, new_comma)
|
|
||||||
break
|
|
||||||
|
|
||||||
leaves_to_track: Set[LeafID] = set()
|
leaves_to_track: Set[LeafID] = set()
|
||||||
if component is _BracketSplitComponent.head:
|
if component is _BracketSplitComponent.head:
|
||||||
|
@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]:
|
|||||||
|
|
||||||
def is_part_of_annotation(leaf: Leaf) -> bool:
|
def is_part_of_annotation(leaf: Leaf) -> bool:
|
||||||
"""Returns whether this leaf is part of a type annotation."""
|
"""Returns whether this leaf is part of a type annotation."""
|
||||||
|
assert leaf.parent is not None
|
||||||
return get_annotation_type(leaf) is not None
|
return get_annotation_type(leaf) is not None
|
||||||
|
|
||||||
|
|
||||||
|
@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult:
|
|||||||
break
|
break
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
if not is_part_of_annotation(leaf) and not contains_comment:
|
if not contains_comment and not is_part_of_annotation(leaf):
|
||||||
string_indices.append(idx)
|
string_indices.append(idx)
|
||||||
|
|
||||||
# Advance to the next non-STRING leaf.
|
# Advance to the next non-STRING leaf.
|
||||||
|
@ -142,6 +142,7 @@ def SimplePyFn(
|
|||||||
Buffer[UInt8, 2],
|
Buffer[UInt8, 2],
|
||||||
Buffer[UInt8, 2],
|
Buffer[UInt8, 2],
|
||||||
]: ...
|
]: ...
|
||||||
|
|
||||||
# output
|
# output
|
||||||
# normal, short, function definition
|
# normal, short, function definition
|
||||||
def foo(a, b) -> tuple[int, float]: ...
|
def foo(a, b) -> tuple[int, float]: ...
|
||||||
|
@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
|
|||||||
argument1, (one, two,), argument4, argument5, argument6
|
argument1, (one, two,), argument4, argument5, argument6
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside parenthesised return type
|
||||||
|
int
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside parenthesised return type
|
||||||
|
# more
|
||||||
|
int
|
||||||
|
# another
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside parenthesised new union return type
|
||||||
|
int | str | bytes
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside plain tuple
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo(arg: (# comment with non-return annotation
|
||||||
|
int
|
||||||
|
# comment with non-return annotation
|
||||||
|
)):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo(arg: (# comment with non-return annotation
|
||||||
|
int | range | memoryview
|
||||||
|
# comment with non-return annotation
|
||||||
|
)):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo(arg: (# only before
|
||||||
|
int
|
||||||
|
)):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo(arg: (
|
||||||
|
int
|
||||||
|
# only after
|
||||||
|
)):
|
||||||
|
pass
|
||||||
|
|
||||||
|
variable: ( # annotation
|
||||||
|
because
|
||||||
|
# why not
|
||||||
|
)
|
||||||
|
|
||||||
|
variable: (
|
||||||
|
because
|
||||||
|
# why not
|
||||||
|
)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
|
|
||||||
def f(
|
def f(
|
||||||
@ -176,3 +234,75 @@ def func() -> (
|
|||||||
argument5,
|
argument5,
|
||||||
argument6,
|
argument6,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside parenthesised return type
|
||||||
|
int
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside parenthesised return type
|
||||||
|
# more
|
||||||
|
int
|
||||||
|
# another
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside parenthesised new union return type
|
||||||
|
int
|
||||||
|
| str
|
||||||
|
| bytes
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
def foo() -> (
|
||||||
|
# comment inside plain tuple
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def foo(
|
||||||
|
arg: ( # comment with non-return annotation
|
||||||
|
int
|
||||||
|
# comment with non-return annotation
|
||||||
|
),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def foo(
|
||||||
|
arg: ( # comment with non-return annotation
|
||||||
|
int
|
||||||
|
| range
|
||||||
|
| memoryview
|
||||||
|
# comment with non-return annotation
|
||||||
|
),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def foo(arg: int): # only before
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def foo(
|
||||||
|
arg: (
|
||||||
|
int
|
||||||
|
# only after
|
||||||
|
),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
variable: ( # annotation
|
||||||
|
because
|
||||||
|
# why not
|
||||||
|
)
|
||||||
|
|
||||||
|
variable: (
|
||||||
|
because
|
||||||
|
# why not
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user