From dc2d1046eecefa34bccad6ec116d021064ba4324 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 11 Dec 2023 13:39:36 -0800 Subject: [PATCH] Fix another dummy impl case --- src/black/linegen.py | 6 ++-- src/black/nodes.py | 13 ++++---- tests/data/cases/expression.py | 18 +++++++---- tests/data/cases/function.py | 3 +- tests/data/cases/pattern_matching_extras.py | 6 ++-- .../expression_skip_magic_trailing_comma.diff | 30 ++++++++++++------- 6 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/black/linegen.py b/src/black/linegen.py index 1745b95..4760c4d 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -42,6 +42,7 @@ is_atom_with_invisible_parens, is_docstring, is_empty_tuple, + is_function_or_class, is_lpar_token, is_multiline_string, is_name_token, @@ -293,9 +294,8 @@ def visit_simple_stmt(self, node: Node) -> Iterator[Line]: wrap_in_parentheses(node, child, visible=False) prev_type = child.type - is_suite_like = node.parent and node.parent.type in STATEMENT - if is_suite_like: - if is_stub_body(node): + if node.parent and node.parent.type in STATEMENT: + if is_stub_body(node) and is_function_or_class(node.parent): yield from self.visit_default(node) else: yield from self.line(+1) diff --git a/src/black/nodes.py b/src/black/nodes.py index da53fa2..36b2c5b 100644 --- a/src/black/nodes.py +++ b/src/black/nodes.py @@ -735,15 +735,14 @@ def is_funcdef(node: Node) -> bool: return node.type == syms.funcdef +def is_function_or_class(node: Node) -> bool: + return node.type in {syms.funcdef, syms.classdef, syms.async_funcdef} + + def is_stub_suite(node: Node) -> bool: """Return True if `node` is a suite with a stub body.""" - if node.parent is not None: - if node.parent.type not in ( - syms.funcdef, - syms.async_funcdef, - syms.classdef, - ): - return False + if node.parent is not None and not is_function_or_class(node.parent): + return False # If there is a comment, we want to keep it. if node.prefix.strip(): diff --git a/tests/data/cases/expression.py b/tests/data/cases/expression.py index 613d2d3..761b33c 100644 --- a/tests/data/cases/expression.py +++ b/tests/data/cases/expression.py @@ -514,12 +514,18 @@ async def f(): force=False ), "Short message" assert parens is TooMany -for (x,) in (1,), (2,), (3,): ... -for y in (): ... -for z in (i for i in (1, 2, 3)): ... -for i in call(): ... -for j in 1 + (2 + 3): ... -while this and that: ... +for (x,) in (1,), (2,), (3,): + ... +for y in (): + ... +for z in (i for i in (1, 2, 3)): + ... +for i in call(): + ... +for j in 1 + (2 + 3): + ... +while this and that: + ... for ( addr_family, addr_type, diff --git a/tests/data/cases/function.py b/tests/data/cases/function.py index 8aba756..4e3f91f 100644 --- a/tests/data/cases/function.py +++ b/tests/data/cases/function.py @@ -114,7 +114,8 @@ def func_no_args(): c if True: raise RuntimeError - if False: ... + if False: + ... for i in range(10): print(i) continue diff --git a/tests/data/cases/pattern_matching_extras.py b/tests/data/cases/pattern_matching_extras.py index df6ef4b..1aef8f1 100644 --- a/tests/data/cases/pattern_matching_extras.py +++ b/tests/data/cases/pattern_matching_extras.py @@ -24,8 +24,10 @@ def func(match: case, case: match) -> case: match Something(): - case func(match, case): ... - case another: ... + case func(match, case): + ... + case another: + ... match a, *b, c: diff --git a/tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff b/tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff index d20ad0d..8d0f1ce 100644 --- a/tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff +++ b/tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff @@ -167,7 +167,7 @@ slice[0:1:2] slice[:] slice[:-1] -@@ -137,118 +156,191 @@ +@@ -137,118 +156,197 @@ numpy[-(c + 1) :, d] numpy[:, l[-2]] numpy[:, ::-1] @@ -265,22 +265,30 @@ -assert this is ComplexTest and not requirements.fit_in_a_single_line(force=False), "Short message" -assert(((parens is TooMany))) -for x, in (1,), (2,), (3,): ... +-for y in (): ... +-for z in (i for i in (1, 2, 3)): ... +-for i in (call()): ... +-for j in (1 + (2 + 3)): ... +-while(this and that): ... +-for addr_family, addr_type, addr_proto, addr_canonname, addr_sockaddr in socket.getaddrinfo('google.com', 'http'): +print(*lambda x: x) +assert not Test, "Short message" +assert this is ComplexTest and not requirements.fit_in_a_single_line( + force=False +), "Short message" +assert parens is TooMany -+for (x,) in (1,), (2,), (3,): ... - for y in (): ... - for z in (i for i in (1, 2, 3)): ... --for i in (call()): ... --for j in (1 + (2 + 3)): ... --while(this and that): ... --for addr_family, addr_type, addr_proto, addr_canonname, addr_sockaddr in socket.getaddrinfo('google.com', 'http'): -+for i in call(): ... -+for j in 1 + (2 + 3): ... -+while this and that: ... ++for (x,) in (1,), (2,), (3,): ++ ... ++for y in (): ++ ... ++for z in (i for i in (1, 2, 3)): ++ ... ++for i in call(): ++ ... ++for j in 1 + (2 + 3): ++ ... ++while this and that: ++ ... +for ( + addr_family, + addr_type,