Fix AST safety check false negative (#4270)
Fixes #4268. Previously we allowed whitespace changes in all strings; now they are allowed only in docstrings. Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
This commit is contained in:
parent
f03ee113c9
commit
6af7d11096
@ -11,6 +11,10 @@
|
||||
<!-- Changes that affect Black's stable style -->
|
||||
|
||||
- Don't move comments along with delimiters, which could cause crashes (#4248)
|
||||
- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions
|
||||
of Black would incorrectly format the contents of certain unusual f-strings containing
|
||||
nested strings with the same quote type. Now, Black will crash on such strings until
|
||||
support for the new f-string syntax is implemented. (#4270)
|
||||
|
||||
### Preview style
|
||||
|
||||
|
@ -77,8 +77,13 @@
|
||||
syms,
|
||||
)
|
||||
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
|
||||
from black.parsing import InvalidInput # noqa F401
|
||||
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
|
||||
from black.parsing import ( # noqa F401
|
||||
ASTSafetyError,
|
||||
InvalidInput,
|
||||
lib2to3_parse,
|
||||
parse_ast,
|
||||
stringify_ast,
|
||||
)
|
||||
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
|
||||
from black.report import Changed, NothingChanged, Report
|
||||
from black.trans import iter_fexpr_spans
|
||||
@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None:
|
||||
try:
|
||||
src_ast = parse_ast(src)
|
||||
except Exception as exc:
|
||||
raise AssertionError(
|
||||
raise ASTSafetyError(
|
||||
"cannot use --safe with this file; failed to parse source file AST: "
|
||||
f"{exc}\n"
|
||||
"This could be caused by running Black with an older Python version "
|
||||
@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None:
|
||||
dst_ast = parse_ast(dst)
|
||||
except Exception as exc:
|
||||
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
|
||||
raise AssertionError(
|
||||
raise ASTSafetyError(
|
||||
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
|
||||
"Please report a bug on https://github.com/psf/black/issues. "
|
||||
f"This invalid output might be helpful: {log}"
|
||||
@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None:
|
||||
dst_ast_str = "\n".join(stringify_ast(dst_ast))
|
||||
if src_ast_str != dst_ast_str:
|
||||
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
|
||||
raise AssertionError(
|
||||
raise ASTSafetyError(
|
||||
"INTERNAL ERROR: Black produced code that is not equivalent to the"
|
||||
" source. Please report a bug on "
|
||||
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
|
||||
|
@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str:
|
||||
return code
|
||||
|
||||
|
||||
class ASTSafetyError(Exception):
    """Signal that the AST safety check could not confirm the output is safe.

    ``assert_equivalent`` raises this when the source cannot be parsed, when
    the formatted output cannot be parsed, or when the two ASTs differ.
    """
|
||||
|
||||
|
||||
def _parse_single_version(
|
||||
src: str, version: Tuple[int, int], *, type_comments: bool
|
||||
) -> ast.AST:
|
||||
@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str:
|
||||
return normalized.strip()
|
||||
|
||||
|
||||
def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
||||
def stringify_ast(node: ast.AST) -> Iterator[str]:
    """Yield text lines describing *node* so two ASTs can be compared by content."""
    # The root node has no ancestors, so rendering starts from an empty stack.
    no_parents: List[ast.AST] = []
    return _stringify_ast(node, no_parents)
|
||||
|
||||
|
||||
def _stringify_ast_with_new_parent(
    node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST
) -> Iterator[str]:
    """Render *node* with *new_parent* temporarily pushed onto *parent_stack*.

    Args:
        node: The child AST node to render.
        parent_stack: Shared, mutable list of ancestors; mutated in place.
        new_parent: The direct parent of *node*, pushed for the duration of
            the rendering.

    Yields:
        The lines produced by ``_stringify_ast`` for *node*.
    """
    parent_stack.append(new_parent)
    try:
        yield from _stringify_ast(node, parent_stack)
    finally:
        # The stack is shared with the caller: restore it even if the consumer
        # stops iterating early (generator closed) or rendering raises, so a
        # stale parent never leaks into later sibling nodes.
        parent_stack.pop()
|
||||
|
||||
|
||||
def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]:
|
||||
if (
|
||||
isinstance(node, ast.Constant)
|
||||
and isinstance(node.value, str)
|
||||
@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
||||
# over the kind
|
||||
node.kind = None
|
||||
|
||||
yield f"{' ' * depth}{node.__class__.__name__}("
|
||||
yield f"{' ' * len(parent_stack)}{node.__class__.__name__}("
|
||||
|
||||
for field in sorted(node._fields): # noqa: F402
|
||||
# TypeIgnore has only one field 'lineno' which breaks this comparison
|
||||
@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
yield f"{' ' * (depth + 1)}{field}="
|
||||
yield f"{' ' * (len(parent_stack) + 1)}{field}="
|
||||
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
@ -191,13 +206,15 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
||||
and isinstance(item, ast.Tuple)
|
||||
):
|
||||
for elt in item.elts:
|
||||
yield from stringify_ast(elt, depth + 2)
|
||||
yield from _stringify_ast_with_new_parent(
|
||||
elt, parent_stack, node
|
||||
)
|
||||
|
||||
elif isinstance(item, ast.AST):
|
||||
yield from stringify_ast(item, depth + 2)
|
||||
yield from _stringify_ast_with_new_parent(item, parent_stack, node)
|
||||
|
||||
elif isinstance(value, ast.AST):
|
||||
yield from stringify_ast(value, depth + 2)
|
||||
yield from _stringify_ast_with_new_parent(value, parent_stack, node)
|
||||
|
||||
else:
|
||||
normalized: object
|
||||
@ -205,6 +222,12 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
||||
isinstance(node, ast.Constant)
|
||||
and field == "value"
|
||||
and isinstance(value, str)
|
||||
and len(parent_stack) >= 2
|
||||
and isinstance(parent_stack[-1], ast.Expr)
|
||||
and isinstance(
|
||||
parent_stack[-2],
|
||||
(ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef),
|
||||
)
|
||||
):
|
||||
# Constant strings may be indented across newlines, if they are
|
||||
# docstrings; fold spaces after newlines when comparing. Similarly,
|
||||
@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
||||
normalized = value.rstrip()
|
||||
else:
|
||||
normalized = value
|
||||
yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}"
|
||||
yield (
|
||||
f"{' ' * (len(parent_stack) + 1)}{normalized!r}, #"
|
||||
f" {value.__class__.__name__}"
|
||||
)
|
||||
|
||||
yield f"{' ' * depth}) # /{node.__class__.__name__}"
|
||||
yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}"
|
||||
|
@ -46,6 +46,7 @@
|
||||
from black.debug import DebugVisitor
|
||||
from black.mode import Mode, Preview
|
||||
from black.output import color_diff, diff
|
||||
from black.parsing import ASTSafetyError
|
||||
from black.report import Report
|
||||
|
||||
# Import other test classes
|
||||
@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None:
|
||||
ff(test_file, write_back=black.WriteBack.YES)
|
||||
self.assertEqual(test_file.read_bytes(), expected)
|
||||
|
||||
def test_assert_equivalent_different_asts(self) -> None:
|
||||
with self.assertRaises(AssertionError):
|
||||
black.assert_equivalent("{}", "None")
|
||||
|
||||
def test_root_logger_not_used_directly(self) -> None:
|
||||
def fail(*args: Any, **kwargs: Any) -> None:
|
||||
self.fail("Record created with root logger")
|
||||
@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None:
|
||||
|
||||
exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
|
||||
|
||||
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
|
||||
with pytest.raises(AssertionError) as err:
|
||||
black.assert_equivalent("a«»a = 1", "a«»a = 1")
|
||||
|
||||
err.match("--safe")
|
||||
# Unfortunately the SyntaxError message has changed in newer versions so we
|
||||
# can't match it directly.
|
||||
err.match("invalid character")
|
||||
err.match(r"\(<unknown>, line 1\)")
|
||||
|
||||
def test_line_ranges_with_code_option(self) -> None:
|
||||
code = textwrap.dedent("""\
|
||||
if a == b:
|
||||
@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None:
|
||||
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())
|
||||
|
||||
|
||||
class TestASTSafety(BlackBaseTestCase):
    """Check black.assert_equivalent() against equivalent and differing code."""

    def check_ast_equivalence(
        self, source: str, dest: str, *, should_fail: bool = False
    ) -> None:
        """Dedent both snippets and run black.assert_equivalent() on them.

        With ``should_fail=True`` the call must raise ASTSafetyError; otherwise
        it must return without raising.
        """
        source = textwrap.dedent(source)
        dest = textwrap.dedent(dest)
        # Parse both snippets first: assert_equivalent() also throws
        # ASTSafetyError for code that is itself invalid, and an expected
        # failure must not be satisfied by that.
        black.parse_ast(source)
        black.parse_ast(dest)
        if not should_fail:
            black.assert_equivalent(source, dest)
            return
        with self.assertRaises(ASTSafetyError):
            black.assert_equivalent(source, dest)

    def test_assert_equivalent_basic(self) -> None:
        self.check_ast_equivalence("{}", "None", should_fail=True)
        equivalent_pairs = (
            ("1+2", "1 + 2"),
            ("hi # comment", "hi"),
        )
        for before, after in equivalent_pairs:
            self.check_ast_equivalence(before, after)

    def test_assert_equivalent_del(self) -> None:
        before, after = "del (a, b)", "del a, b"
        self.check_ast_equivalence(before, after)

    def test_assert_equivalent_strings(self) -> None:
        # Whitespace added inside an ordinary string literal must be rejected.
        self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True)
        # Module docstring: dropping trailing whitespace is accepted.
        self.check_ast_equivalence(
            '''
            """docstring """
            ''',
            '''
            """docstring"""
            ''',
        )
        # Module docstring: changing the text itself must be rejected.
        self.check_ast_equivalence(
            '''
            """docstring """
            ''',
            '''
            """ddocstring"""
            ''',
            should_fail=True,
        )
        # Class docstring: surrounding blank lines may be removed.
        self.check_ast_equivalence(
            '''
            class A:
                """

                docstring


                """
            ''',
            '''
            class A:
                """docstring"""
            ''',
        )
        # Function docstring: padding spaces and the quote style may change.
        self.check_ast_equivalence(
            """
            def f():
                " docstring "
            """,
            '''
            def f():
                """docstring"""
            ''',
        )
        # Same as above for an async function.
        self.check_ast_equivalence(
            """
            async def f():
                " docstring "
            """,
            '''
            async def f():
                """docstring"""
            ''',
        )

    def test_assert_equivalent_fstring(self) -> None:
        if sys.version_info < (3, 12):
            pytest.skip("relies on 3.12+ syntax")
        # https://github.com/psf/black/issues/4268
        nested_quote_cases = (
            (
                """print(f"{"|".join([a,b,c])}")""",
                """print(f"{" | ".join([a,b,c])}")""",
            ),
            (
                """print(f"{"|".join(['a','b','c'])}")""",
                """print(f"{" | ".join(['a','b','c'])}")""",
            ),
        )
        for before, after in nested_quote_cases:
            self.check_ast_equivalence(before, after, should_fail=True)

    def test_equivalency_ast_parse_failure_includes_error(self) -> None:
        with pytest.raises(ASTSafetyError) as exc_info:
            black.assert_equivalent("a«»a = 1", "a«»a = 1")

        exc_info.match("--safe")
        # Unfortunately the SyntaxError message has changed in newer versions so we
        # can't match it directly.
        exc_info.match("invalid character")
        exc_info.match(r"\(<unknown>, line 1\)")
|
||||
|
||||
|
||||
try:
|
||||
with open(black.__file__, "r", encoding="utf-8") as _bf:
|
||||
black_source_lines = _bf.readlines()
|
||||
|
Loading…
Reference in New Issue
Block a user