Implementing mypyc support pt. 2 (#2431)

This commit is contained in:
Richard Si 2021-11-15 23:24:16 -05:00 committed by GitHub
parent 1d7260050d
commit 117891878e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 311 additions and 169 deletions

View File

@ -3,7 +3,6 @@
# free to run mypy on Windows, Linux, or macOS and get consistent
# results.
python_version=3.6
platform=linux
mypy_path=src
@ -24,6 +23,10 @@ warn_redundant_casts=True
warn_unused_ignores=True
disallow_any_generics=True
# Unreachable blocks have been an issue when compiling mypyc, let's try
# to avoid 'em in the first place.
warn_unreachable=True
# The following are off by default. Flip them on if you feel
# adventurous.
disallow_untyped_defs=True
@ -32,6 +35,11 @@ check_untyped_defs=True
# No incremental mode
cache_dir=/dev/null
[mypy-black]
# The following is because of `patch_click()`. Remove when
# we drop Python 3.6 support.
warn_unused_ignores=False
[mypy-black_primer.*]
# Until we drop Python 3.6 support, primer needs this
disallow_any_generics=False

View File

@ -33,3 +33,6 @@ optional-tests = [
"no_blackd: run when `d` extra NOT installed",
"no_jupyter: run when `jupyter` extra NOT installed",
]
markers = [
"incompatible_with_mypyc: run when testing mypyc compiled black"
]

View File

@ -5,6 +5,7 @@
assert sys.version_info >= (3, 6, 2), "black requires Python 3.6.2+"
from pathlib import Path # noqa E402
from typing import List # noqa: E402
CURRENT_DIR = Path(__file__).parent
sys.path.insert(0, str(CURRENT_DIR)) # for setuptools.build_meta
@ -18,6 +19,17 @@ def get_long_description() -> str:
)
def find_python_files(base: Path) -> List[Path]:
    """Recursively gather every ``.py`` file underneath *base*."""
    found: List[Path] = []
    for child in base.iterdir():
        if child.is_dir():
            found.extend(find_python_files(child))
        elif child.is_file() and child.suffix == ".py":
            found.append(child)
    return found
USE_MYPYC = False
# To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH
if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc":
@ -27,21 +39,34 @@ def get_long_description() -> str:
USE_MYPYC = True
if USE_MYPYC:
mypyc_targets = [
"src/black/__init__.py",
"src/blib2to3/pytree.py",
"src/blib2to3/pygram.py",
"src/blib2to3/pgen2/parse.py",
"src/blib2to3/pgen2/grammar.py",
"src/blib2to3/pgen2/token.py",
"src/blib2to3/pgen2/driver.py",
"src/blib2to3/pgen2/pgen.py",
]
from mypyc.build import mypycify
src = CURRENT_DIR / "src"
# TIP: filepaths are normalized to use forward slashes and are relative to ./src/
# before being checked against.
blocklist = [
# Not performance sensitive, so save bytes + compilation time:
"blib2to3/__init__.py",
"blib2to3/pgen2/__init__.py",
"black/output.py",
"black/concurrency.py",
"black/files.py",
"black/report.py",
# Breaks the test suite when compiled (and is also useless):
"black/debug.py",
# Compiled modules can't be run directly and that's a problem here:
"black/__main__.py",
]
discovered = []
# black-primer and blackd have no good reason to be compiled.
discovered.extend(find_python_files(src / "black"))
discovered.extend(find_python_files(src / "blib2to3"))
mypyc_targets = [
str(p) for p in discovered if p.relative_to(src).as_posix() not in blocklist
]
opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
ext_modules = mypycify(mypyc_targets, opt_level=opt_level)
ext_modules = mypycify(mypyc_targets, opt_level=opt_level, verbose=True)
else:
ext_modules = []

View File

@ -30,8 +30,9 @@
Union,
)
from dataclasses import replace
import click
from dataclasses import replace
from mypy_extensions import mypyc_attr
from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER
@ -66,6 +67,8 @@
from _black_version import version as __version__
COMPILED = Path(__file__).suffix in (".pyd", ".so")
# types
FileContent = str
Encoding = str
@ -177,7 +180,12 @@ def validate_regex(
raise click.BadParameter("Not a valid regular expression") from None
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.command(
context_settings=dict(help_option_names=["-h", "--help"]),
# While Click does set this field automatically using the docstring, mypyc
# (annoyingly) strips 'em so we need to set it here too.
help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"-l",
@ -346,7 +354,10 @@ def validate_regex(
" due to exclusion patterns."
),
)
@click.version_option(version=__version__)
@click.version_option(
version=__version__,
message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
"src",
nargs=-1,
@ -387,7 +398,7 @@ def main(
experimental_string_processing: bool,
quiet: bool,
verbose: bool,
required_version: str,
required_version: Optional[str],
include: Pattern[str],
exclude: Optional[Pattern[str]],
extend_exclude: Optional[Pattern[str]],
@ -655,6 +666,9 @@ def reformat_one(
report.failed(src, str(exc))
# diff-shades depends on being able to monkeypatch this function to operate. I know it's
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
sources: Set[Path],
fast: bool,
@ -669,6 +683,7 @@ def reformat_many(
worker_count = workers if workers is not None else DEFAULT_WORKERS
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
assert worker_count is not None
worker_count = min(worker_count, 60)
try:
executor = ProcessPoolExecutor(max_workers=worker_count)

View File

@ -49,7 +49,7 @@
DOT_PRIORITY: Final = 1
class BracketMatchError(KeyError):
class BracketMatchError(Exception):
"""Raised when an opening bracket is unable to be matched to a closing bracket."""

View File

@ -1,8 +1,14 @@
import sys
from dataclasses import dataclass
from functools import lru_cache
import regex as re
from typing import Iterator, List, Optional, Union
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final
from blib2to3.pytree import Node, Leaf
from blib2to3.pgen2 import token
@ -12,11 +18,10 @@
# types
LN = Union[Leaf, Node]
FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
FMT_PASS = {*FMT_OFF, *FMT_SKIP}
FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
FMT_OFF: Final = {"# fmt: off", "# fmt:off", "# yapf: disable"}
FMT_SKIP: Final = {"# fmt: skip", "# fmt:skip"}
FMT_PASS: Final = {*FMT_OFF, *FMT_SKIP}
FMT_ON: Final = {"# fmt: on", "# fmt:on", "# yapf: enable"}
@dataclass

View File

@ -17,6 +17,7 @@
TYPE_CHECKING,
)
from mypy_extensions import mypyc_attr
from pathspec import PathSpec
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
import tomli
@ -88,13 +89,14 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
return None
@mypyc_attr(patchable=True)
def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
"""Parse a pyproject toml file, pulling out relevant parts for Black
If parsing fails, will raise a tomli.TOMLDecodeError
"""
with open(path_config, encoding="utf8") as f:
pyproject_toml = tomli.load(f) # type: ignore # due to deprecated API usage
pyproject_toml = tomli.loads(f.read())
config = pyproject_toml.get("tool", {}).get("black", {})
return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}

View File

@ -333,7 +333,7 @@ def header(self) -> str:
return f"%%{self.name}"
@dataclasses.dataclass
# ast.NodeVisitor + dataclass = breakage under mypyc.
class CellMagicFinder(ast.NodeVisitor):
"""Find cell magics.
@ -352,7 +352,8 @@ class CellMagicFinder(ast.NodeVisitor):
and we look for instances of the latter.
"""
cell_magic: Optional[CellMagic] = None
def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:
    # Plain __init__ instead of a dataclass field: combining @dataclass with
    # an ast.NodeVisitor subclass breaks under mypyc (see comment above the
    # class definition).
    self.cell_magic = cell_magic
def visit_Expr(self, node: ast.Expr) -> None:
"""Find cell magic, extract header and body."""
@ -372,7 +373,8 @@ class OffsetAndMagic:
magic: str
@dataclasses.dataclass
# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here
# as mypyc will generate broken code.
class MagicFinder(ast.NodeVisitor):
"""Visit cell to look for get_ipython calls.
@ -392,9 +394,8 @@ class MagicFinder(ast.NodeVisitor):
types of magics).
"""
magics: Dict[int, List[OffsetAndMagic]] = dataclasses.field(
default_factory=lambda: collections.defaultdict(list)
)
def __init__(self) -> None:
    # Maps a line offset to the magics found at that offset. A defaultdict is
    # used so visitors can append without checking for the key first. This is
    # a plain __init__ rather than a dataclass field because subclassing
    # ast.NodeVisitor with @dataclass generates broken code under mypyc.
    self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)
def visit_Assign(self, node: ast.Assign) -> None:
"""Look for system assign magics.

View File

@ -5,8 +5,6 @@
import sys
from typing import Collection, Iterator, List, Optional, Set, Union
from dataclasses import dataclass, field
from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible
@ -40,7 +38,8 @@ class CannotSplit(CannotTransform):
"""A readable split that fits the allotted line length is impossible."""
@dataclass
# This isn't a dataclass because @dataclass + Generic breaks mypyc.
# See also https://github.com/mypyc/mypyc/issues/827.
class LineGenerator(Visitor[Line]):
"""Generates reformatted Line objects. Empty lines are not emitted.
@ -48,9 +47,11 @@ class LineGenerator(Visitor[Line]):
in ways that will no longer stringify to valid Python code on the tree.
"""
mode: Mode
remove_u_prefix: bool = False
current_line: Line = field(init=False)
def __init__(self, mode: Mode, remove_u_prefix: bool = False) -> None:
self.mode = mode
self.remove_u_prefix = remove_u_prefix
self.current_line: Line
self.__post_init__()
def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line.
@ -339,7 +340,9 @@ def transform_line(
transformers = [left_hand_split]
else:
def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
def _rhs(
self: object, line: Line, features: Collection[Feature]
) -> Iterator[Line]:
"""Wraps calls to `right_hand_split`.
The calls increasingly `omit` right-hand trailers (bracket pairs with
@ -366,6 +369,12 @@ def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
line, line_length=mode.line_length, features=features
)
# HACK: nested functions (like _rhs) compiled by mypyc don't retain their
# __name__ attribute which is needed in `run_transformer` further down.
# Unfortunately a nested class breaks mypyc too. So a class must be created
# via type ... https://github.com/mypyc/mypyc/issues/884
rhs = type("rhs", (), {"__call__": _rhs})()
if mode.experimental_string_processing:
if line.inside_brackets:
transformers = [
@ -980,7 +989,7 @@ def run_transformer(
result.extend(transform_line(transformed_line, mode=mode, features=features))
if (
transform.__name__ != "rhs"
transform.__class__.__name__ != "rhs"
or not line.bracket_tracker.invisible
or any(bracket.value for bracket in line.bracket_tracker.invisible)
or line.contains_multiline_strings()

View File

@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from enum import Enum
from operator import attrgetter
from typing import Dict, Set
from black.const import DEFAULT_LINE_LENGTH
@ -134,7 +135,7 @@ def get_cache_key(self) -> str:
if self.target_versions:
version_str = ",".join(
str(version.value)
for version in sorted(self.target_versions, key=lambda v: v.value)
for version in sorted(self.target_versions, key=attrgetter("value"))
)
else:
version_str = "-"

View File

@ -15,10 +15,12 @@
Union,
)
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final
from mypy_extensions import mypyc_attr
# lib2to3 fork
from blib2to3.pytree import Node, Leaf, type_repr
@ -30,7 +32,7 @@
pygram.initialize(CACHE_DIR)
syms = pygram.python_symbols
syms: Final = pygram.python_symbols
# types
@ -128,16 +130,21 @@
"//=",
}
IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
OPENING_BRACKETS = set(BRACKET.keys())
CLOSING_BRACKETS = set(BRACKET.values())
BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
IMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
BRACKET: Final = {
token.LPAR: token.RPAR,
token.LSQB: token.RSQB,
token.LBRACE: token.RBRACE,
}
OPENING_BRACKETS: Final = set(BRACKET.keys())
CLOSING_BRACKETS: Final = set(BRACKET.values())
BRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS
ALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
RARROW = 55
@mypyc_attr(allow_interpreted_subclasses=True)
class Visitor(Generic[T]):
"""Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
@ -178,9 +185,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901
`complex_subscript` signals whether the given leaf is part of a subscription
which has non-trivial arguments, like arithmetic expressions or function calls.
"""
NO = ""
SPACE = " "
DOUBLESPACE = " "
NO: Final = ""
SPACE: Final = " "
DOUBLESPACE: Final = " "
t = leaf.type
p = leaf.parent
v = leaf.value
@ -441,8 +448,8 @@ def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> b
def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
"""Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
stop_after = None
last = None
stop_after: Optional[Leaf] = None
last: Optional[Leaf] = None
for leaf in reversed(leaves):
if stop_after:
if leaf is stop_after:

View File

@ -11,6 +11,7 @@
from click import echo, style
@mypyc_attr(patchable=True)
def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
if message is not None:
if "bold" not in styles:
@ -19,6 +20,7 @@ def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
echo(message, nl=nl, err=True)
@mypyc_attr(patchable=True)
def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
if message is not None:
if "fg" not in styles:
@ -27,6 +29,7 @@ def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
echo(message, nl=nl, err=True)
@mypyc_attr(patchable=True)
def out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
_out(message, nl=nl, **styles)

View File

@ -4,11 +4,16 @@
import ast
import platform
import sys
from typing import Iterable, Iterator, List, Set, Union, Tuple
from typing import Any, Iterable, Iterator, List, Set, Tuple, Type, Union
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
from typing import Final
# lib2to3 fork
from blib2to3.pytree import Node, Leaf
from blib2to3 import pygram, pytree
from blib2to3 import pygram
from blib2to3.pgen2 import driver
from blib2to3.pgen2.grammar import Grammar
from blib2to3.pgen2.parse import ParseError
@ -16,6 +21,9 @@
from black.mode import TargetVersion, Feature, supports_feature
from black.nodes import syms
ast3: Any
ast27: Any
_IS_PYPY = platform.python_implementation() == "PyPy"
try:
@ -86,7 +94,7 @@ def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -
src_txt += "\n"
for grammar in get_grammars(set(target_versions)):
drv = driver.Driver(grammar, pytree.convert)
drv = driver.Driver(grammar)
try:
result = drv.parse_string(src_txt, True)
break
@ -148,6 +156,10 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
raise SyntaxError(first_error)
ast3_AST: Final[Type[ast3.AST]] = ast3.AST
ast27_AST: Final[Type[ast27.AST]] = ast27.AST
def stringify_ast(
node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]:
@ -189,7 +201,13 @@ def stringify_ast(
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
yield from stringify_ast(item, depth + 2)
elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
# Note that we are referencing the typed-ast ASTs via global variables and not
# direct module attribute accesses because that breaks mypyc. It's probably
# something to do with the ast3 / ast27 variables being marked as Any leading
# mypy to think this branch is always taken, leaving the rest of the code
# unanalyzed. Tightening up the types for the typed-ast AST types avoids the
# mypyc crash.
elif isinstance(value, (ast.AST, ast3_AST, ast27_AST)):
yield from stringify_ast(value, depth + 2)
else:

View File

@ -4,10 +4,20 @@
import regex as re
import sys
from functools import lru_cache
from typing import List, Pattern
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
from typing import Final
STRING_PREFIX_CHARS = "furbFURB" # All possible string prefix characters.
STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters.
STRING_PREFIX_RE: Final = re.compile(
r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", re.DOTALL
)
FIRST_NON_WHITESPACE_RE: Final = re.compile(r"\s*\t+\s*(\S)")
def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
@ -37,7 +47,7 @@ def lines_with_leading_tabs_expanded(s: str) -> List[str]:
for line in s.splitlines():
# Find the index of the first non-whitespace character after a string of
# whitespace that includes at least one tab
match = re.match(r"\s*\t+\s*(\S)", line)
match = FIRST_NON_WHITESPACE_RE.match(line)
if match:
first_non_whitespace_idx = match.start(1)
@ -133,7 +143,7 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str:
If remove_u_prefix is given, also removes any u prefix from the string.
"""
match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", s, re.DOTALL)
match = STRING_PREFIX_RE.match(s)
assert match is not None, f"failed to match string {s!r}"
orig_prefix = match.group(1)
new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
@ -142,6 +152,14 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str:
return f"{new_prefix}{match.group(2)}"
# Re(gex) does actually cache patterns internally but this still improves
# performance on a long list literal of strings by 5-9% since lru_cache's
# caching overhead is much lower.
@lru_cache(maxsize=64)
def _cached_compile(pattern: str) -> re.Pattern:
return re.compile(pattern)
def normalize_string_quotes(s: str) -> str:
"""Prefer double quotes but only if it doesn't cause more escaping.
@ -166,9 +184,9 @@ def normalize_string_quotes(s: str) -> str:
return s # There's an internal error
prefix = s[:first_quote_pos]
unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
unescaped_new_quote = _cached_compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
escaped_new_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
escaped_orig_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
body = s[first_quote_pos + len(orig_quote) : -len(orig_quote)]
if "r" in prefix.casefold():
if unescaped_new_quote.search(body):

View File

@ -8,6 +8,7 @@
from typing import (
Any,
Callable,
ClassVar,
Collection,
Dict,
Iterable,
@ -20,6 +21,14 @@
TypeVar,
Union,
)
import sys
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
from typing import Final
from mypy_extensions import trait
from black.rusty import Result, Ok, Err
@ -62,7 +71,6 @@ def TErr(err_msg: str) -> Err[CannotTransform]:
return Err(cant_transform)
@dataclass # type: ignore
class StringTransformer(ABC):
"""
An implementation of the Transformer protocol that relies on its
@ -90,9 +98,13 @@ class StringTransformer(ABC):
as much as possible.
"""
line_length: int
normalize_strings: bool
__name__ = "StringTransformer"
__name__: Final = "StringTransformer"
# Ideally this would be a dataclass, but unfortunately mypyc breaks when used with
# `abc.ABC`.
def __init__(self, line_length: int, normalize_strings: bool) -> None:
    # Manual __init__ rather than dataclass fields: mypyc breaks when
    # @dataclass is combined with abc.ABC (see comment above).
    self.line_length = line_length
    self.normalize_strings = normalize_strings
@abstractmethod
def do_match(self, line: Line) -> TMatchResult:
@ -184,6 +196,7 @@ class CustomSplit:
break_idx: int
@trait
class CustomSplitMapMixin:
"""
This mixin class is used to map merged strings to a sequence of
@ -191,8 +204,10 @@ class CustomSplitMapMixin:
the resultant substrings go over the configured max line length.
"""
_Key = Tuple[StringID, str]
_CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
_Key: ClassVar = Tuple[StringID, str]
_CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(
tuple
)
@staticmethod
def _get_key(string: str) -> "CustomSplitMapMixin._Key":
@ -243,7 +258,7 @@ def has_custom_splits(self, string: str) -> bool:
return key in self._CUSTOM_SPLIT_MAP
class StringMerger(CustomSplitMapMixin, StringTransformer):
class StringMerger(StringTransformer, CustomSplitMapMixin):
"""StringTransformer that merges strings together.
Requirements:
@ -739,7 +754,7 @@ class BaseStringSplitter(StringTransformer):
* The target string is not a multiline (i.e. triple-quote) string.
"""
STRING_OPERATORS = [
STRING_OPERATORS: Final = [
token.EQEQUAL,
token.GREATER,
token.GREATEREQUAL,
@ -927,7 +942,7 @@ def _get_max_string_length(self, line: Line, string_idx: int) -> int:
return max_string_length
class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
class StringSplitter(BaseStringSplitter, CustomSplitMapMixin):
"""
StringTransformer that splits "atom" strings (i.e. strings which exist on
lines by themselves).
@ -965,9 +980,9 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
CustomSplit objects and add them to the custom split map.
"""
MIN_SUBSTR_SIZE = 6
MIN_SUBSTR_SIZE: Final = 6
# Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
RE_FEXPR = r"""
RE_FEXPR: Final = r"""
(?<!\{) (?:\{\{)* \{ (?!\{)
(?:
[^\{\}]
@ -1426,7 +1441,7 @@ def _get_string_operator_leaves(self, leaves: Iterable[Leaf]) -> List[Leaf]:
return string_op_leaves
class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
"""
StringTransformer that splits non-"atom" strings (i.e. strings that do not
exist on lines by themselves).
@ -1811,20 +1826,20 @@ class StringParser:
```
"""
DEFAULT_TOKEN = -1
DEFAULT_TOKEN: Final = 20210605
# String Parser States
START = 1
DOT = 2
NAME = 3
PERCENT = 4
SINGLE_FMT_ARG = 5
LPAR = 6
RPAR = 7
DONE = 8
START: Final = 1
DOT: Final = 2
NAME: Final = 3
PERCENT: Final = 4
SINGLE_FMT_ARG: Final = 5
LPAR: Final = 6
RPAR: Final = 7
DONE: Final = 8
# Lookup Table for Next State
_goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
_goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = {
# A string trailer may start with '.' OR '%'.
(START, token.DOT): DOT,
(START, token.PERCENT): PERCENT,

View File

@ -104,13 +104,12 @@ async def async_main(
no_diff,
)
return int(ret_val)
finally:
if not keep and work_path.exists():
LOG.debug(f"Removing {work_path}")
rmtree(work_path, onerror=lib.handle_PermissionError)
return -2
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(

View File

@ -19,3 +19,5 @@ Change Log:
https://github.com/python/cpython/commit/cae60187cf7a7b26281d012e1952fafe4e2e97e9
- "bpo-42316: Allow unparenthesized walrus operator in indexes (GH-23317)"
https://github.com/python/cpython/commit/b0aba1fcdc3da952698d99aec2334faa79a8b68c
- Tweaks to help mypyc compile faster code (including inlining type information,
"Final-ing", etc.)

View File

@ -23,6 +23,7 @@
import sys
from typing import (
Any,
cast,
IO,
Iterable,
List,
@ -34,14 +35,15 @@
Generic,
Union,
)
from contextlib import contextmanager
from dataclasses import dataclass, field
# Pgen imports
from . import grammar, parse, token, tokenize, pgen
from logging import Logger
from blib2to3.pytree import _Convert, NL
from blib2to3.pytree import NL
from blib2to3.pgen2.grammar import Grammar
from contextlib import contextmanager
from blib2to3.pgen2.tokenize import GoodTokenInfo
Path = Union[str, "os.PathLike[str]"]
@ -115,29 +117,23 @@ def can_advance(self, to: int) -> bool:
class Driver(object):
def __init__(
self,
grammar: Grammar,
convert: Optional[_Convert] = None,
logger: Optional[Logger] = None,
) -> None:
def __init__(self, grammar: Grammar, logger: Optional[Logger] = None) -> None:
    """Store *grammar* and a logger for later parsing.

    If *logger* is not supplied, fall back to this module's logger.
    """
    self.grammar = grammar
    if logger is None:
        logger = logging.getLogger(__name__)
    self.logger = logger
self.convert = convert
def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
def parse_tokens(self, tokens: Iterable[GoodTokenInfo], debug: bool = False) -> NL:
"""Parse a series of tokens and return the syntax tree."""
# XXX Move the prefix computation into a wrapper around tokenize.
proxy = TokenProxy(tokens)
p = parse.Parser(self.grammar, self.convert)
p = parse.Parser(self.grammar)
p.setup(proxy=proxy)
lineno = 1
column = 0
indent_columns = []
indent_columns: List[int] = []
type = value = start = end = line_text = None
prefix = ""
@ -163,6 +159,7 @@ def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
if type == token.OP:
type = grammar.opmap[value]
if debug:
assert type is not None
self.logger.debug(
"%s %r (prefix=%r)", token.tok_name[type], value, prefix
)
@ -174,7 +171,7 @@ def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
elif type == token.DEDENT:
_indent_col = indent_columns.pop()
prefix, _prefix = self._partially_consume_prefix(prefix, _indent_col)
if p.addtoken(type, value, (prefix, start)):
if p.addtoken(cast(int, type), value, (prefix, start)):
if debug:
self.logger.debug("Stop.")
break

View File

@ -29,7 +29,7 @@
TYPE_CHECKING,
)
from blib2to3.pgen2.grammar import Grammar
from blib2to3.pytree import NL, Context, RawNode, Leaf, Node
from blib2to3.pytree import convert, NL, Context, RawNode, Leaf, Node
if TYPE_CHECKING:
from blib2to3.driver import TokenProxy
@ -70,9 +70,7 @@ def switch_to(self, ilabel: int) -> Iterator[None]:
finally:
self.parser.stack = self._start_point
def add_token(
self, tok_type: int, tok_val: Optional[Text], raw: bool = False
) -> None:
def add_token(self, tok_type: int, tok_val: Text, raw: bool = False) -> None:
func: Callable[..., Any]
if raw:
func = self.parser._addtoken
@ -86,9 +84,7 @@ def add_token(
args.insert(0, ilabel)
func(*args)
def determine_route(
self, value: Optional[Text] = None, force: bool = False
) -> Optional[int]:
def determine_route(self, value: Text = None, force: bool = False) -> Optional[int]:
alive_ilabels = self.ilabels
if len(alive_ilabels) == 0:
*_, most_successful_ilabel = self._dead_ilabels
@ -164,6 +160,11 @@ def __init__(self, grammar: Grammar, convert: Optional[Convert] = None) -> None:
to be converted. The syntax tree is converted from the bottom
up.
**post-note: the convert argument is ignored since for Black's
usage, convert will always be blib2to3.pytree.convert. Allowing
this to be dynamic hurts mypyc's ability to use early binding.
These docs are left for historical and informational value.
A concrete syntax tree node is a (type, value, context, nodes)
tuple, where type is the node type (a token or symbol number),
value is None for symbols and a string for tokens, context is
@ -176,6 +177,7 @@ def __init__(self, grammar: Grammar, convert: Optional[Convert] = None) -> None:
"""
self.grammar = grammar
# See note in docstring above. TL;DR this is ignored.
self.convert = convert or lam_sub
def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None:
@ -203,7 +205,7 @@ def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None:
self.used_names: Set[str] = set()
self.proxy = proxy
def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
def addtoken(self, type: int, value: Text, context: Context) -> bool:
"""Add a token; return True iff this is the end of the program."""
# Map from token to label
ilabels = self.classify(type, value, context)
@ -237,7 +239,7 @@ def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
next_token_type, next_token_value, *_ = proxy.eat(counter)
if next_token_type == tokenize.OP:
next_token_type = grammar.opmap[cast(str, next_token_value)]
next_token_type = grammar.opmap[next_token_value]
recorder.add_token(next_token_type, next_token_value)
counter += 1
@ -247,9 +249,7 @@ def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
return self._addtoken(ilabel, type, value, context)
def _addtoken(
self, ilabel: int, type: int, value: Optional[Text], context: Context
) -> bool:
def _addtoken(self, ilabel: int, type: int, value: Text, context: Context) -> bool:
# Loop until the token is shifted; may raise exceptions
while True:
dfa, state, node = self.stack[-1]
@ -257,10 +257,18 @@ def _addtoken(
arcs = states[state]
# Look for a state with this label
for i, newstate in arcs:
t, v = self.grammar.labels[i]
if ilabel == i:
t = self.grammar.labels[i][0]
if t >= 256:
# See if it's a symbol and if we're in its first set
itsdfa = self.grammar.dfas[t]
itsstates, itsfirst = itsdfa
if ilabel in itsfirst:
# Push a symbol
self.push(t, itsdfa, newstate, context)
break # To continue the outer while loop
elif ilabel == i:
# Look it up in the list of labels
assert t < 256
# Shift a token; we're done with it
self.shift(type, value, newstate, context)
# Pop while we are in an accept-only state
@ -274,14 +282,7 @@ def _addtoken(
states, first = dfa
# Done with this token
return False
elif t >= 256:
# See if it's a symbol and if we're in its first set
itsdfa = self.grammar.dfas[t]
itsstates, itsfirst = itsdfa
if ilabel in itsfirst:
# Push a symbol
self.push(t, self.grammar.dfas[t], newstate, context)
break # To continue the outer while loop
else:
if (0, state) in arcs:
# An accepting state, pop it and try something else
@ -293,14 +294,13 @@ def _addtoken(
# No success finding a transition
raise ParseError("bad input", type, value, context)
def classify(self, type: int, value: Optional[Text], context: Context) -> List[int]:
def classify(self, type: int, value: Text, context: Context) -> List[int]:
"""Turn a token into a label. (Internal)
Depending on whether the value is a soft-keyword or not,
this function may return multiple labels to choose from."""
if type == token.NAME:
# Keep a listing of all used names
assert value is not None
self.used_names.add(value)
# Check for reserved words
if value in self.grammar.keywords:
@ -317,18 +317,13 @@ def classify(self, type: int, value: Optional[Text], context: Context) -> List[i
raise ParseError("bad token", type, value, context)
return [ilabel]
def shift(
self, type: int, value: Optional[Text], newstate: int, context: Context
) -> None:
def shift(self, type: int, value: Text, newstate: int, context: Context) -> None:
"""Shift a token. (Internal)"""
dfa, state, node = self.stack[-1]
assert value is not None
assert context is not None
rawnode: RawNode = (type, value, context, None)
newnode = self.convert(self.grammar, rawnode)
if newnode is not None:
assert node[-1] is not None
node[-1].append(newnode)
newnode = convert(self.grammar, rawnode)
assert node[-1] is not None
node[-1].append(newnode)
self.stack[-1] = (dfa, newstate, node)
def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None:
@ -341,12 +336,11 @@ def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None
def pop(self) -> None:
"""Pop a nonterminal. (Internal)"""
popdfa, popstate, popnode = self.stack.pop()
newnode = self.convert(self.grammar, popnode)
if newnode is not None:
if self.stack:
dfa, state, node = self.stack[-1]
assert node[-1] is not None
node[-1].append(newnode)
else:
self.rootnode = newnode
self.rootnode.used_names = self.used_names
newnode = convert(self.grammar, popnode)
if self.stack:
dfa, state, node = self.stack[-1]
assert node[-1] is not None
node[-1].append(newnode)
else:
self.rootnode = newnode
self.rootnode.used_names = self.used_names

View File

@ -27,6 +27,7 @@
function to which the 5 fields described above are passed as 5 arguments,
each time a new token is found."""
import sys
from typing import (
Callable,
Iterable,
@ -39,6 +40,12 @@
Union,
cast,
)
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final
from blib2to3.pgen2.token import *
from blib2to3.pgen2.grammar import Grammar
@ -139,7 +146,7 @@ def _combinations(*l):
PseudoExtras = group(r"\\\r?\n", Comment, Triple)
PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name)
pseudoprog = re.compile(PseudoToken, re.UNICODE)
pseudoprog: Final = re.compile(PseudoToken, re.UNICODE)
single3prog = re.compile(Single3)
double3prog = re.compile(Double3)
@ -149,7 +156,7 @@ def _combinations(*l):
| {"u", "U", "ur", "uR", "Ur", "UR"}
)
endprogs = {
endprogs: Final = {
"'": re.compile(Single),
'"': re.compile(Double),
"'''": single3prog,
@ -159,12 +166,12 @@ def _combinations(*l):
**{prefix: None for prefix in _strprefixes},
}
triple_quoted = (
triple_quoted: Final = (
{"'''", '"""'}
| {f"{prefix}'''" for prefix in _strprefixes}
| {f'{prefix}"""' for prefix in _strprefixes}
)
single_quoted = (
single_quoted: Final = (
{"'", '"'}
| {f"{prefix}'" for prefix in _strprefixes}
| {f'{prefix}"' for prefix in _strprefixes}
@ -418,7 +425,7 @@ def generate_tokens(
logical line; continuation lines are included.
"""
lnum = parenlev = continued = 0
numchars = "0123456789"
numchars: Final = "0123456789"
contstr, needcont = "", 0
contline: Optional[str] = None
indents = [0]
@ -427,7 +434,7 @@ def generate_tokens(
# `await` as keywords.
async_keywords = False if grammar is None else grammar.async_keywords
# 'stashed' and 'async_*' are used for async/await parsing
stashed = None
stashed: Optional[GoodTokenInfo] = None
async_def = False
async_def_indent = 0
async_def_nl = False
@ -440,7 +447,7 @@ def generate_tokens(
line = readline()
except StopIteration:
line = ""
lnum = lnum + 1
lnum += 1
pos, max = 0, len(line)
if contstr: # continued string
@ -481,14 +488,14 @@ def generate_tokens(
column = 0
while pos < max: # measure leading whitespace
if line[pos] == " ":
column = column + 1
column += 1
elif line[pos] == "\t":
column = (column // tabsize + 1) * tabsize
elif line[pos] == "\f":
column = 0
else:
break
pos = pos + 1
pos += 1
if pos == max:
break
@ -507,7 +514,7 @@ def generate_tokens(
COMMENT,
comment_token,
(lnum, pos),
(lnum, pos + len(comment_token)),
(lnum, nl_pos),
line,
)
yield (NL, line[nl_pos:], (lnum, nl_pos), (lnum, len(line)), line)
@ -652,16 +659,16 @@ def generate_tokens(
continued = 1
else:
if initial in "([{":
parenlev = parenlev + 1
parenlev += 1
elif initial in ")]}":
parenlev = parenlev - 1
parenlev -= 1
if stashed:
yield stashed
stashed = None
yield (OP, token, spos, epos, line)
else:
yield (ERRORTOKEN, line[pos], (lnum, pos), (lnum, pos + 1), line)
pos = pos + 1
pos += 1
if stashed:
yield stashed

View File

@ -14,7 +14,6 @@
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
@ -92,8 +91,6 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented
return self._eq(other)
__hash__ = None # type: Any # For Py3 compatibility.
@property
def prefix(self) -> Text:
raise NotImplementedError
@ -437,7 +434,7 @@ def __str__(self) -> Text:
This reproduces the input source exactly.
"""
return self.prefix + str(self.value)
return self._prefix + str(self.value)
def _eq(self, other) -> bool:
"""Compare two nodes for equality."""
@ -672,8 +669,11 @@ def __init__(
newcontent = list(content)
for i, item in enumerate(newcontent):
assert isinstance(item, BasePattern), (i, item)
if isinstance(item, WildcardPattern):
self.wildcards = True
# I don't even think this code is used anywhere, but it does cause
# unreachable errors from mypy. This function's signature does look
# odd though *shrug*.
if isinstance(item, WildcardPattern): # type: ignore[unreachable]
self.wildcards = True # type: ignore[unreachable]
self.type = type
self.content = newcontent
self.name = name
@ -978,6 +978,3 @@ def generate_matches(
r.update(r0)
r.update(r1)
yield c0 + c1, r
_Convert = Callable[[Grammar, RawNode], Any]

View File

@ -122,7 +122,7 @@ def invokeBlack(
runner = BlackRunner()
if ignore_config:
args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
result = runner.invoke(black.main, args)
result = runner.invoke(black.main, args, catch_exceptions=False)
assert result.stdout_bytes is not None
assert result.stderr_bytes is not None
msg = (
@ -841,6 +841,7 @@ def test_get_future_imports(self) -> None:
)
self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
@pytest.mark.incompatible_with_mypyc
def test_debug_visitor(self) -> None:
source, _ = read_data("debug_visitor.py")
expected, _ = read_data("debug_visitor.out")
@ -891,6 +892,7 @@ def test_endmarker(self) -> None:
self.assertEqual(len(n.children), 1)
self.assertEqual(n.children[0].type, black.token.ENDMARKER)
@pytest.mark.incompatible_with_mypyc
@unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
def test_assertFormatEqual(self) -> None:
out_lines = []
@ -1055,6 +1057,7 @@ def test_pipe_force_py36(self) -> None:
actual = result.output
self.assertFormatEqual(actual, expected)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin(self) -> None:
with patch(
"black.format_stdin_to_stdout",
@ -1072,6 +1075,7 @@ def test_reformat_one_with_stdin(self) -> None:
fsts.assert_called_once()
report.done.assert_called_with(path, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename(self) -> None:
with patch(
"black.format_stdin_to_stdout",
@ -1094,6 +1098,7 @@ def test_reformat_one_with_stdin_filename(self) -> None:
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename_pyi(self) -> None:
with patch(
"black.format_stdin_to_stdout",
@ -1118,6 +1123,7 @@ def test_reformat_one_with_stdin_filename_pyi(self) -> None:
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
with patch(
"black.format_stdin_to_stdout",
@ -1142,6 +1148,7 @@ def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_and_existing_path(self) -> None:
with patch(
"black.format_stdin_to_stdout",
@ -1296,6 +1303,7 @@ def test_read_pyproject_toml(self) -> None:
self.assertEqual(config["exclude"], r"\.pyi?$")
self.assertEqual(config["include"], r"\.py?$")
@pytest.mark.incompatible_with_mypyc
def test_find_project_root(self) -> None:
with TemporaryDirectory() as workspace:
root = Path(workspace)
@ -1483,6 +1491,7 @@ def test_code_option_color_diff(self) -> None:
assert output == result_diff, "The output did not match the expected value."
assert result.exit_code == 0, "The exit code is incorrect."
@pytest.mark.incompatible_with_mypyc
def test_code_option_safe(self) -> None:
"""Test that the code option throws an error when the sanity checks fail."""
# Patch black.assert_equivalent to ensure the sanity checks fail
@ -1507,6 +1516,7 @@ def test_code_option_fast(self) -> None:
self.compare_results(result, formatted, 0)
@pytest.mark.incompatible_with_mypyc
def test_code_option_config(self) -> None:
"""
Test that the code option finds the pyproject.toml in the current directory.
@ -1527,6 +1537,7 @@ def test_code_option_config(self) -> None:
call_args[0].lower() == str(pyproject_path).lower()
), "Incorrect config loaded."
@pytest.mark.incompatible_with_mypyc
def test_code_option_parent_config(self) -> None:
"""
Test that the code option finds the pyproject.toml in the parent directory.
@ -1894,6 +1905,7 @@ def test_extend_exclude(self) -> None:
src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
)
@pytest.mark.incompatible_with_mypyc
def test_symlink_out_of_root_directory(self) -> None:
path = MagicMock()
root = THIS_DIR.resolve()
@ -2047,8 +2059,12 @@ def test_python_2_deprecation_autodetection_extended() -> None:
}, non_python2_case
with open(black.__file__, "r", encoding="utf-8") as _bf:
black_source_lines = _bf.readlines()
try:
with open(black.__file__, "r", encoding="utf-8") as _bf:
black_source_lines = _bf.readlines()
except UnicodeDecodeError:
if not black.COMPILED:
raise
def tracefunc(