respect magic trailing commas in return types (#3916)
This commit is contained in:
parent
947bd3825e
commit
36078bc83f
@ -18,6 +18,7 @@
|
||||
|
||||
- Long type hints are now wrapped in parentheses and properly indented when split across
|
||||
multiple lines (#3899)
|
||||
- Magic trailing commas are now respected in return types. (#3916)
|
||||
|
||||
### Configuration
|
||||
|
||||
|
@ -573,7 +573,7 @@ def transform_line(
|
||||
transformers = [string_merge, string_paren_strip]
|
||||
else:
|
||||
transformers = []
|
||||
elif line.is_def:
|
||||
elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
|
||||
transformers = [left_hand_split]
|
||||
else:
|
||||
|
||||
@ -652,6 +652,40 @@ def _rhs(
|
||||
yield line
|
||||
|
||||
|
||||
def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
|
||||
"""If a funcdef has a magic trailing comma in the return type, then we should first
|
||||
split the line with rhs to respect the comma.
|
||||
"""
|
||||
if Preview.respect_magic_trailing_comma_in_return_type not in mode:
|
||||
return False
|
||||
|
||||
return_type_leaves: List[Leaf] = []
|
||||
in_return_type = False
|
||||
|
||||
for leaf in line.leaves:
|
||||
if leaf.type == token.COLON:
|
||||
in_return_type = False
|
||||
if in_return_type:
|
||||
return_type_leaves.append(leaf)
|
||||
if leaf.type == token.RARROW:
|
||||
in_return_type = True
|
||||
|
||||
# using `bracket_split_build_line` will mess with whitespace, so we duplicate a
|
||||
# couple lines from it.
|
||||
result = Line(mode=line.mode, depth=line.depth)
|
||||
leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)
|
||||
for leaf in return_type_leaves:
|
||||
result.append(
|
||||
leaf,
|
||||
preformatted=True,
|
||||
track_bracket=id(leaf) in leaves_to_track,
|
||||
)
|
||||
|
||||
# we could also return true if the line is too long, and the return type is longer
|
||||
# than the param list. Or if `should_split_rhs` returns True.
|
||||
return result.magic_trailing_comma is not None
|
||||
|
||||
|
||||
class _BracketSplitComponent(Enum):
|
||||
head = auto()
|
||||
body = auto()
|
||||
|
@ -181,6 +181,7 @@ class Preview(Enum):
|
||||
string_processing = auto()
|
||||
parenthesize_conditional_expressions = auto()
|
||||
parenthesize_long_type_hints = auto()
|
||||
respect_magic_trailing_comma_in_return_type = auto()
|
||||
skip_magic_trailing_comma_in_subscript = auto()
|
||||
wrap_long_dict_values_in_parens = auto()
|
||||
wrap_multiple_context_managers_in_parens = auto()
|
||||
|
@ -2,6 +2,10 @@
|
||||
def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
|
||||
pass
|
||||
|
||||
# splitting the string breaks if there's any parameters
|
||||
def frobnicate(a) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
|
||||
pass
|
||||
|
||||
# output
|
||||
|
||||
# Long string example
|
||||
@ -10,3 +14,10 @@ def frobnicate() -> (
|
||||
" list[ThisIsTrulyUnreasonablyExtremelyLongClassName]"
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# splitting the string breaks if there's any parameters
|
||||
def frobnicate(
|
||||
a,
|
||||
) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
|
||||
pass
|
||||
|
300
tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
Normal file
300
tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
Normal file
@ -0,0 +1,300 @@
|
||||
# normal, short, function definition
|
||||
def foo(a, b) -> tuple[int, float]: ...
|
||||
|
||||
|
||||
# normal, short, function definition w/o return type
|
||||
def foo(a, b): ...
|
||||
|
||||
|
||||
# no splitting
|
||||
def foo(a: A, b: B) -> list[p, q]:
|
||||
pass
|
||||
|
||||
|
||||
# magic trailing comma in param list
|
||||
def foo(a, b,): ...
|
||||
|
||||
|
||||
# magic trailing comma in nested params in param list
|
||||
def foo(a, b: tuple[int, float,]): ...
|
||||
|
||||
|
||||
# magic trailing comma in return type, no params
|
||||
def a() -> tuple[
|
||||
a,
|
||||
b,
|
||||
]: ...
|
||||
|
||||
|
||||
# magic trailing comma in return type, params
|
||||
def foo(a: A, b: B) -> list[
|
||||
p,
|
||||
q,
|
||||
]:
|
||||
pass
|
||||
|
||||
|
||||
# magic trailing comma in param list and in return type
|
||||
def foo(
|
||||
a: a,
|
||||
b: b,
|
||||
) -> list[
|
||||
a,
|
||||
a,
|
||||
]:
|
||||
pass
|
||||
|
||||
|
||||
# long function definition, param list is longer
|
||||
def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
|
||||
bbbbbbbbbbbbbbbbbb,
|
||||
) -> cccccccccccccccccccccccccccccc: ...
|
||||
|
||||
|
||||
# long function definition, return type is longer
|
||||
# this should maybe split on rhs?
|
||||
def aaaaaaaaaaaaaaaaa(bbbbbbbbbbbbbbbbbb) -> list[
|
||||
Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd
|
||||
]: ...
|
||||
|
||||
|
||||
# long return type, no param list
|
||||
def foo() -> list[
|
||||
Loooooooooooooooooooooooooooooooooooong,
|
||||
Loooooooooooooooooooong,
|
||||
Looooooooooooong,
|
||||
]: ...
|
||||
|
||||
|
||||
# long function name, no param list, no return value
|
||||
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
|
||||
pass
|
||||
|
||||
|
||||
# long function name, no param list
|
||||
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
|
||||
list[int, float]
|
||||
): ...
|
||||
|
||||
|
||||
# long function name, no return value
|
||||
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
|
||||
a, b
|
||||
): ...
|
||||
|
||||
|
||||
# unskippable type hint (??)
|
||||
def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
def foo(a) -> list[
|
||||
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
|
||||
]: # abpedeifnore
|
||||
pass
|
||||
|
||||
def foo(a, b: list[Bad],): ... # type: ignore
|
||||
|
||||
# don't lose any comments (no magic)
|
||||
def foo( # 1
|
||||
a, # 2
|
||||
b) -> list[ # 3
|
||||
a, # 4
|
||||
b]: # 5
|
||||
... # 6
|
||||
|
||||
|
||||
# don't lose any comments (param list magic)
|
||||
def foo( # 1
|
||||
a, # 2
|
||||
b,) -> list[ # 3
|
||||
a, # 4
|
||||
b]: # 5
|
||||
... # 6
|
||||
|
||||
|
||||
# don't lose any comments (return type magic)
|
||||
def foo( # 1
|
||||
a, # 2
|
||||
b) -> list[ # 3
|
||||
a, # 4
|
||||
b,]: # 5
|
||||
... # 6
|
||||
|
||||
|
||||
# don't lose any comments (both magic)
|
||||
def foo( # 1
|
||||
a, # 2
|
||||
b,) -> list[ # 3
|
||||
a, # 4
|
||||
b,]: # 5
|
||||
... # 6
|
||||
|
||||
# real life example
|
||||
def SimplePyFn(
|
||||
context: hl.GeneratorContext,
|
||||
buffer_input: Buffer[UInt8, 2],
|
||||
func_input: Buffer[Int32, 2],
|
||||
float_arg: Scalar[Float32],
|
||||
offset: int = 0,
|
||||
) -> tuple[
|
||||
Buffer[UInt8, 2],
|
||||
Buffer[UInt8, 2],
|
||||
]: ...
|
||||
# output
|
||||
# normal, short, function definition
|
||||
def foo(a, b) -> tuple[int, float]: ...
|
||||
|
||||
|
||||
# normal, short, function definition w/o return type
|
||||
def foo(a, b): ...
|
||||
|
||||
|
||||
# no splitting
|
||||
def foo(a: A, b: B) -> list[p, q]:
|
||||
pass
|
||||
|
||||
|
||||
# magic trailing comma in param list
|
||||
def foo(
|
||||
a,
|
||||
b,
|
||||
): ...
|
||||
|
||||
|
||||
# magic trailing comma in nested params in param list
|
||||
def foo(
|
||||
a,
|
||||
b: tuple[
|
||||
int,
|
||||
float,
|
||||
],
|
||||
): ...
|
||||
|
||||
|
||||
# magic trailing comma in return type, no params
|
||||
def a() -> tuple[
|
||||
a,
|
||||
b,
|
||||
]: ...
|
||||
|
||||
|
||||
# magic trailing comma in return type, params
|
||||
def foo(a: A, b: B) -> list[
|
||||
p,
|
||||
q,
|
||||
]:
|
||||
pass
|
||||
|
||||
|
||||
# magic trailing comma in param list and in return type
|
||||
def foo(
|
||||
a: a,
|
||||
b: b,
|
||||
) -> list[
|
||||
a,
|
||||
a,
|
||||
]:
|
||||
pass
|
||||
|
||||
|
||||
# long function definition, param list is longer
|
||||
def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
|
||||
bbbbbbbbbbbbbbbbbb,
|
||||
) -> cccccccccccccccccccccccccccccc: ...
|
||||
|
||||
|
||||
# long function definition, return type is longer
|
||||
# this should maybe split on rhs?
|
||||
def aaaaaaaaaaaaaaaaa(
|
||||
bbbbbbbbbbbbbbbbbb,
|
||||
) -> list[Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd]: ...
|
||||
|
||||
|
||||
# long return type, no param list
|
||||
def foo() -> list[
|
||||
Loooooooooooooooooooooooooooooooooooong,
|
||||
Loooooooooooooooooooong,
|
||||
Looooooooooooong,
|
||||
]: ...
|
||||
|
||||
|
||||
# long function name, no param list, no return value
|
||||
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
|
||||
pass
|
||||
|
||||
|
||||
# long function name, no param list
|
||||
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
|
||||
list[int, float]
|
||||
): ...
|
||||
|
||||
|
||||
# long function name, no return value
|
||||
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
|
||||
a, b
|
||||
): ...
|
||||
|
||||
|
||||
# unskippable type hint (??)
|
||||
def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
def foo(
|
||||
a,
|
||||
) -> list[
|
||||
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
|
||||
]: # abpedeifnore
|
||||
pass
|
||||
|
||||
|
||||
def foo(
|
||||
a,
|
||||
b: list[Bad],
|
||||
): ... # type: ignore
|
||||
|
||||
|
||||
# don't lose any comments (no magic)
|
||||
def foo(a, b) -> list[a, b]: # 1 # 2 # 3 # 4 # 5
|
||||
... # 6
|
||||
|
||||
|
||||
# don't lose any comments (param list magic)
|
||||
def foo( # 1
|
||||
a, # 2
|
||||
b,
|
||||
) -> list[a, b]: # 3 # 4 # 5
|
||||
... # 6
|
||||
|
||||
|
||||
# don't lose any comments (return type magic)
|
||||
def foo(a, b) -> list[ # 1 # 2 # 3
|
||||
a, # 4
|
||||
b,
|
||||
]: # 5
|
||||
... # 6
|
||||
|
||||
|
||||
# don't lose any comments (both magic)
|
||||
def foo( # 1
|
||||
a, # 2
|
||||
b,
|
||||
) -> list[ # 3
|
||||
a, # 4
|
||||
b,
|
||||
]: # 5
|
||||
... # 6
|
||||
|
||||
|
||||
# real life example
|
||||
def SimplePyFn(
|
||||
context: hl.GeneratorContext,
|
||||
buffer_input: Buffer[UInt8, 2],
|
||||
func_input: Buffer[Int32, 2],
|
||||
float_arg: Scalar[Float32],
|
||||
offset: int = 0,
|
||||
) -> tuple[
|
||||
Buffer[UInt8, 2],
|
||||
Buffer[UInt8, 2],
|
||||
]: ...
|
@ -87,6 +87,11 @@ def foo() -> tuple[loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo
|
||||
def foo() -> tuple[int, int, int,]:
|
||||
return 2
|
||||
|
||||
# Magic trailing comma example, with params
|
||||
# this is broken - the trailing comma is transferred to the param list. Fixed in preview
|
||||
def foo(a,b) -> tuple[int, int, int,]:
|
||||
return 2
|
||||
|
||||
# output
|
||||
# Control
|
||||
def double(a: int) -> int:
|
||||
@ -208,3 +213,11 @@ def foo() -> (
|
||||
]
|
||||
):
|
||||
return 2
|
||||
|
||||
|
||||
# Magic trailing comma example, with params
|
||||
# this is broken - the trailing comma is transferred to the param list. Fixed in preview
|
||||
def foo(
|
||||
a, b
|
||||
) -> tuple[int, int, int,]:
|
||||
return 2
|
||||
|
Loading…
Reference in New Issue
Block a user