Implementing mypyc support pt. 2 (#2431)

Richard Si 2021-11-15 23:24:16 -05:00 committed by GitHub
parent 1d7260050d
commit 117891878e
22 changed files with 311 additions and 169 deletions

View File

@ -3,7 +3,6 @@
# free to run mypy on Windows, Linux, or macOS and get consistent # free to run mypy on Windows, Linux, or macOS and get consistent
# results. # results.
python_version=3.6 python_version=3.6
platform=linux
mypy_path=src mypy_path=src
@ -24,6 +23,10 @@ warn_redundant_casts=True
warn_unused_ignores=True warn_unused_ignores=True
disallow_any_generics=True disallow_any_generics=True
# Unreachable blocks have been an issue when compiling with mypyc, let's try
# to avoid 'em in the first place.
warn_unreachable=True
# The following are off by default. Flip them on if you feel # The following are off by default. Flip them on if you feel
# adventurous. # adventurous.
disallow_untyped_defs=True disallow_untyped_defs=True
@ -32,6 +35,11 @@ check_untyped_defs=True
# No incremental mode # No incremental mode
cache_dir=/dev/null cache_dir=/dev/null
[mypy-black]
# The following is because of `patch_click()`. Remove when
# we drop Python 3.6 support.
warn_unused_ignores=False
[mypy-black_primer.*] [mypy-black_primer.*]
# Until we drop Python 3.6 support, primer needs this # Until we drop Python 3.6 support, primer needs this
disallow_any_generics=False disallow_any_generics=False
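
For context, a minimal sketch (not from this commit) of what warn_unreachable catches — mypy flags statements it can prove will never run, which are exactly the blocks mypyc tends to miscompile:

from typing import Union

def describe(x: Union[int, str]) -> str:
    if isinstance(x, (int, str)):
        return f"got {x!r}"
    # With warn_unreachable=True, mypy reports here:
    # error: Statement is unreachable  [unreachable]
    return "impossible"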

View File

@ -33,3 +33,6 @@ optional-tests = [
"no_blackd: run when `d` extra NOT installed", "no_blackd: run when `d` extra NOT installed",
"no_jupyter: run when `jupyter` extra NOT installed", "no_jupyter: run when `jupyter` extra NOT installed",
] ]
markers = [
"incompatible_with_mypyc: run when testing mypyc compiled black"
]
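
A sketch of how the new marker is meant to be consumed (the test name below is hypothetical): tests that can't run against compiled Black get tagged, and a mypyc CI job would deselect them with pytest -m "not incompatible_with_mypyc".

import pytest

@pytest.mark.incompatible_with_mypyc
def test_monkeypatches_internals() -> None:
    # Monkeypatching compiled module-level functions has no effect,
    # so this test is only meaningful against interpreted Black.
    assert True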

View File

@ -5,6 +5,7 @@
assert sys.version_info >= (3, 6, 2), "black requires Python 3.6.2+" assert sys.version_info >= (3, 6, 2), "black requires Python 3.6.2+"
from pathlib import Path # noqa E402 from pathlib import Path # noqa E402
from typing import List # noqa: E402
CURRENT_DIR = Path(__file__).parent CURRENT_DIR = Path(__file__).parent
sys.path.insert(0, str(CURRENT_DIR)) # for setuptools.build_meta sys.path.insert(0, str(CURRENT_DIR)) # for setuptools.build_meta
@ -18,6 +19,17 @@ def get_long_description() -> str:
) )
def find_python_files(base: Path) -> List[Path]:
files = []
for entry in base.iterdir():
if entry.is_file() and entry.suffix == ".py":
files.append(entry)
elif entry.is_dir():
files.extend(find_python_files(entry))
return files
USE_MYPYC = False USE_MYPYC = False
# To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH # To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH
if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc": if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc":
@ -27,21 +39,34 @@ def get_long_description() -> str:
USE_MYPYC = True USE_MYPYC = True
if USE_MYPYC: if USE_MYPYC:
mypyc_targets = [
"src/black/__init__.py",
"src/blib2to3/pytree.py",
"src/blib2to3/pygram.py",
"src/blib2to3/pgen2/parse.py",
"src/blib2to3/pgen2/grammar.py",
"src/blib2to3/pgen2/token.py",
"src/blib2to3/pgen2/driver.py",
"src/blib2to3/pgen2/pgen.py",
]
from mypyc.build import mypycify from mypyc.build import mypycify
src = CURRENT_DIR / "src"
# TIP: filepaths are normalized to use forward slashes and are relative to ./src/
# before being checked against the blocklist below.
blocklist = [
# Not performance sensitive, so save bytes + compilation time:
"blib2to3/__init__.py",
"blib2to3/pgen2/__init__.py",
"black/output.py",
"black/concurrency.py",
"black/files.py",
"black/report.py",
# Breaks the test suite when compiled (and is also useless):
"black/debug.py",
# Compiled modules can't be run directly and that's a problem here:
"black/__main__.py",
]
discovered = []
# black-primer and blackd have no good reason to be compiled.
discovered.extend(find_python_files(src / "black"))
discovered.extend(find_python_files(src / "blib2to3"))
mypyc_targets = [
str(p) for p in discovered if p.relative_to(src).as_posix() not in blocklist
]
opt_level = os.getenv("MYPYC_OPT_LEVEL", "3") opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
ext_modules = mypycify(mypyc_targets, opt_level=opt_level) ext_modules = mypycify(mypyc_targets, opt_level=opt_level, verbose=True)
else: else:
ext_modules = [] ext_modules = []
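
Order aside, find_python_files is equivalent to a recursive glob; a dependency-free sketch (not part of the commit):

from pathlib import Path
from typing import List

def find_python_files(base: Path) -> List[Path]:
    # Same result set as the hand-rolled recursion in setup.py,
    # just expressed with Path.rglob.
    return sorted(base.rglob("*.py"))

The opt-in itself is driven by the sys.argv check above, i.e. an invocation along the lines of MYPYC_OPT_LEVEL=3 python setup.py --use-mypyc bdist_wheel.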

View File

@ -30,8 +30,9 @@
Union, Union,
) )
from dataclasses import replace
import click import click
from dataclasses import replace
from mypy_extensions import mypyc_attr
from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER from black.const import STDIN_PLACEHOLDER
@ -66,6 +67,8 @@
from _black_version import version as __version__ from _black_version import version as __version__
COMPILED = Path(__file__).suffix in (".pyd", ".so")
# types # types
FileContent = str FileContent = str
Encoding = str Encoding = str
@ -177,7 +180,12 @@ def validate_regex(
raise click.BadParameter("Not a valid regular expression") from None raise click.BadParameter("Not a valid regular expression") from None
@click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.command(
context_settings=dict(help_option_names=["-h", "--help"]),
# While Click does set this field automatically using the docstring, mypyc
# (annoyingly) strips docstrings, so we need to set it here too.
help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option( @click.option(
"-l", "-l",
@ -346,7 +354,10 @@ def validate_regex(
" due to exclusion patterns." " due to exclusion patterns."
), ),
) )
@click.version_option(version=__version__) @click.version_option(
version=__version__,
message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument( @click.argument(
"src", "src",
nargs=-1, nargs=-1,
@ -387,7 +398,7 @@ def main(
experimental_string_processing: bool, experimental_string_processing: bool,
quiet: bool, quiet: bool,
verbose: bool, verbose: bool,
required_version: str, required_version: Optional[str],
include: Pattern[str], include: Pattern[str],
exclude: Optional[Pattern[str]], exclude: Optional[Pattern[str]],
extend_exclude: Optional[Pattern[str]], extend_exclude: Optional[Pattern[str]],
@ -655,6 +666,9 @@ def reformat_one(
report.failed(src, str(exc)) report.failed(src, str(exc))
# diff-shades depends on being able to monkeypatch this function to operate. I know it's
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many( def reformat_many(
sources: Set[Path], sources: Set[Path],
fast: bool, fast: bool,
@ -669,6 +683,7 @@ def reformat_many(
worker_count = workers if workers is not None else DEFAULT_WORKERS worker_count = workers if workers is not None else DEFAULT_WORKERS
if sys.platform == "win32": if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903 # Work around https://bugs.python.org/issue26903
assert worker_count is not None
worker_count = min(worker_count, 60) worker_count = min(worker_count, 60)
try: try:
executor = ProcessPoolExecutor(max_workers=worker_count) executor = ProcessPoolExecutor(max_workers=worker_count)
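
The patchable=True attribute deserves a note: mypyc normally early-binds calls to compiled functions, so unittest.mock.patch silently does nothing. A sketch of the failure mode it avoids, assuming an importable black:

from unittest.mock import patch

import black

with patch.object(black, "reformat_many") as fake:
    # Calls to black.reformat_many inside this block hit the mock --
    # but only because reformat_many was compiled with patchable=True;
    # without the attribute, compiled callers keep the original binding.
    ...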

View File

@ -49,7 +49,7 @@
DOT_PRIORITY: Final = 1 DOT_PRIORITY: Final = 1
class BracketMatchError(KeyError): class BracketMatchError(Exception):
"""Raised when an opening bracket is unable to be matched to a closing bracket.""" """Raised when an opening bracket is unable to be matched to a closing bracket."""

View File

@ -1,8 +1,14 @@
import sys
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
import regex as re import regex as re
from typing import Iterator, List, Optional, Union from typing import Iterator, List, Optional, Union
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final
from blib2to3.pytree import Node, Leaf from blib2to3.pytree import Node, Leaf
from blib2to3.pgen2 import token from blib2to3.pgen2 import token
@ -12,11 +18,10 @@
# types # types
LN = Union[Leaf, Node] LN = Union[Leaf, Node]
FMT_OFF: Final = {"# fmt: off", "# fmt:off", "# yapf: disable"}
FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"} FMT_SKIP: Final = {"# fmt: skip", "# fmt:skip"}
FMT_SKIP = {"# fmt: skip", "# fmt:skip"} FMT_PASS: Final = {*FMT_OFF, *FMT_SKIP}
FMT_PASS = {*FMT_OFF, *FMT_SKIP} FMT_ON: Final = {"# fmt: on", "# fmt:on", "# yapf: enable"}
FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
@dataclass @dataclass

View File

@ -17,6 +17,7 @@
TYPE_CHECKING, TYPE_CHECKING,
) )
from mypy_extensions import mypyc_attr
from pathspec import PathSpec from pathspec import PathSpec
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
import tomli import tomli
@ -88,13 +89,14 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
return None return None
@mypyc_attr(patchable=True)
def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
"""Parse a pyproject toml file, pulling out relevant parts for Black """Parse a pyproject toml file, pulling out relevant parts for Black
If parsing fails, will raise a tomli.TOMLDecodeError If parsing fails, will raise a tomli.TOMLDecodeError
""" """
with open(path_config, encoding="utf8") as f: with open(path_config, encoding="utf8") as f:
pyproject_toml = tomli.load(f) # type: ignore # due to deprecated API usage pyproject_toml = tomli.loads(f.read())
config = pyproject_toml.get("tool", {}).get("black", {}) config = pyproject_toml.get("tool", {}).get("black", {})
return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}

View File

@ -333,7 +333,7 @@ def header(self) -> str:
return f"%%{self.name}" return f"%%{self.name}"
@dataclasses.dataclass # ast.NodeVisitor + dataclass = breakage under mypyc.
class CellMagicFinder(ast.NodeVisitor): class CellMagicFinder(ast.NodeVisitor):
"""Find cell magics. """Find cell magics.
@ -352,7 +352,8 @@ class CellMagicFinder(ast.NodeVisitor):
and we look for instances of the latter. and we look for instances of the latter.
""" """
cell_magic: Optional[CellMagic] = None def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:
self.cell_magic = cell_magic
def visit_Expr(self, node: ast.Expr) -> None: def visit_Expr(self, node: ast.Expr) -> None:
"""Find cell magic, extract header and body.""" """Find cell magic, extract header and body."""
@ -372,7 +373,8 @@ class OffsetAndMagic:
magic: str magic: str
@dataclasses.dataclass # Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here
# as mypyc will generate broken code.
class MagicFinder(ast.NodeVisitor): class MagicFinder(ast.NodeVisitor):
"""Visit cell to look for get_ipython calls. """Visit cell to look for get_ipython calls.
@ -392,9 +394,8 @@ class MagicFinder(ast.NodeVisitor):
types of magics). types of magics).
""" """
magics: Dict[int, List[OffsetAndMagic]] = dataclasses.field( def __init__(self) -> None:
default_factory=lambda: collections.defaultdict(list) self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)
)
def visit_Assign(self, node: ast.Assign) -> None: def visit_Assign(self, node: ast.Assign) -> None:
"""Look for system assign magics. """Look for system assign magics.

View File

@ -5,8 +5,6 @@
import sys import sys
from typing import Collection, Iterator, List, Optional, Set, Union from typing import Collection, Iterator, List, Optional, Set, Union
from dataclasses import dataclass, field
from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible
@ -40,7 +38,8 @@ class CannotSplit(CannotTransform):
"""A readable split that fits the allotted line length is impossible.""" """A readable split that fits the allotted line length is impossible."""
@dataclass # This isn't a dataclass because @dataclass + Generic breaks mypyc.
# See also https://github.com/mypyc/mypyc/issues/827.
class LineGenerator(Visitor[Line]): class LineGenerator(Visitor[Line]):
"""Generates reformatted Line objects. Empty lines are not emitted. """Generates reformatted Line objects. Empty lines are not emitted.
@ -48,9 +47,11 @@ class LineGenerator(Visitor[Line]):
in ways that will no longer stringify to valid Python code on the tree. in ways that will no longer stringify to valid Python code on the tree.
""" """
mode: Mode def __init__(self, mode: Mode, remove_u_prefix: bool = False) -> None:
remove_u_prefix: bool = False self.mode = mode
current_line: Line = field(init=False) self.remove_u_prefix = remove_u_prefix
self.current_line: Line
self.__post_init__()
def line(self, indent: int = 0) -> Iterator[Line]: def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line. """Generate a line.
@ -339,7 +340,9 @@ def transform_line(
transformers = [left_hand_split] transformers = [left_hand_split]
else: else:
def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]: def _rhs(
self: object, line: Line, features: Collection[Feature]
) -> Iterator[Line]:
"""Wraps calls to `right_hand_split`. """Wraps calls to `right_hand_split`.
The calls increasingly `omit` right-hand trailers (bracket pairs with The calls increasingly `omit` right-hand trailers (bracket pairs with
@ -366,6 +369,12 @@ def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
line, line_length=mode.line_length, features=features line, line_length=mode.line_length, features=features
) )
# HACK: nested functions (like _rhs) compiled by mypyc don't retain their
# __name__ attribute, which is needed in `run_transformer` further down.
# Unfortunately a nested class breaks mypyc too. So a class must be created
# via type ... https://github.com/mypyc/mypyc/issues/884
rhs = type("rhs", (), {"__call__": _rhs})()
if mode.experimental_string_processing: if mode.experimental_string_processing:
if line.inside_brackets: if line.inside_brackets:
transformers = [ transformers = [
@ -980,7 +989,7 @@ def run_transformer(
result.extend(transform_line(transformed_line, mode=mode, features=features)) result.extend(transform_line(transformed_line, mode=mode, features=features))
if ( if (
transform.__name__ != "rhs" transform.__class__.__name__ != "rhs"
or not line.bracket_tracker.invisible or not line.bracket_tracker.invisible
or any(bracket.value for bracket in line.bracket_tracker.invisible) or any(bracket.value for bracket in line.bracket_tracker.invisible)
or line.contains_multiline_strings() or line.contains_multiline_strings()
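
The type(...) trick above is compact enough to miss; a standalone sketch of why it works:

def _rhs(self: object, x: int) -> int:
    # Nested functions lose __name__ under mypyc, but a class created
    # with type() keeps its name even when compiled.
    return x + 1

rhs = type("rhs", (), {"__call__": _rhs})()
assert rhs.__class__.__name__ == "rhs"  # what run_transformer checks
assert rhs(41) == 42  # __call__ receives the instance as `self`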

View File

@ -6,6 +6,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from operator import attrgetter
from typing import Dict, Set from typing import Dict, Set
from black.const import DEFAULT_LINE_LENGTH from black.const import DEFAULT_LINE_LENGTH
@ -134,7 +135,7 @@ def get_cache_key(self) -> str:
if self.target_versions: if self.target_versions:
version_str = ",".join( version_str = ",".join(
str(version.value) str(version.value)
for version in sorted(self.target_versions, key=lambda v: v.value) for version in sorted(self.target_versions, key=attrgetter("value"))
) )
else: else:
version_str = "-" version_str = "-"

View File

@ -15,10 +15,12 @@
Union, Union,
) )
if sys.version_info < (3, 8): if sys.version_info >= (3, 8):
from typing_extensions import Final
else:
from typing import Final from typing import Final
else:
from typing_extensions import Final
from mypy_extensions import mypyc_attr
# lib2to3 fork # lib2to3 fork
from blib2to3.pytree import Node, Leaf, type_repr from blib2to3.pytree import Node, Leaf, type_repr
@ -30,7 +32,7 @@
pygram.initialize(CACHE_DIR) pygram.initialize(CACHE_DIR)
syms = pygram.python_symbols syms: Final = pygram.python_symbols
# types # types
@ -128,16 +130,21 @@
"//=", "//=",
} }
IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist} IMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE} BRACKET: Final = {
OPENING_BRACKETS = set(BRACKET.keys()) token.LPAR: token.RPAR,
CLOSING_BRACKETS = set(BRACKET.values()) token.LSQB: token.RSQB,
BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS token.LBRACE: token.RBRACE,
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} }
OPENING_BRACKETS: Final = set(BRACKET.keys())
CLOSING_BRACKETS: Final = set(BRACKET.values())
BRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS
ALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
RARROW = 55 RARROW = 55
@mypyc_attr(allow_interpreted_subclasses=True)
class Visitor(Generic[T]): class Visitor(Generic[T]):
"""Basic lib2to3 visitor that yields things of type `T` on `visit()`.""" """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
@ -178,9 +185,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901
`complex_subscript` signals whether the given leaf is part of a subscription `complex_subscript` signals whether the given leaf is part of a subscription
which has non-trivial arguments, like arithmetic expressions or function calls. which has non-trivial arguments, like arithmetic expressions or function calls.
""" """
NO = "" NO: Final = ""
SPACE = " " SPACE: Final = " "
DOUBLESPACE = " " DOUBLESPACE: Final = " "
t = leaf.type t = leaf.type
p = leaf.parent p = leaf.parent
v = leaf.value v = leaf.value
@ -441,8 +448,8 @@ def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> b
def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]: def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
"""Return (penultimate, last) leaves skipping brackets in `omit` and contents.""" """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
stop_after = None stop_after: Optional[Leaf] = None
last = None last: Optional[Leaf] = None
for leaf in reversed(leaves): for leaf in reversed(leaves):
if stop_after: if stop_after:
if leaf is stop_after: if leaf is stop_after:
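
On allow_interpreted_subclasses: compiled (native) classes normally can't be subclassed from interpreted Python, which Visitor needs to support since the uncompiled black/debug.py subclasses it. A minimal sketch; the decorator is a no-op when running uncompiled:

from mypy_extensions import mypyc_attr

@mypyc_attr(allow_interpreted_subclasses=True)
class Visitor:
    def visit(self) -> str:
        return "base"

# Fine even when Visitor is compiled; this is what lets blocklisted,
# interpreted modules extend the compiled class (at some call overhead).
class DebugVisitor(Visitor):
    def visit(self) -> str:
        return "debug"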

View File

@ -11,6 +11,7 @@
from click import echo, style from click import echo, style
@mypyc_attr(patchable=True)
def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
if message is not None: if message is not None:
if "bold" not in styles: if "bold" not in styles:
@ -19,6 +20,7 @@ def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
echo(message, nl=nl, err=True) echo(message, nl=nl, err=True)
@mypyc_attr(patchable=True)
def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
if message is not None: if message is not None:
if "fg" not in styles: if "fg" not in styles:
@ -27,6 +29,7 @@ def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
echo(message, nl=nl, err=True) echo(message, nl=nl, err=True)
@mypyc_attr(patchable=True)
def out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: def out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
_out(message, nl=nl, **styles) _out(message, nl=nl, **styles)

View File

@ -4,11 +4,16 @@
import ast import ast
import platform import platform
import sys import sys
from typing import Iterable, Iterator, List, Set, Union, Tuple from typing import Any, Iterable, Iterator, List, Set, Tuple, Type, Union
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
from typing import Final
# lib2to3 fork # lib2to3 fork
from blib2to3.pytree import Node, Leaf from blib2to3.pytree import Node, Leaf
from blib2to3 import pygram, pytree from blib2to3 import pygram
from blib2to3.pgen2 import driver from blib2to3.pgen2 import driver
from blib2to3.pgen2.grammar import Grammar from blib2to3.pgen2.grammar import Grammar
from blib2to3.pgen2.parse import ParseError from blib2to3.pgen2.parse import ParseError
@ -16,6 +21,9 @@
from black.mode import TargetVersion, Feature, supports_feature from black.mode import TargetVersion, Feature, supports_feature
from black.nodes import syms from black.nodes import syms
ast3: Any
ast27: Any
_IS_PYPY = platform.python_implementation() == "PyPy" _IS_PYPY = platform.python_implementation() == "PyPy"
try: try:
@ -86,7 +94,7 @@ def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -
src_txt += "\n" src_txt += "\n"
for grammar in get_grammars(set(target_versions)): for grammar in get_grammars(set(target_versions)):
drv = driver.Driver(grammar, pytree.convert) drv = driver.Driver(grammar)
try: try:
result = drv.parse_string(src_txt, True) result = drv.parse_string(src_txt, True)
break break
@ -148,6 +156,10 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
raise SyntaxError(first_error) raise SyntaxError(first_error)
ast3_AST: Final[Type[ast3.AST]] = ast3.AST
ast27_AST: Final[Type[ast27.AST]] = ast27.AST
def stringify_ast( def stringify_ast(
node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0 node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]: ) -> Iterator[str]:
@ -189,7 +201,13 @@ def stringify_ast(
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
yield from stringify_ast(item, depth + 2) yield from stringify_ast(item, depth + 2)
elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): # Note that we are referencing the typed-ast ASTs via global variables and not
# direct module attribute accesses because that breaks mypyc. It's probably
# something to do with the ast3 / ast27 variables being marked as Any, leading
# mypy to think this branch is always taken, leaving the rest of the code
# unanalyzed. Tightening up the types for the typed-ast ASTs avoids the
# mypyc crash.
elif isinstance(value, (ast.AST, ast3_AST, ast27_AST)):
yield from stringify_ast(value, depth + 2) yield from stringify_ast(value, depth + 2)
else: else:
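
The global-variable workaround above generalizes; a sketch of the pattern using only stdlib ast (assumes Python 3.8+ for typing.Final):

import ast
from typing import Final, Type

# Pin the class object into a precisely typed Final global once, so
# isinstance() sees a concrete type rather than an attribute access on
# an Any-typed module (the shape that tripped mypyc here).
ast_AST: Final[Type[ast.AST]] = ast.AST

def is_ast_node(value: object) -> bool:
    return isinstance(value, ast_AST)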

View File

@ -4,10 +4,20 @@
import regex as re import regex as re
import sys import sys
from functools import lru_cache
from typing import List, Pattern from typing import List, Pattern
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
from typing import Final
STRING_PREFIX_CHARS = "furbFURB" # All possible string prefix characters.
STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters.
STRING_PREFIX_RE: Final = re.compile(
r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", re.DOTALL
)
FIRST_NON_WHITESPACE_RE: Final = re.compile(r"\s*\t+\s*(\S)")
def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
@ -37,7 +47,7 @@ def lines_with_leading_tabs_expanded(s: str) -> List[str]:
for line in s.splitlines(): for line in s.splitlines():
# Find the index of the first non-whitespace character after a string of # Find the index of the first non-whitespace character after a string of
# whitespace that includes at least one tab # whitespace that includes at least one tab
match = re.match(r"\s*\t+\s*(\S)", line) match = FIRST_NON_WHITESPACE_RE.match(line)
if match: if match:
first_non_whitespace_idx = match.start(1) first_non_whitespace_idx = match.start(1)
@ -133,7 +143,7 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str:
If remove_u_prefix is given, also removes any u prefix from the string. If remove_u_prefix is given, also removes any u prefix from the string.
""" """
match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", s, re.DOTALL) match = STRING_PREFIX_RE.match(s)
assert match is not None, f"failed to match string {s!r}" assert match is not None, f"failed to match string {s!r}"
orig_prefix = match.group(1) orig_prefix = match.group(1)
new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
@ -142,6 +152,14 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str:
return f"{new_prefix}{match.group(2)}" return f"{new_prefix}{match.group(2)}"
# Re(gex) does actually cache patterns internally but this still improves
# performance on a long list literal of strings by 5-9% since lru_cache's
# caching overhead is much lower.
@lru_cache(maxsize=64)
def _cached_compile(pattern: str) -> re.Pattern:
return re.compile(pattern)
def normalize_string_quotes(s: str) -> str: def normalize_string_quotes(s: str) -> str:
"""Prefer double quotes but only if it doesn't cause more escaping. """Prefer double quotes but only if it doesn't cause more escaping.
@ -166,9 +184,9 @@ def normalize_string_quotes(s: str) -> str:
return s # There's an internal error return s # There's an internal error
prefix = s[:first_quote_pos] prefix = s[:first_quote_pos]
unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}") unescaped_new_quote = _cached_compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}") escaped_new_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}") escaped_orig_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
body = s[first_quote_pos + len(orig_quote) : -len(orig_quote)] body = s[first_quote_pos + len(orig_quote) : -len(orig_quote)]
if "r" in prefix.casefold(): if "r" in prefix.casefold():
if unescaped_new_quote.search(body): if unescaped_new_quote.search(body):
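
The caching idea behind _cached_compile, shown standalone (stdlib re here instead of regex, purely to keep the sketch dependency-free):

import re
from functools import lru_cache

@lru_cache(maxsize=64)
def _cached_compile(pattern: str) -> "re.Pattern[str]":
    return re.compile(pattern)

p1 = _cached_compile(r"(([^\\]|^)(\\\\)*)'")
p2 = _cached_compile(r"(([^\\]|^)(\\\\)*)'")
assert p1 is p2  # the second call is a cheap cache hit, no recompile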

View File

@ -8,6 +8,7 @@
from typing import ( from typing import (
Any, Any,
Callable, Callable,
ClassVar,
Collection, Collection,
Dict, Dict,
Iterable, Iterable,
@ -20,6 +21,14 @@
TypeVar, TypeVar,
Union, Union,
) )
import sys
if sys.version_info < (3, 8):
from typing_extensions import Final
else:
from typing import Final
from mypy_extensions import trait
from black.rusty import Result, Ok, Err from black.rusty import Result, Ok, Err
@ -62,7 +71,6 @@ def TErr(err_msg: str) -> Err[CannotTransform]:
return Err(cant_transform) return Err(cant_transform)
@dataclass # type: ignore
class StringTransformer(ABC): class StringTransformer(ABC):
""" """
An implementation of the Transformer protocol that relies on its An implementation of the Transformer protocol that relies on its
@ -90,9 +98,13 @@ class StringTransformer(ABC):
as much as possible. as much as possible.
""" """
line_length: int __name__: Final = "StringTransformer"
normalize_strings: bool
__name__ = "StringTransformer" # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with
# `abc.ABC`.
def __init__(self, line_length: int, normalize_strings: bool) -> None:
self.line_length = line_length
self.normalize_strings = normalize_strings
@abstractmethod @abstractmethod
def do_match(self, line: Line) -> TMatchResult: def do_match(self, line: Line) -> TMatchResult:
@ -184,6 +196,7 @@ class CustomSplit:
break_idx: int break_idx: int
@trait
class CustomSplitMapMixin: class CustomSplitMapMixin:
""" """
This mixin class is used to map merged strings to a sequence of This mixin class is used to map merged strings to a sequence of
@ -191,8 +204,10 @@ class CustomSplitMapMixin:
the resultant substrings go over the configured max line length. the resultant substrings go over the configured max line length.
""" """
_Key = Tuple[StringID, str] _Key: ClassVar = Tuple[StringID, str]
_CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple) _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(
tuple
)
@staticmethod @staticmethod
def _get_key(string: str) -> "CustomSplitMapMixin._Key": def _get_key(string: str) -> "CustomSplitMapMixin._Key":
@ -243,7 +258,7 @@ def has_custom_splits(self, string: str) -> bool:
return key in self._CUSTOM_SPLIT_MAP return key in self._CUSTOM_SPLIT_MAP
class StringMerger(CustomSplitMapMixin, StringTransformer): class StringMerger(StringTransformer, CustomSplitMapMixin):
"""StringTransformer that merges strings together. """StringTransformer that merges strings together.
Requirements: Requirements:
@ -739,7 +754,7 @@ class BaseStringSplitter(StringTransformer):
* The target string is not a multiline (i.e. triple-quote) string. * The target string is not a multiline (i.e. triple-quote) string.
""" """
STRING_OPERATORS = [ STRING_OPERATORS: Final = [
token.EQEQUAL, token.EQEQUAL,
token.GREATER, token.GREATER,
token.GREATEREQUAL, token.GREATEREQUAL,
@ -927,7 +942,7 @@ def _get_max_string_length(self, line: Line, string_idx: int) -> int:
return max_string_length return max_string_length
class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): class StringSplitter(BaseStringSplitter, CustomSplitMapMixin):
""" """
StringTransformer that splits "atom" strings (i.e. strings which exist on StringTransformer that splits "atom" strings (i.e. strings which exist on
lines by themselves). lines by themselves).
@ -965,9 +980,9 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
CustomSplit objects and add them to the custom split map. CustomSplit objects and add them to the custom split map.
""" """
MIN_SUBSTR_SIZE = 6 MIN_SUBSTR_SIZE: Final = 6
# Matches an "f-expression" (e.g. {var}) that might be found in an f-string. # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
RE_FEXPR = r""" RE_FEXPR: Final = r"""
(?<!\{) (?:\{\{)* \{ (?!\{) (?<!\{) (?:\{\{)* \{ (?!\{)
(?: (?:
[^\{\}] [^\{\}]
@ -1426,7 +1441,7 @@ def _get_string_operator_leaves(self, leaves: Iterable[Leaf]) -> List[Leaf]:
return string_op_leaves return string_op_leaves
class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
""" """
StringTransformer that splits non-"atom" strings (i.e. strings that do not StringTransformer that splits non-"atom" strings (i.e. strings that do not
exist on lines by themselves). exist on lines by themselves).
@ -1811,20 +1826,20 @@ class StringParser:
``` ```
""" """
DEFAULT_TOKEN = -1 DEFAULT_TOKEN: Final = 20210605
# String Parser States # String Parser States
START = 1 START: Final = 1
DOT = 2 DOT: Final = 2
NAME = 3 NAME: Final = 3
PERCENT = 4 PERCENT: Final = 4
SINGLE_FMT_ARG = 5 SINGLE_FMT_ARG: Final = 5
LPAR = 6 LPAR: Final = 6
RPAR = 7 RPAR: Final = 7
DONE = 8 DONE: Final = 8
# Lookup Table for Next State # Lookup Table for Next State
_goto: Dict[Tuple[ParserState, NodeType], ParserState] = { _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = {
# A string trailer may start with '.' OR '%'. # A string trailer may start with '.' OR '%'.
(START, token.DOT): DOT, (START, token.DOT): DOT,
(START, token.PERCENT): PERCENT, (START, token.PERCENT): PERCENT,
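
The base-class reordering throughout this file (StringTransformer first, CustomSplitMapMixin second) follows from marking the mixin a @trait: my understanding is that mypyc requires the single non-trait base to precede any traits. A sketch, where both decorators are no-ops under the interpreter:

from mypy_extensions import trait

@trait
class CustomSplitMapMixin:
    def has_custom_splits(self) -> bool:
        return False

class StringTransformer:
    pass

# Non-trait base first, trait(s) after -- the order mypyc accepts.
class StringMerger(StringTransformer, CustomSplitMapMixin):
    pass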

View File

@ -104,13 +104,12 @@ async def async_main(
no_diff, no_diff,
) )
return int(ret_val) return int(ret_val)
finally: finally:
if not keep and work_path.exists(): if not keep and work_path.exists():
LOG.debug(f"Removing {work_path}") LOG.debug(f"Removing {work_path}")
rmtree(work_path, onerror=lib.handle_PermissionError) rmtree(work_path, onerror=lib.handle_PermissionError)
return -2
@click.command(context_settings={"help_option_names": ["-h", "--help"]}) @click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option( @click.option(

View File

@ -19,3 +19,5 @@ Change Log:
https://github.com/python/cpython/commit/cae60187cf7a7b26281d012e1952fafe4e2e97e9 https://github.com/python/cpython/commit/cae60187cf7a7b26281d012e1952fafe4e2e97e9
- "bpo-42316: Allow unparenthesized walrus operator in indexes (GH-23317)" - "bpo-42316: Allow unparenthesized walrus operator in indexes (GH-23317)"
https://github.com/python/cpython/commit/b0aba1fcdc3da952698d99aec2334faa79a8b68c https://github.com/python/cpython/commit/b0aba1fcdc3da952698d99aec2334faa79a8b68c
- Tweaks to help mypyc compile faster code (including inlining type information,
"Final-ing", etc.)

View File

@ -23,6 +23,7 @@
import sys import sys
from typing import ( from typing import (
Any, Any,
cast,
IO, IO,
Iterable, Iterable,
List, List,
@ -34,14 +35,15 @@
Generic, Generic,
Union, Union,
) )
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
# Pgen imports # Pgen imports
from . import grammar, parse, token, tokenize, pgen from . import grammar, parse, token, tokenize, pgen
from logging import Logger from logging import Logger
from blib2to3.pytree import _Convert, NL from blib2to3.pytree import NL
from blib2to3.pgen2.grammar import Grammar from blib2to3.pgen2.grammar import Grammar
from contextlib import contextmanager from blib2to3.pgen2.tokenize import GoodTokenInfo
Path = Union[str, "os.PathLike[str]"] Path = Union[str, "os.PathLike[str]"]
@ -115,29 +117,23 @@ def can_advance(self, to: int) -> bool:
class Driver(object): class Driver(object):
def __init__( def __init__(self, grammar: Grammar, logger: Optional[Logger] = None) -> None:
self,
grammar: Grammar,
convert: Optional[_Convert] = None,
logger: Optional[Logger] = None,
) -> None:
self.grammar = grammar self.grammar = grammar
if logger is None: if logger is None:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
self.logger = logger self.logger = logger
self.convert = convert
def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL: def parse_tokens(self, tokens: Iterable[GoodTokenInfo], debug: bool = False) -> NL:
"""Parse a series of tokens and return the syntax tree.""" """Parse a series of tokens and return the syntax tree."""
# XXX Move the prefix computation into a wrapper around tokenize. # XXX Move the prefix computation into a wrapper around tokenize.
proxy = TokenProxy(tokens) proxy = TokenProxy(tokens)
p = parse.Parser(self.grammar, self.convert) p = parse.Parser(self.grammar)
p.setup(proxy=proxy) p.setup(proxy=proxy)
lineno = 1 lineno = 1
column = 0 column = 0
indent_columns = [] indent_columns: List[int] = []
type = value = start = end = line_text = None type = value = start = end = line_text = None
prefix = "" prefix = ""
@ -163,6 +159,7 @@ def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
if type == token.OP: if type == token.OP:
type = grammar.opmap[value] type = grammar.opmap[value]
if debug: if debug:
assert type is not None
self.logger.debug( self.logger.debug(
"%s %r (prefix=%r)", token.tok_name[type], value, prefix "%s %r (prefix=%r)", token.tok_name[type], value, prefix
) )
@ -174,7 +171,7 @@ def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
elif type == token.DEDENT: elif type == token.DEDENT:
_indent_col = indent_columns.pop() _indent_col = indent_columns.pop()
prefix, _prefix = self._partially_consume_prefix(prefix, _indent_col) prefix, _prefix = self._partially_consume_prefix(prefix, _indent_col)
if p.addtoken(type, value, (prefix, start)): if p.addtoken(cast(int, type), value, (prefix, start)):
if debug: if debug:
self.logger.debug("Stop.") self.logger.debug("Stop.")
break break

View File

@ -29,7 +29,7 @@
TYPE_CHECKING, TYPE_CHECKING,
) )
from blib2to3.pgen2.grammar import Grammar from blib2to3.pgen2.grammar import Grammar
from blib2to3.pytree import NL, Context, RawNode, Leaf, Node from blib2to3.pytree import convert, NL, Context, RawNode, Leaf, Node
if TYPE_CHECKING: if TYPE_CHECKING:
from blib2to3.driver import TokenProxy from blib2to3.driver import TokenProxy
@ -70,9 +70,7 @@ def switch_to(self, ilabel: int) -> Iterator[None]:
finally: finally:
self.parser.stack = self._start_point self.parser.stack = self._start_point
def add_token( def add_token(self, tok_type: int, tok_val: Text, raw: bool = False) -> None:
self, tok_type: int, tok_val: Optional[Text], raw: bool = False
) -> None:
func: Callable[..., Any] func: Callable[..., Any]
if raw: if raw:
func = self.parser._addtoken func = self.parser._addtoken
@ -86,9 +84,7 @@ def add_token(
args.insert(0, ilabel) args.insert(0, ilabel)
func(*args) func(*args)
def determine_route( def determine_route(self, value: Text = None, force: bool = False) -> Optional[int]:
self, value: Optional[Text] = None, force: bool = False
) -> Optional[int]:
alive_ilabels = self.ilabels alive_ilabels = self.ilabels
if len(alive_ilabels) == 0: if len(alive_ilabels) == 0:
*_, most_successful_ilabel = self._dead_ilabels *_, most_successful_ilabel = self._dead_ilabels
@ -164,6 +160,11 @@ def __init__(self, grammar: Grammar, convert: Optional[Convert] = None) -> None:
to be converted. The syntax tree is converted from the bottom to be converted. The syntax tree is converted from the bottom
up. up.
**post-note: the convert argument is ignored since for Black's
usage, convert will always be blib2to3.pytree.convert. Allowing
this to be dynamic hurts mypyc's ability to use early binding.
These docs are left for historical and informational value.
A concrete syntax tree node is a (type, value, context, nodes) A concrete syntax tree node is a (type, value, context, nodes)
tuple, where type is the node type (a token or symbol number), tuple, where type is the node type (a token or symbol number),
value is None for symbols and a string for tokens, context is value is None for symbols and a string for tokens, context is
@ -176,6 +177,7 @@ def __init__(self, grammar: Grammar, convert: Optional[Convert] = None) -> None:
""" """
self.grammar = grammar self.grammar = grammar
# See note in docstring above. TL;DR this is ignored.
self.convert = convert or lam_sub self.convert = convert or lam_sub
def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None: def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None:
@ -203,7 +205,7 @@ def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None:
self.used_names: Set[str] = set() self.used_names: Set[str] = set()
self.proxy = proxy self.proxy = proxy
def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool: def addtoken(self, type: int, value: Text, context: Context) -> bool:
"""Add a token; return True iff this is the end of the program.""" """Add a token; return True iff this is the end of the program."""
# Map from token to label # Map from token to label
ilabels = self.classify(type, value, context) ilabels = self.classify(type, value, context)
@ -237,7 +239,7 @@ def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
next_token_type, next_token_value, *_ = proxy.eat(counter) next_token_type, next_token_value, *_ = proxy.eat(counter)
if next_token_type == tokenize.OP: if next_token_type == tokenize.OP:
next_token_type = grammar.opmap[cast(str, next_token_value)] next_token_type = grammar.opmap[next_token_value]
recorder.add_token(next_token_type, next_token_value) recorder.add_token(next_token_type, next_token_value)
counter += 1 counter += 1
@ -247,9 +249,7 @@ def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
return self._addtoken(ilabel, type, value, context) return self._addtoken(ilabel, type, value, context)
def _addtoken( def _addtoken(self, ilabel: int, type: int, value: Text, context: Context) -> bool:
self, ilabel: int, type: int, value: Optional[Text], context: Context
) -> bool:
# Loop until the token is shifted; may raise exceptions # Loop until the token is shifted; may raise exceptions
while True: while True:
dfa, state, node = self.stack[-1] dfa, state, node = self.stack[-1]
@ -257,10 +257,18 @@ def _addtoken(
arcs = states[state] arcs = states[state]
# Look for a state with this label # Look for a state with this label
for i, newstate in arcs: for i, newstate in arcs:
t, v = self.grammar.labels[i] t = self.grammar.labels[i][0]
if ilabel == i: if t >= 256:
# See if it's a symbol and if we're in its first set
itsdfa = self.grammar.dfas[t]
itsstates, itsfirst = itsdfa
if ilabel in itsfirst:
# Push a symbol
self.push(t, itsdfa, newstate, context)
break # To continue the outer while loop
elif ilabel == i:
# Look it up in the list of labels # Look it up in the list of labels
assert t < 256
# Shift a token; we're done with it # Shift a token; we're done with it
self.shift(type, value, newstate, context) self.shift(type, value, newstate, context)
# Pop while we are in an accept-only state # Pop while we are in an accept-only state
@ -274,14 +282,7 @@ def _addtoken(
states, first = dfa states, first = dfa
# Done with this token # Done with this token
return False return False
elif t >= 256:
# See if it's a symbol and if we're in its first set
itsdfa = self.grammar.dfas[t]
itsstates, itsfirst = itsdfa
if ilabel in itsfirst:
# Push a symbol
self.push(t, self.grammar.dfas[t], newstate, context)
break # To continue the outer while loop
else: else:
if (0, state) in arcs: if (0, state) in arcs:
# An accepting state, pop it and try something else # An accepting state, pop it and try something else
@ -293,14 +294,13 @@ def _addtoken(
# No success finding a transition # No success finding a transition
raise ParseError("bad input", type, value, context) raise ParseError("bad input", type, value, context)
def classify(self, type: int, value: Optional[Text], context: Context) -> List[int]: def classify(self, type: int, value: Text, context: Context) -> List[int]:
"""Turn a token into a label. (Internal) """Turn a token into a label. (Internal)
Depending on whether the value is a soft-keyword or not, Depending on whether the value is a soft-keyword or not,
this function may return multiple labels to choose from.""" this function may return multiple labels to choose from."""
if type == token.NAME: if type == token.NAME:
# Keep a listing of all used names # Keep a listing of all used names
assert value is not None
self.used_names.add(value) self.used_names.add(value)
# Check for reserved words # Check for reserved words
if value in self.grammar.keywords: if value in self.grammar.keywords:
@ -317,18 +317,13 @@ def classify(self, type: int, value: Optional[Text], context: Context) -> List[i
raise ParseError("bad token", type, value, context) raise ParseError("bad token", type, value, context)
return [ilabel] return [ilabel]
def shift( def shift(self, type: int, value: Text, newstate: int, context: Context) -> None:
self, type: int, value: Optional[Text], newstate: int, context: Context
) -> None:
"""Shift a token. (Internal)""" """Shift a token. (Internal)"""
dfa, state, node = self.stack[-1] dfa, state, node = self.stack[-1]
assert value is not None
assert context is not None
rawnode: RawNode = (type, value, context, None) rawnode: RawNode = (type, value, context, None)
newnode = self.convert(self.grammar, rawnode) newnode = convert(self.grammar, rawnode)
if newnode is not None: assert node[-1] is not None
assert node[-1] is not None node[-1].append(newnode)
node[-1].append(newnode)
self.stack[-1] = (dfa, newstate, node) self.stack[-1] = (dfa, newstate, node)
def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None: def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None:
@ -341,12 +336,11 @@ def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None
def pop(self) -> None: def pop(self) -> None:
"""Pop a nonterminal. (Internal)""" """Pop a nonterminal. (Internal)"""
popdfa, popstate, popnode = self.stack.pop() popdfa, popstate, popnode = self.stack.pop()
newnode = self.convert(self.grammar, popnode) newnode = convert(self.grammar, popnode)
if newnode is not None: if self.stack:
if self.stack: dfa, state, node = self.stack[-1]
dfa, state, node = self.stack[-1] assert node[-1] is not None
assert node[-1] is not None node[-1].append(newnode)
node[-1].append(newnode) else:
else: self.rootnode = newnode
self.rootnode = newnode self.rootnode.used_names = self.used_names
self.rootnode.used_names = self.used_names
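
The early-binding point made in the docstring above, shown concretely (names below are illustrative, not Black's):

def convert(node: tuple) -> tuple:  # stands in for pytree.convert
    return node

class LateBound:
    def __init__(self) -> None:
        # Stored callable: attribute lookup plus a dynamic Python-level
        # call on every shift()/pop().
        self.convert = convert

    def pop(self, node: tuple) -> tuple:
        return self.convert(node)

class EarlyBound:
    def pop(self, node: tuple) -> tuple:
        # Target known at compile time; mypyc can emit a direct call.
        return convert(node)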

View File

@ -27,6 +27,7 @@
function to which the 5 fields described above are passed as 5 arguments, function to which the 5 fields described above are passed as 5 arguments,
each time a new token is found.""" each time a new token is found."""
import sys
from typing import ( from typing import (
Callable, Callable,
Iterable, Iterable,
@ -39,6 +40,12 @@
Union, Union,
cast, cast,
) )
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final
from blib2to3.pgen2.token import * from blib2to3.pgen2.token import *
from blib2to3.pgen2.grammar import Grammar from blib2to3.pgen2.grammar import Grammar
@ -139,7 +146,7 @@ def _combinations(*l):
PseudoExtras = group(r"\\\r?\n", Comment, Triple) PseudoExtras = group(r"\\\r?\n", Comment, Triple)
PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name) PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name)
pseudoprog = re.compile(PseudoToken, re.UNICODE) pseudoprog: Final = re.compile(PseudoToken, re.UNICODE)
single3prog = re.compile(Single3) single3prog = re.compile(Single3)
double3prog = re.compile(Double3) double3prog = re.compile(Double3)
@ -149,7 +156,7 @@ def _combinations(*l):
| {"u", "U", "ur", "uR", "Ur", "UR"} | {"u", "U", "ur", "uR", "Ur", "UR"}
) )
endprogs = { endprogs: Final = {
"'": re.compile(Single), "'": re.compile(Single),
'"': re.compile(Double), '"': re.compile(Double),
"'''": single3prog, "'''": single3prog,
@ -159,12 +166,12 @@ def _combinations(*l):
**{prefix: None for prefix in _strprefixes}, **{prefix: None for prefix in _strprefixes},
} }
triple_quoted = ( triple_quoted: Final = (
{"'''", '"""'} {"'''", '"""'}
| {f"{prefix}'''" for prefix in _strprefixes} | {f"{prefix}'''" for prefix in _strprefixes}
| {f'{prefix}"""' for prefix in _strprefixes} | {f'{prefix}"""' for prefix in _strprefixes}
) )
single_quoted = ( single_quoted: Final = (
{"'", '"'} {"'", '"'}
| {f"{prefix}'" for prefix in _strprefixes} | {f"{prefix}'" for prefix in _strprefixes}
| {f'{prefix}"' for prefix in _strprefixes} | {f'{prefix}"' for prefix in _strprefixes}
@ -418,7 +425,7 @@ def generate_tokens(
logical line; continuation lines are included. logical line; continuation lines are included.
""" """
lnum = parenlev = continued = 0 lnum = parenlev = continued = 0
numchars = "0123456789" numchars: Final = "0123456789"
contstr, needcont = "", 0 contstr, needcont = "", 0
contline: Optional[str] = None contline: Optional[str] = None
indents = [0] indents = [0]
@ -427,7 +434,7 @@ def generate_tokens(
# `await` as keywords. # `await` as keywords.
async_keywords = False if grammar is None else grammar.async_keywords async_keywords = False if grammar is None else grammar.async_keywords
# 'stashed' and 'async_*' are used for async/await parsing # 'stashed' and 'async_*' are used for async/await parsing
stashed = None stashed: Optional[GoodTokenInfo] = None
async_def = False async_def = False
async_def_indent = 0 async_def_indent = 0
async_def_nl = False async_def_nl = False
@ -440,7 +447,7 @@ def generate_tokens(
line = readline() line = readline()
except StopIteration: except StopIteration:
line = "" line = ""
lnum = lnum + 1 lnum += 1
pos, max = 0, len(line) pos, max = 0, len(line)
if contstr: # continued string if contstr: # continued string
@ -481,14 +488,14 @@ def generate_tokens(
column = 0 column = 0
while pos < max: # measure leading whitespace while pos < max: # measure leading whitespace
if line[pos] == " ": if line[pos] == " ":
column = column + 1 column += 1
elif line[pos] == "\t": elif line[pos] == "\t":
column = (column // tabsize + 1) * tabsize column = (column // tabsize + 1) * tabsize
elif line[pos] == "\f": elif line[pos] == "\f":
column = 0 column = 0
else: else:
break break
pos = pos + 1 pos += 1
if pos == max: if pos == max:
break break
@ -507,7 +514,7 @@ def generate_tokens(
COMMENT, COMMENT,
comment_token, comment_token,
(lnum, pos), (lnum, pos),
(lnum, pos + len(comment_token)), (lnum, nl_pos),
line, line,
) )
yield (NL, line[nl_pos:], (lnum, nl_pos), (lnum, len(line)), line) yield (NL, line[nl_pos:], (lnum, nl_pos), (lnum, len(line)), line)
@ -652,16 +659,16 @@ def generate_tokens(
continued = 1 continued = 1
else: else:
if initial in "([{": if initial in "([{":
parenlev = parenlev + 1 parenlev += 1
elif initial in ")]}": elif initial in ")]}":
parenlev = parenlev - 1 parenlev -= 1
if stashed: if stashed:
yield stashed yield stashed
stashed = None stashed = None
yield (OP, token, spos, epos, line) yield (OP, token, spos, epos, line)
else: else:
yield (ERRORTOKEN, line[pos], (lnum, pos), (lnum, pos + 1), line) yield (ERRORTOKEN, line[pos], (lnum, pos), (lnum, pos + 1), line)
pos = pos + 1 pos += 1
if stashed: if stashed:
yield stashed yield stashed

View File

@ -14,7 +14,6 @@
from typing import ( from typing import (
Any, Any,
Callable,
Dict, Dict,
Iterator, Iterator,
List, List,
@ -92,8 +91,6 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented return NotImplemented
return self._eq(other) return self._eq(other)
__hash__ = None # type: Any # For Py3 compatibility.
@property @property
def prefix(self) -> Text: def prefix(self) -> Text:
raise NotImplementedError raise NotImplementedError
@ -437,7 +434,7 @@ def __str__(self) -> Text:
This reproduces the input source exactly. This reproduces the input source exactly.
""" """
return self.prefix + str(self.value) return self._prefix + str(self.value)
def _eq(self, other) -> bool: def _eq(self, other) -> bool:
"""Compare two nodes for equality.""" """Compare two nodes for equality."""
@ -672,8 +669,11 @@ def __init__(
newcontent = list(content) newcontent = list(content)
for i, item in enumerate(newcontent): for i, item in enumerate(newcontent):
assert isinstance(item, BasePattern), (i, item) assert isinstance(item, BasePattern), (i, item)
if isinstance(item, WildcardPattern): # I don't even think this code is used anywhere, but it does cause
self.wildcards = True # unreachable errors from mypy. This function's signature does look
# odd though *shrug*.
if isinstance(item, WildcardPattern): # type: ignore[unreachable]
self.wildcards = True # type: ignore[unreachable]
self.type = type self.type = type
self.content = newcontent self.content = newcontent
self.name = name self.name = name
@ -978,6 +978,3 @@ def generate_matches(
r.update(r0) r.update(r0)
r.update(r1) r.update(r1)
yield c0 + c1, r yield c0 + c1, r
_Convert = Callable[[Grammar, RawNode], Any]

View File

@ -122,7 +122,7 @@ def invokeBlack(
runner = BlackRunner() runner = BlackRunner()
if ignore_config: if ignore_config:
args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args] args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
result = runner.invoke(black.main, args) result = runner.invoke(black.main, args, catch_exceptions=False)
assert result.stdout_bytes is not None assert result.stdout_bytes is not None
assert result.stderr_bytes is not None assert result.stderr_bytes is not None
msg = ( msg = (
@ -841,6 +841,7 @@ def test_get_future_imports(self) -> None:
) )
self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node)) self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
@pytest.mark.incompatible_with_mypyc
def test_debug_visitor(self) -> None: def test_debug_visitor(self) -> None:
source, _ = read_data("debug_visitor.py") source, _ = read_data("debug_visitor.py")
expected, _ = read_data("debug_visitor.out") expected, _ = read_data("debug_visitor.out")
@ -891,6 +892,7 @@ def test_endmarker(self) -> None:
self.assertEqual(len(n.children), 1) self.assertEqual(len(n.children), 1)
self.assertEqual(n.children[0].type, black.token.ENDMARKER) self.assertEqual(n.children[0].type, black.token.ENDMARKER)
@pytest.mark.incompatible_with_mypyc
@unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT") @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
def test_assertFormatEqual(self) -> None: def test_assertFormatEqual(self) -> None:
out_lines = [] out_lines = []
@ -1055,6 +1057,7 @@ def test_pipe_force_py36(self) -> None:
actual = result.output actual = result.output
self.assertFormatEqual(actual, expected) self.assertFormatEqual(actual, expected)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin(self) -> None: def test_reformat_one_with_stdin(self) -> None:
with patch( with patch(
"black.format_stdin_to_stdout", "black.format_stdin_to_stdout",
@ -1072,6 +1075,7 @@ def test_reformat_one_with_stdin(self) -> None:
fsts.assert_called_once() fsts.assert_called_once()
report.done.assert_called_with(path, black.Changed.YES) report.done.assert_called_with(path, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename(self) -> None: def test_reformat_one_with_stdin_filename(self) -> None:
with patch( with patch(
"black.format_stdin_to_stdout", "black.format_stdin_to_stdout",
@ -1094,6 +1098,7 @@ def test_reformat_one_with_stdin_filename(self) -> None:
# __BLACK_STDIN_FILENAME__ should have been stripped # __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES) report.done.assert_called_with(expected, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename_pyi(self) -> None: def test_reformat_one_with_stdin_filename_pyi(self) -> None:
with patch( with patch(
"black.format_stdin_to_stdout", "black.format_stdin_to_stdout",
@ -1118,6 +1123,7 @@ def test_reformat_one_with_stdin_filename_pyi(self) -> None:
# __BLACK_STDIN_FILENAME__ should have been stripped # __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES) report.done.assert_called_with(expected, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename_ipynb(self) -> None: def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
with patch( with patch(
"black.format_stdin_to_stdout", "black.format_stdin_to_stdout",
@ -1142,6 +1148,7 @@ def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
# __BLACK_STDIN_FILENAME__ should have been stripped # __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES) report.done.assert_called_with(expected, black.Changed.YES)
@pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_and_existing_path(self) -> None: def test_reformat_one_with_stdin_and_existing_path(self) -> None:
with patch( with patch(
"black.format_stdin_to_stdout", "black.format_stdin_to_stdout",
@ -1296,6 +1303,7 @@ def test_read_pyproject_toml(self) -> None:
self.assertEqual(config["exclude"], r"\.pyi?$") self.assertEqual(config["exclude"], r"\.pyi?$")
self.assertEqual(config["include"], r"\.py?$") self.assertEqual(config["include"], r"\.py?$")
@pytest.mark.incompatible_with_mypyc
def test_find_project_root(self) -> None: def test_find_project_root(self) -> None:
with TemporaryDirectory() as workspace: with TemporaryDirectory() as workspace:
root = Path(workspace) root = Path(workspace)
@ -1483,6 +1491,7 @@ def test_code_option_color_diff(self) -> None:
assert output == result_diff, "The output did not match the expected value." assert output == result_diff, "The output did not match the expected value."
assert result.exit_code == 0, "The exit code is incorrect." assert result.exit_code == 0, "The exit code is incorrect."
@pytest.mark.incompatible_with_mypyc
def test_code_option_safe(self) -> None: def test_code_option_safe(self) -> None:
"""Test that the code option throws an error when the sanity checks fail.""" """Test that the code option throws an error when the sanity checks fail."""
# Patch black.assert_equivalent to ensure the sanity checks fail # Patch black.assert_equivalent to ensure the sanity checks fail
@ -1507,6 +1516,7 @@ def test_code_option_fast(self) -> None:
self.compare_results(result, formatted, 0) self.compare_results(result, formatted, 0)
@pytest.mark.incompatible_with_mypyc
def test_code_option_config(self) -> None: def test_code_option_config(self) -> None:
""" """
Test that the code option finds the pyproject.toml in the current directory. Test that the code option finds the pyproject.toml in the current directory.
@ -1527,6 +1537,7 @@ def test_code_option_config(self) -> None:
call_args[0].lower() == str(pyproject_path).lower() call_args[0].lower() == str(pyproject_path).lower()
), "Incorrect config loaded." ), "Incorrect config loaded."
@pytest.mark.incompatible_with_mypyc
def test_code_option_parent_config(self) -> None: def test_code_option_parent_config(self) -> None:
""" """
Test that the code option finds the pyproject.toml in the parent directory. Test that the code option finds the pyproject.toml in the parent directory.
@ -1894,6 +1905,7 @@ def test_extend_exclude(self) -> None:
src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude" src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
) )
@pytest.mark.incompatible_with_mypyc
def test_symlink_out_of_root_directory(self) -> None: def test_symlink_out_of_root_directory(self) -> None:
path = MagicMock() path = MagicMock()
root = THIS_DIR.resolve() root = THIS_DIR.resolve()
@ -2047,8 +2059,12 @@ def test_python_2_deprecation_autodetection_extended() -> None:
}, non_python2_case }, non_python2_case
with open(black.__file__, "r", encoding="utf-8") as _bf: try:
black_source_lines = _bf.readlines() with open(black.__file__, "r", encoding="utf-8") as _bf:
black_source_lines = _bf.readlines()
except UnicodeDecodeError:
if not black.COMPILED:
raise
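
The try/except above exists because a compiled black's __file__ points at the extension module (.pyd/.so), which is not valid UTF-8 text; the guard tolerates the decode failure only in that case. For illustration, assuming an installed black:

import black

print(black.COMPILED)  # True for a mypyc wheel, False for pure Python
print(black.__file__)  # e.g. ".../black/__init__.cpython-39-x86_64-linux-gnu.so"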
def tracefunc( def tracefunc(