Significantly speed up ESP on large expressions that contain many strings (#3467)
This commit is contained in:
parent
3246df89d6
commit
3feff21eca
@ -16,6 +16,7 @@
|
||||
|
||||
<!-- Changes that affect Black's preview style -->
|
||||
|
||||
- Improve the performance on large expressions that contain many strings (#3467)
|
||||
- Fix a crash in preview style with assert + parenthesized string (#3415)
|
||||
- Fix crashes in preview style with walrus operators used in function return annotations
|
||||
and except clauses (#3423)
|
||||
|
@ -69,7 +69,7 @@ class CannotTransform(Exception):
|
||||
ParserState = int
|
||||
StringID = int
|
||||
TResult = Result[T, CannotTransform] # (T)ransform Result
|
||||
TMatchResult = TResult[Index]
|
||||
TMatchResult = TResult[List[Index]]
|
||||
|
||||
|
||||
def TErr(err_msg: str) -> Err[CannotTransform]:
|
||||
@ -198,14 +198,19 @@ def __init__(self, line_length: int, normalize_strings: bool) -> None:
|
||||
def do_match(self, line: Line) -> TMatchResult:
|
||||
"""
|
||||
Returns:
|
||||
* Ok(string_idx) such that `line.leaves[string_idx]` is our target
|
||||
string, if a match was able to be made.
|
||||
* Ok(string_indices) such that for each index, `line.leaves[index]`
|
||||
is our target string if a match was able to be made. For
|
||||
transformers that don't result in more lines (e.g. StringMerger,
|
||||
StringParenStripper), multiple matches and transforms are done at
|
||||
once to reduce the complexity.
|
||||
OR
|
||||
* Err(CannotTransform), if a match was not able to be made.
|
||||
* Err(CannotTransform), if no match could be made.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
|
||||
def do_transform(
|
||||
self, line: Line, string_indices: List[int]
|
||||
) -> Iterator[TResult[Line]]:
|
||||
"""
|
||||
Yields:
|
||||
* Ok(new_line) where new_line is the new transformed line.
|
||||
@ -246,9 +251,9 @@ def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]
|
||||
" this line as one that it can transform."
|
||||
) from cant_transform
|
||||
|
||||
string_idx = match_result.ok()
|
||||
string_indices = match_result.ok()
|
||||
|
||||
for line_result in self.do_transform(line, string_idx):
|
||||
for line_result in self.do_transform(line, string_indices):
|
||||
if isinstance(line_result, Err):
|
||||
cant_transform = line_result.err()
|
||||
raise CannotTransform(
|
||||
@ -371,30 +376,50 @@ def do_match(self, line: Line) -> TMatchResult:
|
||||
|
||||
is_valid_index = is_valid_index_factory(LL)
|
||||
|
||||
for i, leaf in enumerate(LL):
|
||||
string_indices = []
|
||||
idx = 0
|
||||
while is_valid_index(idx):
|
||||
leaf = LL[idx]
|
||||
if (
|
||||
leaf.type == token.STRING
|
||||
and is_valid_index(i + 1)
|
||||
and LL[i + 1].type == token.STRING
|
||||
and is_valid_index(idx + 1)
|
||||
and LL[idx + 1].type == token.STRING
|
||||
):
|
||||
if is_part_of_annotation(leaf):
|
||||
return TErr("String is part of type annotation.")
|
||||
return Ok(i)
|
||||
if not is_part_of_annotation(leaf):
|
||||
string_indices.append(idx)
|
||||
|
||||
if leaf.type == token.STRING and "\\\n" in leaf.value:
|
||||
return Ok(i)
|
||||
# Advance to the next non-STRING leaf.
|
||||
idx += 2
|
||||
while is_valid_index(idx) and LL[idx].type == token.STRING:
|
||||
idx += 1
|
||||
|
||||
return TErr("This line has no strings that need merging.")
|
||||
elif leaf.type == token.STRING and "\\\n" in leaf.value:
|
||||
string_indices.append(idx)
|
||||
# Advance to the next non-STRING leaf.
|
||||
idx += 1
|
||||
while is_valid_index(idx) and LL[idx].type == token.STRING:
|
||||
idx += 1
|
||||
|
||||
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
|
||||
else:
|
||||
idx += 1
|
||||
|
||||
if string_indices:
|
||||
return Ok(string_indices)
|
||||
else:
|
||||
return TErr("This line has no strings that need merging.")
|
||||
|
||||
def do_transform(
|
||||
self, line: Line, string_indices: List[int]
|
||||
) -> Iterator[TResult[Line]]:
|
||||
new_line = line
|
||||
|
||||
rblc_result = self._remove_backslash_line_continuation_chars(
|
||||
new_line, string_idx
|
||||
new_line, string_indices
|
||||
)
|
||||
if isinstance(rblc_result, Ok):
|
||||
new_line = rblc_result.ok()
|
||||
|
||||
msg_result = self._merge_string_group(new_line, string_idx)
|
||||
msg_result = self._merge_string_group(new_line, string_indices)
|
||||
if isinstance(msg_result, Ok):
|
||||
new_line = msg_result.ok()
|
||||
|
||||
@ -415,7 +440,7 @@ def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
|
||||
|
||||
@staticmethod
|
||||
def _remove_backslash_line_continuation_chars(
|
||||
line: Line, string_idx: int
|
||||
line: Line, string_indices: List[int]
|
||||
) -> TResult[Line]:
|
||||
"""
|
||||
Merge strings that were split across multiple lines using
|
||||
@ -429,30 +454,40 @@ def _remove_backslash_line_continuation_chars(
|
||||
"""
|
||||
LL = line.leaves
|
||||
|
||||
string_leaf = LL[string_idx]
|
||||
if not (
|
||||
string_leaf.type == token.STRING
|
||||
and "\\\n" in string_leaf.value
|
||||
and not has_triple_quotes(string_leaf.value)
|
||||
):
|
||||
indices_to_transform = []
|
||||
for string_idx in string_indices:
|
||||
string_leaf = LL[string_idx]
|
||||
if (
|
||||
string_leaf.type == token.STRING
|
||||
and "\\\n" in string_leaf.value
|
||||
and not has_triple_quotes(string_leaf.value)
|
||||
):
|
||||
indices_to_transform.append(string_idx)
|
||||
|
||||
if not indices_to_transform:
|
||||
return TErr(
|
||||
f"String leaf {string_leaf} does not contain any backslash line"
|
||||
" continuation characters."
|
||||
"Found no string leaves that contain backslash line continuation"
|
||||
" characters."
|
||||
)
|
||||
|
||||
new_line = line.clone()
|
||||
new_line.comments = line.comments.copy()
|
||||
append_leaves(new_line, line, LL)
|
||||
|
||||
new_string_leaf = new_line.leaves[string_idx]
|
||||
new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
|
||||
for string_idx in indices_to_transform:
|
||||
new_string_leaf = new_line.leaves[string_idx]
|
||||
new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
|
||||
|
||||
return Ok(new_line)
|
||||
|
||||
def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
|
||||
def _merge_string_group(
|
||||
self, line: Line, string_indices: List[int]
|
||||
) -> TResult[Line]:
|
||||
"""
|
||||
Merges string group (i.e. set of adjacent strings) where the first
|
||||
string in the group is `line.leaves[string_idx]`.
|
||||
Merges string groups (i.e. set of adjacent strings).
|
||||
|
||||
Each index from `string_indices` designates one string group's first
|
||||
leaf in `line.leaves`.
|
||||
|
||||
Returns:
|
||||
Ok(new_line), if ALL of the validation checks found in
|
||||
@ -464,10 +499,54 @@ def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
|
||||
|
||||
is_valid_index = is_valid_index_factory(LL)
|
||||
|
||||
vresult = self._validate_msg(line, string_idx)
|
||||
if isinstance(vresult, Err):
|
||||
return vresult
|
||||
# A dict of {string_idx: tuple[num_of_strings, string_leaf]}.
|
||||
merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {}
|
||||
for string_idx in string_indices:
|
||||
vresult = self._validate_msg(line, string_idx)
|
||||
if isinstance(vresult, Err):
|
||||
continue
|
||||
merged_string_idx_dict[string_idx] = self._merge_one_string_group(
|
||||
LL, string_idx, is_valid_index
|
||||
)
|
||||
|
||||
if not merged_string_idx_dict:
|
||||
return TErr("No string group is merged")
|
||||
|
||||
# Build the final line ('new_line') that this method will later return.
|
||||
new_line = line.clone()
|
||||
previous_merged_string_idx = -1
|
||||
previous_merged_num_of_strings = -1
|
||||
for i, leaf in enumerate(LL):
|
||||
if i in merged_string_idx_dict:
|
||||
previous_merged_string_idx = i
|
||||
previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i]
|
||||
new_line.append(string_leaf)
|
||||
|
||||
if (
|
||||
previous_merged_string_idx
|
||||
<= i
|
||||
< previous_merged_string_idx + previous_merged_num_of_strings
|
||||
):
|
||||
for comment_leaf in line.comments_after(LL[i]):
|
||||
new_line.append(comment_leaf, preformatted=True)
|
||||
continue
|
||||
|
||||
append_leaves(new_line, line, [leaf])
|
||||
|
||||
return Ok(new_line)
|
||||
|
||||
def _merge_one_string_group(
|
||||
self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool]
|
||||
) -> Tuple[int, Leaf]:
|
||||
"""
|
||||
Merges one string group where the first string in the group is
|
||||
`LL[string_idx]`.
|
||||
|
||||
Returns:
|
||||
A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the
|
||||
number of strings merged and `leaf` is the newly merged string
|
||||
to be replaced in the new line.
|
||||
"""
|
||||
# If the string group is wrapped inside an Atom node, we must make sure
|
||||
# to later replace that Atom with our new (merged) string leaf.
|
||||
atom_node = LL[string_idx].parent
|
||||
@ -590,21 +669,8 @@ def make_naked(string: str, string_prefix: str) -> str:
|
||||
# Else replace the atom node with the new string leaf.
|
||||
replace_child(atom_node, string_leaf)
|
||||
|
||||
# Build the final line ('new_line') that this method will later return.
|
||||
new_line = line.clone()
|
||||
for i, leaf in enumerate(LL):
|
||||
if i == string_idx:
|
||||
new_line.append(string_leaf)
|
||||
|
||||
if string_idx <= i < string_idx + num_of_strings:
|
||||
for comment_leaf in line.comments_after(LL[i]):
|
||||
new_line.append(comment_leaf, preformatted=True)
|
||||
continue
|
||||
|
||||
append_leaves(new_line, line, [leaf])
|
||||
|
||||
self.add_custom_splits(string_leaf.value, custom_splits)
|
||||
return Ok(new_line)
|
||||
return num_of_strings, string_leaf
|
||||
|
||||
@staticmethod
|
||||
def _validate_msg(line: Line, string_idx: int) -> TResult[None]:
|
||||
@ -718,7 +784,15 @@ def do_match(self, line: Line) -> TMatchResult:
|
||||
|
||||
is_valid_index = is_valid_index_factory(LL)
|
||||
|
||||
for idx, leaf in enumerate(LL):
|
||||
string_indices = []
|
||||
|
||||
idx = -1
|
||||
while True:
|
||||
idx += 1
|
||||
if idx >= len(LL):
|
||||
break
|
||||
leaf = LL[idx]
|
||||
|
||||
# Should be a string...
|
||||
if leaf.type != token.STRING:
|
||||
continue
|
||||
@ -800,39 +874,73 @@ def do_match(self, line: Line) -> TMatchResult:
|
||||
}:
|
||||
continue
|
||||
|
||||
return Ok(string_idx)
|
||||
string_indices.append(string_idx)
|
||||
idx = string_idx
|
||||
while idx < len(LL) - 1 and LL[idx + 1].type == token.STRING:
|
||||
idx += 1
|
||||
|
||||
if string_indices:
|
||||
return Ok(string_indices)
|
||||
return TErr("This line has no strings wrapped in parens.")
|
||||
|
||||
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
|
||||
def do_transform(
|
||||
self, line: Line, string_indices: List[int]
|
||||
) -> Iterator[TResult[Line]]:
|
||||
LL = line.leaves
|
||||
|
||||
string_parser = StringParser()
|
||||
rpar_idx = string_parser.parse(LL, string_idx)
|
||||
string_and_rpar_indices: List[int] = []
|
||||
for string_idx in string_indices:
|
||||
string_parser = StringParser()
|
||||
rpar_idx = string_parser.parse(LL, string_idx)
|
||||
|
||||
for leaf in (LL[string_idx - 1], LL[rpar_idx]):
|
||||
if line.comments_after(leaf):
|
||||
yield TErr(
|
||||
"Will not strip parentheses which have comments attached to them."
|
||||
)
|
||||
return
|
||||
should_transform = True
|
||||
for leaf in (LL[string_idx - 1], LL[rpar_idx]):
|
||||
if line.comments_after(leaf):
|
||||
# Should not strip parentheses which have comments attached
|
||||
# to them.
|
||||
should_transform = False
|
||||
break
|
||||
if should_transform:
|
||||
string_and_rpar_indices.extend((string_idx, rpar_idx))
|
||||
|
||||
if string_and_rpar_indices:
|
||||
yield Ok(self._transform_to_new_line(line, string_and_rpar_indices))
|
||||
else:
|
||||
yield Err(
|
||||
CannotTransform("All string groups have comments attached to them.")
|
||||
)
|
||||
|
||||
def _transform_to_new_line(
|
||||
self, line: Line, string_and_rpar_indices: List[int]
|
||||
) -> Line:
|
||||
LL = line.leaves
|
||||
|
||||
new_line = line.clone()
|
||||
new_line.comments = line.comments.copy()
|
||||
append_leaves(new_line, line, LL[: string_idx - 1])
|
||||
|
||||
string_leaf = Leaf(token.STRING, LL[string_idx].value)
|
||||
LL[string_idx - 1].remove()
|
||||
replace_child(LL[string_idx], string_leaf)
|
||||
new_line.append(string_leaf)
|
||||
previous_idx = -1
|
||||
# We need to sort the indices, since string_idx and its matching
|
||||
# rpar_idx may not come in order, e.g. in
|
||||
# `("outer" % ("inner".join(items)))`, the "inner" string's
|
||||
# string_idx is smaller than "outer" string's rpar_idx.
|
||||
for idx in sorted(string_and_rpar_indices):
|
||||
leaf = LL[idx]
|
||||
lpar_or_rpar_idx = idx - 1 if leaf.type == token.STRING else idx
|
||||
append_leaves(new_line, line, LL[previous_idx + 1 : lpar_or_rpar_idx])
|
||||
if leaf.type == token.STRING:
|
||||
string_leaf = Leaf(token.STRING, LL[idx].value)
|
||||
LL[lpar_or_rpar_idx].remove() # Remove lpar.
|
||||
replace_child(LL[idx], string_leaf)
|
||||
new_line.append(string_leaf)
|
||||
else:
|
||||
LL[lpar_or_rpar_idx].remove() # This is a rpar.
|
||||
|
||||
append_leaves(
|
||||
new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
|
||||
)
|
||||
previous_idx = idx
|
||||
|
||||
LL[rpar_idx].remove()
|
||||
# Append the leaves after the last idx:
|
||||
append_leaves(new_line, line, LL[idx + 1 :])
|
||||
|
||||
yield Ok(new_line)
|
||||
return new_line
|
||||
|
||||
|
||||
class BaseStringSplitter(StringTransformer):
|
||||
@ -885,7 +993,12 @@ def do_match(self, line: Line) -> TMatchResult:
|
||||
if isinstance(match_result, Err):
|
||||
return match_result
|
||||
|
||||
string_idx = match_result.ok()
|
||||
string_indices = match_result.ok()
|
||||
assert len(string_indices) == 1, (
|
||||
f"{self.__class__.__name__} should only find one match at a time, found"
|
||||
f" {len(string_indices)}"
|
||||
)
|
||||
string_idx = string_indices[0]
|
||||
vresult = self._validate(line, string_idx)
|
||||
if isinstance(vresult, Err):
|
||||
return vresult
|
||||
@ -1219,10 +1332,17 @@ def do_splitter_match(self, line: Line) -> TMatchResult:
|
||||
if is_valid_index(idx):
|
||||
return TErr("This line does not end with a string.")
|
||||
|
||||
return Ok(string_idx)
|
||||
return Ok([string_idx])
|
||||
|
||||
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
|
||||
def do_transform(
|
||||
self, line: Line, string_indices: List[int]
|
||||
) -> Iterator[TResult[Line]]:
|
||||
LL = line.leaves
|
||||
assert len(string_indices) == 1, (
|
||||
f"{self.__class__.__name__} should only find one match at a time, found"
|
||||
f" {len(string_indices)}"
|
||||
)
|
||||
string_idx = string_indices[0]
|
||||
|
||||
QUOTE = LL[string_idx].value[-1]
|
||||
|
||||
@ -1710,7 +1830,7 @@ def do_splitter_match(self, line: Line) -> TMatchResult:
|
||||
" resultant line would still be over the specified line"
|
||||
" length and can't be split further by StringSplitter."
|
||||
)
|
||||
return Ok(string_idx)
|
||||
return Ok([string_idx])
|
||||
|
||||
return TErr("This line does not contain any non-atomic strings.")
|
||||
|
||||
@ -1887,8 +2007,15 @@ def _dict_or_lambda_match(LL: List[Leaf]) -> Optional[int]:
|
||||
|
||||
return None
|
||||
|
||||
def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
|
||||
def do_transform(
|
||||
self, line: Line, string_indices: List[int]
|
||||
) -> Iterator[TResult[Line]]:
|
||||
LL = line.leaves
|
||||
assert len(string_indices) == 1, (
|
||||
f"{self.__class__.__name__} should only find one match at a time, found"
|
||||
f" {len(string_indices)}"
|
||||
)
|
||||
string_idx = string_indices[0]
|
||||
|
||||
is_valid_index = is_valid_index_factory(LL)
|
||||
insert_str_child = insert_str_child_factory(LL[string_idx])
|
||||
|
@ -287,6 +287,23 @@ def foo():
|
||||
),
|
||||
}
|
||||
|
||||
# Complex string concatenations with a method call in the middle.
|
||||
code = (
|
||||
(" return [\n")
|
||||
+ (
|
||||
", \n".join(
|
||||
" (%r, self.%s, visitor.%s)"
|
||||
% (attrname, attrname, visit_name)
|
||||
for attrname, visit_name in names
|
||||
)
|
||||
)
|
||||
+ ("\n ]\n")
|
||||
)
|
||||
|
||||
|
||||
# Test case of an outer string' parens enclose an inner string's parens.
|
||||
call(body=("%s %s" % ((",".join(items)), suffix)))
|
||||
|
||||
|
||||
# output
|
||||
|
||||
@ -828,3 +845,17 @@ def foo():
|
||||
f"{some_function_call(j.right)})"
|
||||
),
|
||||
}
|
||||
|
||||
# Complex string concatenations with a method call in the middle.
|
||||
code = (
|
||||
" return [\n"
|
||||
+ ", \n".join(
|
||||
" (%r, self.%s, visitor.%s)" % (attrname, attrname, visit_name)
|
||||
for attrname, visit_name in names
|
||||
)
|
||||
+ "\n ]\n"
|
||||
)
|
||||
|
||||
|
||||
# Test case of an outer string' parens enclose an inner string's parens.
|
||||
call(body="%s %s" % (",".join(items), suffix))
|
||||
|
Loading…
Reference in New Issue
Block a user