Add support for walrus operator (#935)
* Parse `:=` properly
* Never unwrap parentheses around `:=`
* When checking for AST equivalence, use `ast` instead of `typed-ast` when running on Python >= 3.8
* Assume code that uses `:=` is at least Python 3.8
This commit is contained in:
parent cad4138050
commit d8fa8df052
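A quick illustration of why the parentheses around `:=` must never be unwrapped (not part of the diff; standard Python 3.8+ behavior): an assignment expression is not allowed unparenthesized at the top level of an expression statement, so dropping the parentheses would turn valid code into a syntax error.

```python
import ast

ast.parse("(x := 1)")      # fine: parenthesized assignment expression statement
try:
    ast.parse("x := 1")    # invalid: unparenthesized walrus at statement level
except SyntaxError as exc:
    print("rejected:", exc.msg)
```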
black.py (88 changed lines)
@@ -1,3 +1,4 @@
+import ast
 import asyncio
 from concurrent.futures import Executor, ProcessPoolExecutor
 from contextlib import contextmanager
@@ -141,6 +142,7 @@ class Feature(Enum):
     # set for every version of python.
     ASYNC_IDENTIFIERS = 6
     ASYNC_KEYWORDS = 7
+    ASSIGNMENT_EXPRESSIONS = 8


 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
@@ -175,6 +177,7 @@ class Feature(Enum):
         Feature.TRAILING_COMMA_IN_CALL,
         Feature.TRAILING_COMMA_IN_DEF,
         Feature.ASYNC_KEYWORDS,
+        Feature.ASSIGNMENT_EXPRESSIONS,
     },
 }

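A minimal sketch of how a feature-to-version mapping like this narrows the candidate target versions (a simplified stand-in, not Black's actual classes; the subset check mirrors the idea behind `VERSION_TO_FEATURES` and `detect_target_versions`):

```python
from enum import Enum, auto
from typing import Dict, Set

class Feature(Enum):
    ASYNC_KEYWORDS = auto()
    ASSIGNMENT_EXPRESSIONS = auto()

class TargetVersion(Enum):
    PY37 = 7
    PY38 = 8

VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
    TargetVersion.PY37: {Feature.ASYNC_KEYWORDS},
    TargetVersion.PY38: {Feature.ASYNC_KEYWORDS, Feature.ASSIGNMENT_EXPRESSIONS},
}

def detect_target_versions(features: Set[Feature]) -> Set[TargetVersion]:
    # A version is a candidate only if it supports every feature seen in the file.
    return {v for v, supported in VERSION_TO_FEATURES.items() if features <= supported}

print(detect_target_versions({Feature.ASSIGNMENT_EXPRESSIONS}))  # only PY38 remains
```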
@@ -2863,6 +2866,8 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
             check_lpar = True

         if check_lpar:
+            if is_walrus_assignment(child):
+                continue
             if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     lpar = Leaf(token.LPAR, "")
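A usage sketch of the resulting behavior (assumes a Black build that includes this change; the input line is taken from the new test data below):

```python
import black

src = "if (match := pattern.search(data)) is None:\n    pass\n"
formatted = black.format_str(src, mode=black.FileMode())
print(formatted, end="")
# The parentheses around the walrus expression are preserved:
# if (match := pattern.search(data)) is None:
#     pass
```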
@@ -3017,18 +3022,24 @@ def is_empty_tuple(node: LN) -> bool:
     )


+def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
+    """Returns `wrapped` if `node` is of the shape ( wrapped ).
+
+    Parenthesis can be optional. Returns None otherwise"""
+    if len(node.children) != 3:
+        return None
+    lpar, wrapped, rpar = node.children
+    if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
+        return None
+
+    return wrapped
+
+
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
-        if len(node.children) != 3:
-            return False
-
-        lpar, gexp, rpar = node.children
-        if not (
-            lpar.type == token.LPAR
-            and gexp.type == syms.testlist_gexp
-            and rpar.type == token.RPAR
-        ):
+        gexp = unwrap_singleton_parenthesis(node)
+        if gexp is None or gexp.type != syms.testlist_gexp:
             return False

         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
@@ -3040,6 +3051,12 @@ def is_one_tuple(node: LN) -> bool:
     )


+def is_walrus_assignment(node: LN) -> bool:
+    """Return True iff `node` is of the shape ( test := test )"""
+    inner = unwrap_singleton_parenthesis(node)
+    return inner is not None and inner.type == syms.namedexpr_test
+
+
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
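For reference, the shape that `is_walrus_assignment` looks for in the lib2to3 tree (a parenthesized `namedexpr_test`) corresponds to `ast.NamedExpr` in the standard parser; a quick look on Python 3.8+ (stdlib only, not Black's node types):

```python
import ast
import sys

assert sys.version_info >= (3, 8)  # the walrus operator needs 3.8+

expr = ast.parse("(a := 1)").body[0].value
print(type(expr).__name__)  # NamedExpr
print(ast.dump(expr))       # NamedExpr(target=Name(id='a', ...), value=Constant(value=1, ...))
```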
@@ -3198,6 +3215,9 @@ def get_features_used(node: Node) -> Set[Feature]:
             if "_" in n.value:  # type: ignore
                 features.add(Feature.NUMERIC_UNDERSCORES)

+        elif n.type == token.COLONEQUAL:
+            features.add(Feature.ASSIGNMENT_EXPRESSIONS)
+
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
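A simplified stand-in for this detection step using only the stdlib tokenizer instead of Black's lib2to3 tree (illustrative; assumes it runs on Python 3.8+, where `:=` is emitted as a single OP token):

```python
import io
import token
import tokenize

def uses_walrus(src: str) -> bool:
    for tok in tokenize.generate_tokens(io.StringIO(src).readline):
        if tok.type == token.OP and tok.string == ":=":
            return True
    return False

print(uses_walrus("if (n := 10) > 5:\n    pass\n"))  # True
print(uses_walrus("n = 10\n"))                       # False
```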
@@ -3479,32 +3499,58 @@ def __str__(self) -> str:
         return ", ".join(report) + "."


-def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
+def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    filename = "<unknown>"
+    if sys.version_info >= (3, 8):
+        # TODO: support Python 4+ ;)
+        for minor_version in range(sys.version_info[1], 4, -1):
+            try:
+                return ast.parse(src, filename, feature_version=(3, minor_version))
+            except SyntaxError:
+                continue
+    else:
         for feature_version in (7, 6):
             try:
-                return ast3.parse(src, feature_version=feature_version)
+                return ast3.parse(src, filename, feature_version=feature_version)
             except SyntaxError:
                 continue

     return ast27.parse(src)

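A stdlib-only sketch of the same fallback strategy on Python 3.8+ (a simplification of the branch above; the helper name is made up for illustration): try the newest grammar first, then step down so sources written for older minor versions still parse.

```python
import ast
import sys

def parse_with_fallback(src: str, filename: str = "<unknown>") -> ast.AST:
    assert sys.version_info >= (3, 8)
    for minor_version in range(sys.version_info[1], 4, -1):
        try:
            return ast.parse(src, filename, feature_version=(3, minor_version))
        except SyntaxError:
            continue
    raise SyntaxError(f"{filename}: not valid for any tried feature_version")

tree = parse_with_fallback("if (n := 10) > 5:\n    pass\n")
print(type(tree).__name__)  # Module
```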
+def _fixup_ast_constants(
+    node: Union[ast.AST, ast3.AST, ast27.AST]
+) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    """Map ast nodes deprecated in 3.8 to Constant."""
+    # casts are required until this is released:
+    # https://github.com/python/typeshed/pull/3142
+    if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
+        return cast(ast.AST, ast.Constant(value=node.s))
+    elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
+        return cast(ast.AST, ast.Constant(value=node.n))
+    elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
+        return cast(ast.AST, ast.Constant(value=node.value))
+    return node
+
+
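Why this normalization is needed (not part of the diff; plain stdlib behavior): for the same literal, Python 3.8's `ast` produces `Constant` where `typed-ast` and older releases produce `Num`, `Str`, or `NameConstant`, so both sides are mapped to `Constant` before comparison.

```python
import ast

node = ast.parse("42").body[0].value
print(ast.dump(node))
# Python 3.8+:            Constant(value=42, ...)
# typed_ast / older ast:  Num(n=42)
```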
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""

-    def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
+    def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
         """Simple visitor generating strings to compare ASTs by content."""
+
+        node = _fixup_ast_constants(node)
+
         yield f"{'  ' * depth}{node.__class__.__name__}("

         for field in sorted(node._fields):
             # TypeIgnore has only one field 'lineno' which breaks this comparison
-            if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+            type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
+            if sys.version_info >= (3, 8):
+                type_ignore_classes += (ast.TypeIgnore,)
+            if isinstance(node, type_ignore_classes):
                 break

             # Ignore str kind which is case sensitive / and ignores unicode_literals
             if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
                 continue

             try:
                 value = getattr(node, field)
             except AttributeError:
@@ -3518,15 +3564,15 @@ def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
                     # parentheses and they change the AST.
                     if (
                         field == "targets"
-                        and isinstance(node, (ast3.Delete, ast27.Delete))
-                        and isinstance(item, (ast3.Tuple, ast27.Tuple))
+                        and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
+                        and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
                     ):
                         for item in item.elts:
                             yield from _v(item, depth + 2)
-                    elif isinstance(item, (ast3.AST, ast27.AST)):
+                    elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)

-            elif isinstance(value, (ast3.AST, ast27.AST)):
+            elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
                 yield from _v(value, depth + 2)

             else:
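A much smaller stand-in for what `assert_equivalent` checks (illustrative only; Black's real visitor also has to reconcile `typed-ast` parsing for 2.7/3.x and apply the `Constant` fixup above): on a single Python 3.8+ interpreter, comparing `ast.dump()` output before and after formatting is the same idea.

```python
import ast

def roughly_equivalent(src: str, dst: str) -> bool:
    return ast.dump(ast.parse(src)) == ast.dump(ast.parse(dst))

print(roughly_equivalent("(a := 1)", "(a :=   1)"))  # True: only whitespace differs
print(roughly_equivalent("(a := 1)", "(a := 2)"))    # False: the AST changed
```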

blib2to3/Grammar.txt
@@ -67,7 +67,7 @@ assert_stmt: 'assert' test [',' test]
 compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt
 async_stmt: ASYNC (funcdef | with_stmt | for_stmt)
-if_stmt: 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite]
+if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite]
 while_stmt: 'while' test ':' suite ['else' ':' suite]
 for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite]
 try_stmt: ('try' ':' suite
@@ -91,6 +91,7 @@ testlist_safe: old_test [(',' old_test)+ [',']]
 old_test: or_test | old_lambdef
 old_lambdef: 'lambda' [varargslist] ':' old_test

+namedexpr_test: test [':=' test]
 test: or_test ['if' or_test 'else' test] | lambdef
 or_test: and_test ('or' and_test)*
 and_test: not_test ('and' not_test)*
@@ -111,8 +112,8 @@ atom: ('(' [yield_expr|testlist_gexp] ')' |
        '{' [dictsetmaker] '}' |
        '`' testlist1 '`' |
        NAME | NUMBER | STRING+ | '.' '.' '.')
-listmaker: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
-testlist_gexp: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
+listmaker: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
+testlist_gexp: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
 lambdef: 'lambda' [varargslist] ':' test
 trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
 subscriptlist: subscript (',' subscript)* [',']
@@ -137,6 +138,7 @@ arglist: argument (',' argument)* [',']
 # multiple (test comp_for) arguments are blocked; keyword unpackings
 # that precede iterable unpackings are blocked; etc.
 argument: ( test [comp_for] |
+            test ':=' test |
             test '=' test |
             '**' test |
             '*' test )
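The new `test ':=' test` alternative in `argument` is what allows a walrus directly inside a call, and it is not a keyword argument (not part of the diff; standard Python 3.8+ semantics, mirroring `foo(x := 3, cat="vector")` from the test data below):

```python
def foo(val, cat="scalar"):
    return val, cat

print(foo(x := 3, cat="vector"))  # (3, 'vector'): 3 is passed positionally...
print(x)                          # ...and x is also bound to 3 in this scope
```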

blib2to3/pgen2/grammar.py
@@ -184,6 +184,7 @@ def report(self):
 // DOUBLESLASH
 //= DOUBLESLASHEQUAL
 -> RARROW
+:= COLONEQUAL
 """

 opmap = {}

blib2to3/pgen2/token.py
@@ -63,7 +63,8 @@
 AWAIT = 56
 ASYNC = 57
 ERRORTOKEN = 58
-N_TOKENS = 59
+COLONEQUAL = 59
+N_TOKENS = 60
 NT_OFFSET = 256
 #--end constants--
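For comparison (not part of the diff), CPython 3.8's own `token` module also gained a `COLONEQUAL` token; its numeric value is unrelated to the vendored blib2to3 constant above:

```python
import sys
import token

if sys.version_info >= (3, 8):
    print(token.COLONEQUAL, token.tok_name[token.COLONEQUAL])
```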

blib2to3/pgen2/tokenize.py
@@ -89,7 +89,7 @@ def _combinations(*l):
 # recognized as two instances of =).
 Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
                  r"//=?", r"->",
-                 r"[+\-*/%&@|^=<>]=?",
+                 r"[+\-*/%&@|^=<>:]=?",
                  r"~")

 Bracket = '[][(){}]'
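A self-contained sketch of what adding `:` to that character class changes (the `group` helper below reproduces the tokenizer's alternation-building helper; everything else is stdlib `re`):

```python
import re

def group(*choices):
    return "(" + "|".join(choices) + ")"

old_operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
                     r"//=?", r"->",
                     r"[+\-*/%&@|^=<>]=?",
                     r"~")
new_operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
                     r"//=?", r"->",
                     r"[+\-*/%&@|^=<>:]=?",
                     r"~")

print(bool(re.fullmatch(old_operator, ":=")))  # False: ":=" was not one operator
print(bool(re.fullmatch(new_operator, ":=")))  # True: now tokenized as a single operator
```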

blib2to3/pygram.pyi
@@ -57,6 +57,7 @@ class python_symbols(Symbols):
     import_stmt: int
     lambdef: int
     listmaker: int
+    namedexpr_test: int
     not_test: int
     old_comp_for: int
     old_comp_if: int

tests/data/pep_572.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+(a := 1)
+(a := a)
+if (match := pattern.search(data)) is None:
+    pass
+[y := f(x), y ** 2, y ** 3]
+filtered_data = [y for x in data if (y := f(x)) is None]
+(y := f(x))
+y0 = (y1 := f(x))
+foo(x=(y := f(x)))
+
+
+def foo(answer=(p := 42)):
+    pass
+
+
+def foo(answer: (p := 42) = 5):
+    pass
+
+
+lambda: (x := 1)
+(x := lambda: 1)
+(x := lambda: (y := 1))
+lambda line: (m := re.match(pattern, line)) and m.group(1)
+x = (y := 0)
+(z := (y := (x := 0)))
+(info := (name, phone, *rest))
+(x := 1, 2)
+(total := total + tax)
+len(lines := f.readlines())
+foo(x := 3, cat="vector")
+foo(cat=(category := "vector"))
+if any(len(longline := l) >= 100 for l in lines):
+    print(longline)
+if env_base := os.environ.get("PYTHONUSERBASE", None):
+    return env_base
+if self._is_special and (ans := self._check_nans(context=context)):
+    return ans
+foo(b := 2, a=1)
+foo((b := 2), a=1)
+foo(c=(b := 2), a=1)

tests/test_black.py
@@ -280,6 +280,23 @@ def test_expression(self) -> None:
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, black.FileMode())

+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_pep_572(self) -> None:
+        source, expected = read_data("pep_572")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_stable(source, actual, black.FileMode())
+        if sys.version_info >= (3, 8):
+            black.assert_equivalent(source, actual)
+
+    def test_pep_572_version_detection(self) -> None:
+        source, _ = read_data("pep_572")
+        root = black.lib2to3_parse(source)
+        features = black.get_features_used(root)
+        self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
+        versions = black.detect_target_versions(root)
+        self.assertIn(black.TargetVersion.PY38, versions)
+
     def test_expression_ff(self) -> None:
         source, expected = read_data("expression")
         tmp_file = Path(black.dump_to_file(source))