Add support for walrus operator (#935)

* Parse `:=` properly
* never unwrap parenthesis around `:=`
* When checking for AST-equivalence, use `ast` instead of `typed-ast` when running on python >=3.8
* Assume code that uses `:=` is at least 3.8
This commit is contained in:
Zsolt Dollenstein 2019-07-28 16:03:23 +01:00 committed by GitHub
parent cad4138050
commit d8fa8df052
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 138 additions and 30 deletions

View File

@ -1,3 +1,4 @@
import ast
import asyncio
from concurrent.futures import Executor, ProcessPoolExecutor
from contextlib import contextmanager
@ -141,6 +142,7 @@ class Feature(Enum):
# set for every version of python.
ASYNC_IDENTIFIERS = 6
ASYNC_KEYWORDS = 7
ASSIGNMENT_EXPRESSIONS = 8
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
@ -175,6 +177,7 @@ class Feature(Enum):
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
Feature.ASYNC_KEYWORDS,
Feature.ASSIGNMENT_EXPRESSIONS,
},
}
@ -2863,6 +2866,8 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
check_lpar = True
if check_lpar:
if is_walrus_assignment(child):
continue
if child.type == syms.atom:
if maybe_make_parens_invisible_in_atom(child, parent=node):
lpar = Leaf(token.LPAR, "")
@ -3017,18 +3022,24 @@ def is_empty_tuple(node: LN) -> bool:
)
def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
"""Returns `wrapped` if `node` is of the shape ( wrapped ).
Parenthesis can be optional. Returns None otherwise"""
if len(node.children) != 3:
return None
lpar, wrapped, rpar = node.children
if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
return None
return wrapped
def is_one_tuple(node: LN) -> bool:
"""Return True if `node` holds a tuple with one element, with or without parens."""
if node.type == syms.atom:
if len(node.children) != 3:
return False
lpar, gexp, rpar = node.children
if not (
lpar.type == token.LPAR
and gexp.type == syms.testlist_gexp
and rpar.type == token.RPAR
):
gexp = unwrap_singleton_parenthesis(node)
if gexp is None or gexp.type != syms.testlist_gexp:
return False
return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
@ -3040,6 +3051,12 @@ def is_one_tuple(node: LN) -> bool:
)
def is_walrus_assignment(node: LN) -> bool:
"""Return True iff `node` is of the shape ( test := test )"""
inner = unwrap_singleton_parenthesis(node)
return inner is not None and inner.type == syms.namedexpr_test
def is_yield(node: LN) -> bool:
"""Return True if `node` holds a `yield` or `yield from` expression."""
if node.type == syms.yield_expr:
@ -3198,6 +3215,9 @@ def get_features_used(node: Node) -> Set[Feature]:
if "_" in n.value: # type: ignore
features.add(Feature.NUMERIC_UNDERSCORES)
elif n.type == token.COLONEQUAL:
features.add(Feature.ASSIGNMENT_EXPRESSIONS)
elif (
n.type in {syms.typedargslist, syms.arglist}
and n.children
@ -3479,32 +3499,58 @@ def __str__(self) -> str:
return ", ".join(report) + "."
def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
for feature_version in (7, 6):
try:
return ast3.parse(src, feature_version=feature_version)
except SyntaxError:
continue
def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
filename = "<unknown>"
if sys.version_info >= (3, 8):
# TODO: support Python 4+ ;)
for minor_version in range(sys.version_info[1], 4, -1):
try:
return ast.parse(src, filename, feature_version=(3, minor_version))
except SyntaxError:
continue
else:
for feature_version in (7, 6):
try:
return ast3.parse(src, filename, feature_version=feature_version)
except SyntaxError:
continue
return ast27.parse(src)
def _fixup_ast_constants(
node: Union[ast.AST, ast3.AST, ast27.AST]
) -> Union[ast.AST, ast3.AST, ast27.AST]:
"""Map ast nodes deprecated in 3.8 to Constant."""
# casts are required until this is released:
# https://github.com/python/typeshed/pull/3142
if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
return cast(ast.AST, ast.Constant(value=node.s))
elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
return cast(ast.AST, ast.Constant(value=node.n))
elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
return cast(ast.AST, ast.Constant(value=node.value))
return node
def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""
node = _fixup_ast_constants(node)
yield f"{' ' * depth}{node.__class__.__name__}("
for field in sorted(node._fields):
# TypeIgnore has only one field 'lineno' which breaks this comparison
if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
if sys.version_info >= (3, 8):
type_ignore_classes += (ast.TypeIgnore,)
if isinstance(node, type_ignore_classes):
break
# Ignore str kind which is case sensitive / and ignores unicode_literals
if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
continue
try:
value = getattr(node, field)
except AttributeError:
@ -3518,15 +3564,15 @@ def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
# parentheses and they change the AST.
if (
field == "targets"
and isinstance(node, (ast3.Delete, ast27.Delete))
and isinstance(item, (ast3.Tuple, ast27.Tuple))
and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
):
for item in item.elts:
yield from _v(item, depth + 2)
elif isinstance(item, (ast3.AST, ast27.AST)):
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
yield from _v(item, depth + 2)
elif isinstance(value, (ast3.AST, ast27.AST)):
elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
yield from _v(value, depth + 2)
else:

View File

@ -67,7 +67,7 @@ assert_stmt: 'assert' test [',' test]
compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt
async_stmt: ASYNC (funcdef | with_stmt | for_stmt)
if_stmt: 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite]
if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite]
while_stmt: 'while' test ':' suite ['else' ':' suite]
for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite]
try_stmt: ('try' ':' suite
@ -91,6 +91,7 @@ testlist_safe: old_test [(',' old_test)+ [',']]
old_test: or_test | old_lambdef
old_lambdef: 'lambda' [varargslist] ':' old_test
namedexpr_test: test [':=' test]
test: or_test ['if' or_test 'else' test] | lambdef
or_test: and_test ('or' and_test)*
and_test: not_test ('and' not_test)*
@ -111,8 +112,8 @@ atom: ('(' [yield_expr|testlist_gexp] ')' |
'{' [dictsetmaker] '}' |
'`' testlist1 '`' |
NAME | NUMBER | STRING+ | '.' '.' '.')
listmaker: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
testlist_gexp: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
listmaker: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
testlist_gexp: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
lambdef: 'lambda' [varargslist] ':' test
trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
subscriptlist: subscript (',' subscript)* [',']
@ -137,6 +138,7 @@ arglist: argument (',' argument)* [',']
# multiple (test comp_for) arguments are blocked; keyword unpackings
# that precede iterable unpackings are blocked; etc.
argument: ( test [comp_for] |
test ':=' test |
test '=' test |
'**' test |
'*' test )

View File

@ -184,6 +184,7 @@ def report(self):
// DOUBLESLASH
//= DOUBLESLASHEQUAL
-> RARROW
:= COLONEQUAL
"""
opmap = {}

View File

@ -63,7 +63,8 @@
AWAIT = 56
ASYNC = 57
ERRORTOKEN = 58
N_TOKENS = 59
COLONEQUAL = 59
N_TOKENS = 60
NT_OFFSET = 256
#--end constants--

View File

@ -89,7 +89,7 @@ def _combinations(*l):
# recognized as two instances of =).
Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
r"//=?", r"->",
r"[+\-*/%&@|^=<>]=?",
r"[+\-*/%&@|^=<>:]=?",
r"~")
Bracket = '[][(){}]'

View File

@ -57,6 +57,7 @@ class python_symbols(Symbols):
import_stmt: int
lambdef: int
listmaker: int
namedexpr_test: int
not_test: int
old_comp_for: int
old_comp_if: int

40
tests/data/pep_572.py Normal file
View File

@ -0,0 +1,40 @@
(a := 1)
(a := a)
if (match := pattern.search(data)) is None:
pass
[y := f(x), y ** 2, y ** 3]
filtered_data = [y for x in data if (y := f(x)) is None]
(y := f(x))
y0 = (y1 := f(x))
foo(x=(y := f(x)))
def foo(answer=(p := 42)):
pass
def foo(answer: (p := 42) = 5):
pass
lambda: (x := 1)
(x := lambda: 1)
(x := lambda: (y := 1))
lambda line: (m := re.match(pattern, line)) and m.group(1)
x = (y := 0)
(z := (y := (x := 0)))
(info := (name, phone, *rest))
(x := 1, 2)
(total := total + tax)
len(lines := f.readlines())
foo(x := 3, cat="vector")
foo(cat=(category := "vector"))
if any(len(longline := l) >= 100 for l in lines):
print(longline)
if env_base := os.environ.get("PYTHONUSERBASE", None):
return env_base
if self._is_special and (ans := self._check_nans(context=context)):
return ans
foo(b := 2, a=1)
foo((b := 2), a=1)
foo(c=(b := 2), a=1)

View File

@ -280,6 +280,23 @@ def test_expression(self) -> None:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_pep_572(self) -> None:
source, expected = read_data("pep_572")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, black.FileMode())
if sys.version_info >= (3, 8):
black.assert_equivalent(source, actual)
def test_pep_572_version_detection(self) -> None:
source, _ = read_data("pep_572")
root = black.lib2to3_parse(source)
features = black.get_features_used(root)
self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
versions = black.detect_target_versions(root)
self.assertIn(black.TargetVersion.PY38, versions)
def test_expression_ff(self) -> None:
source, expected = read_data("expression")
tmp_file = Path(black.dump_to_file(source))