Split the TRAILING_COMMA feature (#763)

This commit is contained in:
Jelle Zijlstra 2019-03-25 08:22:02 -07:00 committed by GitHub
parent 0b7913f904
commit cea13f4984
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 58 deletions

View File

@ -68,7 +68,7 @@
Priority = int Priority = int
Index = int Index = int
LN = Union[Leaf, Node] LN = Union[Leaf, Node]
SplitFunc = Callable[["Line", bool], Iterator["Line"]] SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
Timestamp = float Timestamp = float
FileSize = int FileSize = int
CacheInfo = Tuple[Timestamp, FileSize] CacheInfo = Tuple[Timestamp, FileSize]
@ -133,31 +133,35 @@ class Feature(Enum):
UNICODE_LITERALS = 1 UNICODE_LITERALS = 1
F_STRINGS = 2 F_STRINGS = 2
NUMERIC_UNDERSCORES = 3 NUMERIC_UNDERSCORES = 3
TRAILING_COMMA = 4 TRAILING_COMMA_IN_CALL = 4
TRAILING_COMMA_IN_DEF = 5
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
TargetVersion.PY27: set(), TargetVersion.PY27: set(),
TargetVersion.PY33: {Feature.UNICODE_LITERALS}, TargetVersion.PY33: {Feature.UNICODE_LITERALS},
TargetVersion.PY34: {Feature.UNICODE_LITERALS}, TargetVersion.PY34: {Feature.UNICODE_LITERALS},
TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA}, TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
TargetVersion.PY36: { TargetVersion.PY36: {
Feature.UNICODE_LITERALS, Feature.UNICODE_LITERALS,
Feature.F_STRINGS, Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES, Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA, Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
}, },
TargetVersion.PY37: { TargetVersion.PY37: {
Feature.UNICODE_LITERALS, Feature.UNICODE_LITERALS,
Feature.F_STRINGS, Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES, Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA, Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
}, },
TargetVersion.PY38: { TargetVersion.PY38: {
Feature.UNICODE_LITERALS, Feature.UNICODE_LITERALS,
Feature.F_STRINGS, Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES, Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA, Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
}, },
} }
@ -683,6 +687,11 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
elt = EmptyLineTracker(is_pyi=mode.is_pyi) elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line() empty_line = Line()
after = 0 after = 0
split_line_features = {
feature
for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
if supports_feature(versions, feature)
}
for current_line in lines.visit(src_node): for current_line in lines.visit(src_node):
for _ in range(after): for _ in range(after):
dst_contents += str(empty_line) dst_contents += str(empty_line)
@ -690,9 +699,7 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
for _ in range(before): for _ in range(before):
dst_contents += str(empty_line) dst_contents += str(empty_line)
for line in split_line( for line in split_line(
current_line, current_line, line_length=mode.line_length, features=split_line_features
line_length=mode.line_length,
supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
): ):
dst_contents += str(line) dst_contents += str(line)
return dst_contents return dst_contents
@ -2158,7 +2165,7 @@ def split_line(
line: Line, line: Line,
line_length: int, line_length: int,
inner: bool = False, inner: bool = False,
supports_trailing_commas: bool = False, features: Collection[Feature] = (),
) -> Iterator[Line]: ) -> Iterator[Line]:
"""Split a `line` into potentially many lines. """Split a `line` into potentially many lines.
@ -2167,7 +2174,7 @@ def split_line(
current `line`, possibly transitively. This means we can fallback to splitting current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results. by delimiters if the LHS/RHS don't yield any results.
If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature. `features` are syntactical features that may be used in the output.
""" """
if line.is_comment: if line.is_comment:
yield line yield line
@ -2188,13 +2195,9 @@ def split_line(
split_funcs = [left_hand_split] split_funcs = [left_hand_split]
else: else:
def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]: def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
for omit in generate_trailers_to_omit(line, line_length): for omit in generate_trailers_to_omit(line, line_length):
lines = list( lines = list(right_hand_split(line, line_length, features, omit=omit))
right_hand_split(
line, line_length, supports_trailing_commas, omit=omit
)
)
if is_line_short_enough(lines[0], line_length=line_length): if is_line_short_enough(lines[0], line_length=line_length):
yield from lines yield from lines
return return
@ -2202,7 +2205,7 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
# All splits failed, best effort split with no omits. # All splits failed, best effort split with no omits.
# This mostly happens to multiline strings that are by definition # This mostly happens to multiline strings that are by definition
# reported as not fitting a single line. # reported as not fitting a single line.
yield from right_hand_split(line, line_length, supports_trailing_commas) yield from right_hand_split(line, line_length, features=features)
if line.inside_brackets: if line.inside_brackets:
split_funcs = [delimiter_split, standalone_comment_split, rhs] split_funcs = [delimiter_split, standalone_comment_split, rhs]
@ -2214,16 +2217,13 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
# split altogether. # split altogether.
result: List[Line] = [] result: List[Line] = []
try: try:
for l in split_func(line, supports_trailing_commas): for l in split_func(line, features):
if str(l).strip("\n") == line_str: if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result") raise CannotSplit("Split function returned an unchanged result")
result.extend( result.extend(
split_line( split_line(
l, l, line_length=line_length, inner=True, features=features
line_length=line_length,
inner=True,
supports_trailing_commas=supports_trailing_commas,
) )
) )
except CannotSplit: except CannotSplit:
@ -2237,9 +2237,7 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
yield line yield line
def left_hand_split( def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair. """Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions. Note: this usually looks weird, only use this for function definitions.
@ -2278,7 +2276,7 @@ def left_hand_split(
def right_hand_split( def right_hand_split(
line: Line, line: Line,
line_length: int, line_length: int,
supports_trailing_commas: bool = False, features: Collection[Feature] = (),
omit: Collection[LeafID] = (), omit: Collection[LeafID] = (),
) -> Iterator[Line]: ) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair. """Split line into many lines, starting with the last matching bracket pair.
@ -2337,12 +2335,7 @@ def right_hand_split(
): ):
omit = {id(closing_bracket), *omit} omit = {id(closing_bracket), *omit}
try: try:
yield from right_hand_split( yield from right_hand_split(line, line_length, features=features, omit=omit)
line,
line_length,
supports_trailing_commas=supports_trailing_commas,
omit=omit,
)
return return
except CannotSplit: except CannotSplit:
@ -2431,10 +2424,8 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
""" """
@wraps(split_func) @wraps(split_func)
def split_wrapper( def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
line: Line, supports_trailing_commas: bool = False for l in split_func(line, features):
) -> Iterator[Line]:
for l in split_func(line, supports_trailing_commas):
normalize_prefix(l.leaves[0], inside_brackets=True) normalize_prefix(l.leaves[0], inside_brackets=True)
yield l yield l
@ -2442,13 +2433,11 @@ def split_wrapper(
@dont_increase_indentation @dont_increase_indentation
def delimiter_split( def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split according to delimiters of the highest priority. """Split according to delimiters of the highest priority.
If `supports_trailing_commas` is True, the split will add trailing commas If the appropriate Features are given, the split will add trailing commas
also in function signatures that contain `*` and `**`. also in function signatures and calls that contain `*` and `**`.
""" """
try: try:
last_leaf = line.leaves[-1] last_leaf = line.leaves[-1]
@ -2487,10 +2476,16 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
yield from append_to_line(comment_after) yield from append_to_line(comment_after)
lowest_depth = min(lowest_depth, leaf.bracket_depth) lowest_depth = min(lowest_depth, leaf.bracket_depth)
if leaf.bracket_depth == lowest_depth and is_vararg( if leaf.bracket_depth == lowest_depth:
leaf, within=VARARGS_PARENTS if is_vararg(leaf, within={syms.typedargslist}):
): trailing_comma_safe = (
trailing_comma_safe = trailing_comma_safe and supports_trailing_commas trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
)
elif is_vararg(leaf, within={syms.arglist, syms.argument}):
trailing_comma_safe = (
trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
)
leaf_priority = bt.delimiters.get(id(leaf)) leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority: if leaf_priority == delimiter_priority:
yield current_line yield current_line
@ -2509,7 +2504,7 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
@dont_increase_indentation @dont_increase_indentation
def standalone_comment_split( def standalone_comment_split(
line: Line, supports_trailing_commas: bool = False line: Line, features: Collection[Feature] = ()
) -> Iterator[Line]: ) -> Iterator[Line]:
"""Split standalone comments from the rest of the line.""" """Split standalone comments from the rest of the line."""
if not line.contains_standalone_comments(0): if not line.contains_standalone_comments(0):
@ -3059,14 +3054,19 @@ def get_features_used(node: Node) -> Set[Feature]:
and n.children and n.children
and n.children[-1].type == token.COMMA and n.children[-1].type == token.COMMA
): ):
if n.type == syms.typedargslist:
feature = Feature.TRAILING_COMMA_IN_DEF
else:
feature = Feature.TRAILING_COMMA_IN_CALL
for ch in n.children: for ch in n.children:
if ch.type in STARS: if ch.type in STARS:
features.add(Feature.TRAILING_COMMA) features.add(feature)
if ch.type == syms.argument: if ch.type == syms.argument:
for argch in ch.children: for argch in ch.children:
if argch.type in STARS: if argch.type in STARS:
features.add(Feature.TRAILING_COMMA) features.add(feature)
return features return features

View File

@ -857,7 +857,11 @@ def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n") node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertEqual(black.get_features_used(node), set()) self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def f(*, arg,): ...\n") node = black.lib2to3_parse("def f(*, arg,): ...\n")
self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA}) self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
node = black.lib2to3_parse("f(*arg,)\n")
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
)
node = black.lib2to3_parse("def f(*, arg): f'string'\n") node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS}) self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
node = black.lib2to3_parse("123_456\n") node = black.lib2to3_parse("123_456\n")
@ -866,13 +870,14 @@ def test_get_features_used(self) -> None:
self.assertEqual(black.get_features_used(node), set()) self.assertEqual(black.get_features_used(node), set())
source, expected = read_data("function") source, expected = read_data("function")
node = black.lib2to3_parse(source) node = black.lib2to3_parse(source)
self.assertEqual( expected_features = {
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS} Feature.TRAILING_COMMA_IN_CALL,
) Feature.TRAILING_COMMA_IN_DEF,
Feature.F_STRINGS,
}
self.assertEqual(black.get_features_used(node), expected_features)
node = black.lib2to3_parse(expected) node = black.lib2to3_parse(expected)
self.assertEqual( self.assertEqual(black.get_features_used(node), expected_features)
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
)
source, expected = read_data("expression") source, expected = read_data("expression")
node = black.lib2to3_parse(source) node = black.lib2to3_parse(source)
self.assertEqual(black.get_features_used(node), set()) self.assertEqual(black.get_features_used(node), set())
@ -1524,8 +1529,8 @@ async def check(header_value: str, expected_status: int) -> None:
await check("3.6", 200) await check("3.6", 200)
await check("py3.6", 200) await check("py3.6", 200)
await check("3.5,3.7", 200) await check("3.6,3.7", 200)
await check("3.5,py3.7", 200) await check("3.6,py3.7", 200)
await check("2", 204) await check("2", 204)
await check("2.7", 204) await check("2.7", 204)