Significantly speed up ESP on large expressions that contain many strings (#3467)

Yilei "Dolee" Yang 2022-12-23 12:13:45 -08:00 committed by GitHub
parent 3246df89d6
commit 3feff21eca
3 changed files with 235 additions and 76 deletions


@@ -16,6 +16,7 @@
 <!-- Changes that affect Black's preview style -->

+- Improve the performance on large expressions that contain many strings (#3467)
 - Fix a crash in preview style with assert + parenthesized string (#3415)
 - Fix crashes in preview style with walrus operators used in function return annotations
   and except clauses (#3423)


@@ -69,7 +69,7 @@ class CannotTransform(Exception):
 ParserState = int
 StringID = int
 TResult = Result[T, CannotTransform]  # (T)ransform Result
-TMatchResult = TResult[Index]
+TMatchResult = TResult[List[Index]]


 def TErr(err_msg: str) -> Err[CannotTransform]:
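
The one-line alias change above is the heart of the API shift: a successful match now carries every eligible string index found in one scan of the line, not just the first one. Below is a minimal, self-contained sketch of the Ok/Err result pattern these aliases are built on; the names mirror the ones in the diff, but the classes here are simplified stand-ins rather than Black's actual implementation.

from typing import Generic, List, TypeVar, Union

T = TypeVar("T")
E = TypeVar("E", bound=Exception)
Index = int


class Ok(Generic[T]):
    """Successful result wrapper; .ok() returns the wrapped value."""

    def __init__(self, value: T) -> None:
        self._value = value

    def ok(self) -> T:
        return self._value


class Err(Generic[E]):
    """Failed result wrapper; .err() returns the wrapped error."""

    def __init__(self, error: E) -> None:
        self._error = error

    def err(self) -> E:
        return self._error


Result = Union[Ok[T], Err[E]]

# Old contract: at most one matched string index per call.
OldTMatchResult = Result[Index, Exception]
# New contract: all matched string indices from a single pass over the line.
TMatchResult = Result[List[Index], Exception]
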
@@ -198,14 +198,19 @@ def __init__(self, line_length: int, normalize_strings: bool) -> None:
     def do_match(self, line: Line) -> TMatchResult:
         """
         Returns:
-            * Ok(string_idx) such that `line.leaves[string_idx]` is our target
-              string, if a match was able to be made.
+            * Ok(string_indices) such that for each index, `line.leaves[index]`
+              is our target string if a match was able to be made. For
+              transformers that don't result in more lines (e.g. StringMerger,
+              StringParenStripper), multiple matches and transforms are done at
+              once to reduce the complexity.
                 OR
-            * Err(CannotTransform), if a match was not able to be made.
+            * Err(CannotTransform), if no match could be made.
         """

     @abstractmethod
-    def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
+    def do_transform(
+        self, line: Line, string_indices: List[int]
+    ) -> Iterator[TResult[Line]]:
         """
         Yields:
             * Ok(new_line) where new_line is the new transformed line.
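
To make the new do_match/do_transform contract concrete, here is a toy transformer in the same shape: do_match reports every target index from one scan, and do_transform rewrites them all in a single call. ToyLine and UpperCaseStrings are invented for illustration and only mimic the structure of Black's StringTransformer protocol; they are not its real classes.

from typing import List


class ToyLine:
    """Stand-in for Black's Line: just a list of token strings."""

    def __init__(self, leaves: List[str]) -> None:
        self.leaves = leaves


class UpperCaseStrings:
    """Mimics the do_match/do_transform split: match once, transform all at once."""

    def do_match(self, line: ToyLine) -> List[int]:
        # Report every index we intend to rewrite (the new Ok(string_indices)).
        return [i for i, leaf in enumerate(line.leaves) if leaf.startswith('"')]

    def do_transform(self, line: ToyLine, string_indices: List[int]) -> ToyLine:
        new_leaves = list(line.leaves)
        for idx in string_indices:
            new_leaves[idx] = new_leaves[idx].upper()
        return ToyLine(new_leaves)


line = ToyLine(["x", "=", '"a"', "+", '"b"'])
transformer = UpperCaseStrings()
indices = transformer.do_match(line)               # one scan finds [2, 4]
result = transformer.do_transform(line, indices)   # one rewrite handles both
assert result.leaves == ["x", "=", '"A"', "+", '"B"']
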
@@ -246,9 +251,9 @@ def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]
                 " this line as one that it can transform."
             ) from cant_transform

-        string_idx = match_result.ok()
+        string_indices = match_result.ok()

-        for line_result in self.do_transform(line, string_idx):
+        for line_result in self.do_transform(line, string_indices):
             if isinstance(line_result, Err):
                 cant_transform = line_result.err()
                 raise CannotTransform(
@@ -371,30 +376,50 @@ def do_match(self, line: Line) -> TMatchResult:

         is_valid_index = is_valid_index_factory(LL)

-        for i, leaf in enumerate(LL):
+        string_indices = []
+        idx = 0
+        while is_valid_index(idx):
+            leaf = LL[idx]
             if (
                 leaf.type == token.STRING
-                and is_valid_index(i + 1)
-                and LL[i + 1].type == token.STRING
+                and is_valid_index(idx + 1)
+                and LL[idx + 1].type == token.STRING
             ):
-                if is_part_of_annotation(leaf):
-                    return TErr("String is part of type annotation.")
-
-                return Ok(i)
+                if not is_part_of_annotation(leaf):
+                    string_indices.append(idx)
+
+                # Advance to the next non-STRING leaf.
+                idx += 2
+                while is_valid_index(idx) and LL[idx].type == token.STRING:
+                    idx += 1

-            if leaf.type == token.STRING and "\\\n" in leaf.value:
-                return Ok(i)
+            elif leaf.type == token.STRING and "\\\n" in leaf.value:
+                string_indices.append(idx)
+                # Advance to the next non-STRING leaf.
+                idx += 1
+                while is_valid_index(idx) and LL[idx].type == token.STRING:
+                    idx += 1

-        return TErr("This line has no strings that need merging.")
+            else:
+                idx += 1
+
+        if string_indices:
+            return Ok(string_indices)
+        else:
+            return TErr("This line has no strings that need merging.")

-    def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
+    def do_transform(
+        self, line: Line, string_indices: List[int]
+    ) -> Iterator[TResult[Line]]:
         new_line = line
         rblc_result = self._remove_backslash_line_continuation_chars(
-            new_line, string_idx
+            new_line, string_indices
         )
         if isinstance(rblc_result, Ok):
             new_line = rblc_result.ok()

-        msg_result = self._merge_string_group(new_line, string_idx)
+        msg_result = self._merge_string_group(new_line, string_indices)
         if isinstance(msg_result, Ok):
             new_line = msg_result.ok()
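
The rewritten do_match above is where the speedup for StringMerger comes from: instead of returning the first matching index and letting the caller rescan the whole line for every group, it walks the leaves once, records the first leaf of each adjacent-string group, and jumps idx past the rest of the group so no leaf is visited twice. A simplified model of that scan, with plain strings standing in for Leaf objects (the STR prefix and both helper names are invented for illustration):

from typing import List

Leaf = str  # toy "leaf": adjacent elements starting with "STR" form a group


def find_first_group(leaves: List[Leaf]) -> int:
    """Old shape: return the first index starting an adjacent-string group, else -1."""
    for i in range(len(leaves) - 1):
        if leaves[i].startswith("STR") and leaves[i + 1].startswith("STR"):
            return i
    return -1


def find_all_groups(leaves: List[Leaf]) -> List[int]:
    """New shape: one scan that records the first index of every group."""
    indices, i = [], 0
    while i < len(leaves) - 1:
        if leaves[i].startswith("STR") and leaves[i + 1].startswith("STR"):
            indices.append(i)
            # Skip past the rest of this group so it is only reported once.
            while i < len(leaves) and leaves[i].startswith("STR"):
                i += 1
        else:
            i += 1
    return indices


line = ["STR_a", "STR_b", "+", "STR_c", "STR_d", "STR_e", ")"]
assert find_first_group(line) == 0      # old: caller must rescan after each merge
assert find_all_groups(line) == [0, 3]  # new: both groups found in a single pass
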
@@ -415,7 +440,7 @@ def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:

     @staticmethod
     def _remove_backslash_line_continuation_chars(
-        line: Line, string_idx: int
+        line: Line, string_indices: List[int]
     ) -> TResult[Line]:
         """
         Merge strings that were split across multiple lines using
@ -429,30 +454,40 @@ def _remove_backslash_line_continuation_chars(
""" """
LL = line.leaves LL = line.leaves
indices_to_transform = []
for string_idx in string_indices:
string_leaf = LL[string_idx] string_leaf = LL[string_idx]
if not ( if (
string_leaf.type == token.STRING string_leaf.type == token.STRING
and "\\\n" in string_leaf.value and "\\\n" in string_leaf.value
and not has_triple_quotes(string_leaf.value) and not has_triple_quotes(string_leaf.value)
): ):
indices_to_transform.append(string_idx)
if not indices_to_transform:
return TErr( return TErr(
f"String leaf {string_leaf} does not contain any backslash line" "Found no string leaves that contain backslash line continuation"
" continuation characters." " characters."
) )
new_line = line.clone() new_line = line.clone()
new_line.comments = line.comments.copy() new_line.comments = line.comments.copy()
append_leaves(new_line, line, LL) append_leaves(new_line, line, LL)
for string_idx in indices_to_transform:
new_string_leaf = new_line.leaves[string_idx] new_string_leaf = new_line.leaves[string_idx]
new_string_leaf.value = new_string_leaf.value.replace("\\\n", "") new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
return Ok(new_line) return Ok(new_line)
def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]: def _merge_string_group(
self, line: Line, string_indices: List[int]
) -> TResult[Line]:
""" """
Merges string group (i.e. set of adjacent strings) where the first Merges string groups (i.e. set of adjacent strings).
string in the group is `line.leaves[string_idx]`.
Each index from `string_indices` designates one string group's first
leaf in `line.leaves`.
Returns: Returns:
Ok(new_line), if ALL of the validation checks found in Ok(new_line), if ALL of the validation checks found in
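
For reference, the "backslash line continuation characters" the helper above strips are literal backslash-newline sequences inside a single, implicitly continued string leaf; removing them merges the pieces into one logical string. A quick standalone illustration of the replacement the new loop applies to each collected index (plain Python strings here, not Leaf objects):

# A string leaf whose source spanned two physical lines via a trailing backslash.
leaf_value = '"hello \\\nworld"'

# The transform in the diff removes the backslash-newline pair in place.
merged_value = leaf_value.replace("\\\n", "")
assert merged_value == '"hello world"'

# With the new signature, the same replacement simply runs once per collected index.
values = ['"a \\\nb"', '"c"', '"d \\\ne"']
indices_to_transform = [i for i, v in enumerate(values) if "\\\n" in v]
for idx in indices_to_transform:
    values[idx] = values[idx].replace("\\\n", "")
assert values == ['"a b"', '"c"', '"d e"']
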
@@ -464,10 +499,54 @@ def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:

         is_valid_index = is_valid_index_factory(LL)

-        vresult = self._validate_msg(line, string_idx)
-        if isinstance(vresult, Err):
-            return vresult
+        # A dict of {string_idx: tuple[num_of_strings, string_leaf]}.
+        merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {}
+        for string_idx in string_indices:
+            vresult = self._validate_msg(line, string_idx)
+            if isinstance(vresult, Err):
+                continue
+            merged_string_idx_dict[string_idx] = self._merge_one_string_group(
+                LL, string_idx, is_valid_index
+            )
+
+        if not merged_string_idx_dict:
+            return TErr("No string group is merged")
+
+        # Build the final line ('new_line') that this method will later return.
+        new_line = line.clone()
+        previous_merged_string_idx = -1
+        previous_merged_num_of_strings = -1
+        for i, leaf in enumerate(LL):
+            if i in merged_string_idx_dict:
+                previous_merged_string_idx = i
+                previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i]
+                new_line.append(string_leaf)
+
+            if (
+                previous_merged_string_idx
+                <= i
+                < previous_merged_string_idx + previous_merged_num_of_strings
+            ):
+                for comment_leaf in line.comments_after(LL[i]):
+                    new_line.append(comment_leaf, preformatted=True)
+                continue
+
+            append_leaves(new_line, line, [leaf])
+
+        return Ok(new_line)
+
+    def _merge_one_string_group(
+        self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool]
+    ) -> Tuple[int, Leaf]:
+        """
+        Merges one string group where the first string in the group is
+        `LL[string_idx]`.
+
+        Returns:
+            A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the
+            number of strings merged and `leaf` is the newly merged string
+            to be replaced in the new line.
+        """
         # If the string group is wrapped inside an Atom node, we must make sure
         # to later replace that Atom with our new (merged) string leaf.
         atom_node = LL[string_idx].parent
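
The bookkeeping above replaces the old per-group line rebuild: each dict entry maps a group's first index to (num_of_strings, merged_leaf), and a single walk over the original leaves emits the merged leaf at the group's start, skips the leaves that merge consumed, and copies everything else through. A small model of that walk with lists of strings instead of Leaf objects (the merged dict values here are invented for illustration):

from typing import Dict, List, Tuple

leaves: List[str] = ["f(", '"a"', '"b"', ",", '"c"', '"d"', '"e"', ")"]

# {group_start_index: (num_of_strings_merged, merged_leaf)}
merged: Dict[int, Tuple[int, str]] = {1: (2, '"ab"'), 4: (3, '"cde"')}

new_leaves: List[str] = []
prev_start, prev_count = -1, -1
for i, leaf in enumerate(leaves):
    if i in merged:
        prev_start, (prev_count, merged_leaf) = i, merged[i]
        new_leaves.append(merged_leaf)
    if prev_start <= i < prev_start + prev_count:
        # This leaf was consumed by the merge that started at prev_start.
        continue
    new_leaves.append(leaf)

assert new_leaves == ["f(", '"ab"', ",", '"cde"', ")"]
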
@@ -590,21 +669,8 @@ def make_naked(string: str, string_prefix: str) -> str:
                 # Else replace the atom node with the new string leaf.
                 replace_child(atom_node, string_leaf)

-        # Build the final line ('new_line') that this method will later return.
-        new_line = line.clone()
-        for i, leaf in enumerate(LL):
-            if i == string_idx:
-                new_line.append(string_leaf)
-
-            if string_idx <= i < string_idx + num_of_strings:
-                for comment_leaf in line.comments_after(LL[i]):
-                    new_line.append(comment_leaf, preformatted=True)
-                continue
-
-            append_leaves(new_line, line, [leaf])
-
         self.add_custom_splits(string_leaf.value, custom_splits)

-        return Ok(new_line)
+        return num_of_strings, string_leaf

     @staticmethod
     def _validate_msg(line: Line, string_idx: int) -> TResult[None]:
@@ -718,7 +784,15 @@ def do_match(self, line: Line) -> TMatchResult:

         is_valid_index = is_valid_index_factory(LL)

-        for idx, leaf in enumerate(LL):
+        string_indices = []
+
+        idx = -1
+        while True:
+            idx += 1
+            if idx >= len(LL):
+                break
+            leaf = LL[idx]
+
             # Should be a string...
             if leaf.type != token.STRING:
                 continue
@@ -800,39 +874,73 @@ def do_match(self, line: Line) -> TMatchResult:
             }:
                 continue

-            return Ok(string_idx)
+            string_indices.append(string_idx)
+            idx = string_idx
+            while idx < len(LL) - 1 and LL[idx + 1].type == token.STRING:
+                idx += 1

+        if string_indices:
+            return Ok(string_indices)
         return TErr("This line has no strings wrapped in parens.")

-    def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
+    def do_transform(
+        self, line: Line, string_indices: List[int]
+    ) -> Iterator[TResult[Line]]:
         LL = line.leaves

-        string_parser = StringParser()
-        rpar_idx = string_parser.parse(LL, string_idx)
-
-        for leaf in (LL[string_idx - 1], LL[rpar_idx]):
-            if line.comments_after(leaf):
-                yield TErr(
-                    "Will not strip parentheses which have comments attached to them."
-                )
-                return
+        string_and_rpar_indices: List[int] = []
+        for string_idx in string_indices:
+            string_parser = StringParser()
+            rpar_idx = string_parser.parse(LL, string_idx)
+
+            should_transform = True
+            for leaf in (LL[string_idx - 1], LL[rpar_idx]):
+                if line.comments_after(leaf):
+                    # Should not strip parentheses which have comments attached
+                    # to them.
+                    should_transform = False
+                    break
+            if should_transform:
+                string_and_rpar_indices.extend((string_idx, rpar_idx))
+
+        if string_and_rpar_indices:
+            yield Ok(self._transform_to_new_line(line, string_and_rpar_indices))
+        else:
+            yield Err(
+                CannotTransform("All string groups have comments attached to them.")
+            )
+
+    def _transform_to_new_line(
+        self, line: Line, string_and_rpar_indices: List[int]
+    ) -> Line:
+        LL = line.leaves

         new_line = line.clone()
         new_line.comments = line.comments.copy()
-        append_leaves(new_line, line, LL[: string_idx - 1])

-        string_leaf = Leaf(token.STRING, LL[string_idx].value)
-        LL[string_idx - 1].remove()
-        replace_child(LL[string_idx], string_leaf)
-        new_line.append(string_leaf)
-
-        append_leaves(
-            new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
-        )
-
-        LL[rpar_idx].remove()
-
-        yield Ok(new_line)
+        previous_idx = -1
+        # We need to sort the indices, since string_idx and its matching
+        # rpar_idx may not come in order, e.g. in
+        # `("outer" % ("inner".join(items)))`, the "inner" string's
+        # string_idx is smaller than "outer" string's rpar_idx.
+        for idx in sorted(string_and_rpar_indices):
+            leaf = LL[idx]
+            lpar_or_rpar_idx = idx - 1 if leaf.type == token.STRING else idx
+            append_leaves(new_line, line, LL[previous_idx + 1 : lpar_or_rpar_idx])
+            if leaf.type == token.STRING:
+                string_leaf = Leaf(token.STRING, LL[idx].value)
+                LL[lpar_or_rpar_idx].remove()  # Remove lpar.
+                replace_child(LL[idx], string_leaf)
+                new_line.append(string_leaf)
+            else:
+                LL[lpar_or_rpar_idx].remove()  # This is a rpar.
+
+            previous_idx = idx
+
+        # Append the leaves after the last idx:
+        append_leaves(new_line, line, LL[idx + 1 :])
+
+        return new_line


 class BaseStringSplitter(StringTransformer):
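
The sorting comment in the hunk above is the subtle part of the rebuild: string_and_rpar_indices is extended group by group as (string_idx, rpar_idx) pairs, but with nested parentheses the inner string's index can be smaller than the outer string's closing-paren index, so the flat list is not monotonically increasing. Sorting it lets the rebuild walk the leaves left to right exactly once. A rough sketch of the example from the comment; the index values are illustrative, not Black's exact leaf numbering.

# For `("outer" % ("inner".join(items)))`, the matcher reports the outer string
# before the inner one, and each match contributes (string_idx, rpar_idx):
outer_pair = (1, 12)  # outer string leaf, and the ")" that closes its parens
inner_pair = (4, 11)  # inner string leaf sits *between* the outer pair's indices

string_and_rpar_indices = [*outer_pair, *inner_pair]  # [1, 12, 4, 11]

# A left-to-right rebuild of the line must visit these positions in order,
# which is why the diff sorts before iterating.
assert string_and_rpar_indices != sorted(string_and_rpar_indices)
assert sorted(string_and_rpar_indices) == [1, 4, 11, 12]
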
@@ -885,7 +993,12 @@ def do_match(self, line: Line) -> TMatchResult:
         if isinstance(match_result, Err):
             return match_result

-        string_idx = match_result.ok()
+        string_indices = match_result.ok()
+        assert len(string_indices) == 1, (
+            f"{self.__class__.__name__} should only find one match at a time, found"
+            f" {len(string_indices)}"
+        )
+        string_idx = string_indices[0]
         vresult = self._validate(line, string_idx)
         if isinstance(vresult, Err):
             return vresult
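
Only the merger and the paren stripper batch their matches; the splitters still produce more lines per match and keep the one-match-at-a-time flow, so their do_match now simply wraps the single index in a list and the consuming side asserts the list has length one before unwrapping it. A minimal sketch of that convention (toy functions, not Black's classes):

from typing import List


def do_splitter_match(leaf_types: List[str]) -> List[int]:
    # A splitter still reports at most one target per call; it just wraps it
    # in a list so the shared TMatchResult type stays uniform.
    for idx, leaf_type in enumerate(leaf_types):
        if leaf_type == "STRING":
            return [idx]
    return []


def do_splitter_transform(string_indices: List[int]) -> int:
    # Mirrors the new asserts in the diff: exactly one index is expected here.
    assert len(string_indices) == 1, (
        f"splitter should only find one match at a time, found {len(string_indices)}"
    )
    return string_indices[0]


match = do_splitter_match(["NAME", "EQUAL", "STRING"])
assert do_splitter_transform(match) == 2
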
@@ -1219,10 +1332,17 @@ def do_splitter_match(self, line: Line) -> TMatchResult:

         if is_valid_index(idx):
             return TErr("This line does not end with a string.")

-        return Ok(string_idx)
+        return Ok([string_idx])

-    def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
+    def do_transform(
+        self, line: Line, string_indices: List[int]
+    ) -> Iterator[TResult[Line]]:
         LL = line.leaves
+        assert len(string_indices) == 1, (
+            f"{self.__class__.__name__} should only find one match at a time, found"
+            f" {len(string_indices)}"
+        )
+        string_idx = string_indices[0]

         QUOTE = LL[string_idx].value[-1]
@@ -1710,7 +1830,7 @@ def do_splitter_match(self, line: Line) -> TMatchResult:
                             " resultant line would still be over the specified line"
                             " length and can't be split further by StringSplitter."
                         )

-            return Ok(string_idx)
+            return Ok([string_idx])

         return TErr("This line does not contain any non-atomic strings.")
@@ -1887,8 +2007,15 @@ def _dict_or_lambda_match(LL: List[Leaf]) -> Optional[int]:

         return None

-    def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
+    def do_transform(
+        self, line: Line, string_indices: List[int]
+    ) -> Iterator[TResult[Line]]:
         LL = line.leaves
+        assert len(string_indices) == 1, (
+            f"{self.__class__.__name__} should only find one match at a time, found"
+            f" {len(string_indices)}"
+        )
+        string_idx = string_indices[0]

         is_valid_index = is_valid_index_factory(LL)
         insert_str_child = insert_str_child_factory(LL[string_idx])


@@ -287,6 +287,23 @@ def foo():
     ),
 }

+# Complex string concatenations with a method call in the middle.
+code = (
+    (" return [\n")
+    + (
+        ", \n".join(
+            " (%r, self.%s, visitor.%s)"
+            % (attrname, attrname, visit_name)
+            for attrname, visit_name in names
+        )
+    )
+    + ("\n ]\n")
+)
+
+# Test case of an outer string' parens enclose an inner string's parens.
+call(body=("%s %s" % ((",".join(items)), suffix)))
+

 # output
@@ -828,3 +845,17 @@ def foo():
         f"{some_function_call(j.right)})"
     ),
 }
+
+# Complex string concatenations with a method call in the middle.
+code = (
+    " return [\n"
+    + ", \n".join(
+        " (%r, self.%s, visitor.%s)" % (attrname, attrname, visit_name)
+        for attrname, visit_name in names
+    )
+    + "\n ]\n"
+)
+
+# Test case of an outer string' parens enclose an inner string's parens.
+call(body="%s %s" % (",".join(items), suffix))
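
The new regression cases can be exercised directly against the preview style. A quick way to check the input/expected-output pair above from Python, assuming a Black version new enough to ship both the preview mode and this change (black.format_str and black.Mode are Black's public Python API; the snippet string is taken from the test input):

import black

# One of the new regression inputs: nested parens around "%"-formatted strings.
src = 'call(body=("%s %s" % ((",".join(items)), suffix)))\n'

# Format with the preview style enabled, the mode these test files target.
formatted = black.format_str(src, mode=black.Mode(preview=True))
print(formatted)
# Expected, per the test output above:
#   call(body="%s %s" % (",".join(items), suffix))
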