from __future__ import annotations now implies 3.7+ (#2690)

This commit is contained in:
Batuhan Taskaya 2021-12-15 02:22:56 +03:00 committed by GitHub
parent 1c6b3a3a6f
commit ab86513710
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 5 deletions

View File

@ -15,6 +15,7 @@
- Fix determination of f-string expression spans (#2654) - Fix determination of f-string expression spans (#2654)
- Fix bad formatting of error messages about EOF in multi-line statements (#2343) - Fix bad formatting of error messages about EOF in multi-line statements (#2343)
- Functions and classes in blocks now have more consistent surrounding spacing (#2472) - Functions and classes in blocks now have more consistent surrounding spacing (#2472)
- `from __future__ import annotations` statement now implies Python 3.7+ (#2690)
#### Jupyter Notebook support #### Jupyter Notebook support

View File

@ -40,7 +40,7 @@
from black.lines import Line, EmptyLineTracker from black.lines import Line, EmptyLineTracker
from black.linegen import transform_line, LineGenerator, LN from black.linegen import transform_line, LineGenerator, LN
from black.comments import normalize_fmt_off from black.comments import normalize_fmt_off
from black.mode import Mode, TargetVersion from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
from black.mode import Feature, supports_feature, VERSION_TO_FEATURES from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
from black.concurrency import cancel, shutdown, maybe_install_uvloop from black.concurrency import cancel, shutdown, maybe_install_uvloop
@ -1080,7 +1080,7 @@ def f(
if mode.target_versions: if mode.target_versions:
versions = mode.target_versions versions = mode.target_versions
else: else:
versions = detect_target_versions(src_node) versions = detect_target_versions(src_node, future_imports=future_imports)
# TODO: fully drop support and this code hopefully in January 2022 :D # TODO: fully drop support and this code hopefully in January 2022 :D
if TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}: if TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}:
@ -1132,7 +1132,9 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
return tiow.read(), encoding, newline return tiow.read(), encoding, newline
def get_features_used(node: Node) -> Set[Feature]: # noqa: C901 def get_features_used( # noqa: C901
node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
"""Return a set of (relatively) new Python features used in this file. """Return a set of (relatively) new Python features used in this file.
Currently looking for: Currently looking for:
@ -1142,9 +1144,17 @@ def get_features_used(node: Node) -> Set[Feature]: # noqa: C901
- positional only arguments in function signatures and lambdas; - positional only arguments in function signatures and lambdas;
- assignment expression; - assignment expression;
- relaxed decorator syntax; - relaxed decorator syntax;
- usage of __future__ flags (annotations);
- print / exec statements; - print / exec statements;
""" """
features: Set[Feature] = set() features: Set[Feature] = set()
if future_imports:
features |= {
FUTURE_FLAG_TO_FEATURE[future_import]
for future_import in future_imports
if future_import in FUTURE_FLAG_TO_FEATURE
}
for n in node.pre_order(): for n in node.pre_order():
if n.type == token.STRING: if n.type == token.STRING:
value_head = n.value[:2] # type: ignore value_head = n.value[:2] # type: ignore
@ -1229,9 +1239,11 @@ def get_features_used(node: Node) -> Set[Feature]: # noqa: C901
return features return features
def detect_target_versions(node: Node) -> Set[TargetVersion]: def detect_target_versions(
node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
"""Detect the version to target based on the nodes used.""" """Detect the version to target based on the nodes used."""
features = get_features_used(node) features = get_features_used(node, future_imports=future_imports)
return { return {
version for version in TargetVersion if features <= VERSION_TO_FEATURES[version] version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
} }

View File

@ -4,11 +4,18 @@
chosen by the user. chosen by the user.
""" """
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from operator import attrgetter from operator import attrgetter
from typing import Dict, Set from typing import Dict, Set
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
from typing import Final
from black.const import DEFAULT_LINE_LENGTH from black.const import DEFAULT_LINE_LENGTH
@ -44,6 +51,9 @@ class Feature(Enum):
PATTERN_MATCHING = 11 PATTERN_MATCHING = 11
FORCE_OPTIONAL_PARENTHESES = 50 FORCE_OPTIONAL_PARENTHESES = 50
# __future__ flags
FUTURE_ANNOTATIONS = 51
# temporary for Python 2 deprecation # temporary for Python 2 deprecation
PRINT_STMT = 200 PRINT_STMT = 200
EXEC_STMT = 201 EXEC_STMT = 201
@ -55,6 +65,11 @@ class Feature(Enum):
BACKQUOTE_REPR = 207 BACKQUOTE_REPR = 207
FUTURE_FLAG_TO_FEATURE: Final = {
"annotations": Feature.FUTURE_ANNOTATIONS,
}
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
TargetVersion.PY27: { TargetVersion.PY27: {
Feature.ASYNC_IDENTIFIERS, Feature.ASYNC_IDENTIFIERS,
@ -89,6 +104,7 @@ class Feature(Enum):
Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF, Feature.TRAILING_COMMA_IN_DEF,
Feature.ASYNC_KEYWORDS, Feature.ASYNC_KEYWORDS,
Feature.FUTURE_ANNOTATIONS,
}, },
TargetVersion.PY38: { TargetVersion.PY38: {
Feature.UNICODE_LITERALS, Feature.UNICODE_LITERALS,
@ -97,6 +113,7 @@ class Feature(Enum):
Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF, Feature.TRAILING_COMMA_IN_DEF,
Feature.ASYNC_KEYWORDS, Feature.ASYNC_KEYWORDS,
Feature.FUTURE_ANNOTATIONS,
Feature.ASSIGNMENT_EXPRESSIONS, Feature.ASSIGNMENT_EXPRESSIONS,
Feature.POS_ONLY_ARGUMENTS, Feature.POS_ONLY_ARGUMENTS,
}, },
@ -107,6 +124,7 @@ class Feature(Enum):
Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF, Feature.TRAILING_COMMA_IN_DEF,
Feature.ASYNC_KEYWORDS, Feature.ASYNC_KEYWORDS,
Feature.FUTURE_ANNOTATIONS,
Feature.ASSIGNMENT_EXPRESSIONS, Feature.ASSIGNMENT_EXPRESSIONS,
Feature.RELAXED_DECORATORS, Feature.RELAXED_DECORATORS,
Feature.POS_ONLY_ARGUMENTS, Feature.POS_ONLY_ARGUMENTS,
@ -118,6 +136,7 @@ class Feature(Enum):
Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF, Feature.TRAILING_COMMA_IN_DEF,
Feature.ASYNC_KEYWORDS, Feature.ASYNC_KEYWORDS,
Feature.FUTURE_ANNOTATIONS,
Feature.ASSIGNMENT_EXPRESSIONS, Feature.ASSIGNMENT_EXPRESSIONS,
Feature.RELAXED_DECORATORS, Feature.RELAXED_DECORATORS,
Feature.POS_ONLY_ARGUMENTS, Feature.POS_ONLY_ARGUMENTS,

View File

@ -811,6 +811,24 @@ def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def fn(a, /, b): ...") node = black.lib2to3_parse("def fn(a, /, b): ...")
self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS}) self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
def test_get_features_used_for_future_flags(self) -> None:
for src, features in [
("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
(
"from __future__ import (other, annotations)",
{Feature.FUTURE_ANNOTATIONS},
),
("a = 1 + 2\nfrom something import annotations", set()),
("from __future__ import x, y", set()),
]:
with self.subTest(src=src, features=features):
node = black.lib2to3_parse(src)
future_imports = black.get_future_imports(node)
self.assertEqual(
black.get_features_used(node, future_imports=future_imports),
features,
)
def test_get_future_imports(self) -> None: def test_get_future_imports(self) -> None:
node = black.lib2to3_parse("\n") node = black.lib2to3_parse("\n")
self.assertEqual(set(), black.get_future_imports(node)) self.assertEqual(set(), black.get_future_imports(node))