Significantly speed up ESP on large expressions that contain many strings (#3467)

This commit is contained in:
Yilei "Dolée" Yang 2022-12-23 12:13:45 -08:00 committed by GitHub
parent 3246df89d6
commit 3feff21eca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 235 additions and 76 deletions

View File

@ -16,6 +16,7 @@
<!-- Changes that affect Black's preview style -->
- Improve the performance on large expressions that contain many strings (#3467)
- Fix a crash in preview style with assert + parenthesized string (#3415)
- Fix crashes in preview style with walrus operators used in function return annotations
and except clauses (#3423)

View File

@ -69,7 +69,7 @@ class CannotTransform(Exception):
ParserState = int
StringID = int
TResult = Result[T, CannotTransform] # (T)ransform Result
TMatchResult = TResult[Index]
TMatchResult = TResult[List[Index]]
def TErr(err_msg: str) -> Err[CannotTransform]:
@ -198,14 +198,19 @@ def __init__(self, line_length: int, normalize_strings: bool) -> None:
def do_match(self, line: Line) -> TMatchResult:
"""
Returns:
* Ok(string_idx) such that `line.leaves[string_idx]` is our target
string, if a match was able to be made.
* Ok(string_indices) such that for each index, `line.leaves[index]`
is our target string if a match was able to be made. For
transformers that don't result in more lines (e.g. StringMerger,
StringParenStripper), multiple matches and transforms are done at
once to reduce the complexity.
OR
* Err(CannotTransform), if a match was not able to be made.
* Err(CannotTransform), if no match could be made.
"""
@abstractmethod
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
def do_transform(
self, line: Line, string_indices: List[int]
) -> Iterator[TResult[Line]]:
"""
Yields:
* Ok(new_line) where new_line is the new transformed line.
@ -246,9 +251,9 @@ def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]
" this line as one that it can transform."
) from cant_transform
string_idx = match_result.ok()
string_indices = match_result.ok()
for line_result in self.do_transform(line, string_idx):
for line_result in self.do_transform(line, string_indices):
if isinstance(line_result, Err):
cant_transform = line_result.err()
raise CannotTransform(
@ -371,30 +376,50 @@ def do_match(self, line: Line) -> TMatchResult:
is_valid_index = is_valid_index_factory(LL)
for i, leaf in enumerate(LL):
string_indices = []
idx = 0
while is_valid_index(idx):
leaf = LL[idx]
if (
leaf.type == token.STRING
and is_valid_index(i + 1)
and LL[i + 1].type == token.STRING
and is_valid_index(idx + 1)
and LL[idx + 1].type == token.STRING
):
if is_part_of_annotation(leaf):
return TErr("String is part of type annotation.")
return Ok(i)
if not is_part_of_annotation(leaf):
string_indices.append(idx)
if leaf.type == token.STRING and "\\\n" in leaf.value:
return Ok(i)
# Advance to the next non-STRING leaf.
idx += 2
while is_valid_index(idx) and LL[idx].type == token.STRING:
idx += 1
return TErr("This line has no strings that need merging.")
elif leaf.type == token.STRING and "\\\n" in leaf.value:
string_indices.append(idx)
# Advance to the next non-STRING leaf.
idx += 1
while is_valid_index(idx) and LL[idx].type == token.STRING:
idx += 1
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
else:
idx += 1
if string_indices:
return Ok(string_indices)
else:
return TErr("This line has no strings that need merging.")
def do_transform(
self, line: Line, string_indices: List[int]
) -> Iterator[TResult[Line]]:
new_line = line
rblc_result = self._remove_backslash_line_continuation_chars(
new_line, string_idx
new_line, string_indices
)
if isinstance(rblc_result, Ok):
new_line = rblc_result.ok()
msg_result = self._merge_string_group(new_line, string_idx)
msg_result = self._merge_string_group(new_line, string_indices)
if isinstance(msg_result, Ok):
new_line = msg_result.ok()
@ -415,7 +440,7 @@ def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
@staticmethod
def _remove_backslash_line_continuation_chars(
line: Line, string_idx: int
line: Line, string_indices: List[int]
) -> TResult[Line]:
"""
Merge strings that were split across multiple lines using
@ -429,30 +454,40 @@ def _remove_backslash_line_continuation_chars(
"""
LL = line.leaves
string_leaf = LL[string_idx]
if not (
string_leaf.type == token.STRING
and "\\\n" in string_leaf.value
and not has_triple_quotes(string_leaf.value)
):
indices_to_transform = []
for string_idx in string_indices:
string_leaf = LL[string_idx]
if (
string_leaf.type == token.STRING
and "\\\n" in string_leaf.value
and not has_triple_quotes(string_leaf.value)
):
indices_to_transform.append(string_idx)
if not indices_to_transform:
return TErr(
f"String leaf {string_leaf} does not contain any backslash line"
" continuation characters."
"Found no string leaves that contain backslash line continuation"
" characters."
)
new_line = line.clone()
new_line.comments = line.comments.copy()
append_leaves(new_line, line, LL)
new_string_leaf = new_line.leaves[string_idx]
new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
for string_idx in indices_to_transform:
new_string_leaf = new_line.leaves[string_idx]
new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
return Ok(new_line)
def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
def _merge_string_group(
self, line: Line, string_indices: List[int]
) -> TResult[Line]:
"""
Merges string group (i.e. set of adjacent strings) where the first
string in the group is `line.leaves[string_idx]`.
Merges string groups (i.e. set of adjacent strings).
Each index from `string_indices` designates one string group's first
leaf in `line.leaves`.
Returns:
Ok(new_line), if ALL of the validation checks found in
@ -464,10 +499,54 @@ def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
is_valid_index = is_valid_index_factory(LL)
vresult = self._validate_msg(line, string_idx)
if isinstance(vresult, Err):
return vresult
# A dict of {string_idx: tuple[num_of_strings, string_leaf]}.
merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {}
for string_idx in string_indices:
vresult = self._validate_msg(line, string_idx)
if isinstance(vresult, Err):
continue
merged_string_idx_dict[string_idx] = self._merge_one_string_group(
LL, string_idx, is_valid_index
)
if not merged_string_idx_dict:
return TErr("No string group is merged")
# Build the final line ('new_line') that this method will later return.
new_line = line.clone()
previous_merged_string_idx = -1
previous_merged_num_of_strings = -1
for i, leaf in enumerate(LL):
if i in merged_string_idx_dict:
previous_merged_string_idx = i
previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i]
new_line.append(string_leaf)
if (
previous_merged_string_idx
<= i
< previous_merged_string_idx + previous_merged_num_of_strings
):
for comment_leaf in line.comments_after(LL[i]):
new_line.append(comment_leaf, preformatted=True)
continue
append_leaves(new_line, line, [leaf])
return Ok(new_line)
def _merge_one_string_group(
self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool]
) -> Tuple[int, Leaf]:
"""
Merges one string group where the first string in the group is
`LL[string_idx]`.
Returns:
A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the
number of strings merged and `leaf` is the newly merged string
to be replaced in the new line.
"""
# If the string group is wrapped inside an Atom node, we must make sure
# to later replace that Atom with our new (merged) string leaf.
atom_node = LL[string_idx].parent
@ -590,21 +669,8 @@ def make_naked(string: str, string_prefix: str) -> str:
# Else replace the atom node with the new string leaf.
replace_child(atom_node, string_leaf)
# Build the final line ('new_line') that this method will later return.
new_line = line.clone()
for i, leaf in enumerate(LL):
if i == string_idx:
new_line.append(string_leaf)
if string_idx <= i < string_idx + num_of_strings:
for comment_leaf in line.comments_after(LL[i]):
new_line.append(comment_leaf, preformatted=True)
continue
append_leaves(new_line, line, [leaf])
self.add_custom_splits(string_leaf.value, custom_splits)
return Ok(new_line)
return num_of_strings, string_leaf
@staticmethod
def _validate_msg(line: Line, string_idx: int) -> TResult[None]:
@ -718,7 +784,15 @@ def do_match(self, line: Line) -> TMatchResult:
is_valid_index = is_valid_index_factory(LL)
for idx, leaf in enumerate(LL):
string_indices = []
idx = -1
while True:
idx += 1
if idx >= len(LL):
break
leaf = LL[idx]
# Should be a string...
if leaf.type != token.STRING:
continue
@ -800,39 +874,73 @@ def do_match(self, line: Line) -> TMatchResult:
}:
continue
return Ok(string_idx)
string_indices.append(string_idx)
idx = string_idx
while idx < len(LL) - 1 and LL[idx + 1].type == token.STRING:
idx += 1
if string_indices:
return Ok(string_indices)
return TErr("This line has no strings wrapped in parens.")
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
def do_transform(
self, line: Line, string_indices: List[int]
) -> Iterator[TResult[Line]]:
LL = line.leaves
string_parser = StringParser()
rpar_idx = string_parser.parse(LL, string_idx)
string_and_rpar_indices: List[int] = []
for string_idx in string_indices:
string_parser = StringParser()
rpar_idx = string_parser.parse(LL, string_idx)
for leaf in (LL[string_idx - 1], LL[rpar_idx]):
if line.comments_after(leaf):
yield TErr(
"Will not strip parentheses which have comments attached to them."
)
return
should_transform = True
for leaf in (LL[string_idx - 1], LL[rpar_idx]):
if line.comments_after(leaf):
# Should not strip parentheses which have comments attached
# to them.
should_transform = False
break
if should_transform:
string_and_rpar_indices.extend((string_idx, rpar_idx))
if string_and_rpar_indices:
yield Ok(self._transform_to_new_line(line, string_and_rpar_indices))
else:
yield Err(
CannotTransform("All string groups have comments attached to them.")
)
def _transform_to_new_line(
self, line: Line, string_and_rpar_indices: List[int]
) -> Line:
LL = line.leaves
new_line = line.clone()
new_line.comments = line.comments.copy()
append_leaves(new_line, line, LL[: string_idx - 1])
string_leaf = Leaf(token.STRING, LL[string_idx].value)
LL[string_idx - 1].remove()
replace_child(LL[string_idx], string_leaf)
new_line.append(string_leaf)
previous_idx = -1
# We need to sort the indices, since string_idx and its matching
# rpar_idx may not come in order, e.g. in
# `("outer" % ("inner".join(items)))`, the "inner" string's
# string_idx is smaller than "outer" string's rpar_idx.
for idx in sorted(string_and_rpar_indices):
leaf = LL[idx]
lpar_or_rpar_idx = idx - 1 if leaf.type == token.STRING else idx
append_leaves(new_line, line, LL[previous_idx + 1 : lpar_or_rpar_idx])
if leaf.type == token.STRING:
string_leaf = Leaf(token.STRING, LL[idx].value)
LL[lpar_or_rpar_idx].remove() # Remove lpar.
replace_child(LL[idx], string_leaf)
new_line.append(string_leaf)
else:
LL[lpar_or_rpar_idx].remove() # This is a rpar.
append_leaves(
new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
)
previous_idx = idx
LL[rpar_idx].remove()
# Append the leaves after the last idx:
append_leaves(new_line, line, LL[idx + 1 :])
yield Ok(new_line)
return new_line
class BaseStringSplitter(StringTransformer):
@ -885,7 +993,12 @@ def do_match(self, line: Line) -> TMatchResult:
if isinstance(match_result, Err):
return match_result
string_idx = match_result.ok()
string_indices = match_result.ok()
assert len(string_indices) == 1, (
f"{self.__class__.__name__} should only find one match at a time, found"
f" {len(string_indices)}"
)
string_idx = string_indices[0]
vresult = self._validate(line, string_idx)
if isinstance(vresult, Err):
return vresult
@ -1219,10 +1332,17 @@ def do_splitter_match(self, line: Line) -> TMatchResult:
if is_valid_index(idx):
return TErr("This line does not end with a string.")
return Ok(string_idx)
return Ok([string_idx])
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
def do_transform(
self, line: Line, string_indices: List[int]
) -> Iterator[TResult[Line]]:
LL = line.leaves
assert len(string_indices) == 1, (
f"{self.__class__.__name__} should only find one match at a time, found"
f" {len(string_indices)}"
)
string_idx = string_indices[0]
QUOTE = LL[string_idx].value[-1]
@ -1710,7 +1830,7 @@ def do_splitter_match(self, line: Line) -> TMatchResult:
" resultant line would still be over the specified line"
" length and can't be split further by StringSplitter."
)
return Ok(string_idx)
return Ok([string_idx])
return TErr("This line does not contain any non-atomic strings.")
@ -1887,8 +2007,15 @@ def _dict_or_lambda_match(LL: List[Leaf]) -> Optional[int]:
return None
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
def do_transform(
self, line: Line, string_indices: List[int]
) -> Iterator[TResult[Line]]:
LL = line.leaves
assert len(string_indices) == 1, (
f"{self.__class__.__name__} should only find one match at a time, found"
f" {len(string_indices)}"
)
string_idx = string_indices[0]
is_valid_index = is_valid_index_factory(LL)
insert_str_child = insert_str_child_factory(LL[string_idx])

View File

@ -287,6 +287,23 @@ def foo():
),
}
# Complex string concatenations with a method call in the middle.
code = (
(" return [\n")
+ (
", \n".join(
" (%r, self.%s, visitor.%s)"
% (attrname, attrname, visit_name)
for attrname, visit_name in names
)
)
+ ("\n ]\n")
)
# Test case of an outer string' parens enclose an inner string's parens.
call(body=("%s %s" % ((",".join(items)), suffix)))
# output
@ -828,3 +845,17 @@ def foo():
f"{some_function_call(j.right)})"
),
}
# Complex string concatenations with a method call in the middle.
code = (
" return [\n"
+ ", \n".join(
" (%r, self.%s, visitor.%s)" % (attrname, attrname, visit_name)
for attrname, visit_name in names
)
+ "\n ]\n"
)
# Test case of an outer string' parens enclose an inner string's parens.
call(body="%s %s" % (",".join(items), suffix))