diff --git a/README.md b/README.md index b3e6985..5a6825f 100644 --- a/README.md +++ b/README.md @@ -260,6 +260,11 @@ You can still try but prepare to be disappointed. * added `--check` +* only put trailing commas in function signatures and calls if it's + safe to do so. If the file is Python 3.6+ it's always safe, otherwise + only safe if there are no `*args` or `**kwargs` used in the signature + or call. (#8) + * fixed invalid spacing of dots in relative imports (#6, #13) * fixed invalid splitting after comma on unpacked variables in for-loops diff --git a/black.py b/black.py index 774d91d..0a9d3ea 100644 --- a/black.py +++ b/black.py @@ -7,6 +7,7 @@ import os from pathlib import Path import tokenize +import sys from typing import ( Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union ) @@ -192,6 +193,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent: comments: List[Line] = [] lines = LineGenerator() elt = EmptyLineTracker() + py36 = is_python36(src_node) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -204,7 +206,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent: for comment in comments: dst_contents += str(comment) comments = [] - for line in split_line(current_line, line_length=line_length): + for line in split_line(current_line, line_length=line_length, py36=py36): dst_contents += str(line) else: comments.append(current_line) @@ -1108,13 +1110,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: yield Leaf(STANDALONE_COMMENT, line) -def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]: +def split_line( + line: Line, line_length: int, inner: bool = False, py36: bool = False +) -> Iterator[Line]: """Splits a `line` into potentially many lines. They should fit in the allotted `line_length` but might not be able to. `inner` signifies that there were a pair of brackets somewhere around the current `line`, possibly transitively. This means we can fallback to splitting by delimiters if the LHS/RHS don't yield any results. + + If `py36` is True, splitting may generate syntax that is only compatible + with Python 3.6 and later. """ line_str = str(line).strip('\n') if len(line_str) <= line_length and '\n' not in line_str: @@ -1137,11 +1144,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li # split altogether. result: List[Line] = [] try: - for l in split_func(line): + for l in split_func(line, py36=py36): if str(l).strip('\n') == line_str: raise CannotSplit("Split function returned an unchanged result") - result.extend(split_line(l, line_length=line_length, inner=True)) + result.extend( + split_line(l, line_length=line_length, inner=True, py36=py36) + ) except CannotSplit as cs: continue @@ -1153,7 +1162,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li yield line -def left_hand_split(line: Line) -> Iterator[Line]: +def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. @@ -1208,7 +1217,7 @@ def left_hand_split(line: Line) -> Iterator[Line]: yield result -def right_hand_split(line: Line) -> Iterator[Line]: +def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the last matching bracket pair.""" head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1259,10 +1268,12 @@ def right_hand_split(line: Line) -> Iterator[Line]: yield result -def delimiter_split(line: Line) -> Iterator[Line]: +def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split according to delimiters of the highest priority. This kind of split doesn't increase indentation. + If `py36` is True, the split will add trailing commas also in function + signatures that contain * and **. """ try: last_leaf = line.leaves[-1] @@ -1276,11 +1287,20 @@ def delimiter_split(line: Line) -> Iterator[Line]: raise CannotSplit("No delimiters found") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + lowest_depth = sys.maxsize + trailing_comma_safe = True for leaf in line.leaves: current_line.append(leaf, preformatted=True) comment_after = line.comments.get(id(leaf)) if comment_after: current_line.append(comment_after, preformatted=True) + lowest_depth = min(lowest_depth, leaf.bracket_depth) + if ( + leaf.bracket_depth == lowest_depth and # type: ignore + leaf.type == token.STAR or + leaf.type == token.DOUBLESTAR + ): + trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: normalize_prefix(current_line.leaves[0]) @@ -1290,7 +1310,8 @@ def delimiter_split(line: Line) -> Iterator[Line]: if current_line: if ( delimiter_priority == COMMA_PRIORITY and - current_line.leaves[-1].type != token.COMMA + current_line.leaves[-1].type != token.COMMA and + trailing_comma_safe ): current_line.append(Leaf(token.COMMA, ',')) normalize_prefix(current_line.leaves[0]) @@ -1325,6 +1346,31 @@ def normalize_prefix(leaf: Leaf) -> None: leaf.prefix = '' +def is_python36(node: Node) -> bool: + """Returns True if the current file is using Python 3.6+ features. + + Currently looking for: + - f-strings; and + - trailing commas after * or ** in function signatures. + """ + for n in node.pre_order(): + if n.type == token.STRING: + assert isinstance(n, Leaf) + if n.value[:2] in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}: + return True + + elif ( + n.type == syms.typedargslist and + n.children and + n.children[-1].type == token.COMMA + ): + for ch in n.children: + if ch.type == token.STAR or ch.type == token.DOUBLESTAR: + return True + + return False + + PYTHON_EXTENSIONS = {'.py'} BLACKLISTED_DIRECTORIES = { 'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv' diff --git a/tests/expression.py b/tests/expression.py index 59e4211..a3c810e 100644 --- a/tests/expression.py +++ b/tests/expression.py @@ -71,6 +71,7 @@ call(kwarg='hey') call(arg, kwarg='hey') call(arg, another, kwarg='hey', **kwargs) +call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs) # note: no trailing comma pre-3.6 lukasz.langa.pl call.me(maybe) 1 .real @@ -88,11 +89,6 @@ slice[1:] slice[::-1] (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None) -f'f-string without formatted values is just a string' -f'{{NOT a formatted value}}' -f'some f-string with {a} {few():.2f} {formatted.values!r}' -f"{f'{nested} inner'} outer" -f'space between opening braces: { {a for a in (1, 2, 3)}}' {'2.7': dead, '3.7': long_live or die_hard} {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] @@ -200,6 +196,13 @@ async def f(): call(kwarg='hey') call(arg, kwarg='hey') call(arg, another, kwarg='hey', **kwargs) +call( + this_is_a_very_long_variable_which_will_force_a_delimiter_split, + arg, + another, + kwarg='hey', + **kwargs +) # note: no trailing comma pre-3.6 lukasz.langa.pl call.me(maybe) 1 .real @@ -217,11 +220,6 @@ async def f(): slice[1:] slice[::-1] (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None) -f'f-string without formatted values is just a string' -f'{{NOT a formatted value}}' -f'some f-string with {a} {few():.2f} {formatted.values!r}' -f"{f'{nested} inner'} outer" -f'space between opening braces: { {a for a in (1, 2, 3)}}' {'2.7': dead, '3.7': long_live or die_hard} {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] diff --git a/tests/fstring.py b/tests/fstring.py new file mode 100644 index 0000000..6b821be --- /dev/null +++ b/tests/fstring.py @@ -0,0 +1,5 @@ +f'f-string without formatted values is just a string' +f'{{NOT a formatted value}}' +f'some f-string with {a} {few():.2f} {formatted.values!r}' +f"{f'{nested} inner'} outer" +f'space between opening braces: { {a for a in (1, 2, 3)}}' diff --git a/tests/function.py b/tests/function.py index 858b042..abe2200 100644 --- a/tests/function.py +++ b/tests/function.py @@ -6,7 +6,7 @@ from library import some_connection, \ some_decorator - +f'trigger 3.6 mode' def func_no_args(): a; b; c if True: raise RuntimeError @@ -71,6 +71,8 @@ def long_lines(): from library import some_connection, some_decorator +f'trigger 3.6 mode' + def func_no_args(): a diff --git a/tests/test_black.py b/tests/test_black.py index 223c907..1dda5fc 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -108,6 +108,14 @@ def test_expression(self) -> None: black.assert_equivalent(source, actual) black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) + def test_fstring(self) -> None: + source, expected = read_data('fstring') + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) def test_comments(self) -> None: source, expected = read_data('comments') @@ -215,6 +223,24 @@ def err(msg: str, **kwargs): ) self.assertEqual(report.return_code, 123) + def test_is_python36(self): + node = black.lib2to3_parse("def f(*, arg): ...\n") + self.assertFalse(black.is_python36(node)) + node = black.lib2to3_parse("def f(*, arg,): ...\n") + self.assertTrue(black.is_python36(node)) + node = black.lib2to3_parse("def f(*, arg): f'string'\n") + self.assertTrue(black.is_python36(node)) + source, expected = read_data('function') + node = black.lib2to3_parse(source) + self.assertTrue(black.is_python36(node)) + node = black.lib2to3_parse(expected) + self.assertTrue(black.is_python36(node)) + source, expected = read_data('expression') + node = black.lib2to3_parse(source) + self.assertFalse(black.is_python36(node)) + node = black.lib2to3_parse(expected) + self.assertFalse(black.is_python36(node)) + if __name__ == '__main__': unittest.main()