Add support for walrus operator (#935)
* Parse `:=` properly * never unwrap parenthesis around `:=` * When checking for AST-equivalence, use `ast` instead of `typed-ast` when running on python >=3.8 * Assume code that uses `:=` is at least 3.8
This commit is contained in:
parent
cad4138050
commit
d8fa8df052
88
black.py
88
black.py
@ -1,3 +1,4 @@
|
|||||||
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import Executor, ProcessPoolExecutor
|
from concurrent.futures import Executor, ProcessPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@ -141,6 +142,7 @@ class Feature(Enum):
|
|||||||
# set for every version of python.
|
# set for every version of python.
|
||||||
ASYNC_IDENTIFIERS = 6
|
ASYNC_IDENTIFIERS = 6
|
||||||
ASYNC_KEYWORDS = 7
|
ASYNC_KEYWORDS = 7
|
||||||
|
ASSIGNMENT_EXPRESSIONS = 8
|
||||||
|
|
||||||
|
|
||||||
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
|
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
|
||||||
@ -175,6 +177,7 @@ class Feature(Enum):
|
|||||||
Feature.TRAILING_COMMA_IN_CALL,
|
Feature.TRAILING_COMMA_IN_CALL,
|
||||||
Feature.TRAILING_COMMA_IN_DEF,
|
Feature.TRAILING_COMMA_IN_DEF,
|
||||||
Feature.ASYNC_KEYWORDS,
|
Feature.ASYNC_KEYWORDS,
|
||||||
|
Feature.ASSIGNMENT_EXPRESSIONS,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2863,6 +2866,8 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
|
|||||||
check_lpar = True
|
check_lpar = True
|
||||||
|
|
||||||
if check_lpar:
|
if check_lpar:
|
||||||
|
if is_walrus_assignment(child):
|
||||||
|
continue
|
||||||
if child.type == syms.atom:
|
if child.type == syms.atom:
|
||||||
if maybe_make_parens_invisible_in_atom(child, parent=node):
|
if maybe_make_parens_invisible_in_atom(child, parent=node):
|
||||||
lpar = Leaf(token.LPAR, "")
|
lpar = Leaf(token.LPAR, "")
|
||||||
@ -3017,18 +3022,24 @@ def is_empty_tuple(node: LN) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
|
||||||
|
"""Returns `wrapped` if `node` is of the shape ( wrapped ).
|
||||||
|
|
||||||
|
Parenthesis can be optional. Returns None otherwise"""
|
||||||
|
if len(node.children) != 3:
|
||||||
|
return None
|
||||||
|
lpar, wrapped, rpar = node.children
|
||||||
|
if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def is_one_tuple(node: LN) -> bool:
|
def is_one_tuple(node: LN) -> bool:
|
||||||
"""Return True if `node` holds a tuple with one element, with or without parens."""
|
"""Return True if `node` holds a tuple with one element, with or without parens."""
|
||||||
if node.type == syms.atom:
|
if node.type == syms.atom:
|
||||||
if len(node.children) != 3:
|
gexp = unwrap_singleton_parenthesis(node)
|
||||||
return False
|
if gexp is None or gexp.type != syms.testlist_gexp:
|
||||||
|
|
||||||
lpar, gexp, rpar = node.children
|
|
||||||
if not (
|
|
||||||
lpar.type == token.LPAR
|
|
||||||
and gexp.type == syms.testlist_gexp
|
|
||||||
and rpar.type == token.RPAR
|
|
||||||
):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
|
return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
|
||||||
@ -3040,6 +3051,12 @@ def is_one_tuple(node: LN) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_walrus_assignment(node: LN) -> bool:
|
||||||
|
"""Return True iff `node` is of the shape ( test := test )"""
|
||||||
|
inner = unwrap_singleton_parenthesis(node)
|
||||||
|
return inner is not None and inner.type == syms.namedexpr_test
|
||||||
|
|
||||||
|
|
||||||
def is_yield(node: LN) -> bool:
|
def is_yield(node: LN) -> bool:
|
||||||
"""Return True if `node` holds a `yield` or `yield from` expression."""
|
"""Return True if `node` holds a `yield` or `yield from` expression."""
|
||||||
if node.type == syms.yield_expr:
|
if node.type == syms.yield_expr:
|
||||||
@ -3198,6 +3215,9 @@ def get_features_used(node: Node) -> Set[Feature]:
|
|||||||
if "_" in n.value: # type: ignore
|
if "_" in n.value: # type: ignore
|
||||||
features.add(Feature.NUMERIC_UNDERSCORES)
|
features.add(Feature.NUMERIC_UNDERSCORES)
|
||||||
|
|
||||||
|
elif n.type == token.COLONEQUAL:
|
||||||
|
features.add(Feature.ASSIGNMENT_EXPRESSIONS)
|
||||||
|
|
||||||
elif (
|
elif (
|
||||||
n.type in {syms.typedargslist, syms.arglist}
|
n.type in {syms.typedargslist, syms.arglist}
|
||||||
and n.children
|
and n.children
|
||||||
@ -3479,32 +3499,58 @@ def __str__(self) -> str:
|
|||||||
return ", ".join(report) + "."
|
return ", ".join(report) + "."
|
||||||
|
|
||||||
|
|
||||||
def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
|
def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
|
||||||
|
filename = "<unknown>"
|
||||||
|
if sys.version_info >= (3, 8):
|
||||||
|
# TODO: support Python 4+ ;)
|
||||||
|
for minor_version in range(sys.version_info[1], 4, -1):
|
||||||
|
try:
|
||||||
|
return ast.parse(src, filename, feature_version=(3, minor_version))
|
||||||
|
except SyntaxError:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
for feature_version in (7, 6):
|
for feature_version in (7, 6):
|
||||||
try:
|
try:
|
||||||
return ast3.parse(src, feature_version=feature_version)
|
return ast3.parse(src, filename, feature_version=feature_version)
|
||||||
except SyntaxError:
|
except SyntaxError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return ast27.parse(src)
|
return ast27.parse(src)
|
||||||
|
|
||||||
|
|
||||||
|
def _fixup_ast_constants(
|
||||||
|
node: Union[ast.AST, ast3.AST, ast27.AST]
|
||||||
|
) -> Union[ast.AST, ast3.AST, ast27.AST]:
|
||||||
|
"""Map ast nodes deprecated in 3.8 to Constant."""
|
||||||
|
# casts are required until this is released:
|
||||||
|
# https://github.com/python/typeshed/pull/3142
|
||||||
|
if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
|
||||||
|
return cast(ast.AST, ast.Constant(value=node.s))
|
||||||
|
elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
|
||||||
|
return cast(ast.AST, ast.Constant(value=node.n))
|
||||||
|
elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
|
||||||
|
return cast(ast.AST, ast.Constant(value=node.value))
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
def assert_equivalent(src: str, dst: str) -> None:
|
def assert_equivalent(src: str, dst: str) -> None:
|
||||||
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
|
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
|
||||||
|
|
||||||
def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
|
def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
|
||||||
"""Simple visitor generating strings to compare ASTs by content."""
|
"""Simple visitor generating strings to compare ASTs by content."""
|
||||||
|
|
||||||
|
node = _fixup_ast_constants(node)
|
||||||
|
|
||||||
yield f"{' ' * depth}{node.__class__.__name__}("
|
yield f"{' ' * depth}{node.__class__.__name__}("
|
||||||
|
|
||||||
for field in sorted(node._fields):
|
for field in sorted(node._fields):
|
||||||
# TypeIgnore has only one field 'lineno' which breaks this comparison
|
# TypeIgnore has only one field 'lineno' which breaks this comparison
|
||||||
if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
|
type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
|
||||||
|
if sys.version_info >= (3, 8):
|
||||||
|
type_ignore_classes += (ast.TypeIgnore,)
|
||||||
|
if isinstance(node, type_ignore_classes):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Ignore str kind which is case sensitive / and ignores unicode_literals
|
|
||||||
if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
value = getattr(node, field)
|
value = getattr(node, field)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -3518,15 +3564,15 @@ def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
|
|||||||
# parentheses and they change the AST.
|
# parentheses and they change the AST.
|
||||||
if (
|
if (
|
||||||
field == "targets"
|
field == "targets"
|
||||||
and isinstance(node, (ast3.Delete, ast27.Delete))
|
and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
|
||||||
and isinstance(item, (ast3.Tuple, ast27.Tuple))
|
and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
|
||||||
):
|
):
|
||||||
for item in item.elts:
|
for item in item.elts:
|
||||||
yield from _v(item, depth + 2)
|
yield from _v(item, depth + 2)
|
||||||
elif isinstance(item, (ast3.AST, ast27.AST)):
|
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
|
||||||
yield from _v(item, depth + 2)
|
yield from _v(item, depth + 2)
|
||||||
|
|
||||||
elif isinstance(value, (ast3.AST, ast27.AST)):
|
elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
|
||||||
yield from _v(value, depth + 2)
|
yield from _v(value, depth + 2)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -67,7 +67,7 @@ assert_stmt: 'assert' test [',' test]
|
|||||||
|
|
||||||
compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt
|
compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt
|
||||||
async_stmt: ASYNC (funcdef | with_stmt | for_stmt)
|
async_stmt: ASYNC (funcdef | with_stmt | for_stmt)
|
||||||
if_stmt: 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite]
|
if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite]
|
||||||
while_stmt: 'while' test ':' suite ['else' ':' suite]
|
while_stmt: 'while' test ':' suite ['else' ':' suite]
|
||||||
for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite]
|
for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite]
|
||||||
try_stmt: ('try' ':' suite
|
try_stmt: ('try' ':' suite
|
||||||
@ -91,6 +91,7 @@ testlist_safe: old_test [(',' old_test)+ [',']]
|
|||||||
old_test: or_test | old_lambdef
|
old_test: or_test | old_lambdef
|
||||||
old_lambdef: 'lambda' [varargslist] ':' old_test
|
old_lambdef: 'lambda' [varargslist] ':' old_test
|
||||||
|
|
||||||
|
namedexpr_test: test [':=' test]
|
||||||
test: or_test ['if' or_test 'else' test] | lambdef
|
test: or_test ['if' or_test 'else' test] | lambdef
|
||||||
or_test: and_test ('or' and_test)*
|
or_test: and_test ('or' and_test)*
|
||||||
and_test: not_test ('and' not_test)*
|
and_test: not_test ('and' not_test)*
|
||||||
@ -111,8 +112,8 @@ atom: ('(' [yield_expr|testlist_gexp] ')' |
|
|||||||
'{' [dictsetmaker] '}' |
|
'{' [dictsetmaker] '}' |
|
||||||
'`' testlist1 '`' |
|
'`' testlist1 '`' |
|
||||||
NAME | NUMBER | STRING+ | '.' '.' '.')
|
NAME | NUMBER | STRING+ | '.' '.' '.')
|
||||||
listmaker: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
|
listmaker: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
|
||||||
testlist_gexp: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
|
testlist_gexp: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
|
||||||
lambdef: 'lambda' [varargslist] ':' test
|
lambdef: 'lambda' [varargslist] ':' test
|
||||||
trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
|
trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
|
||||||
subscriptlist: subscript (',' subscript)* [',']
|
subscriptlist: subscript (',' subscript)* [',']
|
||||||
@ -137,6 +138,7 @@ arglist: argument (',' argument)* [',']
|
|||||||
# multiple (test comp_for) arguments are blocked; keyword unpackings
|
# multiple (test comp_for) arguments are blocked; keyword unpackings
|
||||||
# that precede iterable unpackings are blocked; etc.
|
# that precede iterable unpackings are blocked; etc.
|
||||||
argument: ( test [comp_for] |
|
argument: ( test [comp_for] |
|
||||||
|
test ':=' test |
|
||||||
test '=' test |
|
test '=' test |
|
||||||
'**' test |
|
'**' test |
|
||||||
'*' test )
|
'*' test )
|
||||||
|
@ -184,6 +184,7 @@ def report(self):
|
|||||||
// DOUBLESLASH
|
// DOUBLESLASH
|
||||||
//= DOUBLESLASHEQUAL
|
//= DOUBLESLASHEQUAL
|
||||||
-> RARROW
|
-> RARROW
|
||||||
|
:= COLONEQUAL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
opmap = {}
|
opmap = {}
|
||||||
|
@ -63,7 +63,8 @@
|
|||||||
AWAIT = 56
|
AWAIT = 56
|
||||||
ASYNC = 57
|
ASYNC = 57
|
||||||
ERRORTOKEN = 58
|
ERRORTOKEN = 58
|
||||||
N_TOKENS = 59
|
COLONEQUAL = 59
|
||||||
|
N_TOKENS = 60
|
||||||
NT_OFFSET = 256
|
NT_OFFSET = 256
|
||||||
#--end constants--
|
#--end constants--
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def _combinations(*l):
|
|||||||
# recognized as two instances of =).
|
# recognized as two instances of =).
|
||||||
Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
|
Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
|
||||||
r"//=?", r"->",
|
r"//=?", r"->",
|
||||||
r"[+\-*/%&@|^=<>]=?",
|
r"[+\-*/%&@|^=<>:]=?",
|
||||||
r"~")
|
r"~")
|
||||||
|
|
||||||
Bracket = '[][(){}]'
|
Bracket = '[][(){}]'
|
||||||
|
@ -57,6 +57,7 @@ class python_symbols(Symbols):
|
|||||||
import_stmt: int
|
import_stmt: int
|
||||||
lambdef: int
|
lambdef: int
|
||||||
listmaker: int
|
listmaker: int
|
||||||
|
namedexpr_test: int
|
||||||
not_test: int
|
not_test: int
|
||||||
old_comp_for: int
|
old_comp_for: int
|
||||||
old_comp_if: int
|
old_comp_if: int
|
||||||
|
40
tests/data/pep_572.py
Normal file
40
tests/data/pep_572.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
(a := 1)
|
||||||
|
(a := a)
|
||||||
|
if (match := pattern.search(data)) is None:
|
||||||
|
pass
|
||||||
|
[y := f(x), y ** 2, y ** 3]
|
||||||
|
filtered_data = [y for x in data if (y := f(x)) is None]
|
||||||
|
(y := f(x))
|
||||||
|
y0 = (y1 := f(x))
|
||||||
|
foo(x=(y := f(x)))
|
||||||
|
|
||||||
|
|
||||||
|
def foo(answer=(p := 42)):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def foo(answer: (p := 42) = 5):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
lambda: (x := 1)
|
||||||
|
(x := lambda: 1)
|
||||||
|
(x := lambda: (y := 1))
|
||||||
|
lambda line: (m := re.match(pattern, line)) and m.group(1)
|
||||||
|
x = (y := 0)
|
||||||
|
(z := (y := (x := 0)))
|
||||||
|
(info := (name, phone, *rest))
|
||||||
|
(x := 1, 2)
|
||||||
|
(total := total + tax)
|
||||||
|
len(lines := f.readlines())
|
||||||
|
foo(x := 3, cat="vector")
|
||||||
|
foo(cat=(category := "vector"))
|
||||||
|
if any(len(longline := l) >= 100 for l in lines):
|
||||||
|
print(longline)
|
||||||
|
if env_base := os.environ.get("PYTHONUSERBASE", None):
|
||||||
|
return env_base
|
||||||
|
if self._is_special and (ans := self._check_nans(context=context)):
|
||||||
|
return ans
|
||||||
|
foo(b := 2, a=1)
|
||||||
|
foo((b := 2), a=1)
|
||||||
|
foo(c=(b := 2), a=1)
|
@ -280,6 +280,23 @@ def test_expression(self) -> None:
|
|||||||
black.assert_equivalent(source, actual)
|
black.assert_equivalent(source, actual)
|
||||||
black.assert_stable(source, actual, black.FileMode())
|
black.assert_stable(source, actual, black.FileMode())
|
||||||
|
|
||||||
|
@patch("black.dump_to_file", dump_to_stderr)
|
||||||
|
def test_pep_572(self) -> None:
|
||||||
|
source, expected = read_data("pep_572")
|
||||||
|
actual = fs(source)
|
||||||
|
self.assertFormatEqual(expected, actual)
|
||||||
|
black.assert_stable(source, actual, black.FileMode())
|
||||||
|
if sys.version_info >= (3, 8):
|
||||||
|
black.assert_equivalent(source, actual)
|
||||||
|
|
||||||
|
def test_pep_572_version_detection(self) -> None:
|
||||||
|
source, _ = read_data("pep_572")
|
||||||
|
root = black.lib2to3_parse(source)
|
||||||
|
features = black.get_features_used(root)
|
||||||
|
self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
|
||||||
|
versions = black.detect_target_versions(root)
|
||||||
|
self.assertIn(black.TargetVersion.PY38, versions)
|
||||||
|
|
||||||
def test_expression_ff(self) -> None:
|
def test_expression_ff(self) -> None:
|
||||||
source, expected = read_data("expression")
|
source, expected = read_data("expression")
|
||||||
tmp_file = Path(black.dump_to_file(source))
|
tmp_file = Path(black.dump_to_file(source))
|
||||||
|
Loading…
Reference in New Issue
Block a user