Split the TRAILING_COMMA feature (#763)

This commit is contained in:
Jelle Zijlstra 2019-03-25 08:22:02 -07:00 committed by GitHub
parent 0b7913f904
commit cea13f4984
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 58 deletions

View File

@ -68,7 +68,7 @@
Priority = int
Index = int
LN = Union[Leaf, Node]
SplitFunc = Callable[["Line", bool], Iterator["Line"]]
SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
Timestamp = float
FileSize = int
CacheInfo = Tuple[Timestamp, FileSize]
@ -133,31 +133,35 @@ class Feature(Enum):
UNICODE_LITERALS = 1
F_STRINGS = 2
NUMERIC_UNDERSCORES = 3
TRAILING_COMMA = 4
TRAILING_COMMA_IN_CALL = 4
TRAILING_COMMA_IN_DEF = 5
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
TargetVersion.PY27: set(),
TargetVersion.PY33: {Feature.UNICODE_LITERALS},
TargetVersion.PY34: {Feature.UNICODE_LITERALS},
TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
TargetVersion.PY36: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
},
TargetVersion.PY37: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
},
TargetVersion.PY38: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
},
}
@ -683,6 +687,11 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line()
after = 0
split_line_features = {
feature
for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
if supports_feature(versions, feature)
}
for current_line in lines.visit(src_node):
for _ in range(after):
dst_contents += str(empty_line)
@ -690,9 +699,7 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
for _ in range(before):
dst_contents += str(empty_line)
for line in split_line(
current_line,
line_length=mode.line_length,
supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
current_line, line_length=mode.line_length, features=split_line_features
):
dst_contents += str(line)
return dst_contents
@ -2158,7 +2165,7 @@ def split_line(
line: Line,
line_length: int,
inner: bool = False,
supports_trailing_commas: bool = False,
features: Collection[Feature] = (),
) -> Iterator[Line]:
"""Split a `line` into potentially many lines.
@ -2167,7 +2174,7 @@ def split_line(
current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results.
If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
`features` are syntactical features that may be used in the output.
"""
if line.is_comment:
yield line
@ -2188,13 +2195,9 @@ def split_line(
split_funcs = [left_hand_split]
else:
def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
for omit in generate_trailers_to_omit(line, line_length):
lines = list(
right_hand_split(
line, line_length, supports_trailing_commas, omit=omit
)
)
lines = list(right_hand_split(line, line_length, features, omit=omit))
if is_line_short_enough(lines[0], line_length=line_length):
yield from lines
return
@ -2202,7 +2205,7 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
# All splits failed, best effort split with no omits.
# This mostly happens to multiline strings that are by definition
# reported as not fitting a single line.
yield from right_hand_split(line, line_length, supports_trailing_commas)
yield from right_hand_split(line, line_length, features=features)
if line.inside_brackets:
split_funcs = [delimiter_split, standalone_comment_split, rhs]
@ -2214,16 +2217,13 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
# split altogether.
result: List[Line] = []
try:
for l in split_func(line, supports_trailing_commas):
for l in split_func(line, features):
if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result")
result.extend(
split_line(
l,
line_length=line_length,
inner=True,
supports_trailing_commas=supports_trailing_commas,
l, line_length=line_length, inner=True, features=features
)
)
except CannotSplit:
@ -2237,9 +2237,7 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
yield line
def left_hand_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions.
@ -2278,7 +2276,7 @@ def left_hand_split(
def right_hand_split(
line: Line,
line_length: int,
supports_trailing_commas: bool = False,
features: Collection[Feature] = (),
omit: Collection[LeafID] = (),
) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair.
@ -2337,12 +2335,7 @@ def right_hand_split(
):
omit = {id(closing_bracket), *omit}
try:
yield from right_hand_split(
line,
line_length,
supports_trailing_commas=supports_trailing_commas,
omit=omit,
)
yield from right_hand_split(line, line_length, features=features, omit=omit)
return
except CannotSplit:
@ -2431,10 +2424,8 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
"""
@wraps(split_func)
def split_wrapper(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
for l in split_func(line, supports_trailing_commas):
def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
for l in split_func(line, features):
normalize_prefix(l.leaves[0], inside_brackets=True)
yield l
@ -2442,13 +2433,11 @@ def split_wrapper(
@dont_increase_indentation
def delimiter_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split according to delimiters of the highest priority.
If `supports_trailing_commas` is True, the split will add trailing commas
also in function signatures that contain `*` and `**`.
If the appropriate Features are given, the split will add trailing commas
also in function signatures and calls that contain `*` and `**`.
"""
try:
last_leaf = line.leaves[-1]
@ -2487,10 +2476,16 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
yield from append_to_line(comment_after)
lowest_depth = min(lowest_depth, leaf.bracket_depth)
if leaf.bracket_depth == lowest_depth and is_vararg(
leaf, within=VARARGS_PARENTS
):
trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
if leaf.bracket_depth == lowest_depth:
if is_vararg(leaf, within={syms.typedargslist}):
trailing_comma_safe = (
trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
)
elif is_vararg(leaf, within={syms.arglist, syms.argument}):
trailing_comma_safe = (
trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
)
leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority:
yield current_line
@ -2509,7 +2504,7 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
@dont_increase_indentation
def standalone_comment_split(
line: Line, supports_trailing_commas: bool = False
line: Line, features: Collection[Feature] = ()
) -> Iterator[Line]:
"""Split standalone comments from the rest of the line."""
if not line.contains_standalone_comments(0):
@ -3059,14 +3054,19 @@ def get_features_used(node: Node) -> Set[Feature]:
and n.children
and n.children[-1].type == token.COMMA
):
if n.type == syms.typedargslist:
feature = Feature.TRAILING_COMMA_IN_DEF
else:
feature = Feature.TRAILING_COMMA_IN_CALL
for ch in n.children:
if ch.type in STARS:
features.add(Feature.TRAILING_COMMA)
features.add(feature)
if ch.type == syms.argument:
for argch in ch.children:
if argch.type in STARS:
features.add(Feature.TRAILING_COMMA)
features.add(feature)
return features

View File

@ -857,7 +857,11 @@ def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def f(*, arg,): ...\n")
self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
node = black.lib2to3_parse("f(*arg,)\n")
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
)
node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
node = black.lib2to3_parse("123_456\n")
@ -866,13 +870,14 @@ def test_get_features_used(self) -> None:
self.assertEqual(black.get_features_used(node), set())
source, expected = read_data("function")
node = black.lib2to3_parse(source)
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
)
expected_features = {
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
Feature.F_STRINGS,
}
self.assertEqual(black.get_features_used(node), expected_features)
node = black.lib2to3_parse(expected)
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
)
self.assertEqual(black.get_features_used(node), expected_features)
source, expected = read_data("expression")
node = black.lib2to3_parse(source)
self.assertEqual(black.get_features_used(node), set())
@ -1524,8 +1529,8 @@ async def check(header_value: str, expected_status: int) -> None:
await check("3.6", 200)
await check("py3.6", 200)
await check("3.5,3.7", 200)
await check("3.5,py3.7", 200)
await check("3.6,3.7", 200)
await check("3.6,py3.7", 200)
await check("2", 204)
await check("2.7", 204)