Fix AST safety check false negative (#4270)
Fixes #4268. Previously we would allow whitespace changes in all strings; now we allow them only in docstrings. Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
This commit is contained in:
parent
f03ee113c9
commit
6af7d11096
@ -11,6 +11,10 @@
|
|||||||
<!-- Changes that affect Black's stable style -->
|
<!-- Changes that affect Black's stable style -->
|
||||||
|
|
||||||
- Don't move comments along with delimiters, which could cause crashes (#4248)
|
- Don't move comments along with delimiters, which could cause crashes (#4248)
|
||||||
|
- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions
|
||||||
|
of Black would incorrectly format the contents of certain unusual f-strings containing
|
||||||
|
nested strings with the same quote type. Now, Black will crash on such strings until
|
||||||
|
support for the new f-string syntax is implemented. (#4270)
|
||||||
|
|
||||||
### Preview style
|
### Preview style
|
||||||
|
|
||||||
|
@ -77,8 +77,13 @@
|
|||||||
syms,
|
syms,
|
||||||
)
|
)
|
||||||
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
|
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
|
||||||
from black.parsing import InvalidInput # noqa F401
|
from black.parsing import ( # noqa F401
|
||||||
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
|
ASTSafetyError,
|
||||||
|
InvalidInput,
|
||||||
|
lib2to3_parse,
|
||||||
|
parse_ast,
|
||||||
|
stringify_ast,
|
||||||
|
)
|
||||||
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
|
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
|
||||||
from black.report import Changed, NothingChanged, Report
|
from black.report import Changed, NothingChanged, Report
|
||||||
from black.trans import iter_fexpr_spans
|
from black.trans import iter_fexpr_spans
|
||||||
@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None:
|
|||||||
try:
|
try:
|
||||||
src_ast = parse_ast(src)
|
src_ast = parse_ast(src)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise AssertionError(
|
raise ASTSafetyError(
|
||||||
"cannot use --safe with this file; failed to parse source file AST: "
|
"cannot use --safe with this file; failed to parse source file AST: "
|
||||||
f"{exc}\n"
|
f"{exc}\n"
|
||||||
"This could be caused by running Black with an older Python version "
|
"This could be caused by running Black with an older Python version "
|
||||||
@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None:
|
|||||||
dst_ast = parse_ast(dst)
|
dst_ast = parse_ast(dst)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
|
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
|
||||||
raise AssertionError(
|
raise ASTSafetyError(
|
||||||
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
|
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
|
||||||
"Please report a bug on https://github.com/psf/black/issues. "
|
"Please report a bug on https://github.com/psf/black/issues. "
|
||||||
f"This invalid output might be helpful: {log}"
|
f"This invalid output might be helpful: {log}"
|
||||||
@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None:
|
|||||||
dst_ast_str = "\n".join(stringify_ast(dst_ast))
|
dst_ast_str = "\n".join(stringify_ast(dst_ast))
|
||||||
if src_ast_str != dst_ast_str:
|
if src_ast_str != dst_ast_str:
|
||||||
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
|
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
|
||||||
raise AssertionError(
|
raise ASTSafetyError(
|
||||||
"INTERNAL ERROR: Black produced code that is not equivalent to the"
|
"INTERNAL ERROR: Black produced code that is not equivalent to the"
|
||||||
" source. Please report a bug on "
|
" source. Please report a bug on "
|
||||||
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
|
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
|
||||||
|
@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str:
|
|||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
||||||
|
class ASTSafetyError(Exception):
|
||||||
|
"""Raised when Black's generated code is not equivalent to the old AST."""
|
||||||
|
|
||||||
|
|
||||||
def _parse_single_version(
|
def _parse_single_version(
|
||||||
src: str, version: Tuple[int, int], *, type_comments: bool
|
src: str, version: Tuple[int, int], *, type_comments: bool
|
||||||
) -> ast.AST:
|
) -> ast.AST:
|
||||||
@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str:
|
|||||||
return normalized.strip()
|
return normalized.strip()
|
||||||
|
|
||||||
|
|
||||||
def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
def stringify_ast(node: ast.AST) -> Iterator[str]:
|
||||||
"""Simple visitor generating strings to compare ASTs by content."""
|
"""Simple visitor generating strings to compare ASTs by content."""
|
||||||
|
return _stringify_ast(node, [])
|
||||||
|
|
||||||
|
|
||||||
|
def _stringify_ast_with_new_parent(
|
||||||
|
node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST
|
||||||
|
) -> Iterator[str]:
|
||||||
|
parent_stack.append(new_parent)
|
||||||
|
yield from _stringify_ast(node, parent_stack)
|
||||||
|
parent_stack.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]:
|
||||||
if (
|
if (
|
||||||
isinstance(node, ast.Constant)
|
isinstance(node, ast.Constant)
|
||||||
and isinstance(node.value, str)
|
and isinstance(node.value, str)
|
||||||
@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
|||||||
# over the kind
|
# over the kind
|
||||||
node.kind = None
|
node.kind = None
|
||||||
|
|
||||||
yield f"{' ' * depth}{node.__class__.__name__}("
|
yield f"{' ' * len(parent_stack)}{node.__class__.__name__}("
|
||||||
|
|
||||||
for field in sorted(node._fields): # noqa: F402
|
for field in sorted(node._fields): # noqa: F402
|
||||||
# TypeIgnore has only one field 'lineno' which breaks this comparison
|
# TypeIgnore has only one field 'lineno' which breaks this comparison
|
||||||
@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield f"{' ' * (depth + 1)}{field}="
|
yield f"{' ' * (len(parent_stack) + 1)}{field}="
|
||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
for item in value:
|
for item in value:
|
||||||
@ -191,13 +206,15 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
|||||||
and isinstance(item, ast.Tuple)
|
and isinstance(item, ast.Tuple)
|
||||||
):
|
):
|
||||||
for elt in item.elts:
|
for elt in item.elts:
|
||||||
yield from stringify_ast(elt, depth + 2)
|
yield from _stringify_ast_with_new_parent(
|
||||||
|
elt, parent_stack, node
|
||||||
|
)
|
||||||
|
|
||||||
elif isinstance(item, ast.AST):
|
elif isinstance(item, ast.AST):
|
||||||
yield from stringify_ast(item, depth + 2)
|
yield from _stringify_ast_with_new_parent(item, parent_stack, node)
|
||||||
|
|
||||||
elif isinstance(value, ast.AST):
|
elif isinstance(value, ast.AST):
|
||||||
yield from stringify_ast(value, depth + 2)
|
yield from _stringify_ast_with_new_parent(value, parent_stack, node)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
normalized: object
|
normalized: object
|
||||||
@ -205,6 +222,12 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
|||||||
isinstance(node, ast.Constant)
|
isinstance(node, ast.Constant)
|
||||||
and field == "value"
|
and field == "value"
|
||||||
and isinstance(value, str)
|
and isinstance(value, str)
|
||||||
|
and len(parent_stack) >= 2
|
||||||
|
and isinstance(parent_stack[-1], ast.Expr)
|
||||||
|
and isinstance(
|
||||||
|
parent_stack[-2],
|
||||||
|
(ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef),
|
||||||
|
)
|
||||||
):
|
):
|
||||||
# Constant strings may be indented across newlines, if they are
|
# Constant strings may be indented across newlines, if they are
|
||||||
# docstrings; fold spaces after newlines when comparing. Similarly,
|
# docstrings; fold spaces after newlines when comparing. Similarly,
|
||||||
@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
|
|||||||
normalized = value.rstrip()
|
normalized = value.rstrip()
|
||||||
else:
|
else:
|
||||||
normalized = value
|
normalized = value
|
||||||
yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}"
|
yield (
|
||||||
|
f"{' ' * (len(parent_stack) + 1)}{normalized!r}, #"
|
||||||
|
f" {value.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
yield f"{' ' * depth}) # /{node.__class__.__name__}"
|
yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}"
|
||||||
|
@ -46,6 +46,7 @@
|
|||||||
from black.debug import DebugVisitor
|
from black.debug import DebugVisitor
|
||||||
from black.mode import Mode, Preview
|
from black.mode import Mode, Preview
|
||||||
from black.output import color_diff, diff
|
from black.output import color_diff, diff
|
||||||
|
from black.parsing import ASTSafetyError
|
||||||
from black.report import Report
|
from black.report import Report
|
||||||
|
|
||||||
# Import other test classes
|
# Import other test classes
|
||||||
@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None:
|
|||||||
ff(test_file, write_back=black.WriteBack.YES)
|
ff(test_file, write_back=black.WriteBack.YES)
|
||||||
self.assertEqual(test_file.read_bytes(), expected)
|
self.assertEqual(test_file.read_bytes(), expected)
|
||||||
|
|
||||||
def test_assert_equivalent_different_asts(self) -> None:
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
black.assert_equivalent("{}", "None")
|
|
||||||
|
|
||||||
def test_root_logger_not_used_directly(self) -> None:
|
def test_root_logger_not_used_directly(self) -> None:
|
||||||
def fail(*args: Any, **kwargs: Any) -> None:
|
def fail(*args: Any, **kwargs: Any) -> None:
|
||||||
self.fail("Record created with root logger")
|
self.fail("Record created with root logger")
|
||||||
@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None:
|
|||||||
|
|
||||||
exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
|
exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
|
||||||
|
|
||||||
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
|
|
||||||
with pytest.raises(AssertionError) as err:
|
|
||||||
black.assert_equivalent("a«»a = 1", "a«»a = 1")
|
|
||||||
|
|
||||||
err.match("--safe")
|
|
||||||
# Unfortunately the SyntaxError message has changed in newer versions so we
|
|
||||||
# can't match it directly.
|
|
||||||
err.match("invalid character")
|
|
||||||
err.match(r"\(<unknown>, line 1\)")
|
|
||||||
|
|
||||||
def test_line_ranges_with_code_option(self) -> None:
|
def test_line_ranges_with_code_option(self) -> None:
|
||||||
code = textwrap.dedent("""\
|
code = textwrap.dedent("""\
|
||||||
if a == b:
|
if a == b:
|
||||||
@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None:
|
|||||||
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())
|
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())
|
||||||
|
|
||||||
|
|
||||||
|
class TestASTSafety(BlackBaseTestCase):
|
||||||
|
def check_ast_equivalence(
|
||||||
|
self, source: str, dest: str, *, should_fail: bool = False
|
||||||
|
) -> None:
|
||||||
|
# If we get a failure, make sure it's not because the code itself
|
||||||
|
# is invalid, since that will also cause assert_equivalent() to throw
|
||||||
|
# ASTSafetyError.
|
||||||
|
source = textwrap.dedent(source)
|
||||||
|
dest = textwrap.dedent(dest)
|
||||||
|
black.parse_ast(source)
|
||||||
|
black.parse_ast(dest)
|
||||||
|
if should_fail:
|
||||||
|
with self.assertRaises(ASTSafetyError):
|
||||||
|
black.assert_equivalent(source, dest)
|
||||||
|
else:
|
||||||
|
black.assert_equivalent(source, dest)
|
||||||
|
|
||||||
|
def test_assert_equivalent_basic(self) -> None:
|
||||||
|
self.check_ast_equivalence("{}", "None", should_fail=True)
|
||||||
|
self.check_ast_equivalence("1+2", "1 + 2")
|
||||||
|
self.check_ast_equivalence("hi # comment", "hi")
|
||||||
|
|
||||||
|
def test_assert_equivalent_del(self) -> None:
|
||||||
|
self.check_ast_equivalence("del (a, b)", "del a, b")
|
||||||
|
|
||||||
|
def test_assert_equivalent_strings(self) -> None:
|
||||||
|
self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True)
|
||||||
|
self.check_ast_equivalence(
|
||||||
|
'''
|
||||||
|
"""docstring """
|
||||||
|
''',
|
||||||
|
'''
|
||||||
|
"""docstring"""
|
||||||
|
''',
|
||||||
|
)
|
||||||
|
self.check_ast_equivalence(
|
||||||
|
'''
|
||||||
|
"""docstring """
|
||||||
|
''',
|
||||||
|
'''
|
||||||
|
"""ddocstring"""
|
||||||
|
''',
|
||||||
|
should_fail=True,
|
||||||
|
)
|
||||||
|
self.check_ast_equivalence(
|
||||||
|
'''
|
||||||
|
class A:
|
||||||
|
"""
|
||||||
|
|
||||||
|
docstring
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
''',
|
||||||
|
'''
|
||||||
|
class A:
|
||||||
|
"""docstring"""
|
||||||
|
''',
|
||||||
|
)
|
||||||
|
self.check_ast_equivalence(
|
||||||
|
"""
|
||||||
|
def f():
|
||||||
|
" docstring "
|
||||||
|
""",
|
||||||
|
'''
|
||||||
|
def f():
|
||||||
|
"""docstring"""
|
||||||
|
''',
|
||||||
|
)
|
||||||
|
self.check_ast_equivalence(
|
||||||
|
"""
|
||||||
|
async def f():
|
||||||
|
" docstring "
|
||||||
|
""",
|
||||||
|
'''
|
||||||
|
async def f():
|
||||||
|
"""docstring"""
|
||||||
|
''',
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_assert_equivalent_fstring(self) -> None:
|
||||||
|
major, minor = sys.version_info[:2]
|
||||||
|
if major < 3 or (major == 3 and minor < 12):
|
||||||
|
pytest.skip("relies on 3.12+ syntax")
|
||||||
|
# https://github.com/psf/black/issues/4268
|
||||||
|
self.check_ast_equivalence(
|
||||||
|
"""print(f"{"|".join([a,b,c])}")""",
|
||||||
|
"""print(f"{" | ".join([a,b,c])}")""",
|
||||||
|
should_fail=True,
|
||||||
|
)
|
||||||
|
self.check_ast_equivalence(
|
||||||
|
"""print(f"{"|".join(['a','b','c'])}")""",
|
||||||
|
"""print(f"{" | ".join(['a','b','c'])}")""",
|
||||||
|
should_fail=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
|
||||||
|
with pytest.raises(ASTSafetyError) as err:
|
||||||
|
black.assert_equivalent("a«»a = 1", "a«»a = 1")
|
||||||
|
|
||||||
|
err.match("--safe")
|
||||||
|
# Unfortunately the SyntaxError message has changed in newer versions so we
|
||||||
|
# can't match it directly.
|
||||||
|
err.match("invalid character")
|
||||||
|
err.match(r"\(<unknown>, line 1\)")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(black.__file__, "r", encoding="utf-8") as _bf:
|
with open(black.__file__, "r", encoding="utf-8") as _bf:
|
||||||
black_source_lines = _bf.readlines()
|
black_source_lines = _bf.readlines()
|
||||||
|
Loading…
Reference in New Issue
Block a user