Add --target-version option to allow users to choose targeted Python versions (#618)

This commit is contained in:
Jelle Zijlstra 2019-02-06 18:43:50 -08:00 committed by GitHub
parent a9d8af466a
commit 36d3c516d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 545 additions and 387 deletions

View File

@ -71,46 +71,60 @@ black {source_file_or_directory}
black [OPTIONS] [SRC]... black [OPTIONS] [SRC]...
Options: Options:
-l, --line-length INTEGER Where to wrap around. [default: 88] -l, --line-length INTEGER How many characters per line to allow.
--py36 Allow using Python 3.6-only syntax on all input [default: 88]
files. This will put trailing commas in function -t, --target-version [pypy35|cpy27|cpy33|cpy34|cpy35|cpy36|cpy37|cpy38]
signatures and calls also after *args and Python versions that should be supported by
**kwargs. [default: per-file auto-detection] Black's output. [default: per-file auto-
detection]
--py36 Allow using Python 3.6-only syntax on all
input files. This will put trailing commas
in function signatures and calls also after
*args and **kwargs. [default: per-file
auto-detection]
--pyi Format all input files like typing stubs --pyi Format all input files like typing stubs
regardless of file extension (useful when piping regardless of file extension (useful when
source on standard input). piping source on standard input).
-S, --skip-string-normalization -S, --skip-string-normalization
Don't normalize string quotes or prefixes. Don't normalize string quotes or prefixes.
-N, --skip-numeric-underscore-normalization -N, --skip-numeric-underscore-normalization
Don't normalize underscores in numeric literals. Don't normalize underscores in numeric
literals.
--check Don't write the files back, just return the --check Don't write the files back, just return the
status. Return code 0 means nothing would status. Return code 0 means nothing would
change. Return code 1 means some files would be change. Return code 1 means some files
reformatted. Return code 123 means there was an would be reformatted. Return code 123 means
internal error. there was an internal error.
--diff Don't write the files back, just output a diff --diff Don't write the files back, just output a
for each file on stdout. diff for each file on stdout.
--fast / --safe If --fast given, skip temporary sanity checks. --fast / --safe If --fast given, skip temporary sanity
[default: --safe] checks. [default: --safe]
--include TEXT A regular expression that matches files and --include TEXT A regular expression that matches files and
directories that should be included on directories that should be included on
recursive searches. On Windows, use forward recursive searches. An empty value means
slashes for directories. [default: \.pyi?$] all files are included regardless of the
name. Use forward slashes for directories
on all platforms (Windows, too). Exclusions
are calculated first, inclusions later.
[default: \.pyi?$]
--exclude TEXT A regular expression that matches files and --exclude TEXT A regular expression that matches files and
directories that should be excluded on directories that should be excluded on
recursive searches. On Windows, use forward recursive searches. An empty value means no
slashes for directories. [default: paths are excluded. Use forward slashes for
build/|buck-out/|dist/|_build/|\.eggs/|\.git/| directories on all platforms (Windows, too).
\.hg/|\.mypy_cache/|\.nox/|\.tox/|\.venv/] Exclusions are calculated first, inclusions
-q, --quiet Don't emit non-error messages to stderr. Errors later. [default: /(\.eggs|\.git|\.hg|\.mypy
are still emitted, silence those with _cache|\.nox|\.tox|\.venv|_build|buck-
out|build|dist)/]
-q, --quiet Don't emit non-error messages to stderr.
Errors are still emitted, silence those with
2>/dev/null. 2>/dev/null.
-v, --verbose Also emit messages to stderr about files -v, --verbose Also emit messages to stderr about files
that were not changed or were ignored due to that were not changed or were ignored due to
--exclude=. --exclude=.
--version Show the version and exit. --version Show the version and exit.
--config PATH Read configuration from PATH. --config PATH Read configuration from PATH.
--help Show this message and exit. -h, --help Show this message and exit.
``` ```
*Black* is a well-behaved Unix-style command-line tool: *Black* is a well-behaved Unix-style command-line tool:
@ -815,8 +829,9 @@ The headers controlling how code is formatted are:
passed the `--fast` command line flag. passed the `--fast` command line flag.
- `X-Python-Variant`: if set to `pyi`, `blackd` will act as *Black* does when - `X-Python-Variant`: if set to `pyi`, `blackd` will act as *Black* does when
passed the `--pyi` command line flag. Otherwise, its value must correspond to passed the `--pyi` command line flag. Otherwise, its value must correspond to
a Python version. If this value represents at least Python 3.6, `blackd` will a Python version or a set of comma-separated Python versions, optionally
act as *Black* does when passed the `--py36` command line flag. prefixed with `cpy` or `pypy`. For example, to request code that is compatible
with PyPy 3.5 and CPython 3.5, set the header to `pypy3.5,cpy3.5`.
If any of these headers are set to invalid values, `blackd` returns a `HTTP 400` If any of these headers are set to invalid values, `blackd` returns a `HTTP 400`
error response, mentioning the name of the problematic header in the message body. error response, mentioning the name of the problematic header in the message body.
@ -935,6 +950,11 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
## Change Log ## Change Log
### 18.11b0
* new option `--target-version` to control which Python versions
*Black*-formatted code should target
### 18.9b0 ### 18.9b0
* numeric literals are now formatted by *Black* (#452, #461, #464, #469): * numeric literals are now formatted by *Black* (#452, #461, #464, #469):

361
black.py
View File

@ -2,7 +2,7 @@
from asyncio.base_events import BaseEventLoop from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor from concurrent.futures import Executor, ProcessPoolExecutor
from datetime import datetime from datetime import datetime
from enum import Enum, Flag from enum import Enum
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
import io import io
import itertools import itertools
@ -37,7 +37,7 @@
) )
from appdirs import user_cache_dir from appdirs import user_cache_dir
from attr import dataclass, Factory from attr import dataclass, evolve, Factory
import click import click
import toml import toml
@ -45,6 +45,7 @@
from blib2to3.pytree import Node, Leaf, type_repr from blib2to3.pytree import Node, Leaf, type_repr
from blib2to3 import pygram, pytree from blib2to3 import pygram, pytree
from blib2to3.pgen2 import driver, token from blib2to3.pgen2 import driver, token
from blib2to3.pgen2.grammar import Grammar
from blib2to3.pgen2.parse import ParseError from blib2to3.pgen2.parse import ParseError
@ -111,32 +112,86 @@ class Changed(Enum):
YES = 2 YES = 2
class FileMode(Flag): class TargetVersion(Enum):
AUTO_DETECT = 0 PYPY35 = 1
PYTHON36 = 1 CPY27 = 2
PYI = 2 CPY33 = 3
NO_STRING_NORMALIZATION = 4 CPY34 = 4
NO_NUMERIC_UNDERSCORE_NORMALIZATION = 8 CPY35 = 5
CPY36 = 6
CPY37 = 7
CPY38 = 8
@classmethod def is_python2(self) -> bool:
def from_configuration( return self is TargetVersion.CPY27
cls,
*,
py36: bool, PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
pyi: bool,
skip_string_normalization: bool,
skip_numeric_underscore_normalization: bool, class Feature(Enum):
) -> "FileMode": # All string literals are unicode
mode = cls.AUTO_DETECT UNICODE_LITERALS = 1
if py36: F_STRINGS = 2
mode |= cls.PYTHON36 NUMERIC_UNDERSCORES = 3
if pyi: TRAILING_COMMA = 4
mode |= cls.PYI
if skip_string_normalization:
mode |= cls.NO_STRING_NORMALIZATION VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
if skip_numeric_underscore_normalization: TargetVersion.CPY27: set(),
mode |= cls.NO_NUMERIC_UNDERSCORE_NORMALIZATION TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
return mode TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
TargetVersion.CPY36: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
},
TargetVersion.CPY37: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
},
TargetVersion.CPY38: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
},
}
@dataclass
class FileMode:
target_versions: Set[TargetVersion] = Factory(set)
line_length: int = DEFAULT_LINE_LENGTH
numeric_underscore_normalization: bool = True
string_normalization: bool = True
is_pyi: bool = False
def get_cache_key(self) -> str:
if self.target_versions:
version_str = ",".join(
str(version.value)
for version in sorted(self.target_versions, key=lambda v: v.value)
)
else:
version_str = "-"
parts = [
version_str,
str(self.line_length),
str(int(self.numeric_underscore_normalization)),
str(int(self.string_normalization)),
str(int(self.is_pyi)),
]
return ".".join(parts)
def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
def read_pyproject_toml( def read_pyproject_toml(
@ -184,6 +239,17 @@ def read_pyproject_toml(
help="How many characters per line to allow.", help="How many characters per line to allow.",
show_default=True, show_default=True,
) )
@click.option(
"-t",
"--target-version",
type=click.Choice([v.name.lower() for v in TargetVersion]),
callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
multiple=True,
help=(
"Python versions that should be supported by Black's output. [default: "
"per-file auto-detection]"
),
)
@click.option( @click.option(
"--py36", "--py36",
is_flag=True, is_flag=True,
@ -297,6 +363,7 @@ def read_pyproject_toml(
def main( def main(
ctx: click.Context, ctx: click.Context,
line_length: int, line_length: int,
target_version: List[TargetVersion],
check: bool, check: bool,
diff: bool, diff: bool,
fast: bool, fast: bool,
@ -313,11 +380,23 @@ def main(
) -> None: ) -> None:
"""The uncompromising code formatter.""" """The uncompromising code formatter."""
write_back = WriteBack.from_configuration(check=check, diff=diff) write_back = WriteBack.from_configuration(check=check, diff=diff)
mode = FileMode.from_configuration( if target_version:
py36=py36, if py36:
pyi=pyi, err(f"Cannot use both --target-version and --py36")
skip_string_normalization=skip_string_normalization, ctx.exit(2)
skip_numeric_underscore_normalization=skip_numeric_underscore_normalization, else:
versions = set(target_version)
elif py36:
versions = PY36_VERSIONS
else:
# We'll autodetect later.
versions = set()
mode = FileMode(
target_versions=versions,
line_length=line_length,
is_pyi=pyi,
string_normalization=not skip_string_normalization,
numeric_underscore_normalization=not skip_numeric_underscore_normalization,
) )
if config and verbose: if config and verbose:
out(f"Using configuration from {config}.", bold=False, fg="blue") out(f"Using configuration from {config}.", bold=False, fg="blue")
@ -353,7 +432,6 @@ def main(
if len(sources) == 1: if len(sources) == 1:
reformat_one( reformat_one(
src=sources.pop(), src=sources.pop(),
line_length=line_length,
fast=fast, fast=fast,
write_back=write_back, write_back=write_back,
mode=mode, mode=mode,
@ -366,7 +444,6 @@ def main(
loop.run_until_complete( loop.run_until_complete(
schedule_formatting( schedule_formatting(
sources=sources, sources=sources,
line_length=line_length,
fast=fast, fast=fast,
write_back=write_back, write_back=write_back,
mode=mode, mode=mode,
@ -385,12 +462,7 @@ def main(
def reformat_one( def reformat_one(
src: Path, src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
line_length: int,
fast: bool,
write_back: WriteBack,
mode: FileMode,
report: "Report",
) -> None: ) -> None:
"""Reformat a single file under `src` without spawning child processes. """Reformat a single file under `src` without spawning child processes.
@ -401,29 +473,23 @@ def reformat_one(
try: try:
changed = Changed.NO changed = Changed.NO
if not src.is_file() and str(src) == "-": if not src.is_file() and str(src) == "-":
if format_stdin_to_stdout( if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
line_length=line_length, fast=fast, write_back=write_back, mode=mode
):
changed = Changed.YES changed = Changed.YES
else: else:
cache: Cache = {} cache: Cache = {}
if write_back != WriteBack.DIFF: if write_back != WriteBack.DIFF:
cache = read_cache(line_length, mode) cache = read_cache(mode)
res_src = src.resolve() res_src = src.resolve()
if res_src in cache and cache[res_src] == get_cache_info(res_src): if res_src in cache and cache[res_src] == get_cache_info(res_src):
changed = Changed.CACHED changed = Changed.CACHED
if changed is not Changed.CACHED and format_file_in_place( if changed is not Changed.CACHED and format_file_in_place(
src, src, fast=fast, write_back=write_back, mode=mode
line_length=line_length,
fast=fast,
write_back=write_back,
mode=mode,
): ):
changed = Changed.YES changed = Changed.YES
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or ( if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
write_back is WriteBack.CHECK and changed is Changed.NO write_back is WriteBack.CHECK and changed is Changed.NO
): ):
write_cache(cache, [src], line_length, mode) write_cache(cache, [src], mode)
report.done(src, changed) report.done(src, changed)
except Exception as exc: except Exception as exc:
report.failed(src, str(exc)) report.failed(src, str(exc))
@ -431,7 +497,6 @@ def reformat_one(
async def schedule_formatting( async def schedule_formatting(
sources: Set[Path], sources: Set[Path],
line_length: int,
fast: bool, fast: bool,
write_back: WriteBack, write_back: WriteBack,
mode: FileMode, mode: FileMode,
@ -448,7 +513,7 @@ async def schedule_formatting(
""" """
cache: Cache = {} cache: Cache = {}
if write_back != WriteBack.DIFF: if write_back != WriteBack.DIFF:
cache = read_cache(line_length, mode) cache = read_cache(mode)
sources, cached = filter_cached(cache, sources) sources, cached = filter_cached(cache, sources)
for src in sorted(cached): for src in sorted(cached):
report.done(src, Changed.CACHED) report.done(src, Changed.CACHED)
@ -465,14 +530,7 @@ async def schedule_formatting(
lock = manager.Lock() lock = manager.Lock()
tasks = { tasks = {
loop.run_in_executor( loop.run_in_executor(
executor, executor, format_file_in_place, src, fast, mode, write_back, lock
format_file_in_place,
src,
line_length,
fast,
write_back,
mode,
lock,
): src ): src
for src in sorted(sources) for src in sorted(sources)
} }
@ -503,15 +561,14 @@ async def schedule_formatting(
if cancelled: if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
if sources_to_cache: if sources_to_cache:
write_cache(cache, sources_to_cache, line_length, mode) write_cache(cache, sources_to_cache, mode)
def format_file_in_place( def format_file_in_place(
src: Path, src: Path,
line_length: int,
fast: bool, fast: bool,
mode: FileMode,
write_back: WriteBack = WriteBack.NO, write_back: WriteBack = WriteBack.NO,
mode: FileMode = FileMode.AUTO_DETECT,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool: ) -> bool:
"""Format file under `src` path. Return True if changed. """Format file under `src` path. Return True if changed.
@ -521,15 +578,13 @@ def format_file_in_place(
`line_length` and `fast` options are passed to :func:`format_file_contents`. `line_length` and `fast` options are passed to :func:`format_file_contents`.
""" """
if src.suffix == ".pyi": if src.suffix == ".pyi":
mode |= FileMode.PYI mode = evolve(mode, is_pyi=True)
then = datetime.utcfromtimestamp(src.stat().st_mtime) then = datetime.utcfromtimestamp(src.stat().st_mtime)
with open(src, "rb") as buf: with open(src, "rb") as buf:
src_contents, encoding, newline = decode_bytes(buf.read()) src_contents, encoding, newline = decode_bytes(buf.read())
try: try:
dst_contents = format_file_contents( dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
src_contents, line_length=line_length, fast=fast, mode=mode
)
except NothingChanged: except NothingChanged:
return False return False
@ -559,23 +614,19 @@ def format_file_in_place(
def format_stdin_to_stdout( def format_stdin_to_stdout(
line_length: int, fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
fast: bool,
write_back: WriteBack = WriteBack.NO,
mode: FileMode = FileMode.AUTO_DETECT,
) -> bool: ) -> bool:
"""Format file on stdin. Return True if changed. """Format file on stdin. Return True if changed.
If `write_back` is YES, write reformatted code back to stdout. If it is DIFF, If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
write a diff to stdout. write a diff to stdout. The `mode` argument is passed to
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
:func:`format_file_contents`. :func:`format_file_contents`.
""" """
then = datetime.utcnow() then = datetime.utcnow()
src, encoding, newline = decode_bytes(sys.stdin.buffer.read()) src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
dst = src dst = src
try: try:
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode) dst = format_file_contents(src, fast=fast, mode=mode)
return True return True
except NothingChanged: except NothingChanged:
@ -596,11 +647,7 @@ def format_stdin_to_stdout(
def format_file_contents( def format_file_contents(
src_contents: str, src_contents: str, *, fast: bool, mode: FileMode
*,
line_length: int,
fast: bool,
mode: FileMode = FileMode.AUTO_DETECT,
) -> FileContent: ) -> FileContent:
"""Reformat contents a file and return new contents. """Reformat contents a file and return new contents.
@ -611,38 +658,38 @@ def format_file_contents(
if src_contents.strip() == "": if src_contents.strip() == "":
raise NothingChanged raise NothingChanged
dst_contents = format_str(src_contents, line_length=line_length, mode=mode) dst_contents = format_str(src_contents, mode=mode)
if src_contents == dst_contents: if src_contents == dst_contents:
raise NothingChanged raise NothingChanged
if not fast: if not fast:
assert_equivalent(src_contents, dst_contents) assert_equivalent(src_contents, dst_contents)
assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode) assert_stable(src_contents, dst_contents, mode=mode)
return dst_contents return dst_contents
def format_str( def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
) -> FileContent:
"""Reformat a string and return new contents. """Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed. `line_length` determines how many characters per line are allowed.
""" """
src_node = lib2to3_parse(src_contents.lstrip()) src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = "" dst_contents = ""
future_imports = get_future_imports(src_node) future_imports = get_future_imports(src_node)
is_pyi = bool(mode & FileMode.PYI) if mode.target_versions:
py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node) versions = mode.target_versions
normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION) else:
versions = detect_target_versions(src_node)
normalize_fmt_off(src_node) normalize_fmt_off(src_node)
lines = LineGenerator( lines = LineGenerator(
remove_u_prefix=py36 or "unicode_literals" in future_imports, remove_u_prefix="unicode_literals" in future_imports
is_pyi=is_pyi, or supports_feature(versions, Feature.UNICODE_LITERALS),
normalize_strings=normalize_strings, is_pyi=mode.is_pyi,
allow_underscores=py36 normalize_strings=mode.string_normalization,
and not bool(mode & FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION), allow_underscores=mode.numeric_underscore_normalization
and supports_feature(versions, Feature.NUMERIC_UNDERSCORES),
) )
elt = EmptyLineTracker(is_pyi=is_pyi) elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line() empty_line = Line()
after = 0 after = 0
for current_line in lines.visit(src_node): for current_line in lines.visit(src_node):
@ -651,7 +698,11 @@ def format_str(
before, after = elt.maybe_empty_lines(current_line) before, after = elt.maybe_empty_lines(current_line)
for _ in range(before): for _ in range(before):
dst_contents += str(empty_line) dst_contents += str(empty_line)
for line in split_line(current_line, line_length=line_length, py36=py36): for line in split_line(
current_line,
line_length=mode.line_length,
supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
):
dst_contents += str(line) dst_contents += str(line)
return dst_contents return dst_contents
@ -680,11 +731,25 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
] ]
def lib2to3_parse(src_txt: str) -> Node: def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
if not target_versions:
return GRAMMARS
elif all(not version.is_python2() for version in target_versions):
# Python 2-compatible code, so don't try Python 3 grammar.
return [
pygram.python_grammar_no_print_statement_no_exec_statement,
pygram.python_grammar_no_print_statement,
]
else:
return [pygram.python_grammar]
def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
"""Given a string with source, return the lib2to3 Node.""" """Given a string with source, return the lib2to3 Node."""
if src_txt[-1:] != "\n": if src_txt[-1:] != "\n":
src_txt += "\n" src_txt += "\n"
for grammar in GRAMMARS:
for grammar in get_grammars(set(target_versions)):
drv = driver.Driver(grammar, pytree.convert) drv = driver.Driver(grammar, pytree.convert)
try: try:
result = drv.parse_string(src_txt, True) result = drv.parse_string(src_txt, True)
@ -2093,7 +2158,10 @@ def make_comment(content: str) -> str:
def split_line( def split_line(
line: Line, line_length: int, inner: bool = False, py36: bool = False line: Line,
line_length: int,
inner: bool = False,
supports_trailing_commas: bool = False,
) -> Iterator[Line]: ) -> Iterator[Line]:
"""Split a `line` into potentially many lines. """Split a `line` into potentially many lines.
@ -2102,8 +2170,7 @@ def split_line(
current `line`, possibly transitively. This means we can fallback to splitting current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results. by delimiters if the LHS/RHS don't yield any results.
If `py36` is True, splitting may generate syntax that is only compatible If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
with Python 3.6 and later.
""" """
if line.is_comment: if line.is_comment:
yield line yield line
@ -2132,9 +2199,13 @@ def split_line(
split_funcs = [left_hand_split] split_funcs = [left_hand_split]
else: else:
def rhs(line: Line, py36: bool = False) -> Iterator[Line]: def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
for omit in generate_trailers_to_omit(line, line_length): for omit in generate_trailers_to_omit(line, line_length):
lines = list(right_hand_split(line, line_length, py36, omit=omit)) lines = list(
right_hand_split(
line, line_length, supports_trailing_commas, omit=omit
)
)
if is_line_short_enough(lines[0], line_length=line_length): if is_line_short_enough(lines[0], line_length=line_length):
yield from lines yield from lines
return return
@ -2142,7 +2213,7 @@ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
# All splits failed, best effort split with no omits. # All splits failed, best effort split with no omits.
# This mostly happens to multiline strings that are by definition # This mostly happens to multiline strings that are by definition
# reported as not fitting a single line. # reported as not fitting a single line.
yield from right_hand_split(line, py36) yield from right_hand_split(line, supports_trailing_commas)
if line.inside_brackets: if line.inside_brackets:
split_funcs = [delimiter_split, standalone_comment_split, rhs] split_funcs = [delimiter_split, standalone_comment_split, rhs]
@ -2154,12 +2225,17 @@ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
# split altogether. # split altogether.
result: List[Line] = [] result: List[Line] = []
try: try:
for l in split_func(line, py36): for l in split_func(line, supports_trailing_commas):
if str(l).strip("\n") == line_str: if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result") raise CannotSplit("Split function returned an unchanged result")
result.extend( result.extend(
split_line(l, line_length=line_length, inner=True, py36=py36) split_line(
l,
line_length=line_length,
inner=True,
supports_trailing_commas=supports_trailing_commas,
)
) )
except CannotSplit: except CannotSplit:
continue continue
@ -2172,7 +2248,9 @@ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
yield line yield line
def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: def left_hand_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair. """Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions. Note: this usually looks weird, only use this for function definitions.
@ -2209,7 +2287,10 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
def right_hand_split( def right_hand_split(
line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = () line: Line,
line_length: int,
supports_trailing_commas: bool = False,
omit: Collection[LeafID] = (),
) -> Iterator[Line]: ) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair. """Split line into many lines, starting with the last matching bracket pair.
@ -2267,7 +2348,12 @@ def right_hand_split(
): ):
omit = {id(closing_bracket), *omit} omit = {id(closing_bracket), *omit}
try: try:
yield from right_hand_split(line, line_length, py36=py36, omit=omit) yield from right_hand_split(
line,
line_length,
supports_trailing_commas=supports_trailing_commas,
omit=omit,
)
return return
except CannotSplit: except CannotSplit:
@ -2356,8 +2442,10 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
""" """
@wraps(split_func) @wraps(split_func)
def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]: def split_wrapper(
for l in split_func(line, py36): line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
for l in split_func(line, supports_trailing_commas):
normalize_prefix(l.leaves[0], inside_brackets=True) normalize_prefix(l.leaves[0], inside_brackets=True)
yield l yield l
@ -2365,7 +2453,9 @@ def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
@dont_increase_indentation @dont_increase_indentation
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: def delimiter_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split according to delimiters of the highest priority. """Split according to delimiters of the highest priority.
If `py36` is True, the split will add trailing commas also in function If `py36` is True, the split will add trailing commas also in function
@ -2411,7 +2501,7 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
if leaf.bracket_depth == lowest_depth and is_vararg( if leaf.bracket_depth == lowest_depth and is_vararg(
leaf, within=VARARGS_PARENTS leaf, within=VARARGS_PARENTS
): ):
trailing_comma_safe = trailing_comma_safe and py36 trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
leaf_priority = bt.delimiters.get(id(leaf)) leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority: if leaf_priority == delimiter_priority:
yield current_line yield current_line
@ -2429,7 +2519,9 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
@dont_increase_indentation @dont_increase_indentation
def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: def standalone_comment_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
"""Split standalone comments from the rest of the line.""" """Split standalone comments from the rest of the line."""
if not line.contains_standalone_comments(0): if not line.contains_standalone_comments(0):
raise CannotSplit("Line does not have any standalone comments") raise CannotSplit("Line does not have any standalone comments")
@ -2988,23 +3080,24 @@ def should_explode(line: Line, opening_bracket: Leaf) -> bool:
return max_priority == COMMA_PRIORITY return max_priority == COMMA_PRIORITY
def is_python36(node: Node) -> bool: def get_features_used(node: Node) -> Set[Feature]:
"""Return True if the current file is using Python 3.6+ features. """Return a set of (relatively) new Python features used in this file.
Currently looking for: Currently looking for:
- f-strings; - f-strings;
- underscores in numeric literals; and - underscores in numeric literals; and
- trailing commas after * or ** in function signatures and calls. - trailing commas after * or ** in function signatures and calls.
""" """
features: Set[Feature] = set()
for n in node.pre_order(): for n in node.pre_order():
if n.type == token.STRING: if n.type == token.STRING:
value_head = n.value[:2] # type: ignore value_head = n.value[:2] # type: ignore
if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}: if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
return True features.add(Feature.F_STRINGS)
elif n.type == token.NUMBER: elif n.type == token.NUMBER:
if "_" in n.value: # type: ignore if "_" in n.value: # type: ignore
return True features.add(Feature.NUMERIC_UNDERSCORES)
elif ( elif (
n.type in {syms.typedargslist, syms.arglist} n.type in {syms.typedargslist, syms.arglist}
@ -3013,14 +3106,22 @@ def is_python36(node: Node) -> bool:
): ):
for ch in n.children: for ch in n.children:
if ch.type in STARS: if ch.type in STARS:
return True features.add(Feature.TRAILING_COMMA)
if ch.type == syms.argument: if ch.type == syms.argument:
for argch in ch.children: for argch in ch.children:
if argch.type in STARS: if argch.type in STARS:
return True features.add(Feature.TRAILING_COMMA)
return False return features
def detect_target_versions(node: Node) -> Set[TargetVersion]:
"""Detect the version to target based on the nodes used."""
features = get_features_used(node)
return {
version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
}
def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]: def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
@ -3337,11 +3438,9 @@ def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
) from None ) from None
def assert_stable( def assert_stable(src: str, dst: str, mode: FileMode) -> None:
src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
) -> None:
"""Raise AssertionError if `dst` reformats differently the second time.""" """Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str(dst, line_length=line_length, mode=mode) newdst = format_str(dst, mode=mode)
if dst != newdst: if dst != newdst:
log = dump_to_file( log = dump_to_file(
diff(src, dst, "source", "first pass"), diff(src, dst, "source", "first pass"),
@ -3598,16 +3697,16 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
return False return False
def get_cache_file(line_length: int, mode: FileMode) -> Path: def get_cache_file(mode: FileMode) -> Path:
return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle" return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
def read_cache(line_length: int, mode: FileMode) -> Cache: def read_cache(mode: FileMode) -> Cache:
"""Read the cache if it exists and is well formed. """Read the cache if it exists and is well formed.
If it is not well formed, the call to write_cache later should resolve the issue. If it is not well formed, the call to write_cache later should resolve the issue.
""" """
cache_file = get_cache_file(line_length, mode) cache_file = get_cache_file(mode)
if not cache_file.exists(): if not cache_file.exists():
return {} return {}
@ -3642,11 +3741,9 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
return todo, done return todo, done
def write_cache( def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
) -> None:
"""Update the cache file.""" """Update the cache file."""
cache_file = get_cache_file(line_length, mode) cache_file = get_cache_file(mode)
try: try:
CACHE_DIR.mkdir(parents=True, exist_ok=True) CACHE_DIR.mkdir(parents=True, exist_ok=True)
new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}

View File

@ -3,6 +3,7 @@
from functools import partial from functools import partial
import logging import logging
from multiprocessing import freeze_support from multiprocessing import freeze_support
from typing import Set, Tuple
from aiohttp import web from aiohttp import web
import aiohttp_cors import aiohttp_cors
@ -29,6 +30,10 @@
] ]
class InvalidVariantHeader(Exception):
pass
@click.command(context_settings={"help_option_names": ["-h", "--help"]}) @click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option( @click.option(
"--bind-host", type=str, help="Address to bind the server to.", default="localhost" "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
@ -73,22 +78,20 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
) )
except ValueError: except ValueError:
return web.Response(status=400, text="Invalid line length header value") return web.Response(status=400, text="Invalid line length header value")
py36 = False
pyi = False
if PYTHON_VARIANT_HEADER in request.headers: if PYTHON_VARIANT_HEADER in request.headers:
value = request.headers[PYTHON_VARIANT_HEADER] value = request.headers[PYTHON_VARIANT_HEADER]
if value == "pyi":
pyi = True
else:
try: try:
major, *rest = value.split(".") pyi, versions = parse_python_variant_header(value)
if int(major) == 3 and len(rest) > 0: except InvalidVariantHeader as e:
if int(rest[0]) >= 6:
py36 = True
except ValueError:
return web.Response( return web.Response(
status=400, text=f"Invalid value for {PYTHON_VARIANT_HEADER}" status=400,
text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
) )
else:
pyi = False
versions = set()
skip_string_normalization = bool( skip_string_normalization = bool(
request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False) request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
) )
@ -98,25 +101,19 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
fast = False fast = False
if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast": if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast":
fast = True fast = True
mode = black.FileMode.from_configuration( mode = black.FileMode(
py36=py36, target_versions=versions,
pyi=pyi, is_pyi=pyi,
skip_string_normalization=skip_string_normalization, line_length=line_length,
skip_numeric_underscore_normalization=skip_numeric_underscore_normalization, string_normalization=not skip_string_normalization,
numeric_underscore_normalization=not skip_numeric_underscore_normalization,
) )
req_bytes = await request.content.read() req_bytes = await request.content.read()
charset = request.charset if request.charset is not None else "utf8" charset = request.charset if request.charset is not None else "utf8"
req_str = req_bytes.decode(charset) req_str = req_bytes.decode(charset)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
formatted_str = await loop.run_in_executor( formatted_str = await loop.run_in_executor(
executor, executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
partial(
black.format_file_contents,
req_str,
line_length=line_length,
fast=fast,
mode=mode,
),
) )
return web.Response( return web.Response(
content_type=request.content_type, charset=charset, text=formatted_str content_type=request.content_type, charset=charset, text=formatted_str
@ -130,6 +127,45 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
if value == "pyi":
return True, set()
else:
versions = set()
for version in value.split(","):
tag = "cpy"
if version.startswith("cpy"):
version = version[len("cpy") :]
elif version.startswith("pypy"):
tag = "pypy"
version = version[len("pypy") :]
major_str, *rest = version.split(".")
try:
major = int(major_str)
if major not in (2, 3):
raise InvalidVariantHeader("major version must be 2 or 3")
if len(rest) > 0:
minor = int(rest[0])
if major == 2 and minor != 7:
raise InvalidVariantHeader(
"minor version must be 7 for Python 2"
)
else:
# Default to lowest supported minor version.
minor = 7 if major == 2 else 3
version_str = f"{tag.upper()}{major}{minor}"
# If PyPY is the same as CPython in some version, use
# the corresponding CPython version.
if tag == "pypy" and not hasattr(black.TargetVersion, version_str):
version_str = f"CPY{major}{minor}"
if major == 3 and not hasattr(black.TargetVersion, version_str):
raise InvalidVariantHeader(f"3.{minor} is not supported")
versions.add(black.TargetVersion[version_str])
except (KeyError, ValueError):
raise InvalidVariantHeader("expected e.g. '3.7', 'pypy3.5'")
return False, versions
def patched_main() -> None: def patched_main() -> None:
freeze_support() freeze_support()
black.patch_click() black.patch_click()

View File

@ -39,7 +39,7 @@ def function_signature_stress_test(number:int,no_annotation=None,text:str='defau
return text[number:-1] return text[number:-1]
# fmt: on # fmt: on
def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''): def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''):
offset = attr.ib(default=attr.Factory( lambda: _r.uniform(10000, 200000))) offset = attr.ib(default=attr.Factory( lambda: _r.uniform(1, 2)))
assert task._cancel_stack[:len(old_stack)] == old_stack assert task._cancel_stack[:len(old_stack)] == old_stack
def spaces_types(a: int = 1, b: tuple = (), c: list = [], d: dict = {}, e: bool = True, f: int = -1, g: int = 1 if False else 2, h: str = "", i: str = r''): ... def spaces_types(a: int = 1, b: tuple = (), c: list = [], d: dict = {}, e: bool = True, f: int = -1, g: int = 1 if False else 2, h: str = "", i: str = r''): ...
def spaces2(result= _core.Value(None)): def spaces2(result= _core.Value(None)):
@ -225,7 +225,7 @@ def function_signature_stress_test(number:int,no_annotation=None,text:str='defau
return text[number:-1] return text[number:-1]
# fmt: on # fmt: on
def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r""): def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r""):
offset = attr.ib(default=attr.Factory(lambda: _r.uniform(10000, 200_000))) offset = attr.ib(default=attr.Factory(lambda: _r.uniform(1, 2)))
assert task._cancel_stack[: len(old_stack)] == old_stack assert task._cancel_stack[: len(old_stack)] == old_stack

1
tests/empty.toml Normal file
View File

@ -0,0 +1 @@
# Empty configuration file; used in tests to avoid interference from Black's own config.

View File

@ -27,6 +27,7 @@
from click.testing import CliRunner from click.testing import CliRunner
import black import black
from black import Feature
try: try:
import blackd import blackd
@ -37,9 +38,8 @@
has_blackd_deps = True has_blackd_deps = True
ll = 88 ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
ff = partial(black.format_file_in_place, line_length=ll, fast=True) fs = partial(black.format_str, mode=black.FileMode())
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__) THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent THIS_DIR = THIS_FILE.parent
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)" EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
@ -155,13 +155,22 @@ def assertFormatEqual(self, expected: str, actual: str) -> None:
black.err(str(ve)) black.err(str(ve))
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def invokeBlack(
self, args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
runner = BlackRunner()
if ignore_config:
args = ["--config", str(THIS_DIR / "empty.toml"), *args]
result = runner.invoke(black.main, args)
self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_empty(self) -> None: def test_empty(self) -> None:
source = expected = "" source = expected = ""
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
def test_empty_ff(self) -> None: def test_empty_ff(self) -> None:
expected = "" expected = ""
@ -180,7 +189,7 @@ def test_self(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_FILE)) self.assertFalse(ff(THIS_FILE))
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
@ -189,20 +198,20 @@ def test_black(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_DIR / ".." / "black.py")) self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
def test_piping(self) -> None: def test_piping(self) -> None:
source, expected = read_data("../black", data=False) source, expected = read_data("../black", data=False)
result = BlackRunner().invoke( result = BlackRunner().invoke(
black.main, black.main,
["-", "--fast", f"--line-length={ll}"], ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
input=BytesIO(source.encode("utf8")), input=BytesIO(source.encode("utf8")),
) )
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertFormatEqual(expected, result.output) self.assertFormatEqual(expected, result.output)
black.assert_equivalent(source, result.output) black.assert_equivalent(source, result.output)
black.assert_stable(source, result.output, line_length=ll) black.assert_stable(source, result.output, black.FileMode())
def test_piping_diff(self) -> None: def test_piping_diff(self) -> None:
diff_header = re.compile( diff_header = re.compile(
@ -212,7 +221,13 @@ def test_piping_diff(self) -> None:
source, _ = read_data("expression.py") source, _ = read_data("expression.py")
expected, _ = read_data("expression.diff") expected, _ = read_data("expression.diff")
config = THIS_DIR / "data" / "empty_pyproject.toml" config = THIS_DIR / "data" / "empty_pyproject.toml"
args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"] args = [
"-",
"--fast",
f"--line-length={black.DEFAULT_LINE_LENGTH}",
"--diff",
f"--config={config}",
]
result = BlackRunner().invoke( result = BlackRunner().invoke(
black.main, args, input=BytesIO(source.encode("utf8")) black.main, args, input=BytesIO(source.encode("utf8"))
) )
@ -227,7 +242,7 @@ def test_setup(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_DIR / ".." / "setup.py")) self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
@ -236,7 +251,7 @@ def test_function(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_function2(self) -> None: def test_function2(self) -> None:
@ -244,7 +259,7 @@ def test_function2(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_expression(self) -> None: def test_expression(self) -> None:
@ -252,7 +267,7 @@ def test_expression(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
def test_expression_ff(self) -> None: def test_expression_ff(self) -> None:
source, expected = read_data("expression") source, expected = read_data("expression")
@ -266,7 +281,7 @@ def test_expression_ff(self) -> None:
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
with patch("black.dump_to_file", dump_to_stderr): with patch("black.dump_to_file", dump_to_stderr):
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
def test_expression_diff(self) -> None: def test_expression_diff(self) -> None:
source, _ = read_data("expression.py") source, _ = read_data("expression.py")
@ -299,7 +314,7 @@ def test_fstring(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_string_quotes(self) -> None: def test_string_quotes(self) -> None:
@ -307,12 +322,12 @@ def test_string_quotes(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
mode = black.FileMode.NO_STRING_NORMALIZATION mode = black.FileMode(string_normalization=False)
not_normalized = fs(source, mode=mode) not_normalized = fs(source, mode=mode)
self.assertFormatEqual(source, not_normalized) self.assertFormatEqual(source, not_normalized)
black.assert_equivalent(source, not_normalized) black.assert_equivalent(source, not_normalized)
black.assert_stable(source, not_normalized, line_length=ll, mode=mode) black.assert_stable(source, not_normalized, mode=mode)
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_slices(self) -> None: def test_slices(self) -> None:
@ -320,7 +335,7 @@ def test_slices(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments(self) -> None: def test_comments(self) -> None:
@ -328,7 +343,7 @@ def test_comments(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments2(self) -> None: def test_comments2(self) -> None:
@ -336,7 +351,7 @@ def test_comments2(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments3(self) -> None: def test_comments3(self) -> None:
@ -344,7 +359,7 @@ def test_comments3(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments4(self) -> None: def test_comments4(self) -> None:
@ -352,7 +367,7 @@ def test_comments4(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments5(self) -> None: def test_comments5(self) -> None:
@ -360,7 +375,7 @@ def test_comments5(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments6(self) -> None: def test_comments6(self) -> None:
@ -368,7 +383,7 @@ def test_comments6(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_cantfit(self) -> None: def test_cantfit(self) -> None:
@ -376,7 +391,7 @@ def test_cantfit(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_import_spacing(self) -> None: def test_import_spacing(self) -> None:
@ -384,7 +399,7 @@ def test_import_spacing(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_composition(self) -> None: def test_composition(self) -> None:
@ -392,7 +407,7 @@ def test_composition(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_empty_lines(self) -> None: def test_empty_lines(self) -> None:
@ -400,7 +415,7 @@ def test_empty_lines(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_string_prefixes(self) -> None: def test_string_prefixes(self) -> None:
@ -408,33 +423,34 @@ def test_string_prefixes(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals(self) -> None: def test_numeric_literals(self) -> None:
source, expected = read_data("numeric_literals") source, expected = read_data("numeric_literals")
actual = fs(source, mode=black.FileMode.PYTHON36) mode = black.FileMode(target_versions=black.PY36_VERSIONS)
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals_ignoring_underscores(self) -> None: def test_numeric_literals_ignoring_underscores(self) -> None:
source, expected = read_data("numeric_literals_skip_underscores") source, expected = read_data("numeric_literals_skip_underscores")
mode = ( mode = black.FileMode(
black.FileMode.PYTHON36 | black.FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION numeric_underscore_normalization=False, target_versions=black.PY36_VERSIONS
) )
actual = fs(source, mode=mode) actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll, mode=mode) black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals_py2(self) -> None: def test_numeric_literals_py2(self) -> None:
source, expected = read_data("numeric_literals_py2") source, expected = read_data("numeric_literals_py2")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_python2(self) -> None: def test_python2(self) -> None:
@ -442,22 +458,22 @@ def test_python2(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
# black.assert_equivalent(source, actual) # black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_python2_unicode_literals(self) -> None: def test_python2_unicode_literals(self) -> None:
source, expected = read_data("python2_unicode_literals") source, expected = read_data("python2_unicode_literals")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_stub(self) -> None: def test_stub(self) -> None:
mode = black.FileMode.PYI mode = black.FileMode(is_pyi=True)
source, expected = read_data("stub.pyi") source, expected = read_data("stub.pyi")
actual = fs(source, mode=mode) actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_stable(source, actual, line_length=ll, mode=mode) black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_python37(self) -> None: def test_python37(self) -> None:
@ -467,7 +483,7 @@ def test_python37(self) -> None:
major, minor = sys.version_info[:2] major, minor = sys.version_info[:2]
if major > 3 or (major == 3 and minor >= 7): if major > 3 or (major == 3 and minor >= 7):
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None: def test_fmtonoff(self) -> None:
@ -475,7 +491,7 @@ def test_fmtonoff(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff2(self) -> None: def test_fmtonoff2(self) -> None:
@ -483,7 +499,7 @@ def test_fmtonoff2(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_remove_empty_parentheses_after_class(self) -> None: def test_remove_empty_parentheses_after_class(self) -> None:
@ -491,7 +507,7 @@ def test_remove_empty_parentheses_after_class(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_new_line_between_class_and_code(self) -> None: def test_new_line_between_class_and_code(self) -> None:
@ -499,7 +515,7 @@ def test_new_line_between_class_and_code(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_bracket_match(self) -> None: def test_bracket_match(self) -> None:
@ -507,7 +523,7 @@ def test_bracket_match(self) -> None:
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, black.FileMode())
def test_comment_indentation(self) -> None: def test_comment_indentation(self) -> None:
contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n" contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
@ -794,27 +810,32 @@ def err(msg: str, **kwargs: Any) -> None:
"2 files would fail to reformat.", "2 files would fail to reformat.",
) )
def test_is_python36(self) -> None: def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n") node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertFalse(black.is_python36(node)) self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def f(*, arg,): ...\n") node = black.lib2to3_parse("def f(*, arg,): ...\n")
self.assertTrue(black.is_python36(node)) self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
node = black.lib2to3_parse("def f(*, arg): f'string'\n") node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertTrue(black.is_python36(node)) self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
node = black.lib2to3_parse("123_456\n") node = black.lib2to3_parse("123_456\n")
self.assertTrue(black.is_python36(node)) self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
node = black.lib2to3_parse("123456\n") node = black.lib2to3_parse("123456\n")
self.assertFalse(black.is_python36(node)) self.assertEqual(black.get_features_used(node), set())
source, expected = read_data("function") source, expected = read_data("function")
node = black.lib2to3_parse(source) node = black.lib2to3_parse(source)
self.assertTrue(black.is_python36(node)) self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
)
node = black.lib2to3_parse(expected) node = black.lib2to3_parse(expected)
self.assertTrue(black.is_python36(node)) self.assertEqual(
black.get_features_used(node),
{Feature.TRAILING_COMMA, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES},
)
source, expected = read_data("expression") source, expected = read_data("expression")
node = black.lib2to3_parse(source) node = black.lib2to3_parse(source)
self.assertFalse(black.is_python36(node)) self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse(expected) node = black.lib2to3_parse(expected)
self.assertFalse(black.is_python36(node)) self.assertEqual(black.get_features_used(node), set())
def test_get_future_imports(self) -> None: def test_get_future_imports(self) -> None:
node = black.lib2to3_parse("\n") node = black.lib2to3_parse("\n")
@ -872,21 +893,22 @@ def err(msg: str, **kwargs: Any) -> None:
def test_format_file_contents(self) -> None: def test_format_file_contents(self) -> None:
empty = "" empty = ""
mode = black.FileMode()
with self.assertRaises(black.NothingChanged): with self.assertRaises(black.NothingChanged):
black.format_file_contents(empty, line_length=ll, fast=False) black.format_file_contents(empty, mode=mode, fast=False)
just_nl = "\n" just_nl = "\n"
with self.assertRaises(black.NothingChanged): with self.assertRaises(black.NothingChanged):
black.format_file_contents(just_nl, line_length=ll, fast=False) black.format_file_contents(just_nl, mode=mode, fast=False)
same = "l = [1, 2, 3]\n" same = "l = [1, 2, 3]\n"
with self.assertRaises(black.NothingChanged): with self.assertRaises(black.NothingChanged):
black.format_file_contents(same, line_length=ll, fast=False) black.format_file_contents(same, mode=mode, fast=False)
different = "l = [1,2,3]" different = "l = [1,2,3]"
expected = same expected = same
actual = black.format_file_contents(different, line_length=ll, fast=False) actual = black.format_file_contents(different, mode=mode, fast=False)
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
invalid = "return if you can" invalid = "return if you can"
with self.assertRaises(black.InvalidInput) as e: with self.assertRaises(black.InvalidInput) as e:
black.format_file_contents(invalid, line_length=ll, fast=False) black.format_file_contents(invalid, mode=mode, fast=False)
self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can") self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
def test_endmarker(self) -> None: def test_endmarker(self) -> None:
@ -916,35 +938,33 @@ def err(msg: str, **kwargs: Any) -> None:
self.assertEqual("".join(err_lines), "") self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None: def test_cache_broken_file(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir() as workspace: with cache_dir() as workspace:
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode) cache_file = black.get_cache_file(mode)
with cache_file.open("w") as fobj: with cache_file.open("w") as fobj:
fobj.write("this is not a pickle") fobj.write("this is not a pickle")
self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {}) self.assertEqual(black.read_cache(mode), {})
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
with src.open("w") as fobj: with src.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
result = CliRunner().invoke(black.main, [str(src)]) self.invokeBlack([str(src)])
self.assertEqual(result.exit_code, 0) cache = black.read_cache(mode)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.assertIn(src, cache) self.assertIn(src, cache)
def test_cache_single_file_already_cached(self) -> None: def test_cache_single_file_already_cached(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir() as workspace: with cache_dir() as workspace:
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
with src.open("w") as fobj: with src.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode) black.write_cache({}, [src], mode)
result = CliRunner().invoke(black.main, [str(src)]) self.invokeBlack([str(src)])
self.assertEqual(result.exit_code, 0)
with src.open("r") as fobj: with src.open("r") as fobj:
self.assertEqual(fobj.read(), "print('hello')") self.assertEqual(fobj.read(), "print('hello')")
@event_loop(close=False) @event_loop(close=False)
def test_cache_multiple_files(self) -> None: def test_cache_multiple_files(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir() as workspace, patch( with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor "black.ProcessPoolExecutor", new=ThreadPoolExecutor
): ):
@ -954,50 +974,48 @@ def test_cache_multiple_files(self) -> None:
two = (workspace / "two.py").resolve() two = (workspace / "two.py").resolve()
with two.open("w") as fobj: with two.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode) black.write_cache({}, [one], mode)
result = CliRunner().invoke(black.main, [str(workspace)]) self.invokeBlack([str(workspace)])
self.assertEqual(result.exit_code, 0)
with one.open("r") as fobj: with one.open("r") as fobj:
self.assertEqual(fobj.read(), "print('hello')") self.assertEqual(fobj.read(), "print('hello')")
with two.open("r") as fobj: with two.open("r") as fobj:
self.assertEqual(fobj.read(), 'print("hello")\n') self.assertEqual(fobj.read(), 'print("hello")\n')
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode) cache = black.read_cache(mode)
self.assertIn(one, cache) self.assertIn(one, cache)
self.assertIn(two, cache) self.assertIn(two, cache)
def test_no_cache_when_writeback_diff(self) -> None: def test_no_cache_when_writeback_diff(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir() as workspace: with cache_dir() as workspace:
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
with src.open("w") as fobj: with src.open("w") as fobj:
fobj.write("print('hello')") fobj.write("print('hello')")
result = CliRunner().invoke(black.main, [str(src), "--diff"]) self.invokeBlack([str(src), "--diff"])
self.assertEqual(result.exit_code, 0) cache_file = black.get_cache_file(mode)
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
self.assertFalse(cache_file.exists()) self.assertFalse(cache_file.exists())
def test_no_cache_when_stdin(self) -> None: def test_no_cache_when_stdin(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir(): with cache_dir():
result = CliRunner().invoke( result = CliRunner().invoke(
black.main, ["-"], input=BytesIO(b"print('hello')") black.main, ["-"], input=BytesIO(b"print('hello')")
) )
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode) cache_file = black.get_cache_file(mode)
self.assertFalse(cache_file.exists()) self.assertFalse(cache_file.exists())
def test_read_cache_no_cachefile(self) -> None: def test_read_cache_no_cachefile(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir(): with cache_dir():
self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {}) self.assertEqual(black.read_cache(mode), {})
def test_write_cache_read_cache(self) -> None: def test_write_cache_read_cache(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir() as workspace: with cache_dir() as workspace:
src = (workspace / "test.py").resolve() src = (workspace / "test.py").resolve()
src.touch() src.touch()
black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode) black.write_cache({}, [src], mode)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode) cache = black.read_cache(mode)
self.assertIn(src, cache) self.assertIn(src, cache)
self.assertEqual(cache[src], black.get_cache_info(src)) self.assertEqual(cache[src], black.get_cache_info(src))
@ -1018,15 +1036,15 @@ def test_filter_cached(self) -> None:
self.assertEqual(done, {cached}) self.assertEqual(done, {cached})
def test_write_cache_creates_directory_if_needed(self) -> None: def test_write_cache_creates_directory_if_needed(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir(exists=False) as workspace: with cache_dir(exists=False) as workspace:
self.assertFalse(workspace.exists()) self.assertFalse(workspace.exists())
black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode) black.write_cache({}, [], mode)
self.assertTrue(workspace.exists()) self.assertTrue(workspace.exists())
@event_loop(close=False) @event_loop(close=False)
def test_failed_formatting_does_not_get_cached(self) -> None: def test_failed_formatting_does_not_get_cached(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir() as workspace, patch( with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor "black.ProcessPoolExecutor", new=ThreadPoolExecutor
): ):
@ -1036,40 +1054,33 @@ def test_failed_formatting_does_not_get_cached(self) -> None:
clean = (workspace / "clean.py").resolve() clean = (workspace / "clean.py").resolve()
with clean.open("w") as fobj: with clean.open("w") as fobj:
fobj.write('print("hello")\n') fobj.write('print("hello")\n')
result = CliRunner().invoke(black.main, [str(workspace)]) self.invokeBlack([str(workspace)], exit_code=123)
self.assertEqual(result.exit_code, 123) cache = black.read_cache(mode)
cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
self.assertNotIn(failing, cache) self.assertNotIn(failing, cache)
self.assertIn(clean, cache) self.assertIn(clean, cache)
def test_write_cache_write_fail(self) -> None: def test_write_cache_write_fail(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
with cache_dir(), patch.object(Path, "open") as mock: with cache_dir(), patch.object(Path, "open") as mock:
mock.side_effect = OSError mock.side_effect = OSError
black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode) black.write_cache({}, [], mode)
@event_loop(close=False) @event_loop(close=False)
def test_check_diff_use_together(self) -> None: def test_check_diff_use_together(self) -> None:
with cache_dir(): with cache_dir():
# Files which will be reformatted. # Files which will be reformatted.
src1 = (THIS_DIR / "data" / "string_quotes.py").resolve() src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"]) self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
self.assertEqual(result.exit_code, 1, result.output)
# Files which will not be reformatted. # Files which will not be reformatted.
src2 = (THIS_DIR / "data" / "composition.py").resolve() src2 = (THIS_DIR / "data" / "composition.py").resolve()
result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"]) self.invokeBlack([str(src2), "--diff", "--check"])
self.assertEqual(result.exit_code, 0, result.output)
# Multi file command. # Multi file command.
result = CliRunner().invoke( self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
black.main, [str(src1), str(src2), "--diff", "--check"]
)
self.assertEqual(result.exit_code, 1, result.output)
def test_no_files(self) -> None: def test_no_files(self) -> None:
with cache_dir(): with cache_dir():
# Without an argument, black exits with error code 0. # Without an argument, black exits with error code 0.
result = CliRunner().invoke(black.main, []) self.invokeBlack([])
self.assertEqual(result.exit_code, 0)
def test_broken_symlink(self) -> None: def test_broken_symlink(self) -> None:
with cache_dir() as workspace: with cache_dir() as workspace:
@ -1078,43 +1089,42 @@ def test_broken_symlink(self) -> None:
symlink.symlink_to("nonexistent.py") symlink.symlink_to("nonexistent.py")
except OSError as e: except OSError as e:
self.skipTest(f"Can't create symlinks: {e}") self.skipTest(f"Can't create symlinks: {e}")
result = CliRunner().invoke(black.main, [str(workspace.resolve())]) self.invokeBlack([str(workspace.resolve())])
self.assertEqual(result.exit_code, 0)
def test_read_cache_line_lengths(self) -> None: def test_read_cache_line_lengths(self) -> None:
mode = black.FileMode.AUTO_DETECT mode = black.FileMode()
short_mode = black.FileMode(line_length=1)
with cache_dir() as workspace: with cache_dir() as workspace:
path = (workspace / "file.py").resolve() path = (workspace / "file.py").resolve()
path.touch() path.touch()
black.write_cache({}, [path], 1, mode) black.write_cache({}, [path], mode)
one = black.read_cache(1, mode) one = black.read_cache(mode)
self.assertIn(path, one) self.assertIn(path, one)
two = black.read_cache(2, mode) two = black.read_cache(short_mode)
self.assertNotIn(path, two) self.assertNotIn(path, two)
def test_single_file_force_pyi(self) -> None: def test_single_file_force_pyi(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT reg_mode = black.FileMode()
pyi_mode = black.FileMode.PYI pyi_mode = black.FileMode(is_pyi=True)
contents, expected = read_data("force_pyi") contents, expected = read_data("force_pyi")
with cache_dir() as workspace: with cache_dir() as workspace:
path = (workspace / "file.py").resolve() path = (workspace / "file.py").resolve()
with open(path, "w") as fh: with open(path, "w") as fh:
fh.write(contents) fh.write(contents)
result = CliRunner().invoke(black.main, [str(path), "--pyi"]) self.invokeBlack([str(path), "--pyi"])
self.assertEqual(result.exit_code, 0)
with open(path, "r") as fh: with open(path, "r") as fh:
actual = fh.read() actual = fh.read()
# verify cache with --pyi is separate # verify cache with --pyi is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode) pyi_cache = black.read_cache(pyi_mode)
self.assertIn(path, pyi_cache) self.assertIn(path, pyi_cache)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode) normal_cache = black.read_cache(reg_mode)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@event_loop(close=False) @event_loop(close=False)
def test_multi_file_force_pyi(self) -> None: def test_multi_file_force_pyi(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT reg_mode = black.FileMode()
pyi_mode = black.FileMode.PYI pyi_mode = black.FileMode(is_pyi=True)
contents, expected = read_data("force_pyi") contents, expected = read_data("force_pyi")
with cache_dir() as workspace: with cache_dir() as workspace:
paths = [ paths = [
@ -1124,15 +1134,14 @@ def test_multi_file_force_pyi(self) -> None:
for path in paths: for path in paths:
with open(path, "w") as fh: with open(path, "w") as fh:
fh.write(contents) fh.write(contents)
result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"]) self.invokeBlack([str(p) for p in paths] + ["--pyi"])
self.assertEqual(result.exit_code, 0)
for path in paths: for path in paths:
with open(path, "r") as fh: with open(path, "r") as fh:
actual = fh.read() actual = fh.read()
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
# verify cache with --pyi is separate # verify cache with --pyi is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode) pyi_cache = black.read_cache(pyi_mode)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode) normal_cache = black.read_cache(reg_mode)
for path in paths: for path in paths:
self.assertIn(path, pyi_cache) self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)
@ -1147,28 +1156,27 @@ def test_pipe_force_pyi(self) -> None:
self.assertFormatEqual(actual, expected) self.assertFormatEqual(actual, expected)
def test_single_file_force_py36(self) -> None: def test_single_file_force_py36(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT reg_mode = black.FileMode()
py36_mode = black.FileMode.PYTHON36 py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
source, expected = read_data("force_py36") source, expected = read_data("force_py36")
with cache_dir() as workspace: with cache_dir() as workspace:
path = (workspace / "file.py").resolve() path = (workspace / "file.py").resolve()
with open(path, "w") as fh: with open(path, "w") as fh:
fh.write(source) fh.write(source)
result = CliRunner().invoke(black.main, [str(path), "--py36"]) self.invokeBlack([str(path), "--py36"])
self.assertEqual(result.exit_code, 0)
with open(path, "r") as fh: with open(path, "r") as fh:
actual = fh.read() actual = fh.read()
# verify cache with --py36 is separate # verify cache with --py36 is separate
py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode) py36_cache = black.read_cache(py36_mode)
self.assertIn(path, py36_cache) self.assertIn(path, py36_cache)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode) normal_cache = black.read_cache(reg_mode)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@event_loop(close=False) @event_loop(close=False)
def test_multi_file_force_py36(self) -> None: def test_multi_file_force_py36(self) -> None:
reg_mode = black.FileMode.AUTO_DETECT reg_mode = black.FileMode()
py36_mode = black.FileMode.PYTHON36 py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
source, expected = read_data("force_py36") source, expected = read_data("force_py36")
with cache_dir() as workspace: with cache_dir() as workspace:
paths = [ paths = [
@ -1178,17 +1186,14 @@ def test_multi_file_force_py36(self) -> None:
for path in paths: for path in paths:
with open(path, "w") as fh: with open(path, "w") as fh:
fh.write(source) fh.write(source)
result = CliRunner().invoke( self.invokeBlack([str(p) for p in paths] + ["--py36"])
black.main, [str(p) for p in paths] + ["--py36"]
)
self.assertEqual(result.exit_code, 0)
for path in paths: for path in paths:
with open(path, "r") as fh: with open(path, "r") as fh:
actual = fh.read() actual = fh.read()
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
# verify cache with --py36 is separate # verify cache with --py36 is separate
pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode) pyi_cache = black.read_cache(py36_mode)
normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode) normal_cache = black.read_cache(reg_mode)
for path in paths: for path in paths:
self.assertIn(path, pyi_cache) self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache) self.assertNotIn(path, normal_cache)
@ -1265,8 +1270,7 @@ def test_empty_exclude(self) -> None:
def test_invalid_include_exclude(self) -> None: def test_invalid_include_exclude(self) -> None:
for option in ["--include", "--exclude"]: for option in ["--include", "--exclude"]:
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"]) self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
self.assertEqual(result.exit_code, 2)
def test_preserves_line_endings(self) -> None: def test_preserves_line_endings(self) -> None:
with TemporaryDirectory() as workspace: with TemporaryDirectory() as workspace:
@ -1407,10 +1411,24 @@ async def test_blackd_supported_version(self) -> None:
async def test_blackd_invalid_python_variant(self) -> None: async def test_blackd_invalid_python_variant(self) -> None:
app = blackd.make_app() app = blackd.make_app()
async with TestClient(TestServer(app)) as client: async with TestClient(TestServer(app)) as client:
async def check(header_value: str, expected_status: int = 400) -> None:
response = await client.post( response = await client.post(
"/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: "lol"} "/",
data=b"what",
headers={blackd.PYTHON_VARIANT_HEADER: header_value},
) )
self.assertEqual(response.status, 400) self.assertEqual(response.status, expected_status)
await check("lol")
await check("ruby3.5")
await check("pyi3.6")
await check("cpy1.5")
await check("2.8")
await check("cpy2.8")
await check("3.0")
await check("pypy3.0")
await check("jython3.4")
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test @async_test
@ -1426,51 +1444,37 @@ async def test_blackd_pyi(self) -> None:
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test @async_test
async def test_blackd_py36(self) -> None: async def test_blackd_python_variant(self) -> None:
app = blackd.make_app() app = blackd.make_app()
code = (
"def f(\n"
" and_has_a_bunch_of,\n"
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n"
" **and_very_long_keyword_arguments\n"
"):\n"
" pass\n"
)
async with TestClient(TestServer(app)) as client: async with TestClient(TestServer(app)) as client:
async def check(header_value: str, expected_status: int) -> None:
response = await client.post( response = await client.post(
"/", "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
data=(
"def f(\n"
" and_has_a_bunch_of,\n"
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n"
" **and_very_long_keyword_arguments\n"
"):\n"
" pass\n"
),
headers={blackd.PYTHON_VARIANT_HEADER: "3.6"},
) )
self.assertEqual(response.status, 200) self.assertEqual(response.status, expected_status)
response = await client.post(
"/", await check("3.6", 200)
data=( await check("cpy3.6", 200)
"def f(\n" await check("3.5,3.7", 200)
" and_has_a_bunch_of,\n" await check("3.5,cpy3.7", 200)
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n" await check("2", 204)
" **and_very_long_keyword_arguments\n" await check("2.7", 204)
"):\n" await check("cpy2.7", 204)
" pass\n" await check("pypy2.7", 204)
), await check("3.4", 204)
headers={blackd.PYTHON_VARIANT_HEADER: "3.5"}, await check("cpy3.4", 204)
) await check("pypy3.4", 204)
self.assertEqual(response.status, 204)
response = await client.post(
"/",
data=(
"def f(\n"
" and_has_a_bunch_of,\n"
" very_long_arguments_too,\n"
" and_lots_of_them_as_well_lol,\n"
" **and_very_long_keyword_arguments\n"
"):\n"
" pass\n"
),
headers={blackd.PYTHON_VARIANT_HEADER: "2"},
)
self.assertEqual(response.status, 204)
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test @async_test