Fix trailing comma for function with one arg (#880) (#891)

Modified maybe_remove_trailing_comma to remove trailing commas for
typedarglists (in addition to arglists), and updated line split logic
to ensure that all lines in a function definition that contain only one
arg have a trailing comma.
This commit is contained in:
dylanjblack 2019-06-15 14:49:49 +10:00 committed by Jelle Zijlstra
parent 1bbb01b854
commit 9394de150e
3 changed files with 33 additions and 4 deletions

View File

@ -1352,7 +1352,10 @@ def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
bracket_depth = leaf.bracket_depth
if bracket_depth == depth and leaf.type == token.COMMA:
commas += 1
if leaf.parent and leaf.parent.type == syms.arglist:
if leaf.parent and leaf.parent.type in {
syms.arglist,
syms.typedargslist,
}:
commas += 1
break
@ -2488,9 +2491,13 @@ def bracket_split_build_line(
if leaves:
# Since body is a new indent level, remove spurious leading whitespace.
normalize_prefix(leaves[0], inside_brackets=True)
# Ensure a trailing comma for imports, but be careful not to add one after
# any comments.
if original.is_import:
# Ensure a trailing comma for imports and standalone function arguments, but
# be careful not to add one after any comments.
no_commas = original.is_def and not any(
l.type == token.COMMA for l in leaves
)
if original.is_import or no_commas:
for i in range(len(leaves) - 1, -1, -1):
if leaves[i].type == STANDALONE_COMMENT:
continue

View File

@ -0,0 +1,14 @@
def f(a,):
...
def f(a:int=1,):
...
# output
def f(a):
...
def f(a: int = 1):
...

View File

@ -264,6 +264,14 @@ def test_function2(self) -> None:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_function_trailing_comma(self) -> None:
source, expected = read_data("function_trailing_comma")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_expression(self) -> None:
source, expected = read_data("expression")