Fix AST safety check false negative (#4270)

Fixes #4268

Previously we would allow whitespace changes in all strings; now we
allow them only in docstrings.

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
This commit is contained in:
Jelle Zijlstra 2024-03-09 17:42:29 -08:00 committed by GitHub
parent f03ee113c9
commit 6af7d11096
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 156 additions and 27 deletions

View File

@ -11,6 +11,10 @@
<!-- Changes that affect Black's stable style --> <!-- Changes that affect Black's stable style -->
- Don't move comments along with delimiters, which could cause crashes (#4248) - Don't move comments along with delimiters, which could cause crashes (#4248)
- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions
of Black would incorrectly format the contents of certain unusual f-strings containing
nested strings with the same quote type. Now, Black will crash on such strings until
support for the new f-string syntax is implemented. (#4270)
### Preview style ### Preview style

View File

@ -77,8 +77,13 @@
syms, syms,
) )
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
from black.parsing import InvalidInput # noqa F401 from black.parsing import ( # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast ASTSafetyError,
InvalidInput,
lib2to3_parse,
parse_ast,
stringify_ast,
)
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
from black.report import Changed, NothingChanged, Report from black.report import Changed, NothingChanged, Report
from black.trans import iter_fexpr_spans from black.trans import iter_fexpr_spans
@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None:
try: try:
src_ast = parse_ast(src) src_ast = parse_ast(src)
except Exception as exc: except Exception as exc:
raise AssertionError( raise ASTSafetyError(
"cannot use --safe with this file; failed to parse source file AST: " "cannot use --safe with this file; failed to parse source file AST: "
f"{exc}\n" f"{exc}\n"
"This could be caused by running Black with an older Python version " "This could be caused by running Black with an older Python version "
@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast = parse_ast(dst) dst_ast = parse_ast(dst)
except Exception as exc: except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError( raise ASTSafetyError(
f"INTERNAL ERROR: Black produced invalid code: {exc}. " f"INTERNAL ERROR: Black produced invalid code: {exc}. "
"Please report a bug on https://github.com/psf/black/issues. " "Please report a bug on https://github.com/psf/black/issues. "
f"This invalid output might be helpful: {log}" f"This invalid output might be helpful: {log}"
@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast_str = "\n".join(stringify_ast(dst_ast)) dst_ast_str = "\n".join(stringify_ast(dst_ast))
if src_ast_str != dst_ast_str: if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError( raise ASTSafetyError(
"INTERNAL ERROR: Black produced code that is not equivalent to the" "INTERNAL ERROR: Black produced code that is not equivalent to the"
" source. Please report a bug on " " source. Please report a bug on "
f"https://github.com/psf/black/issues. This diff might be helpful: {log}" f"https://github.com/psf/black/issues. This diff might be helpful: {log}"

View File

@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str:
return code return code
class ASTSafetyError(Exception):
    """Signal that the AST safety check could not confirm equivalence.

    `assert_equivalent` raises this when the source cannot be parsed, when the
    formatted output cannot be parsed, or when the stringified ASTs of the two
    differ.
    """
def _parse_single_version( def _parse_single_version(
src: str, version: Tuple[int, int], *, type_comments: bool src: str, version: Tuple[int, int], *, type_comments: bool
) -> ast.AST: ) -> ast.AST:
@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str:
return normalized.strip() return normalized.strip()
def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]: def stringify_ast(node: ast.AST) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content.""" """Simple visitor generating strings to compare ASTs by content."""
return _stringify_ast(node, [])
def _stringify_ast_with_new_parent(
    node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST
) -> Iterator[str]:
    """Yield the stringified form of `node` with `new_parent` as its direct parent.

    `new_parent` is pushed onto `parent_stack` while `node` is traversed and
    removed again afterwards, so the caller sees the same stack once this
    generator is finished.
    """
    parent_stack.append(new_parent)
    try:
        yield from _stringify_ast(node, parent_stack)
    finally:
        # Pop even when the traversal raises or the generator is closed early;
        # otherwise the shared stack would keep a stale parent entry.
        parent_stack.pop()
def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]:
if ( if (
isinstance(node, ast.Constant) isinstance(node, ast.Constant)
and isinstance(node.value, str) and isinstance(node.value, str)
@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
# over the kind # over the kind
node.kind = None node.kind = None
yield f"{' ' * depth}{node.__class__.__name__}(" yield f"{' ' * len(parent_stack)}{node.__class__.__name__}("
for field in sorted(node._fields): # noqa: F402 for field in sorted(node._fields): # noqa: F402
# TypeIgnore has only one field 'lineno' which breaks this comparison # TypeIgnore has only one field 'lineno' which breaks this comparison
@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
except AttributeError: except AttributeError:
continue continue
yield f"{' ' * (depth + 1)}{field}=" yield f"{' ' * (len(parent_stack) + 1)}{field}="
if isinstance(value, list): if isinstance(value, list):
for item in value: for item in value:
@ -191,13 +206,15 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
and isinstance(item, ast.Tuple) and isinstance(item, ast.Tuple)
): ):
for elt in item.elts: for elt in item.elts:
yield from stringify_ast(elt, depth + 2) yield from _stringify_ast_with_new_parent(
elt, parent_stack, node
)
elif isinstance(item, ast.AST): elif isinstance(item, ast.AST):
yield from stringify_ast(item, depth + 2) yield from _stringify_ast_with_new_parent(item, parent_stack, node)
elif isinstance(value, ast.AST): elif isinstance(value, ast.AST):
yield from stringify_ast(value, depth + 2) yield from _stringify_ast_with_new_parent(value, parent_stack, node)
else: else:
normalized: object normalized: object
@ -205,6 +222,12 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
isinstance(node, ast.Constant) isinstance(node, ast.Constant)
and field == "value" and field == "value"
and isinstance(value, str) and isinstance(value, str)
and len(parent_stack) >= 2
and isinstance(parent_stack[-1], ast.Expr)
and isinstance(
parent_stack[-2],
(ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef),
)
): ):
# Constant strings may be indented across newlines, if they are # Constant strings may be indented across newlines, if they are
# docstrings; fold spaces after newlines when comparing. Similarly, # docstrings; fold spaces after newlines when comparing. Similarly,
@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
normalized = value.rstrip() normalized = value.rstrip()
else: else:
normalized = value normalized = value
yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}" yield (
f"{' ' * (len(parent_stack) + 1)}{normalized!r}, #"
f" {value.__class__.__name__}"
)
yield f"{' ' * depth}) # /{node.__class__.__name__}" yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}"

View File

@ -46,6 +46,7 @@
from black.debug import DebugVisitor from black.debug import DebugVisitor
from black.mode import Mode, Preview from black.mode import Mode, Preview
from black.output import color_diff, diff from black.output import color_diff, diff
from black.parsing import ASTSafetyError
from black.report import Report from black.report import Report
# Import other test classes # Import other test classes
@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None:
ff(test_file, write_back=black.WriteBack.YES) ff(test_file, write_back=black.WriteBack.YES)
self.assertEqual(test_file.read_bytes(), expected) self.assertEqual(test_file.read_bytes(), expected)
def test_assert_equivalent_different_asts(self) -> None:
with self.assertRaises(AssertionError):
black.assert_equivalent("{}", "None")
def test_root_logger_not_used_directly(self) -> None: def test_root_logger_not_used_directly(self) -> None:
def fail(*args: Any, **kwargs: Any) -> None: def fail(*args: Any, **kwargs: Any) -> None:
self.fail("Record created with root logger") self.fail("Record created with root logger")
@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None:
exc_info.match("Cannot parse: 2:0: EOF in multi-line statement") exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(AssertionError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")
err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")
def test_line_ranges_with_code_option(self) -> None: def test_line_ranges_with_code_option(self) -> None:
code = textwrap.dedent("""\ code = textwrap.dedent("""\
if a == b: if a == b:
@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None:
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode()) black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())
# Tests for black.assert_equivalent() and the ASTSafetyError it raises.
class TestASTSafety(BlackBaseTestCase):
# Helper: dedent both snippets, confirm each parses on its own, then check that
# assert_equivalent() raises ASTSafetyError when should_fail is true and
# completes without raising otherwise.
def check_ast_equivalence(
self, source: str, dest: str, *, should_fail: bool = False
) -> None:
# If we get a failure, make sure it's not because the code itself
# is invalid, since that will also cause assert_equivalent() to throw
# ASTSafetyError.
source = textwrap.dedent(source)
dest = textwrap.dedent(dest)
black.parse_ast(source)
black.parse_ast(dest)
if should_fail:
with self.assertRaises(ASTSafetyError):
black.assert_equivalent(source, dest)
else:
black.assert_equivalent(source, dest)
# Different expressions must be rejected; spacing around operators and a
# dropped comment must be accepted.
def test_assert_equivalent_basic(self) -> None:
self.check_ast_equivalence("{}", "None", should_fail=True)
self.check_ast_equivalence("1+2", "1 + 2")
self.check_ast_equivalence("hi # comment", "hi")
# Parenthesized and bare `del` targets are expected to compare as equivalent.
def test_assert_equivalent_del(self) -> None:
self.check_ast_equivalence("del (a, b)", "del a, b")
# Whitespace changes inside an ordinary string must be rejected, while
# whitespace-only changes in module, class, function and async function
# docstrings are accepted. A docstring whose text changes must be rejected.
def test_assert_equivalent_strings(self) -> None:
self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""docstring"""
''',
)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""ddocstring"""
''',
should_fail=True,
)
self.check_ast_equivalence(
'''
class A:
"""
docstring
"""
''',
'''
class A:
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
def f():
" docstring "
""",
'''
def f():
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
async def f():
" docstring "
""",
'''
async def f():
"""docstring"""
''',
)
# Skipped on interpreters older than 3.12. A whitespace change in a string
# nested inside an f-string replacement field must be rejected.
def test_assert_equivalent_fstring(self) -> None:
major, minor = sys.version_info[:2]
if major < 3 or (major == 3 and minor < 12):
pytest.skip("relies on 3.12+ syntax")
# https://github.com/psf/black/issues/4268
self.check_ast_equivalence(
"""print(f"{"|".join([a,b,c])}")""",
"""print(f"{" | ".join([a,b,c])}")""",
should_fail=True,
)
self.check_ast_equivalence(
"""print(f"{"|".join(['a','b','c'])}")""",
"""print(f"{" | ".join(['a','b','c'])}")""",
should_fail=True,
)
# Unparseable source must raise ASTSafetyError whose message mentions --safe
# and carries the underlying parse error text.
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(ASTSafetyError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")
err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")
try: try:
with open(black.__file__, "r", encoding="utf-8") as _bf: with open(black.__file__, "r", encoding="utf-8") as _bf:
black_source_lines = _bf.readlines() black_source_lines = _bf.readlines()