Add --target-version option to allow users to choose targeted Python versions (#618)

This commit is contained in:
Jelle Zijlstra 2019-02-06 18:43:50 -08:00 committed by GitHub
parent a9d8af466a
commit 36d3c516d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 545 additions and 387 deletions

100
README.md
View File

@ -71,46 +71,60 @@ black {source_file_or_directory}
black [OPTIONS] [SRC]...
Options:
-l, --line-length INTEGER Where to wrap around. [default: 88]
--py36 Allow using Python 3.6-only syntax on all input
files. This will put trailing commas in function
signatures and calls also after *args and
**kwargs. [default: per-file auto-detection]
--pyi Format all input files like typing stubs
regardless of file extension (useful when piping
source on standard input).
-l, --line-length INTEGER How many characters per line to allow.
[default: 88]
-t, --target-version [pypy35|cpy27|cpy33|cpy34|cpy35|cpy36|cpy37|cpy38]
Python versions that should be supported by
Black's output. [default: per-file auto-
detection]
--py36 Allow using Python 3.6-only syntax on all
input files. This will put trailing commas
in function signatures and calls also after
*args and **kwargs. [default: per-file
auto-detection]
--pyi Format all input files like typing stubs
regardless of file extension (useful when
piping source on standard input).
-S, --skip-string-normalization
Don't normalize string quotes or prefixes.
Don't normalize string quotes or prefixes.
-N, --skip-numeric-underscore-normalization
Don't normalize underscores in numeric literals.
--check Don't write the files back, just return the
status. Return code 0 means nothing would
change. Return code 1 means some files would be
reformatted. Return code 123 means there was an
internal error.
--diff Don't write the files back, just output a diff
for each file on stdout.
--fast / --safe If --fast given, skip temporary sanity checks.
[default: --safe]
--include TEXT A regular expression that matches files and
directories that should be included on
recursive searches. On Windows, use forward
slashes for directories. [default: \.pyi?$]
--exclude TEXT A regular expression that matches files and
directories that should be excluded on
recursive searches. On Windows, use forward
slashes for directories. [default:
build/|buck-out/|dist/|_build/|\.eggs/|\.git/|
\.hg/|\.mypy_cache/|\.nox/|\.tox/|\.venv/]
-q, --quiet Don't emit non-error messages to stderr. Errors
are still emitted, silence those with
2>/dev/null.
-v, --verbose Also emit messages to stderr about files
that were not changed or were ignored due to
--exclude=.
--version Show the version and exit.
--config PATH Read configuration from PATH.
--help Show this message and exit.
Don't normalize underscores in numeric
literals.
--check Don't write the files back, just return the
status. Return code 0 means nothing would
change. Return code 1 means some files
would be reformatted. Return code 123 means
there was an internal error.
--diff Don't write the files back, just output a
diff for each file on stdout.
--fast / --safe If --fast given, skip temporary sanity
checks. [default: --safe]
--include TEXT A regular expression that matches files and
directories that should be included on
recursive searches. An empty value means
all files are included regardless of the
name. Use forward slashes for directories
on all platforms (Windows, too). Exclusions
are calculated first, inclusions later.
[default: \.pyi?$]
--exclude TEXT A regular expression that matches files and
directories that should be excluded on
recursive searches. An empty value means no
paths are excluded. Use forward slashes for
directories on all platforms (Windows, too).
Exclusions are calculated first, inclusions
later. [default: /(\.eggs|\.git|\.hg|\.mypy
_cache|\.nox|\.tox|\.venv|_build|buck-
out|build|dist)/]
-q, --quiet Don't emit non-error messages to stderr.
Errors are still emitted, silence those with
2>/dev/null.
-v, --verbose Also emit messages to stderr about files
that were not changed or were ignored due to
--exclude=.
--version Show the version and exit.
--config PATH Read configuration from PATH.
-h, --help Show this message and exit.
```
*Black* is a well-behaved Unix-style command-line tool:
@ -815,8 +829,9 @@ The headers controlling how code is formatted are:
passed the `--fast` command line flag.
- `X-Python-Variant`: if set to `pyi`, `blackd` will act as *Black* does when
passed the `--pyi` command line flag. Otherwise, its value must correspond to
a Python version. If this value represents at least Python 3.6, `blackd` will
act as *Black* does when passed the `--py36` command line flag.
a Python version or a set of comma-separated Python versions, optionally
prefixed with `cpy` or `pypy`. For example, to request code that is compatible
with PyPy 3.5 and CPython 3.5, set the header to `pypy3.5,cpy3.5`.
If any of these headers are set to invalid values, `blackd` returns a `HTTP 400`
error response, mentioning the name of the problematic header in the message body.
@ -935,6 +950,11 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
## Change Log
### 18.11b0
* new option `--target-version` to control which Python versions
*Black*-formatted code should target
### 18.9b0
* numeric literals are now formatted by *Black* (#452, #461, #464, #469):

361
black.py
View File

@ -2,7 +2,7 @@
from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
from datetime import datetime
from enum import Enum, Flag
from enum import Enum
from functools import lru_cache, partial, wraps
import io
import itertools
@ -37,7 +37,7 @@
)
from appdirs import user_cache_dir
from attr import dataclass, Factory
from attr import dataclass, evolve, Factory
import click
import toml
@ -45,6 +45,7 @@
from blib2to3.pytree import Node, Leaf, type_repr
from blib2to3 import pygram, pytree
from blib2to3.pgen2 import driver, token
from blib2to3.pgen2.grammar import Grammar
from blib2to3.pgen2.parse import ParseError
@ -111,32 +112,86 @@ class Changed(Enum):
YES = 2
class FileMode(Flag):
AUTO_DETECT = 0
PYTHON36 = 1
PYI = 2
NO_STRING_NORMALIZATION = 4
NO_NUMERIC_UNDERSCORE_NORMALIZATION = 8
class TargetVersion(Enum):
PYPY35 = 1
CPY27 = 2
CPY33 = 3
CPY34 = 4
CPY35 = 5
CPY36 = 6
CPY37 = 7
CPY38 = 8
@classmethod
def from_configuration(
cls,
*,
py36: bool,
pyi: bool,
skip_string_normalization: bool,
skip_numeric_underscore_normalization: bool,
) -> "FileMode":
mode = cls.AUTO_DETECT
if py36:
mode |= cls.PYTHON36
if pyi:
mode |= cls.PYI
if skip_string_normalization:
mode |= cls.NO_STRING_NORMALIZATION
if skip_numeric_underscore_normalization:
mode |= cls.NO_NUMERIC_UNDERSCORE_NORMALIZATION
return mode
def is_python2(self) -> bool:
return self is TargetVersion.CPY27
PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
class Feature(Enum):
# All string literals are unicode
UNICODE_LITERALS = 1
F_STRINGS = 2
NUMERIC_UNDERSCORES = 3
TRAILING_COMMA = 4
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
TargetVersion.CPY27: set(),
TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
TargetVersion.CPY36: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
},
TargetVersion.CPY37: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
},
TargetVersion.CPY38: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
},
}
@dataclass
class FileMode:
target_versions: Set[TargetVersion] = Factory(set)
line_length: int = DEFAULT_LINE_LENGTH
numeric_underscore_normalization: bool = True
string_normalization: bool = True
is_pyi: bool = False
def get_cache_key(self) -> str:
if self.target_versions:
version_str = ",".join(
str(version.value)
for version in sorted(self.target_versions, key=lambda v: v.value)
)
else:
version_str = "-"
parts = [
version_str,
str(self.line_length),
str(int(self.numeric_underscore_normalization)),
str(int(self.string_normalization)),
str(int(self.is_pyi)),
]
return ".".join(parts)
def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
def read_pyproject_toml(
@ -184,6 +239,17 @@ def read_pyproject_toml(
help="How many characters per line to allow.",
show_default=True,
)
@click.option(
"-t",
"--target-version",
type=click.Choice([v.name.lower() for v in TargetVersion]),
callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
multiple=True,
help=(
"Python versions that should be supported by Black's output. [default: "
"per-file auto-detection]"
),
)
@click.option(
"--py36",
is_flag=True,
@ -297,6 +363,7 @@ def read_pyproject_toml(
def main(
ctx: click.Context,
line_length: int,
target_version: List[TargetVersion],
check: bool,
diff: bool,
fast: bool,
@ -313,11 +380,23 @@ def main(
) -> None:
"""The uncompromising code formatter."""
write_back = WriteBack.from_configuration(check=check, diff=diff)
mode = FileMode.from_configuration(
py36=py36,
pyi=pyi,
skip_string_normalization=skip_string_normalization,
skip_numeric_underscore_normalization=skip_numeric_underscore_normalization,
if target_version:
if py36:
err(f"Cannot use both --target-version and --py36")
ctx.exit(2)
else:
versions = set(target_version)
elif py36:
versions = PY36_VERSIONS
else:
# We'll autodetect later.
versions = set()
mode = FileMode(
target_versions=versions,
line_length=line_length,
is_pyi=pyi,
string_normalization=not skip_string_normalization,
numeric_underscore_normalization=not skip_numeric_underscore_normalization,
)
if config and verbose:
out(f"Using configuration from {config}.", bold=False, fg="blue")
@ -353,7 +432,6 @@ def main(
if len(sources) == 1:
reformat_one(
src=sources.pop(),
line_length=line_length,
fast=fast,
write_back=write_back,
mode=mode,
@ -366,7 +444,6 @@ def main(
loop.run_until_complete(
schedule_formatting(
sources=sources,
line_length=line_length,
fast=fast,
write_back=write_back,
mode=mode,
@ -385,12 +462,7 @@ def main(
def reformat_one(
src: Path,
line_length: int,
fast: bool,
write_back: WriteBack,
mode: FileMode,
report: "Report",
src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
) -> None:
"""Reformat a single file under `src` without spawning child processes.
@ -401,29 +473,23 @@ def reformat_one(
try:
changed = Changed.NO
if not src.is_file() and str(src) == "-":
if format_stdin_to_stdout(
line_length=line_length, fast=fast, write_back=write_back, mode=mode
):
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
changed = Changed.YES
else:
cache: Cache = {}
if write_back != WriteBack.DIFF:
cache = read_cache(line_length, mode)
cache = read_cache(mode)
res_src = src.resolve()
if res_src in cache and cache[res_src] == get_cache_info(res_src):
changed = Changed.CACHED
if changed is not Changed.CACHED and format_file_in_place(
src,
line_length=line_length,
fast=fast,
write_back=write_back,
mode=mode,
src, fast=fast, write_back=write_back, mode=mode
):
changed = Changed.YES
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
write_back is WriteBack.CHECK and changed is Changed.NO
):
write_cache(cache, [src], line_length, mode)
write_cache(cache, [src], mode)
report.done(src, changed)
except Exception as exc:
report.failed(src, str(exc))
@ -431,7 +497,6 @@ def reformat_one(
async def schedule_formatting(
sources: Set[Path],
line_length: int,
fast: bool,
write_back: WriteBack,
mode: FileMode,
@ -448,7 +513,7 @@ async def schedule_formatting(
"""
cache: Cache = {}
if write_back != WriteBack.DIFF:
cache = read_cache(line_length, mode)
cache = read_cache(mode)
sources, cached = filter_cached(cache, sources)
for src in sorted(cached):
report.done(src, Changed.CACHED)
@ -465,14 +530,7 @@ async def schedule_formatting(
lock = manager.Lock()
tasks = {
loop.run_in_executor(
executor,
format_file_in_place,
src,
line_length,
fast,
write_back,
mode,
lock,
executor, format_file_in_place, src, fast, mode, write_back, lock
): src
for src in sorted(sources)
}
@ -503,15 +561,14 @@ async def schedule_formatting(
if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
if sources_to_cache:
write_cache(cache, sources_to_cache, line_length, mode)
write_cache(cache, sources_to_cache, mode)
def format_file_in_place(
src: Path,
line_length: int,
fast: bool,
mode: FileMode,
write_back: WriteBack = WriteBack.NO,
mode: FileMode = FileMode.AUTO_DETECT,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
"""Format file under `src` path. Return True if changed.
@ -521,15 +578,13 @@ def format_file_in_place(
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
if src.suffix == ".pyi":
mode |= FileMode.PYI
mode = evolve(mode, is_pyi=True)
then = datetime.utcfromtimestamp(src.stat().st_mtime)
with open(src, "rb") as buf:
src_contents, encoding, newline = decode_bytes(buf.read())
try:
dst_contents = format_file_contents(
src_contents, line_length=line_length, fast=fast, mode=mode
)
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
except NothingChanged:
return False
@ -559,23 +614,19 @@ def format_file_in_place(
def format_stdin_to_stdout(
line_length: int,
fast: bool,
write_back: WriteBack = WriteBack.NO,
mode: FileMode = FileMode.AUTO_DETECT,
fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
) -> bool:
"""Format file on stdin. Return True if changed.
If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
write a diff to stdout.
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
write a diff to stdout. The `mode` argument is passed to
:func:`format_file_contents`.
"""
then = datetime.utcnow()
src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
dst = src
try:
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
dst = format_file_contents(src, fast=fast, mode=mode)
return True
except NothingChanged:
@ -596,11 +647,7 @@ def format_stdin_to_stdout(
def format_file_contents(
src_contents: str,
*,
line_length: int,
fast: bool,
mode: FileMode = FileMode.AUTO_DETECT,
src_contents: str, *, fast: bool, mode: FileMode
) -> FileContent:
"""Reformat contents a file and return new contents.
@ -611,38 +658,38 @@ def format_file_contents(
if src_contents.strip() == "":
raise NothingChanged
dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
dst_contents = format_str(src_contents, mode=mode)
if src_contents == dst_contents:
raise NothingChanged
if not fast:
assert_equivalent(src_contents, dst_contents)
assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
assert_stable(src_contents, dst_contents, mode=mode)
return dst_contents
def format_str(
src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
) -> FileContent:
def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
"""Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed.
"""
src_node = lib2to3_parse(src_contents.lstrip())
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = ""
future_imports = get_future_imports(src_node)
is_pyi = bool(mode & FileMode.PYI)
py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
if mode.target_versions:
versions = mode.target_versions
else:
versions = detect_target_versions(src_node)
normalize_fmt_off(src_node)
lines = LineGenerator(
remove_u_prefix=py36 or "unicode_literals" in future_imports,
is_pyi=is_pyi,
normalize_strings=normalize_strings,
allow_underscores=py36
and not bool(mode & FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION),
remove_u_prefix="unicode_literals" in future_imports
or supports_feature(versions, Feature.UNICODE_LITERALS),
is_pyi=mode.is_pyi,
normalize_strings=mode.string_normalization,
allow_underscores=mode.numeric_underscore_normalization
and supports_feature(versions, Feature.NUMERIC_UNDERSCORES),
)
elt = EmptyLineTracker(is_pyi=is_pyi)
elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
@ -651,7 +698,11 @@ def format_str(
before, after = elt.maybe_empty_lines(current_line)
for _ in range(before):
dst_contents += str(empty_line)
for line in split_line(current_line, line_length=line_length, py36=py36):
for line in split_line(
current_line,
line_length=mode.line_length,
supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
):
dst_contents += str(line)
return dst_contents
@ -680,11 +731,25 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
]
def lib2to3_parse(src_txt: str) -> Node:
def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
if not target_versions:
return GRAMMARS
elif all(not version.is_python2() for version in target_versions):
# Python 2-compatible code, so don't try Python 3 grammar.
return [
pygram.python_grammar_no_print_statement_no_exec_statement,
pygram.python_grammar_no_print_statement,
]
else:
return [pygram.python_grammar]
def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
"""Given a string with source, return the lib2to3 Node."""
if src_txt[-1:] != "\n":
src_txt += "\n"
for grammar in GRAMMARS:
for grammar in get_grammars(set(target_versions)):
drv = driver.Driver(grammar, pytree.convert)
try:
result = drv.parse_string(src_txt, True)
@ -2093,7 +2158,10 @@ def make_comment(content: str) -> str:
def split_line(
line: Line, line_length: int, inner: bool = False, py36: bool = False
line: Line,
line_length: int,
inner: bool = False,
supports_trailing_commas: bool = False,
) -> Iterator[Line]:
"""Split a `line` into potentially many lines.
@ -2102,8 +2170,7 @@ def split_line(
current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results.
If `py36` is True, splitting may generate syntax that is only compatible
with Python 3.6 and later.
If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
"""
if line.is_comment:
yield line
@ -2132,9 +2199,13 @@ def split_line(
split_funcs = [left_hand_split]
else:
def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
for omit in generate_trailers_to_omit(line, line_length):
lines = list(right_hand_split(line, line_length, py36, omit=omit))
lines = list(
right_hand_split(
line, line_length, supports_trailing_commas, omit=omit
)
)
if is_line_short_enough(lines[0], line_length=line_length):
yield from lines
return
@ -2142,7 +2213,7 @@ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
# All splits failed, best effort split with no omits.
# This mostly happens to multiline strings that are by definition
# reported as not fitting a single line.
yield from right_hand_split(line, py36)
yield from right_hand_split(line, supports_trailing_commas)
if line.inside_brackets:
split_funcs = [delimiter_split, standalone_comment_split, rhs]
@ -2154,12 +2225,17 @@ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
# split altogether.
result: List[Line] = []
try:
for l in split_func(line, py36):
for l in split_func(line, supports_trailing_commas):
if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result")
result.extend(
split_line(l, line_length=line_length, inner=True, py36=py36)
split_line(
l,
line_length=line_length,
inner=True,
supports_trailing_commas=supports_trailing_commas,
)
)
except CannotSplit:
continue
@ -2172,7 +2248,9 @@ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
yield line
def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
def left_hand_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions.
@ -2209,7 +2287,10 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
def right_hand_split(
line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
line: Line,
line_length: int,
supports_trailing_commas: bool = False,
omit: Collection[LeafID] = (),
) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair.
@ -2267,7 +2348,12 @@ def right_hand_split(
):
omit = {id(closing_bracket), *omit}
try:
yield from right_hand_split(line, line_length, py36=py36, omit=omit)
yield from right_hand_split(
line,
line_length,
supports_trailing_commas=supports_trailing_commas,
omit=omit,
)
return
except CannotSplit:
@ -2356,8 +2442,10 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
"""
@wraps(split_func)
def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
for l in split_func(line, py36):
def split_wrapper(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
for l in split_func(line, supports_trailing_commas):
normalize_prefix(l.leaves[0], inside_brackets=True)
yield l
@ -2365,7 +2453,9 @@ def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
@dont_increase_indentation
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
def delimiter_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split according to delimiters of the highest priority.
If `py36` is True, the split will add trailing commas also in function
@ -2411,7 +2501,7 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
if leaf.bracket_depth == lowest_depth and is_vararg(
leaf, within=VARARGS_PARENTS
):
trailing_comma_safe = trailing_comma_safe and py36
trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority:
yield current_line
@ -2429,7 +2519,9 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
@dont_increase_indentation
def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
def standalone_comment_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split standalone comments from the rest of the line."""
if not line.contains_standalone_comments(0):
raise CannotSplit("Line does not have any standalone comments")
@ -2988,23 +3080,24 @@ def should_explode(line: Line, opening_bracket: Leaf) -> bool:
return max_priority == COMMA_PRIORITY
def is_python36(node: Node) -> bool:
"""Return True if the current file is using Python 3.6+ features.
def get_features_used(node: Node) -> Set[Feature]:
"""Return a set of (relatively) new Python features used in this file.
Currently looking for:
- f-strings;
- underscores in numeric literals; and
- trailing commas after * or ** in function signatures and calls.
"""
features: Set[Feature] = set()
for n in node.pre_order():
if n.type == token.STRING:
value_head = n.value[:2] # type: ignore
if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
return True
features.add(Feature.F_STRINGS)
elif n.type == token.NUMBER:
if "_" in n.value: # type: ignore
return True
features.add(Feature.NUMERIC_UNDERSCORES)
elif (
n.type in {syms.typedargslist, syms.arglist}
@ -3013,14 +3106,22 @@ def is_python36(node: Node) -> bool:
):
for ch in n.children:
if ch.type in STARS:
return True
features.add(Feature.TRAILING_COMMA)
if ch.type == syms.argument:
for argch in ch.children:
if argch.type in STARS:
return True
features.add(Feature.TRAILING_COMMA)
return False
return features
def detect_target_versions(node: Node) -> Set[TargetVersion]:
"""Detect the version to target based on the nodes used."""
features = get_features_used(node)
return {
version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
}
def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
@ -3337,11 +3438,9 @@ def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
) from None
def assert_stable(
src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
) -> None:
def assert_stable(src: str, dst: str, mode: FileMode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str(dst, line_length=line_length, mode=mode)
newdst = format_str(dst, mode=mode)
if dst != newdst:
log = dump_to_file(
diff(src, dst, "source", "first pass"),
@ -3598,16 +3697,16 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
return False
def get_cache_file(line_length: int, mode: FileMode) -> Path:
return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
def get_cache_file(mode: FileMode) -> Path:
return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
def read_cache(line_length: int, mode: FileMode) -> Cache:
def read_cache(mode: FileMode) -> Cache:
"""Read the cache if it exists and is well formed.
If it is not well formed, the call to write_cache later should resolve the issue.
"""
cache_file = get_cache_file(line_length, mode)
cache_file = get_cache_file(mode)
if not cache_file.exists():
return {}
@ -3642,11 +3741,9 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
return todo, done
def write_cache(
cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
) -> None:
def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
"""Update the cache file."""
cache_file = get_cache_file(line_length, mode)
cache_file = get_cache_file(mode)
try:
CACHE_DIR.mkdir(parents=True, exist_ok=True)
new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}

View File

@ -3,6 +3,7 @@
from functools import partial
import logging
from multiprocessing import freeze_support
from typing import Set, Tuple
from aiohttp import web
import aiohttp_cors
@ -29,6 +30,10 @@
]
class InvalidVariantHeader(Exception):
pass
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
"--bind-host", type=str, help="Address to bind the server to.", default="localhost"
@ -73,22 +78,20 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
)
except ValueError:
return web.Response(status=400, text="Invalid line length header value")
py36 = False
pyi = False
if PYTHON_VARIANT_HEADER in request.headers:
value = request.headers[PYTHON_VARIANT_HEADER]
if value == "pyi":
pyi = True
else:
try:
major, *rest = value.split(".")
if int(major) == 3 and len(rest) > 0:
if int(rest[0]) >= 6:
py36 = True
except ValueError:
return web.Response(
status=400, text=f"Invalid value for {PYTHON_VARIANT_HEADER}"
)
try:
pyi, versions = parse_python_variant_header(value)
except InvalidVariantHeader as e:
return web.Response(
status=400,
text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
)
else:
pyi = False
versions = set()
skip_string_normalization = bool(
request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
)
@ -98,25 +101,19 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
fast = False
if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast":
fast = True
mode = black.FileMode.from_configuration(
py36=py36,
pyi=pyi,
skip_string_normalization=skip_string_normalization,
skip_numeric_underscore_normalization=skip_numeric_underscore_normalization,
mode = black.FileMode(
target_versions=versions,
is_pyi=pyi,
line_length=line_length,
string_normalization=not skip_string_normalization,
numeric_underscore_normalization=not skip_numeric_underscore_normalization,
)
req_bytes = await request.content.read()
charset = request.charset if request.charset is not None else "utf8"
req_str = req_bytes.decode(charset)
loop = asyncio.get_event_loop()
formatted_str = await loop.run_in_executor(
executor,
partial(
black.format_file_contents,
req_str,
line_length=line_length,
fast=fast,
mode=mode,
),
executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
)
return web.Response(
content_type=request.content_type, charset=charset, text=formatted_str
@ -130,6 +127,45 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
return web.Response(status=500, text=str(e))
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
if value == "pyi":
return True, set()
else:
versions = set()
for version in value.split(","):
tag = "cpy"
if version.startswith("cpy"):
version = version[len("cpy") :]
elif version.startswith("pypy"):
tag = "pypy"
version = version[len("pypy") :]
major_str, *rest = version.split(".")
try:
major = int(major_str)
if major not in (2, 3):
raise InvalidVariantHeader("major version must be 2 or 3")
if len(rest) > 0:
minor = int(rest[0])
if major == 2 and minor != 7:
raise InvalidVariantHeader(
"minor version must be 7 for Python 2"
)
else:
# Default to lowest supported minor version.
minor = 7 if major == 2 else 3
version_str = f"{tag.upper()}{major}{minor}"
# If PyPY is the same as CPython in some version, use
# the corresponding CPython version.
if tag == "pypy" and not hasattr(black.TargetVersion, version_str):
version_str = f"CPY{major}{minor}"
if major == 3 and not hasattr(black.TargetVersion, version_str):
raise InvalidVariantHeader(f"3.{minor} is not supported")
versions.add(black.TargetVersion[version_str])
except (KeyError, ValueError):
raise InvalidVariantHeader("expected e.g. '3.7', 'pypy3.5'")
return False, versions
def patched_main() -> None:
freeze_support()
black.patch_click()

View File

@ -39,7 +39,7 @@ def function_signature_stress_test(number:int,no_annotation=None,text:str='defau
return text[number:-1]
# fmt: on
def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''):
offset = attr.ib(default=attr.Factory( lambda: _r.uniform(10000, 200000)))
offset = attr.ib(default=attr.Factory( lambda: _r.uniform(1, 2)))
assert task._cancel_stack[:len(old_stack)] == old_stack
def spaces_types(a: int = 1, b: tuple = (), c: list = [], d: dict = {}, e: bool = True, f: int = -1, g: int = 1 if False else 2, h: str = "", i: str = r''): ...
def spaces2(result= _core.Value(None)):
@ -225,7 +225,7 @@ def function_signature_stress_test(number:int,no_annotation=None,text:str='defau
return text[number:-1]
# fmt: on
def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r""):
offset = attr.ib(default=attr.Factory(lambda: _r.uniform(10000, 200_000)))
offset = attr.ib(default=attr.Factory(lambda: _r.uniform(1, 2)))
assert task._cancel_stack[: len(old_stack)] == old_stack

1
tests/empty.toml Normal file
View File

@ -0,0 +1 @@
# Empty configuration file; used in tests to avoid interference from Black's own config.

View File

@ -27,6 +27,7 @@
from click.testing import CliRunner
import black
from black import Feature
try:
import blackd
@ -37,9 +38,8 @@
has_blackd_deps = True
ll = 88
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
@ -155,13 +155,22 @@ def assertFormatEqual(self, expected: str, actual: str) -> None:
black.err(str(ve))
self.assertEqual(expected, actual)
def invokeBlack(
self, args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
runner = BlackRunner()
if ignore_config:
args = ["--config", str(THIS_DIR / "empty.toml"), *args]
result = runner.invoke(black.main, args)
self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
@patch("black.dump_to_file", dump_to_stderr)
def test_empty(self) -> None:
source = expected = ""
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
def test_empty_ff(self) -> None:
expected = ""
@ -180,7 +189,7 @@ def test_self(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_FILE))
@patch("black.dump_to_file", dump_to_stderr)
@ -189,20 +198,20 @@ def test_black(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
def test_piping(self) -> None:
source, expected = read_data("../black", data=False)
result = BlackRunner().invoke(
black.main,
["-", "--fast", f"--line-length={ll}"],
["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
input=BytesIO(source.encode("utf8")),
)
self.assertEqual(result.exit_code, 0)
self.assertFormatEqual(expected, result.output)
black.assert_equivalent(source, result.output)
black.assert_stable(source, result.output, line_length=ll)
black.assert_stable(source, result.output, black.FileMode())
def test_piping_diff(self) -> None:
diff_header = re.compile(
@ -212,7 +221,13 @@ def test_piping_diff(self) -> None:
source, _ = read_data("expression.py")
expected, _ = read_data("expression.diff")
config = THIS_DIR / "data" / "empty_pyproject.toml"
args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
args = [
"-",
"--fast",
f"--line-length={black.DEFAULT_LINE_LENGTH}",
"--diff",
f"--config={config}",
]
result = BlackRunner().invoke(
black.main, args, input=BytesIO(source.encode("utf8"))
)
@ -227,7 +242,7 @@ def test_setup(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
@patch("black.dump_to_file", dump_to_stderr)
@ -236,7 +251,7 @@ def test_function(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_function2(self) -> None:
@ -244,7 +259,7 @@ def test_function2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_expression(self) -> None:
@ -252,7 +267,7 @@ def test_expression(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
def test_expression_ff(self) -> None:
source, expected = read_data("expression")
@ -266,7 +281,7 @@ def test_expression_ff(self) -> None:
self.assertFormatEqual(expected, actual)
with patch("black.dump_to_file", dump_to_stderr):
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
def test_expression_diff(self) -> None:
source, _ = read_data("expression.py")
@ -299,7 +314,7 @@ def test_fstring(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_string_quotes(self) -> None:
@ -307,12 +322,12 @@ def test_string_quotes(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
mode = black.FileMode.NO_STRING_NORMALIZATION
black.assert_stable(source, actual, black.FileMode())
mode = black.FileMode(string_normalization=False)
not_normalized = fs(source, mode=mode)
self.assertFormatEqual(source, not_normalized)
black.assert_equivalent(source, not_normalized)
black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
black.assert_stable(source, not_normalized, mode=mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_slices(self) -> None:
@ -320,7 +335,7 @@ def test_slices(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments(self) -> None:
@ -328,7 +343,7 @@ def test_comments(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments2(self) -> None:
@ -336,7 +351,7 @@ def test_comments2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments3(self) -> None:
@ -344,7 +359,7 @@ def test_comments3(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments4(self) -> None:
@ -352,7 +367,7 @@ def test_comments4(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments5(self) -> None:
@ -360,7 +375,7 @@ def test_comments5(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments6(self) -> None:
@ -368,7 +383,7 @@ def test_comments6(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_cantfit(self) -> None:
@ -376,7 +391,7 @@ def test_cantfit(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_import_spacing(self) -> None:
@ -384,7 +399,7 @@ def test_import_spacing(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_composition(self) -> None:
@ -392,7 +407,7 @@ def test_composition(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_empty_lines(self) -> None:
@ -400,7 +415,7 @@ def test_empty_lines(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_string_prefixes(self) -> None:
@ -408,33 +423,34 @@ def test_string_prefixes(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals(self) -> None:
source, expected = read_data("numeric_literals")
actual = fs(source, mode=black.FileMode.PYTHON36)
mode = black.FileMode(target_versions=black.PY36_VERSIONS)
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals_ignoring_underscores(self) -> None:
source, expected = read_data("numeric_literals_skip_underscores")
mode = (
black.FileMode.PYTHON36 | black.FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION
mode = black.FileMode(
numeric_underscore_normalization=False, target_versions=black.PY36_VERSIONS
)
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll, mode=mode)
black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals_py2(self) -> None:
source, expected = read_data("numeric_literals_py2")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_python2(self) -> None:
@ -442,22 +458,22 @@ def test_python2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
# black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_python2_unicode_literals(self) -> None:
source, expected = read_data("python2_unicode_literals")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_stub(self) -> None:
mode = black.FileMode.PYI
mode = black.FileMode(is_pyi=True)
source, expected = read_data("stub.pyi")
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll, mode=mode)
black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_python37(self) -> None:
@ -467,7 +483,7 @@ def test_python37(self) -> None:
major, minor = sys.version_info[:2]
if major > 3 or (major == 3 and minor >= 7):
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None:
@ -475,7 +491,7 @@ def test_fmtonoff(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff2(self) -> None:
@ -483,7 +499,7 @@ def test_fmtonoff2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_remove_empty_parentheses_after_class(self) -> None:
@ -491,7 +507,7 @@ def test_remove_empty_parentheses_after_class(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_new_line_between_class_and_code(self) -> None:
@ -499,7 +515,7 @@ def test_new_line_between_class_and_code(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_bracket_match(self) -> None:
@ -507,7 +523,7 @@ def test_bracket_match(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
black.assert_stable(source, actual, black.FileMode())
def test_comment_indentation(self) -> None:
contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
@ -794,27 +810,32 @@ def err(msg: str, **kwargs: Any) -> None:
"2 files would fail to reformat.",
)
def test_is_python36(self) -> None:
def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertFalse(black.is_python36(node))
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def f(*, arg,): ...\n")
self.assertTrue(black.is_python36(node))
self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertTrue(black.is_python36(node))
self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
node = black.lib2to3_parse("123_456\n")
self.assertTrue(black.is_python36(node))
self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
node = black.lib2to3_parse("123456\n")
self.assertFalse(black.is_python36(node))
self.assertEqual(black.get_features_used(node), set())
source, expected = read_data("function")
node = black.lib2to3_parse(source)
self.assertTrue(black.is_python36(node))
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
)
node = black.lib2to3_parse(expected)
self.assertTrue(black.is_python36(node))
self.assertEqual(
black.get_features_used(node),
{Feature.TRAILING_COMMA, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES},
)
source, expected = read_data("expression")
node = black.lib2to3_parse(source)
self.assertFalse(black.is_python36(node))
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse(expected)
self.assertFalse(black.is_python36(node))
self.assertEqual(black.get_features_used(node), set())
def test_get_future_imports(self) -> None:
node = black.lib2to3_parse("\n")
@ -872,21 +893,22 @@ def err(msg: str, **kwargs: Any) -> None:
def test_format_file_contents(self) -> None:
empty = ""
mode = black.FileMode()
with self.assertRaises(black.NothingChanged):
black.format_file_contents(empty, line_length=ll, fast=False)
black.format_file_contents(empty, mode=mode, fast=False)
just_nl = "\n"
with self.assertRaises(black.NothingChanged):
black.format_file_contents(just_nl, line_length=ll, fast=False)
black.format_file_contents(just_nl, mode=mode, fast=False)
same = "l = [1, 2, 3]\n"
with self.assertRaises(black.NothingChanged):
black.format_file_contents(same, line_length=ll, fast=False)
black.format_file_contents(same, mode=mode, fast=False)
different = "l = [1,2,3]"
expected = same
actual = black.format_file_contents(different, line_length=ll, fast=False)
actual = black.format_file_contents(different, mode=mode, fast=False)
self.assertEqual(expected, actual)
invalid = "return if you can"
with self.assertRaises(black.InvalidInput) as e:
black.format_file_contents(invalid, line_length=ll, fast=False)
black.format_file_contents(invalid, mode=mode, fast=False)
self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
def test_endmarker(self) -> None:
@ -916,35 +938,33 @@ def err(msg: str, **kwargs: Any) -> None:
self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir() as workspace:
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
cache_file = black.get_cache_file(mode)
with cache_file.open("w") as fobj:
fobj.write("this is not a pickle")
self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
self.assertEqual(black.read_cache(mode), {})
src = (workspace / "test.py").resolve()
with src.open("w") as fobj:
fobj.write("print('hello')")
result = CliRunner().invoke(black.main, [str(src)])
self.assertEqual(result.exit_code, 0)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.invokeBlack([str(src)])
cache = black.read_cache(mode)
self.assertIn(src, cache)
def test_cache_single_file_already_cached(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir() as workspace:
src = (workspace / "test.py").resolve()
with src.open("w") as fobj:
fobj.write("print('hello')")
black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
result = CliRunner().invoke(black.main, [str(src)])
self.assertEqual(result.exit_code, 0)
black.write_cache({}, [src], mode)
self.invokeBlack([str(src)])
with src.open("r") as fobj:
self.assertEqual(fobj.read(), "print('hello')")
@event_loop(close=False)
def test_cache_multiple_files(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor
):
@ -954,50 +974,48 @@ def test_cache_multiple_files(self) -> None:
two = (workspace / "two.py").resolve()
with two.open("w") as fobj:
fobj.write("print('hello')")
black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
result = CliRunner().invoke(black.main, [str(workspace)])
self.assertEqual(result.exit_code, 0)
black.write_cache({}, [one], mode)
self.invokeBlack([str(workspace)])
with one.open("r") as fobj:
self.assertEqual(fobj.read(), "print('hello')")
with two.open("r") as fobj:
self.assertEqual(fobj.read(), 'print("hello")\n')
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
cache = black.read_cache(mode)
self.assertIn(one, cache)
self.assertIn(two, cache)
def test_no_cache_when_writeback_diff(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir() as workspace:
src = (workspace / "test.py").resolve()
with src.open("w") as fobj:
fobj.write("print('hello')")
result = CliRunner().invoke(black.main, [str(src), "--diff"])
self.assertEqual(result.exit_code, 0)
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
self.invokeBlack([str(src), "--diff"])
cache_file = black.get_cache_file(mode)
self.assertFalse(cache_file.exists())
def test_no_cache_when_stdin(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir():
result = CliRunner().invoke(
black.main, ["-"], input=BytesIO(b"print('hello')")
)
self.assertEqual(result.exit_code, 0)
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
cache_file = black.get_cache_file(mode)
self.assertFalse(cache_file.exists())
def test_read_cache_no_cachefile(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir():
self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
self.assertEqual(black.read_cache(mode), {})
def test_write_cache_read_cache(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir() as workspace:
src = (workspace / "test.py").resolve()
src.touch()
black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
black.write_cache({}, [src], mode)
cache = black.read_cache(mode)
self.assertIn(src, cache)
self.assertEqual(cache[src], black.get_cache_info(src))
@ -1018,15 +1036,15 @@ def test_filter_cached(self) -> None:
self.assertEqual(done, {cached})
def test_write_cache_creates_directory_if_needed(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir(exists=False) as workspace:
self.assertFalse(workspace.exists())
black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
black.write_cache({}, [], mode)
self.assertTrue(workspace.exists())
@event_loop(close=False)
def test_failed_formatting_does_not_get_cached(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor
):
@ -1036,40 +1054,33 @@ def test_failed_formatting_does_not_get_cached(self) -> None:
clean = (workspace / "clean.py").resolve()
with clean.open("w") as fobj:
fobj.write('print("hello")\n')
result = CliRunner().invoke(black.main, [str(workspace)])
self.assertEqual(result.exit_code, 123)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.invokeBlack([str(workspace)], exit_code=123)
cache = black.read_cache(mode)
self.assertNotIn(failing, cache)
self.assertIn(clean, cache)
def test_write_cache_write_fail(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
with cache_dir(), patch.object(Path, "open") as mock:
mock.side_effect = OSError
black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
black.write_cache({}, [], mode)
@event_loop(close=False)
def test_check_diff_use_together(self) -> None:
with cache_dir():
# Files which will be reformatted.
src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
self.assertEqual(result.exit_code, 1, result.output)
self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
# Files which will not be reformatted.
src2 = (THIS_DIR / "data" / "composition.py").resolve()
result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
self.assertEqual(result.exit_code, 0, result.output)
self.invokeBlack([str(src2), "--diff", "--check"])
# Multi file command.
result = CliRunner().invoke(
black.main, [str(src1), str(src2), "--diff", "--check"]
)
self.assertEqual(result.exit_code, 1, result.output)
self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
def test_no_files(self) -> None:
with cache_dir():
# Without an argument, black exits with error code 0.
result = CliRunner().invoke(black.main, [])
self.assertEqual(result.exit_code, 0)
self.invokeBlack([])
def test_broken_symlink(self) -> None:
with cache_dir() as workspace:
@ -1078,43 +1089,42 @@ def test_broken_symlink(self) -> None:
symlink.symlink_to("nonexistent.py")
except OSError as e:
self.skipTest(f"Can't create symlinks: {e}")
result = CliRunner().invoke(black.main, [str(workspace.resolve())])
self.assertEqual(result.exit_code, 0)
self.invokeBlack([str(workspace.resolve())])
def test_read_cache_line_lengths(self) -> None:
mode = black.FileMode.AUTO_DETECT
mode = black.FileMode()
short_mode = black.FileMode(line_length=1)
with cache_dir() as workspace:
path = (workspace / "file.py").resolve()
path.touch()
black.write_cache({}, [path], 1, mode)
one = black.read_cache(1, mode)
black.write_cache({}, [path], mode)
one = black.read_cache(mode)
self.assertIn(path, one)
two = black.read_cache(2, mode)
two = black.read_cache(short_mode)
self.assertNotIn(path, two)
def test_single_file_force_pyi(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
pyi_mode = black.FileMode.PYI
reg_mode = black.FileMode()
pyi_mode = black.FileMode(is_pyi=True)
contents, expected = read_data("force_pyi")
with cache_dir() as workspace:
path = (workspace / "file.py").resolve()
with open(path, "w") as fh:
fh.write(contents)
result = CliRunner().invoke(black.main, [str(path), "--pyi"])
self.assertEqual(result.exit_code, 0)
self.invokeBlack([str(path), "--pyi"])
with open(path, "r") as fh:
actual = fh.read()
# verify cache with --pyi is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
pyi_cache = black.read_cache(pyi_mode)
self.assertIn(path, pyi_cache)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
normal_cache = black.read_cache(reg_mode)
self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected)
@event_loop(close=False)
def test_multi_file_force_pyi(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
pyi_mode = black.FileMode.PYI
reg_mode = black.FileMode()
pyi_mode = black.FileMode(is_pyi=True)
contents, expected = read_data("force_pyi")
with cache_dir() as workspace:
paths = [
@ -1124,15 +1134,14 @@ def test_multi_file_force_pyi(self) -> None:
for path in paths:
with open(path, "w") as fh:
fh.write(contents)
result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
self.assertEqual(result.exit_code, 0)
self.invokeBlack([str(p) for p in paths] + ["--pyi"])
for path in paths:
with open(path, "r") as fh:
actual = fh.read()
self.assertEqual(actual, expected)
# verify cache with --pyi is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
pyi_cache = black.read_cache(pyi_mode)
normal_cache = black.read_cache(reg_mode)
for path in paths:
self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache)
@ -1147,28 +1156,27 @@ def test_pipe_force_pyi(self) -> None:
self.assertFormatEqual(actual, expected)
def test_single_file_force_py36(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
py36_mode = black.FileMode.PYTHON36
reg_mode = black.FileMode()
py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
source, expected = read_data("force_py36")
with cache_dir() as workspace:
path = (workspace / "file.py").resolve()
with open(path, "w") as fh:
fh.write(source)
result = CliRunner().invoke(black.main, [str(path), "--py36"])
self.assertEqual(result.exit_code, 0)
self.invokeBlack([str(path), "--py36"])
with open(path, "r") as fh:
actual = fh.read()
# verify cache with --py36 is separate
py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
py36_cache = black.read_cache(py36_mode)
self.assertIn(path, py36_cache)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
normal_cache = black.read_cache(reg_mode)
self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected)
@event_loop(close=False)
def test_multi_file_force_py36(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT
py36_mode = black.FileMode.PYTHON36
reg_mode = black.FileMode()
py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
source, expected = read_data("force_py36")
with cache_dir() as workspace:
paths = [
@ -1178,17 +1186,14 @@ def test_multi_file_force_py36(self) -> None:
for path in paths:
with open(path, "w") as fh:
fh.write(source)
result = CliRunner().invoke(
black.main, [str(p) for p in paths] + ["--py36"]
)
self.assertEqual(result.exit_code, 0)
self.invokeBlack([str(p) for p in paths] + ["--py36"])
for path in paths:
with open(path, "r") as fh:
actual = fh.read()
self.assertEqual(actual, expected)
# verify cache with --py36 is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
pyi_cache = black.read_cache(py36_mode)
normal_cache = black.read_cache(reg_mode)
for path in paths:
self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache)
@ -1265,8 +1270,7 @@ def test_empty_exclude(self) -> None:
def test_invalid_include_exclude(self) -> None:
for option in ["--include", "--exclude"]:
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
self.assertEqual(result.exit_code, 2)
self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
def test_preserves_line_endings(self) -> None:
with TemporaryDirectory() as workspace:
@ -1407,10 +1411,24 @@ async def test_blackd_supported_version(self) -> None:
async def test_blackd_invalid_python_variant(self) -> None:
app = blackd.make_app()
async with TestClient(TestServer(app)) as client:
response = await client.post(
"/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: "lol"}
)
self.assertEqual(response.status, 400)
async def check(header_value: str, expected_status: int = 400) -> None:
response = await client.post(
"/",
data=b"what",
headers={blackd.PYTHON_VARIANT_HEADER: header_value},
)
self.assertEqual(response.status, expected_status)
await check("lol")
await check("ruby3.5")
await check("pyi3.6")
await check("cpy1.5")
await check("2.8")
await check("cpy2.8")
await check("3.0")
await check("pypy3.0")
await check("jython3.4")
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test
@ -1426,51 +1444,37 @@ async def test_blackd_pyi(self) -> None:
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test
async def test_blackd_py36(self) -> None:
async def test_blackd_python_variant(self) -> None:
app = blackd.make_app()
code = (
"def f(\n"
" and_has_a_bunch_of,\n"
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n"
" **and_very_long_keyword_arguments\n"
"):\n"
" pass\n"
)
async with TestClient(TestServer(app)) as client:
response = await client.post(
"/",
data=(
"def f(\n"
" and_has_a_bunch_of,\n"
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n"
" **and_very_long_keyword_arguments\n"
"):\n"
" pass\n"
),
headers={blackd.PYTHON_VARIANT_HEADER: "3.6"},
)
self.assertEqual(response.status, 200)
response = await client.post(
"/",
data=(
"def f(\n"
" and_has_a_bunch_of,\n"
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n"
" **and_very_long_keyword_arguments\n"
"):\n"
" pass\n"
),
headers={blackd.PYTHON_VARIANT_HEADER: "3.5"},
)
self.assertEqual(response.status, 204)
response = await client.post(
"/",
data=(
"def f(\n"
" and_has_a_bunch_of,\n"
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n"
" **and_very_long_keyword_arguments\n"
"):\n"
" pass\n"
),
headers={blackd.PYTHON_VARIANT_HEADER: "2"},
)
self.assertEqual(response.status, 204)
async def check(header_value: str, expected_status: int) -> None:
response = await client.post(
"/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
)
self.assertEqual(response.status, expected_status)
await check("3.6", 200)
await check("cpy3.6", 200)
await check("3.5,3.7", 200)
await check("3.5,cpy3.7", 200)
await check("2", 204)
await check("2.7", 204)
await check("cpy2.7", 204)
await check("pypy2.7", 204)
await check("3.4", 204)
await check("cpy3.4", 204)
await check("pypy3.4", 204)
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test