Support formatting specified lines (#4020)
This commit is contained in:
parent
ecbd9e8cf7
commit
46be1f8e54
@ -6,6 +6,9 @@
|
||||
|
||||
<!-- Include any especially major or disruptive changes here -->
|
||||
|
||||
- Support formatting ranges of lines with the new `--line-ranges` command-line option
|
||||
(#4020).
|
||||
|
||||
### Stable style
|
||||
|
||||
- Fix crash on formatting bytes strings that look like docstrings (#4003)
|
||||
|
@ -175,6 +175,23 @@ All done! ✨ 🍰 ✨
|
||||
1 file would be reformatted.
|
||||
```
|
||||
|
||||
### `--line-ranges`
|
||||
|
||||
When specified, _Black_ will try its best to only format these lines.
|
||||
|
||||
This option can be specified multiple times, and a union of the lines will be formatted.
|
||||
Each range must be specified as two integers connected by a `-`: `<START>-<END>`. The
|
||||
`<START>` and `<END>` integer indices are 1-based and inclusive on both ends.
|
||||
|
||||
_Black_ may still format lines outside of the ranges for multi-line statements.
|
||||
Formatting more than one file or any ipynb files with this option is not supported. This
|
||||
option cannot be specified in the `pyproject.toml` config.
|
||||
|
||||
Example: `black --line-ranges=1-10 --line-ranges=21-30 test.py` will format lines from
|
||||
`1` to `10` and `21` to `30`.
|
||||
|
||||
This option is mainly for editor integrations, such as "Format Selection".
|
||||
|
||||
#### `--color` / `--no-color`
|
||||
|
||||
Show (or do not show) colored diff. Only applies when `--diff` is given.
|
||||
|
@ -13,6 +13,7 @@
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
@ -77,6 +78,7 @@
|
||||
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
|
||||
from black.parsing import InvalidInput # noqa F401
|
||||
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
|
||||
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
|
||||
from black.report import Changed, NothingChanged, Report
|
||||
from black.trans import iter_fexpr_spans
|
||||
from blib2to3.pgen2 import token
|
||||
@ -163,6 +165,12 @@ def read_pyproject_toml(
|
||||
"extend-exclude", "Config key extend-exclude must be a string"
|
||||
)
|
||||
|
||||
line_ranges = config.get("line_ranges")
|
||||
if line_ranges is not None:
|
||||
raise click.BadOptionUsage(
|
||||
"line-ranges", "Cannot use line-ranges in the pyproject.toml file."
|
||||
)
|
||||
|
||||
default_map: Dict[str, Any] = {}
|
||||
if ctx.default_map:
|
||||
default_map.update(ctx.default_map)
|
||||
@ -304,6 +312,19 @@ def validate_regex(
|
||||
is_flag=True,
|
||||
help="Don't write the files back, just output a diff for each file on stdout.",
|
||||
)
|
||||
@click.option(
|
||||
"--line-ranges",
|
||||
multiple=True,
|
||||
metavar="START-END",
|
||||
help=(
|
||||
"When specified, _Black_ will try its best to only format these lines. This"
|
||||
" option can be specified multiple times, and a union of the lines will be"
|
||||
" formatted. Each range must be specified as two integers connected by a `-`:"
|
||||
" `<START>-<END>`. The `<START>` and `<END>` integer indices are 1-based and"
|
||||
" inclusive on both ends."
|
||||
),
|
||||
default=(),
|
||||
)
|
||||
@click.option(
|
||||
"--color/--no-color",
|
||||
is_flag=True,
|
||||
@ -443,6 +464,7 @@ def main( # noqa: C901
|
||||
target_version: List[TargetVersion],
|
||||
check: bool,
|
||||
diff: bool,
|
||||
line_ranges: Sequence[str],
|
||||
color: bool,
|
||||
fast: bool,
|
||||
pyi: bool,
|
||||
@ -544,6 +566,18 @@ def main( # noqa: C901
|
||||
python_cell_magics=set(python_cell_magics),
|
||||
)
|
||||
|
||||
lines: List[Tuple[int, int]] = []
|
||||
if line_ranges:
|
||||
if ipynb:
|
||||
err("Cannot use --line-ranges with ipynb files.")
|
||||
ctx.exit(1)
|
||||
|
||||
try:
|
||||
lines = parse_line_ranges(line_ranges)
|
||||
except ValueError as e:
|
||||
err(str(e))
|
||||
ctx.exit(1)
|
||||
|
||||
if code is not None:
|
||||
# Run in quiet mode by default with -c; the extra output isn't useful.
|
||||
# You can still pass -v to get verbose output.
|
||||
@ -553,7 +587,12 @@ def main( # noqa: C901
|
||||
|
||||
if code is not None:
|
||||
reformat_code(
|
||||
content=code, fast=fast, write_back=write_back, mode=mode, report=report
|
||||
content=code,
|
||||
fast=fast,
|
||||
write_back=write_back,
|
||||
mode=mode,
|
||||
report=report,
|
||||
lines=lines,
|
||||
)
|
||||
else:
|
||||
assert root is not None # root is only None if code is not None
|
||||
@ -588,10 +627,14 @@ def main( # noqa: C901
|
||||
write_back=write_back,
|
||||
mode=mode,
|
||||
report=report,
|
||||
lines=lines,
|
||||
)
|
||||
else:
|
||||
from black.concurrency import reformat_many
|
||||
|
||||
if lines:
|
||||
err("Cannot use --line-ranges to format multiple files.")
|
||||
ctx.exit(1)
|
||||
reformat_many(
|
||||
sources=sources,
|
||||
fast=fast,
|
||||
@ -714,7 +757,13 @@ def path_empty(
|
||||
|
||||
|
||||
def reformat_code(
|
||||
content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
|
||||
content: str,
|
||||
fast: bool,
|
||||
write_back: WriteBack,
|
||||
mode: Mode,
|
||||
report: Report,
|
||||
*,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Reformat and print out `content` without spawning child processes.
|
||||
@ -727,7 +776,7 @@ def reformat_code(
|
||||
try:
|
||||
changed = Changed.NO
|
||||
if format_stdin_to_stdout(
|
||||
content=content, fast=fast, write_back=write_back, mode=mode
|
||||
content=content, fast=fast, write_back=write_back, mode=mode, lines=lines
|
||||
):
|
||||
changed = Changed.YES
|
||||
report.done(path, changed)
|
||||
@ -741,7 +790,13 @@ def reformat_code(
|
||||
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
|
||||
@mypyc_attr(patchable=True)
|
||||
def reformat_one(
|
||||
src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
|
||||
src: Path,
|
||||
fast: bool,
|
||||
write_back: WriteBack,
|
||||
mode: Mode,
|
||||
report: "Report",
|
||||
*,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> None:
|
||||
"""Reformat a single file under `src` without spawning child processes.
|
||||
|
||||
@ -766,7 +821,9 @@ def reformat_one(
|
||||
mode = replace(mode, is_pyi=True)
|
||||
elif src.suffix == ".ipynb":
|
||||
mode = replace(mode, is_ipynb=True)
|
||||
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
|
||||
if format_stdin_to_stdout(
|
||||
fast=fast, write_back=write_back, mode=mode, lines=lines
|
||||
):
|
||||
changed = Changed.YES
|
||||
else:
|
||||
cache = Cache.read(mode)
|
||||
@ -774,7 +831,7 @@ def reformat_one(
|
||||
if not cache.is_changed(src):
|
||||
changed = Changed.CACHED
|
||||
if changed is not Changed.CACHED and format_file_in_place(
|
||||
src, fast=fast, write_back=write_back, mode=mode
|
||||
src, fast=fast, write_back=write_back, mode=mode, lines=lines
|
||||
):
|
||||
changed = Changed.YES
|
||||
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
|
||||
@ -794,6 +851,8 @@ def format_file_in_place(
|
||||
mode: Mode,
|
||||
write_back: WriteBack = WriteBack.NO,
|
||||
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
|
||||
*,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> bool:
|
||||
"""Format file under `src` path. Return True if changed.
|
||||
|
||||
@ -813,7 +872,9 @@ def format_file_in_place(
|
||||
header = buf.readline()
|
||||
src_contents, encoding, newline = decode_bytes(buf.read())
|
||||
try:
|
||||
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
|
||||
dst_contents = format_file_contents(
|
||||
src_contents, fast=fast, mode=mode, lines=lines
|
||||
)
|
||||
except NothingChanged:
|
||||
return False
|
||||
except JSONDecodeError:
|
||||
@ -858,6 +919,7 @@ def format_stdin_to_stdout(
|
||||
content: Optional[str] = None,
|
||||
write_back: WriteBack = WriteBack.NO,
|
||||
mode: Mode,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> bool:
|
||||
"""Format file on stdin. Return True if changed.
|
||||
|
||||
@ -876,7 +938,7 @@ def format_stdin_to_stdout(
|
||||
|
||||
dst = src
|
||||
try:
|
||||
dst = format_file_contents(src, fast=fast, mode=mode)
|
||||
dst = format_file_contents(src, fast=fast, mode=mode, lines=lines)
|
||||
return True
|
||||
|
||||
except NothingChanged:
|
||||
@ -904,7 +966,11 @@ def format_stdin_to_stdout(
|
||||
|
||||
|
||||
def check_stability_and_equivalence(
|
||||
src_contents: str, dst_contents: str, *, mode: Mode
|
||||
src_contents: str,
|
||||
dst_contents: str,
|
||||
*,
|
||||
mode: Mode,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> None:
|
||||
"""Perform stability and equivalence checks.
|
||||
|
||||
@ -913,10 +979,16 @@ def check_stability_and_equivalence(
|
||||
content differently.
|
||||
"""
|
||||
assert_equivalent(src_contents, dst_contents)
|
||||
assert_stable(src_contents, dst_contents, mode=mode)
|
||||
assert_stable(src_contents, dst_contents, mode=mode, lines=lines)
|
||||
|
||||
|
||||
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
|
||||
def format_file_contents(
|
||||
src_contents: str,
|
||||
*,
|
||||
fast: bool,
|
||||
mode: Mode,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> FileContent:
|
||||
"""Reformat contents of a file and return new contents.
|
||||
|
||||
If `fast` is False, additionally confirm that the reformatted code is
|
||||
@ -926,13 +998,15 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
|
||||
if mode.is_ipynb:
|
||||
dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
|
||||
else:
|
||||
dst_contents = format_str(src_contents, mode=mode)
|
||||
dst_contents = format_str(src_contents, mode=mode, lines=lines)
|
||||
if src_contents == dst_contents:
|
||||
raise NothingChanged
|
||||
|
||||
if not fast and not mode.is_ipynb:
|
||||
# Jupyter notebooks will already have been checked above.
|
||||
check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
|
||||
check_stability_and_equivalence(
|
||||
src_contents, dst_contents, mode=mode, lines=lines
|
||||
)
|
||||
return dst_contents
|
||||
|
||||
|
||||
@ -1043,7 +1117,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
|
||||
raise NothingChanged
|
||||
|
||||
|
||||
def format_str(src_contents: str, *, mode: Mode) -> str:
|
||||
def format_str(
|
||||
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
|
||||
) -> str:
|
||||
"""Reformat a string and return new contents.
|
||||
|
||||
`mode` determines formatting options, such as how many characters per line are
|
||||
@ -1073,16 +1149,20 @@ def f(
|
||||
hey
|
||||
|
||||
"""
|
||||
dst_contents = _format_str_once(src_contents, mode=mode)
|
||||
dst_contents = _format_str_once(src_contents, mode=mode, lines=lines)
|
||||
# Forced second pass to work around optional trailing commas (becoming
|
||||
# forced trailing commas on pass 2) interacting differently with optional
|
||||
# parentheses. Admittedly ugly.
|
||||
if src_contents != dst_contents:
|
||||
return _format_str_once(dst_contents, mode=mode)
|
||||
if lines:
|
||||
lines = adjusted_lines(lines, src_contents, dst_contents)
|
||||
return _format_str_once(dst_contents, mode=mode, lines=lines)
|
||||
return dst_contents
|
||||
|
||||
|
||||
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
|
||||
def _format_str_once(
|
||||
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
|
||||
) -> str:
|
||||
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
|
||||
dst_blocks: List[LinesBlock] = []
|
||||
if mode.target_versions:
|
||||
@ -1097,7 +1177,11 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
|
||||
if supports_feature(versions, feature)
|
||||
}
|
||||
normalize_fmt_off(src_node, mode)
|
||||
lines = LineGenerator(mode=mode, features=context_manager_features)
|
||||
if lines:
|
||||
# This should be called after normalize_fmt_off.
|
||||
convert_unchanged_lines(src_node, lines)
|
||||
|
||||
line_generator = LineGenerator(mode=mode, features=context_manager_features)
|
||||
elt = EmptyLineTracker(mode=mode)
|
||||
split_line_features = {
|
||||
feature
|
||||
@ -1105,7 +1189,7 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
|
||||
if supports_feature(versions, feature)
|
||||
}
|
||||
block: Optional[LinesBlock] = None
|
||||
for current_line in lines.visit(src_node):
|
||||
for current_line in line_generator.visit(src_node):
|
||||
block = elt.maybe_empty_lines(current_line)
|
||||
dst_blocks.append(block)
|
||||
for line in transform_line(
|
||||
@ -1373,12 +1457,16 @@ def assert_equivalent(src: str, dst: str) -> None:
|
||||
) from None
|
||||
|
||||
|
||||
def assert_stable(src: str, dst: str, mode: Mode) -> None:
|
||||
def assert_stable(
|
||||
src: str, dst: str, mode: Mode, *, lines: Collection[Tuple[int, int]] = ()
|
||||
) -> None:
|
||||
"""Raise AssertionError if `dst` reformats differently the second time."""
|
||||
# We shouldn't call format_str() here, because that formats the string
|
||||
# twice and may hide a bug where we bounce back and forth between two
|
||||
# versions.
|
||||
newdst = _format_str_once(dst, mode=mode)
|
||||
if lines:
|
||||
lines = adjusted_lines(lines, src, dst)
|
||||
newdst = _format_str_once(dst, mode=mode, lines=lines)
|
||||
if dst != newdst:
|
||||
log = dump_to_file(
|
||||
str(mode),
|
||||
|
@ -935,3 +935,31 @@ def is_part_of_annotation(leaf: Leaf) -> bool:
|
||||
return True
|
||||
ancestor = ancestor.parent
|
||||
return False
|
||||
|
||||
|
||||
def first_leaf(node: LN) -> Optional[Leaf]:
|
||||
"""Returns the first leaf of the ancestor node."""
|
||||
if isinstance(node, Leaf):
|
||||
return node
|
||||
elif not node.children:
|
||||
return None
|
||||
else:
|
||||
return first_leaf(node.children[0])
|
||||
|
||||
|
||||
def last_leaf(node: LN) -> Optional[Leaf]:
|
||||
"""Returns the last leaf of the ancestor node."""
|
||||
if isinstance(node, Leaf):
|
||||
return node
|
||||
elif not node.children:
|
||||
return None
|
||||
else:
|
||||
return last_leaf(node.children[-1])
|
||||
|
||||
|
||||
def furthest_ancestor_with_last_leaf(leaf: Leaf) -> LN:
|
||||
"""Returns the furthest ancestor that has this leaf node as the last leaf."""
|
||||
node: LN = leaf
|
||||
while node.parent and node.parent.children and node is node.parent.children[-1]:
|
||||
node = node.parent
|
||||
return node
|
||||
|
496
src/black/ranges.py
Normal file
496
src/black/ranges.py
Normal file
@ -0,0 +1,496 @@
|
||||
"""Functions related to Black's formatting by line ranges feature."""
|
||||
|
||||
import difflib
|
||||
from dataclasses import dataclass
|
||||
from typing import Collection, Iterator, List, Sequence, Set, Tuple, Union
|
||||
|
||||
from black.nodes import (
|
||||
LN,
|
||||
STANDALONE_COMMENT,
|
||||
Leaf,
|
||||
Node,
|
||||
Visitor,
|
||||
first_leaf,
|
||||
furthest_ancestor_with_last_leaf,
|
||||
last_leaf,
|
||||
syms,
|
||||
)
|
||||
from blib2to3.pgen2.token import ASYNC, NEWLINE
|
||||
|
||||
|
||||
def parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]:
|
||||
lines: List[Tuple[int, int]] = []
|
||||
for lines_str in line_ranges:
|
||||
parts = lines_str.split("-")
|
||||
if len(parts) != 2:
|
||||
raise ValueError(
|
||||
"Incorrect --line-ranges format, expect 'START-END', found"
|
||||
f" {lines_str!r}"
|
||||
)
|
||||
try:
|
||||
start = int(parts[0])
|
||||
end = int(parts[1])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Incorrect --line-ranges value, expect integer ranges, found"
|
||||
f" {lines_str!r}"
|
||||
) from None
|
||||
else:
|
||||
lines.append((start, end))
|
||||
return lines
|
||||
|
||||
|
||||
def is_valid_line_range(lines: Tuple[int, int]) -> bool:
|
||||
"""Returns whether the line range is valid."""
|
||||
return not lines or lines[0] <= lines[1]
|
||||
|
||||
|
||||
def adjusted_lines(
|
||||
lines: Collection[Tuple[int, int]],
|
||||
original_source: str,
|
||||
modified_source: str,
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""Returns the adjusted line ranges based on edits from the original code.
|
||||
|
||||
This computes the new line ranges by diffing original_source and
|
||||
modified_source, and adjust each range based on how the range overlaps with
|
||||
the diffs.
|
||||
|
||||
Note the diff can contain lines outside of the original line ranges. This can
|
||||
happen when the formatting has to be done in adjacent to maintain consistent
|
||||
local results. For example:
|
||||
|
||||
1. def my_func(arg1, arg2,
|
||||
2. arg3,):
|
||||
3. pass
|
||||
|
||||
If it restricts to line 2-2, it can't simply reformat line 2, it also has
|
||||
to reformat line 1:
|
||||
|
||||
1. def my_func(
|
||||
2. arg1,
|
||||
3. arg2,
|
||||
4. arg3,
|
||||
5. ):
|
||||
6. pass
|
||||
|
||||
In this case, we will expand the line ranges to also include the whole diff
|
||||
block.
|
||||
|
||||
Args:
|
||||
lines: a collection of line ranges.
|
||||
original_source: the original source.
|
||||
modified_source: the modified source.
|
||||
"""
|
||||
lines_mappings = _calculate_lines_mappings(original_source, modified_source)
|
||||
|
||||
new_lines = []
|
||||
# Keep an index of the current search. Since the lines and lines_mappings are
|
||||
# sorted, this makes the search complexity linear.
|
||||
current_mapping_index = 0
|
||||
for start, end in sorted(lines):
|
||||
start_mapping_index = _find_lines_mapping_index(
|
||||
start,
|
||||
lines_mappings,
|
||||
current_mapping_index,
|
||||
)
|
||||
end_mapping_index = _find_lines_mapping_index(
|
||||
end,
|
||||
lines_mappings,
|
||||
start_mapping_index,
|
||||
)
|
||||
current_mapping_index = start_mapping_index
|
||||
if start_mapping_index >= len(lines_mappings) or end_mapping_index >= len(
|
||||
lines_mappings
|
||||
):
|
||||
# Protect against invalid inputs.
|
||||
continue
|
||||
start_mapping = lines_mappings[start_mapping_index]
|
||||
end_mapping = lines_mappings[end_mapping_index]
|
||||
if start_mapping.is_changed_block:
|
||||
# When the line falls into a changed block, expands to the whole block.
|
||||
new_start = start_mapping.modified_start
|
||||
else:
|
||||
new_start = (
|
||||
start - start_mapping.original_start + start_mapping.modified_start
|
||||
)
|
||||
if end_mapping.is_changed_block:
|
||||
# When the line falls into a changed block, expands to the whole block.
|
||||
new_end = end_mapping.modified_end
|
||||
else:
|
||||
new_end = end - end_mapping.original_start + end_mapping.modified_start
|
||||
new_range = (new_start, new_end)
|
||||
if is_valid_line_range(new_range):
|
||||
new_lines.append(new_range)
|
||||
return new_lines
|
||||
|
||||
|
||||
def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) -> None:
|
||||
"""Converts unchanged lines to STANDALONE_COMMENT.
|
||||
|
||||
The idea is similar to how `# fmt: on/off` is implemented. It also converts the
|
||||
nodes between those markers as a single `STANDALONE_COMMENT` leaf node with
|
||||
the unformatted code as its value. `STANDALONE_COMMENT` is a "fake" token
|
||||
that will be formatted as-is with its prefix normalized.
|
||||
|
||||
Here we perform two passes:
|
||||
|
||||
1. Visit the top-level statements, and convert them to a single
|
||||
`STANDALONE_COMMENT` when unchanged. This speeds up formatting when some
|
||||
of the top-level statements aren't changed.
|
||||
2. Convert unchanged "unwrapped lines" to `STANDALONE_COMMENT` nodes line by
|
||||
line. "unwrapped lines" are divided by the `NEWLINE` token. e.g. a
|
||||
multi-line statement is *one* "unwrapped line" that ends with `NEWLINE`,
|
||||
even though this statement itself can span multiple lines, and the
|
||||
tokenizer only sees the last '\n' as the `NEWLINE` token.
|
||||
|
||||
NOTE: During pass (2), comment prefixes and indentations are ALWAYS
|
||||
normalized even when the lines aren't changed. This is fixable by moving
|
||||
more formatting to pass (1). However, it's hard to get it correct when
|
||||
incorrect indentations are used. So we defer this to future optimizations.
|
||||
"""
|
||||
lines_set: Set[int] = set()
|
||||
for start, end in lines:
|
||||
lines_set.update(range(start, end + 1))
|
||||
visitor = _TopLevelStatementsVisitor(lines_set)
|
||||
_ = list(visitor.visit(src_node)) # Consume all results.
|
||||
_convert_unchanged_line_by_line(src_node, lines_set)
|
||||
|
||||
|
||||
def _contains_standalone_comment(node: LN) -> bool:
|
||||
if isinstance(node, Leaf):
|
||||
return node.type == STANDALONE_COMMENT
|
||||
else:
|
||||
for child in node.children:
|
||||
if _contains_standalone_comment(child):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class _TopLevelStatementsVisitor(Visitor[None]):
|
||||
"""
|
||||
A node visitor that converts unchanged top-level statements to
|
||||
STANDALONE_COMMENT.
|
||||
|
||||
This is used in addition to _convert_unchanged_lines_by_flatterning, to
|
||||
speed up formatting when there are unchanged top-level
|
||||
classes/functions/statements.
|
||||
"""
|
||||
|
||||
def __init__(self, lines_set: Set[int]):
|
||||
self._lines_set = lines_set
|
||||
|
||||
def visit_simple_stmt(self, node: Node) -> Iterator[None]:
|
||||
# This is only called for top-level statements, since `visit_suite`
|
||||
# won't visit its children nodes.
|
||||
yield from []
|
||||
newline_leaf = last_leaf(node)
|
||||
if not newline_leaf:
|
||||
return
|
||||
assert (
|
||||
newline_leaf.type == NEWLINE
|
||||
), f"Unexpectedly found leaf.type={newline_leaf.type}"
|
||||
# We need to find the furthest ancestor with the NEWLINE as the last
|
||||
# leaf, since a `suite` can simply be a `simple_stmt` when it puts
|
||||
# its body on the same line. Example: `if cond: pass`.
|
||||
ancestor = furthest_ancestor_with_last_leaf(newline_leaf)
|
||||
if not _get_line_range(ancestor).intersection(self._lines_set):
|
||||
_convert_node_to_standalone_comment(ancestor)
|
||||
|
||||
def visit_suite(self, node: Node) -> Iterator[None]:
|
||||
yield from []
|
||||
# If there is a STANDALONE_COMMENT node, it means parts of the node tree
|
||||
# have fmt on/off/skip markers. Those STANDALONE_COMMENT nodes can't
|
||||
# be simply converted by calling str(node). So we just don't convert
|
||||
# here.
|
||||
if _contains_standalone_comment(node):
|
||||
return
|
||||
# Find the semantic parent of this suite. For `async_stmt` and
|
||||
# `async_funcdef`, the ASYNC token is defined on a separate level by the
|
||||
# grammar.
|
||||
semantic_parent = node.parent
|
||||
if semantic_parent is not None:
|
||||
if (
|
||||
semantic_parent.prev_sibling is not None
|
||||
and semantic_parent.prev_sibling.type == ASYNC
|
||||
):
|
||||
semantic_parent = semantic_parent.parent
|
||||
if semantic_parent is not None and not _get_line_range(
|
||||
semantic_parent
|
||||
).intersection(self._lines_set):
|
||||
_convert_node_to_standalone_comment(semantic_parent)
|
||||
|
||||
|
||||
def _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None:
|
||||
"""Converts unchanged to STANDALONE_COMMENT line by line."""
|
||||
for leaf in node.leaves():
|
||||
if leaf.type != NEWLINE:
|
||||
# We only consider "unwrapped lines", which are divided by the NEWLINE
|
||||
# token.
|
||||
continue
|
||||
if leaf.parent and leaf.parent.type == syms.match_stmt:
|
||||
# The `suite` node is defined as:
|
||||
# match_stmt: "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT
|
||||
# Here we need to check `subject_expr`. The `case_block+` will be
|
||||
# checked by their own NEWLINEs.
|
||||
nodes_to_ignore: List[LN] = []
|
||||
prev_sibling = leaf.prev_sibling
|
||||
while prev_sibling:
|
||||
nodes_to_ignore.insert(0, prev_sibling)
|
||||
prev_sibling = prev_sibling.prev_sibling
|
||||
if not _get_line_range(nodes_to_ignore).intersection(lines_set):
|
||||
_convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)
|
||||
elif leaf.parent and leaf.parent.type == syms.suite:
|
||||
# The `suite` node is defined as:
|
||||
# suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT
|
||||
# We will check `simple_stmt` and `stmt+` separately against the lines set
|
||||
parent_sibling = leaf.parent.prev_sibling
|
||||
nodes_to_ignore = []
|
||||
while parent_sibling and not parent_sibling.type == syms.suite:
|
||||
# NOTE: Multiple suite nodes can exist as siblings in e.g. `if_stmt`.
|
||||
nodes_to_ignore.insert(0, parent_sibling)
|
||||
parent_sibling = parent_sibling.prev_sibling
|
||||
# Special case for `async_stmt` and `async_funcdef` where the ASYNC
|
||||
# token is on the grandparent node.
|
||||
grandparent = leaf.parent.parent
|
||||
if (
|
||||
grandparent is not None
|
||||
and grandparent.prev_sibling is not None
|
||||
and grandparent.prev_sibling.type == ASYNC
|
||||
):
|
||||
nodes_to_ignore.insert(0, grandparent.prev_sibling)
|
||||
if not _get_line_range(nodes_to_ignore).intersection(lines_set):
|
||||
_convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)
|
||||
else:
|
||||
ancestor = furthest_ancestor_with_last_leaf(leaf)
|
||||
# Consider multiple decorators as a whole block, as their
|
||||
# newlines have different behaviors than the rest of the grammar.
|
||||
if (
|
||||
ancestor.type == syms.decorator
|
||||
and ancestor.parent
|
||||
and ancestor.parent.type == syms.decorators
|
||||
):
|
||||
ancestor = ancestor.parent
|
||||
if not _get_line_range(ancestor).intersection(lines_set):
|
||||
_convert_node_to_standalone_comment(ancestor)
|
||||
|
||||
|
||||
def _convert_node_to_standalone_comment(node: LN) -> None:
|
||||
"""Convert node to STANDALONE_COMMENT by modifying the tree inline."""
|
||||
parent = node.parent
|
||||
if not parent:
|
||||
return
|
||||
first = first_leaf(node)
|
||||
last = last_leaf(node)
|
||||
if not first or not last:
|
||||
return
|
||||
if first is last:
|
||||
# This can happen on the following edge cases:
|
||||
# 1. A block of `# fmt: off/on` code except the `# fmt: on` is placed
|
||||
# on the end of the last line instead of on a new line.
|
||||
# 2. A single backslash on its own line followed by a comment line.
|
||||
# Ideally we don't want to format them when not requested, but fixing
|
||||
# isn't easy. These cases are also badly formatted code, so it isn't
|
||||
# too bad we reformat them.
|
||||
return
|
||||
# The prefix contains comments and indentation whitespaces. They are
|
||||
# reformatted accordingly to the correct indentation level.
|
||||
# This also means the indentation will be changed on the unchanged lines, and
|
||||
# this is actually required to not break incremental reformatting.
|
||||
prefix = first.prefix
|
||||
first.prefix = ""
|
||||
index = node.remove()
|
||||
if index is not None:
|
||||
# Remove the '\n', as STANDALONE_COMMENT will have '\n' appended when
|
||||
# genearting the formatted code.
|
||||
value = str(node)[:-1]
|
||||
parent.insert_child(
|
||||
index,
|
||||
Leaf(
|
||||
STANDALONE_COMMENT,
|
||||
value,
|
||||
prefix=prefix,
|
||||
fmt_pass_converted_first_leaf=first,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_nodes_to_standalone_comment(nodes: Sequence[LN], *, newline: Leaf) -> None:
|
||||
"""Convert nodes to STANDALONE_COMMENT by modifying the tree inline."""
|
||||
if not nodes:
|
||||
return
|
||||
parent = nodes[0].parent
|
||||
first = first_leaf(nodes[0])
|
||||
if not parent or not first:
|
||||
return
|
||||
prefix = first.prefix
|
||||
first.prefix = ""
|
||||
value = "".join(str(node) for node in nodes)
|
||||
# The prefix comment on the NEWLINE leaf is the trailing comment of the statement.
|
||||
if newline.prefix:
|
||||
value += newline.prefix
|
||||
newline.prefix = ""
|
||||
index = nodes[0].remove()
|
||||
for node in nodes[1:]:
|
||||
node.remove()
|
||||
if index is not None:
|
||||
parent.insert_child(
|
||||
index,
|
||||
Leaf(
|
||||
STANDALONE_COMMENT,
|
||||
value,
|
||||
prefix=prefix,
|
||||
fmt_pass_converted_first_leaf=first,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _leaf_line_end(leaf: Leaf) -> int:
|
||||
"""Returns the line number of the leaf node's last line."""
|
||||
if leaf.type == NEWLINE:
|
||||
return leaf.lineno
|
||||
else:
|
||||
# Leaf nodes like multiline strings can occupy multiple lines.
|
||||
return leaf.lineno + str(leaf).count("\n")
|
||||
|
||||
|
||||
def _get_line_range(node_or_nodes: Union[LN, List[LN]]) -> Set[int]:
|
||||
"""Returns the line range of this node or list of nodes."""
|
||||
if isinstance(node_or_nodes, list):
|
||||
nodes = node_or_nodes
|
||||
if not nodes:
|
||||
return set()
|
||||
first = first_leaf(nodes[0])
|
||||
last = last_leaf(nodes[-1])
|
||||
if first and last:
|
||||
line_start = first.lineno
|
||||
line_end = _leaf_line_end(last)
|
||||
return set(range(line_start, line_end + 1))
|
||||
else:
|
||||
return set()
|
||||
else:
|
||||
node = node_or_nodes
|
||||
if isinstance(node, Leaf):
|
||||
return set(range(node.lineno, _leaf_line_end(node) + 1))
|
||||
else:
|
||||
first = first_leaf(node)
|
||||
last = last_leaf(node)
|
||||
if first and last:
|
||||
return set(range(first.lineno, _leaf_line_end(last) + 1))
|
||||
else:
|
||||
return set()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LinesMapping:
|
||||
"""1-based lines mapping from original source to modified source.
|
||||
|
||||
Lines [original_start, original_end] from original source
|
||||
are mapped to [modified_start, modified_end].
|
||||
|
||||
The ranges are inclusive on both ends.
|
||||
"""
|
||||
|
||||
original_start: int
|
||||
original_end: int
|
||||
modified_start: int
|
||||
modified_end: int
|
||||
# Whether this range corresponds to a changed block, or an unchanged block.
|
||||
is_changed_block: bool
|
||||
|
||||
|
||||
def _calculate_lines_mappings(
|
||||
original_source: str,
|
||||
modified_source: str,
|
||||
) -> Sequence[_LinesMapping]:
|
||||
"""Returns a sequence of _LinesMapping by diffing the sources.
|
||||
|
||||
For example, given the following diff:
|
||||
import re
|
||||
- def func(arg1,
|
||||
- arg2, arg3):
|
||||
+ def func(arg1, arg2, arg3):
|
||||
pass
|
||||
It returns the following mappings:
|
||||
original -> modified
|
||||
(1, 1) -> (1, 1), is_changed_block=False (the "import re" line)
|
||||
(2, 3) -> (2, 2), is_changed_block=True (the diff)
|
||||
(4, 4) -> (3, 3), is_changed_block=False (the "pass" line)
|
||||
|
||||
You can think of this visually as if it brings up a side-by-side diff, and tries
|
||||
to map the line ranges from the left side to the right side:
|
||||
|
||||
(1, 1)->(1, 1) 1. import re 1. import re
|
||||
(2, 3)->(2, 2) 2. def func(arg1, 2. def func(arg1, arg2, arg3):
|
||||
3. arg2, arg3):
|
||||
(4, 4)->(3, 3) 4. pass 3. pass
|
||||
|
||||
Args:
|
||||
original_source: the original source.
|
||||
modified_source: the modified source.
|
||||
"""
|
||||
matcher = difflib.SequenceMatcher(
|
||||
None,
|
||||
original_source.splitlines(keepends=True),
|
||||
modified_source.splitlines(keepends=True),
|
||||
)
|
||||
matching_blocks = matcher.get_matching_blocks()
|
||||
lines_mappings: List[_LinesMapping] = []
|
||||
# matching_blocks is a sequence of "same block of code ranges", see
|
||||
# https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher.get_matching_blocks
|
||||
# Each block corresponds to a _LinesMapping with is_changed_block=False,
|
||||
# and the ranges between two blocks corresponds to a _LinesMapping with
|
||||
# is_changed_block=True,
|
||||
# NOTE: matching_blocks is 0-based, but _LinesMapping is 1-based.
|
||||
for i, block in enumerate(matching_blocks):
|
||||
if i == 0:
|
||||
if block.a != 0 or block.b != 0:
|
||||
lines_mappings.append(
|
||||
_LinesMapping(
|
||||
original_start=1,
|
||||
original_end=block.a,
|
||||
modified_start=1,
|
||||
modified_end=block.b,
|
||||
is_changed_block=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
previous_block = matching_blocks[i - 1]
|
||||
lines_mappings.append(
|
||||
_LinesMapping(
|
||||
original_start=previous_block.a + previous_block.size + 1,
|
||||
original_end=block.a,
|
||||
modified_start=previous_block.b + previous_block.size + 1,
|
||||
modified_end=block.b,
|
||||
is_changed_block=True,
|
||||
)
|
||||
)
|
||||
if i < len(matching_blocks) - 1:
|
||||
lines_mappings.append(
|
||||
_LinesMapping(
|
||||
original_start=block.a + 1,
|
||||
original_end=block.a + block.size,
|
||||
modified_start=block.b + 1,
|
||||
modified_end=block.b + block.size,
|
||||
is_changed_block=False,
|
||||
)
|
||||
)
|
||||
return lines_mappings
|
||||
|
||||
|
||||
def _find_lines_mapping_index(
|
||||
original_line: int,
|
||||
lines_mappings: Sequence[_LinesMapping],
|
||||
start_index: int,
|
||||
) -> int:
|
||||
"""Returns the original index of the lines mappings for the original line."""
|
||||
index = start_index
|
||||
while index < len(lines_mappings):
|
||||
mapping = lines_mappings[index]
|
||||
if (
|
||||
mapping.original_start <= original_line
|
||||
and original_line <= mapping.original_end
|
||||
):
|
||||
return index
|
||||
index += 1
|
||||
return index
|
107
tests/data/cases/line_ranges_basic.py
Normal file
107
tests/data/cases/line_ranges_basic.py
Normal file
@ -0,0 +1,107 @@
|
||||
# flags: --line-ranges=5-6
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
def foo1(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
|
||||
def foo2(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
|
||||
def foo3(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
|
||||
def foo4(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
|
||||
|
||||
# Adding some unformated code covering a wide range of syntaxes.
|
||||
|
||||
if True:
|
||||
# Incorrectly indented prefix comments.
|
||||
pass
|
||||
|
||||
import typing
|
||||
from typing import (
|
||||
Any ,
|
||||
)
|
||||
class MyClass( object): # Trailing comment with extra leading space.
|
||||
#NOTE: The following indentation is incorrect:
|
||||
@decor( 1 * 3 )
|
||||
def my_func( arg):
|
||||
pass
|
||||
|
||||
try: # Trailing comment with extra leading space.
|
||||
for i in range(10): # Trailing comment with extra leading space.
|
||||
while condition:
|
||||
if something:
|
||||
then_something( )
|
||||
elif something_else:
|
||||
then_something_else( )
|
||||
except ValueError as e:
|
||||
unformatted( )
|
||||
finally:
|
||||
unformatted( )
|
||||
|
||||
async def test_async_unformatted( ): # Trailing comment with extra leading space.
|
||||
async for i in some_iter( unformatted ): # Trailing comment with extra leading space.
|
||||
await asyncio.sleep( 1 )
|
||||
async with some_context( unformatted ):
|
||||
print( "unformatted" )
|
||||
|
||||
|
||||
# output
|
||||
# flags: --line-ranges=5-6
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
def foo1(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
|
||||
def foo2(
|
||||
parameter_1,
|
||||
parameter_2,
|
||||
parameter_3,
|
||||
parameter_4,
|
||||
parameter_5,
|
||||
parameter_6,
|
||||
parameter_7,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def foo3(
|
||||
parameter_1,
|
||||
parameter_2,
|
||||
parameter_3,
|
||||
parameter_4,
|
||||
parameter_5,
|
||||
parameter_6,
|
||||
parameter_7,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def foo4(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
|
||||
|
||||
# Adding some unformated code covering a wide range of syntaxes.
|
||||
|
||||
if True:
|
||||
# Incorrectly indented prefix comments.
|
||||
pass
|
||||
|
||||
import typing
|
||||
from typing import (
|
||||
Any ,
|
||||
)
|
||||
class MyClass( object): # Trailing comment with extra leading space.
|
||||
#NOTE: The following indentation is incorrect:
|
||||
@decor( 1 * 3 )
|
||||
def my_func( arg):
|
||||
pass
|
||||
|
||||
try: # Trailing comment with extra leading space.
|
||||
for i in range(10): # Trailing comment with extra leading space.
|
||||
while condition:
|
||||
if something:
|
||||
then_something( )
|
||||
elif something_else:
|
||||
then_something_else( )
|
||||
except ValueError as e:
|
||||
unformatted( )
|
||||
finally:
|
||||
unformatted( )
|
||||
|
||||
async def test_async_unformatted( ): # Trailing comment with extra leading space.
|
||||
async for i in some_iter( unformatted ): # Trailing comment with extra leading space.
|
||||
await asyncio.sleep( 1 )
|
||||
async with some_context( unformatted ):
|
||||
print( "unformatted" )
|
49
tests/data/cases/line_ranges_fmt_off.py
Normal file
49
tests/data/cases/line_ranges_fmt_off.py
Normal file
@ -0,0 +1,49 @@
|
||||
# flags: --line-ranges=7-7 --line-ranges=17-23
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
# fmt: off
|
||||
import os
|
||||
def myfunc( ): # Intentionally unformatted.
|
||||
pass
|
||||
# fmt: on
|
||||
|
||||
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: off
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: on
|
||||
|
||||
|
||||
def myfunc( ): # This will be reformatted.
|
||||
print( {"this will be reformatted"} )
|
||||
|
||||
# output
|
||||
|
||||
# flags: --line-ranges=7-7 --line-ranges=17-23
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
# fmt: off
|
||||
import os
|
||||
def myfunc( ): # Intentionally unformatted.
|
||||
pass
|
||||
# fmt: on
|
||||
|
||||
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: off
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: on
|
||||
|
||||
|
||||
def myfunc(): # This will be reformatted.
|
||||
print({"this will be reformatted"})
|
27
tests/data/cases/line_ranges_fmt_off_decorator.py
Normal file
27
tests/data/cases/line_ranges_fmt_off_decorator.py
Normal file
@ -0,0 +1,27 @@
|
||||
# flags: --line-ranges=12-12
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
# Regression test for an edge case involving decorators and fmt: off/on.
|
||||
class MyClass:
|
||||
|
||||
# fmt: off
|
||||
@decorator ( )
|
||||
# fmt: on
|
||||
def method():
|
||||
print ( "str" )
|
||||
|
||||
# output
|
||||
|
||||
# flags: --line-ranges=12-12
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
# Regression test for an edge case involving decorators and fmt: off/on.
|
||||
class MyClass:
|
||||
|
||||
# fmt: off
|
||||
@decorator ( )
|
||||
# fmt: on
|
||||
def method():
|
||||
print("str")
|
37
tests/data/cases/line_ranges_fmt_off_overlap.py
Normal file
37
tests/data/cases/line_ranges_fmt_off_overlap.py
Normal file
@ -0,0 +1,37 @@
|
||||
# flags: --line-ranges=11-17
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: off
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: on
|
||||
|
||||
|
||||
def myfunc( ): # This will be reformatted.
|
||||
print( {"this will be reformatted"} )
|
||||
|
||||
# output
|
||||
|
||||
# flags: --line-ranges=11-17
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: off
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
def myfunc( ): # This will not be reformatted.
|
||||
print( {"also won't be reformatted"} )
|
||||
# fmt: on
|
||||
|
||||
|
||||
def myfunc(): # This will be reformatted.
|
||||
print({"this will be reformatted"})
|
9
tests/data/cases/line_ranges_imports.py
Normal file
9
tests/data/cases/line_ranges_imports.py
Normal file
@ -0,0 +1,9 @@
|
||||
# flags: --line-ranges=8-8
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
# This test ensures no empty lines are added around import lines.
|
||||
# It caused an issue before https://github.com/psf/black/pull/3610 is merged.
|
||||
import os
|
||||
import re
|
||||
import sys
|
27
tests/data/cases/line_ranges_indentation.py
Normal file
27
tests/data/cases/line_ranges_indentation.py
Normal file
@ -0,0 +1,27 @@
|
||||
# flags: --line-ranges=5-5
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
if cond1:
|
||||
print("first")
|
||||
if cond2:
|
||||
print("second")
|
||||
else:
|
||||
print("else")
|
||||
|
||||
if another_cond:
|
||||
print("will not be changed")
|
||||
|
||||
# output
|
||||
|
||||
# flags: --line-ranges=5-5
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
if cond1:
|
||||
print("first")
|
||||
if cond2:
|
||||
print("second")
|
||||
else:
|
||||
print("else")
|
||||
|
||||
if another_cond:
|
||||
print("will not be changed")
|
27
tests/data/cases/line_ranges_two_passes.py
Normal file
27
tests/data/cases/line_ranges_two_passes.py
Normal file
@ -0,0 +1,27 @@
|
||||
# flags: --line-ranges=9-11
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
# This is a specific case for Black's two-pass formatting behavior in `format_str`.
|
||||
# The second pass must respect the line ranges before the first pass.
|
||||
|
||||
|
||||
def restrict_to_this_line(arg1,
|
||||
arg2,
|
||||
arg3):
|
||||
print ( "This should not be formatted." )
|
||||
print ( "Note that in the second pass, the original line range 9-11 will cover these print lines.")
|
||||
|
||||
# output
|
||||
|
||||
# flags: --line-ranges=9-11
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
|
||||
# This is a specific case for Black's two-pass formatting behavior in `format_str`.
|
||||
# The second pass must respect the line ranges before the first pass.
|
||||
|
||||
|
||||
def restrict_to_this_line(arg1, arg2, arg3):
|
||||
print ( "This should not be formatted." )
|
||||
print ( "Note that in the second pass, the original line range 9-11 will cover these print lines.")
|
25
tests/data/cases/line_ranges_unwrapping.py
Normal file
25
tests/data/cases/line_ranges_unwrapping.py
Normal file
@ -0,0 +1,25 @@
|
||||
# flags: --line-ranges=5-5 --line-ranges=9-9 --line-ranges=13-13
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
alist = [
|
||||
1, 2
|
||||
]
|
||||
|
||||
adict = {
|
||||
"key" : "value"
|
||||
}
|
||||
|
||||
func_call (
|
||||
arg = value
|
||||
)
|
||||
|
||||
# output
|
||||
|
||||
# flags: --line-ranges=5-5 --line-ranges=9-9 --line-ranges=13-13
|
||||
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
|
||||
# flag above as it's formatting specifically these lines.
|
||||
alist = [1, 2]
|
||||
|
||||
adict = {"key": "value"}
|
||||
|
||||
func_call(arg=value)
|
2
tests/data/invalid_line_ranges.toml
Normal file
2
tests/data/invalid_line_ranges.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[tool.black]
|
||||
line-ranges = "1-1"
|
50
tests/data/line_ranges_formatted/basic.py
Normal file
50
tests/data/line_ranges_formatted/basic.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""Module doc."""
|
||||
|
||||
from typing import (
|
||||
Callable,
|
||||
Literal,
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
class Unformatted:
|
||||
def should_also_work(self):
|
||||
pass
|
||||
# fmt: on
|
||||
|
||||
|
||||
a = [1, 2] # fmt: skip
|
||||
|
||||
|
||||
# This should cover as many syntaxes as possible.
|
||||
class Foo:
|
||||
"""Class doc."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@add_logging
|
||||
@memoize.memoize(max_items=2)
|
||||
def plus_one(
|
||||
self,
|
||||
number: int,
|
||||
) -> int:
|
||||
return number + 1
|
||||
|
||||
async def async_plus_one(self, number: int) -> int:
|
||||
await asyncio.sleep(1)
|
||||
async with some_context():
|
||||
return number + 1
|
||||
|
||||
|
||||
try:
|
||||
for i in range(10):
|
||||
while condition:
|
||||
if something:
|
||||
then_something()
|
||||
elif something_else:
|
||||
then_something_else()
|
||||
except ValueError as e:
|
||||
handle(e)
|
||||
finally:
|
||||
done()
|
25
tests/data/line_ranges_formatted/pattern_matching.py
Normal file
25
tests/data/line_ranges_formatted/pattern_matching.py
Normal file
@ -0,0 +1,25 @@
|
||||
# flags: --minimum-version=3.10
|
||||
|
||||
|
||||
def pattern_matching():
|
||||
match status:
|
||||
case 1:
|
||||
return "1"
|
||||
case [single]:
|
||||
return "single"
|
||||
case [
|
||||
action,
|
||||
obj,
|
||||
]:
|
||||
return "act on obj"
|
||||
case Point(x=0):
|
||||
return "class pattern"
|
||||
case {"text": message}:
|
||||
return "mapping"
|
||||
case {
|
||||
"text": message,
|
||||
"format": _,
|
||||
}:
|
||||
return "mapping"
|
||||
case _:
|
||||
return "fallback"
|
@ -8,6 +8,7 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import types
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager, redirect_stderr
|
||||
@ -1269,7 +1270,7 @@ def test_reformat_one_with_stdin_filename(self) -> None:
|
||||
report=report,
|
||||
)
|
||||
fsts.assert_called_once_with(
|
||||
fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
|
||||
fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE, lines=()
|
||||
)
|
||||
# __BLACK_STDIN_FILENAME__ should have been stripped
|
||||
report.done.assert_called_with(expected, black.Changed.YES)
|
||||
@ -1295,6 +1296,7 @@ def test_reformat_one_with_stdin_filename_pyi(self) -> None:
|
||||
fast=True,
|
||||
write_back=black.WriteBack.YES,
|
||||
mode=replace(DEFAULT_MODE, is_pyi=True),
|
||||
lines=(),
|
||||
)
|
||||
# __BLACK_STDIN_FILENAME__ should have been stripped
|
||||
report.done.assert_called_with(expected, black.Changed.YES)
|
||||
@ -1320,6 +1322,7 @@ def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
|
||||
fast=True,
|
||||
write_back=black.WriteBack.YES,
|
||||
mode=replace(DEFAULT_MODE, is_ipynb=True),
|
||||
lines=(),
|
||||
)
|
||||
# __BLACK_STDIN_FILENAME__ should have been stripped
|
||||
report.done.assert_called_with(expected, black.Changed.YES)
|
||||
@ -1941,6 +1944,88 @@ def test_equivalency_ast_parse_failure_includes_error(self) -> None:
|
||||
err.match("invalid character")
|
||||
err.match(r"\(<unknown>, line 1\)")
|
||||
|
||||
def test_line_ranges_with_code_option(self) -> None:
|
||||
code = textwrap.dedent("""\
|
||||
if a == b:
|
||||
print ( "OK" )
|
||||
""")
|
||||
args = ["--line-ranges=1-1", "--code", code]
|
||||
result = CliRunner().invoke(black.main, args)
|
||||
|
||||
expected = textwrap.dedent("""\
|
||||
if a == b:
|
||||
print ( "OK" )
|
||||
""")
|
||||
self.compare_results(result, expected, expected_exit_code=0)
|
||||
|
||||
def test_line_ranges_with_stdin(self) -> None:
|
||||
code = textwrap.dedent("""\
|
||||
if a == b:
|
||||
print ( "OK" )
|
||||
""")
|
||||
runner = BlackRunner()
|
||||
result = runner.invoke(
|
||||
black.main, ["--line-ranges=1-1", "-"], input=BytesIO(code.encode("utf-8"))
|
||||
)
|
||||
|
||||
expected = textwrap.dedent("""\
|
||||
if a == b:
|
||||
print ( "OK" )
|
||||
""")
|
||||
self.compare_results(result, expected, expected_exit_code=0)
|
||||
|
||||
def test_line_ranges_with_source(self) -> None:
|
||||
with TemporaryDirectory() as workspace:
|
||||
test_file = Path(workspace) / "test.py"
|
||||
test_file.write_text(
|
||||
textwrap.dedent("""\
|
||||
if a == b:
|
||||
print ( "OK" )
|
||||
"""),
|
||||
encoding="utf-8",
|
||||
)
|
||||
args = ["--line-ranges=1-1", str(test_file)]
|
||||
result = CliRunner().invoke(black.main, args)
|
||||
assert not result.exit_code
|
||||
|
||||
formatted = test_file.read_text(encoding="utf-8")
|
||||
expected = textwrap.dedent("""\
|
||||
if a == b:
|
||||
print ( "OK" )
|
||||
""")
|
||||
assert expected == formatted
|
||||
|
||||
def test_line_ranges_with_multiple_sources(self) -> None:
|
||||
with TemporaryDirectory() as workspace:
|
||||
test1_file = Path(workspace) / "test1.py"
|
||||
test1_file.write_text("", encoding="utf-8")
|
||||
test2_file = Path(workspace) / "test2.py"
|
||||
test2_file.write_text("", encoding="utf-8")
|
||||
args = ["--line-ranges=1-1", str(test1_file), str(test2_file)]
|
||||
result = CliRunner().invoke(black.main, args)
|
||||
assert result.exit_code == 1
|
||||
assert "Cannot use --line-ranges to format multiple files" in result.output
|
||||
|
||||
def test_line_ranges_with_ipynb(self) -> None:
|
||||
with TemporaryDirectory() as workspace:
|
||||
test_file = Path(workspace) / "test.ipynb"
|
||||
test_file.write_text("{}", encoding="utf-8")
|
||||
args = ["--line-ranges=1-1", "--ipynb", str(test_file)]
|
||||
result = CliRunner().invoke(black.main, args)
|
||||
assert "Cannot use --line-ranges with ipynb files" in result.output
|
||||
assert result.exit_code == 1
|
||||
|
||||
def test_line_ranges_in_pyproject_toml(self) -> None:
|
||||
config = THIS_DIR / "data" / "invalid_line_ranges.toml"
|
||||
result = BlackRunner().invoke(
|
||||
black.main, ["--code", "print()", "--config", str(config)]
|
||||
)
|
||||
assert result.exit_code == 2
|
||||
assert result.stderr_bytes is not None
|
||||
assert (
|
||||
b"Cannot use line-ranges in the pyproject.toml file." in result.stderr_bytes
|
||||
)
|
||||
|
||||
|
||||
class TestCaching:
|
||||
def test_get_cache_dir(
|
||||
|
@ -29,13 +29,19 @@ def check_file(subdir: str, filename: str, *, data: bool = True) -> None:
|
||||
args.mode,
|
||||
fast=args.fast,
|
||||
minimum_version=args.minimum_version,
|
||||
lines=args.lines,
|
||||
)
|
||||
if args.minimum_version is not None:
|
||||
major, minor = args.minimum_version
|
||||
target_version = TargetVersion[f"PY{major}{minor}"]
|
||||
mode = replace(args.mode, target_versions={target_version})
|
||||
assert_format(
|
||||
source, expected, mode, fast=args.fast, minimum_version=args.minimum_version
|
||||
source,
|
||||
expected,
|
||||
mode,
|
||||
fast=args.fast,
|
||||
minimum_version=args.minimum_version,
|
||||
lines=args.lines,
|
||||
)
|
||||
|
||||
|
||||
@ -45,6 +51,24 @@ def test_simple_format(filename: str) -> None:
|
||||
check_file("cases", filename)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("filename", all_data_cases("line_ranges_formatted"))
|
||||
def test_line_ranges_line_by_line(filename: str) -> None:
|
||||
args, source, expected = read_data_with_mode("line_ranges_formatted", filename)
|
||||
assert (
|
||||
source == expected
|
||||
), "Test cases in line_ranges_formatted must already be formatted."
|
||||
line_count = len(source.splitlines())
|
||||
for line in range(1, line_count + 1):
|
||||
assert_format(
|
||||
source,
|
||||
expected,
|
||||
args.mode,
|
||||
fast=args.fast,
|
||||
minimum_version=args.minimum_version,
|
||||
lines=[(line, line)],
|
||||
)
|
||||
|
||||
|
||||
# =============== #
|
||||
# Unusual cases
|
||||
# =============== #
|
||||
|
185
tests/test_ranges.py
Normal file
185
tests/test_ranges.py
Normal file
@ -0,0 +1,185 @@
|
||||
"""Test the black.ranges module."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from black.ranges import adjusted_lines
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lines",
|
||||
[[(1, 1)], [(1, 3)], [(1, 1), (3, 4)]],
|
||||
)
|
||||
def test_no_diff(lines: List[Tuple[int, int]]) -> None:
|
||||
source = """\
|
||||
import re
|
||||
|
||||
def func():
|
||||
pass
|
||||
"""
|
||||
assert lines == adjusted_lines(lines, source, source)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lines",
|
||||
[
|
||||
[(1, 0)],
|
||||
[(-8, 0)],
|
||||
[(-8, 8)],
|
||||
[(1, 100)],
|
||||
[(2, 1)],
|
||||
[(0, 8), (3, 1)],
|
||||
],
|
||||
)
|
||||
def test_invalid_lines(lines: List[Tuple[int, int]]) -> None:
|
||||
original_source = """\
|
||||
import re
|
||||
def foo(arg):
|
||||
'''This is the foo function.
|
||||
|
||||
This is foo function's
|
||||
docstring with more descriptive texts.
|
||||
'''
|
||||
|
||||
def func(arg1,
|
||||
arg2, arg3):
|
||||
pass
|
||||
"""
|
||||
modified_source = """\
|
||||
import re
|
||||
def foo(arg):
|
||||
'''This is the foo function.
|
||||
|
||||
This is foo function's
|
||||
docstring with more descriptive texts.
|
||||
'''
|
||||
|
||||
def func(arg1, arg2, arg3):
|
||||
pass
|
||||
"""
|
||||
assert not adjusted_lines(lines, original_source, modified_source)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lines,adjusted",
|
||||
[
|
||||
(
|
||||
[(1, 1)],
|
||||
[(1, 1)],
|
||||
),
|
||||
(
|
||||
[(1, 2)],
|
||||
[(1, 1)],
|
||||
),
|
||||
(
|
||||
[(1, 6)],
|
||||
[(1, 2)],
|
||||
),
|
||||
(
|
||||
[(6, 6)],
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_removals(
|
||||
lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]]
|
||||
) -> None:
|
||||
original_source = """\
|
||||
1. first line
|
||||
2. second line
|
||||
3. third line
|
||||
4. fourth line
|
||||
5. fifth line
|
||||
6. sixth line
|
||||
"""
|
||||
modified_source = """\
|
||||
2. second line
|
||||
5. fifth line
|
||||
"""
|
||||
assert adjusted == adjusted_lines(lines, original_source, modified_source)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lines,adjusted",
|
||||
[
|
||||
(
|
||||
[(1, 1)],
|
||||
[(2, 2)],
|
||||
),
|
||||
(
|
||||
[(1, 2)],
|
||||
[(2, 5)],
|
||||
),
|
||||
(
|
||||
[(2, 2)],
|
||||
[(5, 5)],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_additions(
|
||||
lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]]
|
||||
) -> None:
|
||||
original_source = """\
|
||||
1. first line
|
||||
2. second line
|
||||
"""
|
||||
modified_source = """\
|
||||
this is added
|
||||
1. first line
|
||||
this is added
|
||||
this is added
|
||||
2. second line
|
||||
this is added
|
||||
"""
|
||||
assert adjusted == adjusted_lines(lines, original_source, modified_source)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lines,adjusted",
|
||||
[
|
||||
(
|
||||
[(1, 11)],
|
||||
[(1, 10)],
|
||||
),
|
||||
(
|
||||
[(1, 12)],
|
||||
[(1, 11)],
|
||||
),
|
||||
(
|
||||
[(10, 10)],
|
||||
[(9, 9)],
|
||||
),
|
||||
([(1, 1), (9, 10)], [(1, 1), (9, 9)]),
|
||||
([(9, 10), (1, 1)], [(1, 1), (9, 9)]),
|
||||
],
|
||||
)
|
||||
def test_diffs(lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]]) -> None:
|
||||
original_source = """\
|
||||
1. import re
|
||||
2. def foo(arg):
|
||||
3. '''This is the foo function.
|
||||
4.
|
||||
5. This is foo function's
|
||||
6. docstring with more descriptive texts.
|
||||
7. '''
|
||||
8.
|
||||
9. def func(arg1,
|
||||
10. arg2, arg3):
|
||||
11. pass
|
||||
12. # last line
|
||||
"""
|
||||
modified_source = """\
|
||||
1. import re # changed
|
||||
2. def foo(arg):
|
||||
3. '''This is the foo function.
|
||||
4.
|
||||
5. This is foo function's
|
||||
6. docstring with more descriptive texts.
|
||||
7. '''
|
||||
8.
|
||||
9. def func(arg1, arg2, arg3):
|
||||
11. pass
|
||||
12. # last line changed
|
||||
"""
|
||||
assert adjusted == adjusted_lines(lines, original_source, modified_source)
|
@ -8,13 +8,14 @@
|
||||
from dataclasses import dataclass, field, replace
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, List, Optional, Tuple
|
||||
from typing import Any, Collection, Iterator, List, Optional, Tuple
|
||||
|
||||
import black
|
||||
from black.const import DEFAULT_LINE_LENGTH
|
||||
from black.debug import DebugVisitor
|
||||
from black.mode import TargetVersion
|
||||
from black.output import diff, err, out
|
||||
from black.ranges import parse_line_ranges
|
||||
|
||||
from . import conftest
|
||||
|
||||
@ -44,6 +45,7 @@ class TestCaseArgs:
|
||||
mode: black.Mode = field(default_factory=black.Mode)
|
||||
fast: bool = False
|
||||
minimum_version: Optional[Tuple[int, int]] = None
|
||||
lines: Collection[Tuple[int, int]] = ()
|
||||
|
||||
|
||||
def _assert_format_equal(expected: str, actual: str) -> None:
|
||||
@ -93,6 +95,7 @@ def assert_format(
|
||||
*,
|
||||
fast: bool = False,
|
||||
minimum_version: Optional[Tuple[int, int]] = None,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> None:
|
||||
"""Convenience function to check that Black formats as expected.
|
||||
|
||||
@ -101,7 +104,7 @@ def assert_format(
|
||||
separate from TargetVerson Mode configuration.
|
||||
"""
|
||||
_assert_format_inner(
|
||||
source, expected, mode, fast=fast, minimum_version=minimum_version
|
||||
source, expected, mode, fast=fast, minimum_version=minimum_version, lines=lines
|
||||
)
|
||||
|
||||
# For both preview and non-preview tests, ensure that Black doesn't crash on
|
||||
@ -113,6 +116,7 @@ def assert_format(
|
||||
replace(mode, preview=not mode.preview),
|
||||
fast=fast,
|
||||
minimum_version=minimum_version,
|
||||
lines=lines,
|
||||
)
|
||||
except Exception as e:
|
||||
text = "non-preview" if mode.preview else "preview"
|
||||
@ -129,6 +133,7 @@ def assert_format(
|
||||
replace(mode, preview=False, line_length=1),
|
||||
fast=fast,
|
||||
minimum_version=minimum_version,
|
||||
lines=lines,
|
||||
)
|
||||
except Exception as e:
|
||||
raise FormatFailure(
|
||||
@ -143,8 +148,9 @@ def _assert_format_inner(
|
||||
*,
|
||||
fast: bool = False,
|
||||
minimum_version: Optional[Tuple[int, int]] = None,
|
||||
lines: Collection[Tuple[int, int]] = (),
|
||||
) -> None:
|
||||
actual = black.format_str(source, mode=mode)
|
||||
actual = black.format_str(source, mode=mode, lines=lines)
|
||||
if expected is not None:
|
||||
_assert_format_equal(expected, actual)
|
||||
# It's not useful to run safety checks if we're expecting no changes anyway. The
|
||||
@ -156,7 +162,7 @@ def _assert_format_inner(
|
||||
# when checking modern code on older versions.
|
||||
if minimum_version is None or sys.version_info >= minimum_version:
|
||||
black.assert_equivalent(source, actual)
|
||||
black.assert_stable(source, actual, mode=mode)
|
||||
black.assert_stable(source, actual, mode=mode, lines=lines)
|
||||
|
||||
|
||||
def dump_to_stderr(*output: str) -> str:
|
||||
@ -239,6 +245,7 @@ def get_flags_parser() -> argparse.ArgumentParser:
|
||||
" version works correctly."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--line-ranges", action="append")
|
||||
return parser
|
||||
|
||||
|
||||
@ -254,7 +261,13 @@ def parse_mode(flags_line: str) -> TestCaseArgs:
|
||||
magic_trailing_comma=not args.skip_magic_trailing_comma,
|
||||
preview=args.preview,
|
||||
)
|
||||
return TestCaseArgs(mode=mode, fast=args.fast, minimum_version=args.minimum_version)
|
||||
if args.line_ranges:
|
||||
lines = parse_line_ranges(args.line_ranges)
|
||||
else:
|
||||
lines = []
|
||||
return TestCaseArgs(
|
||||
mode=mode, fast=args.fast, minimum_version=args.minimum_version, lines=lines
|
||||
)
|
||||
|
||||
|
||||
def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
|
||||
@ -267,6 +280,12 @@ def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
|
||||
for line in lines:
|
||||
if not _input and line.startswith("# flags: "):
|
||||
mode = parse_mode(line[len("# flags: ") :])
|
||||
if mode.lines:
|
||||
# Retain the `# flags: ` line when using --line-ranges=. This requires
|
||||
# the `# output` section to also include this line, but retaining the
|
||||
# line is important to make the line ranges match what you see in the
|
||||
# test file.
|
||||
result.append(line)
|
||||
continue
|
||||
line = line.replace(EMPTY_LINE, "")
|
||||
if line.rstrip() == "# output":
|
||||
|
Loading…
Reference in New Issue
Block a user