Support formatting specified lines (#4020)

This commit is contained in:
Yilei Yang 2023-11-06 18:05:25 -08:00 committed by GitHub
parent ecbd9e8cf7
commit 46be1f8e54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1358 additions and 28 deletions

View File

@ -6,6 +6,9 @@
<!-- Include any especially major or disruptive changes here -->
- Support formatting ranges of lines with the new `--line-ranges` command-line option
(#4020).
### Stable style
- Fix crash on formatting bytes strings that look like docstrings (#4003)

View File

@ -175,6 +175,23 @@ All done! ✨ 🍰 ✨
1 file would be reformatted.
```
### `--line-ranges`
When specified, _Black_ will try its best to only format these lines.
This option can be specified multiple times, and a union of the lines will be formatted.
Each range must be specified as two integers connected by a `-`: `<START>-<END>`. The
`<START>` and `<END>` integer indices are 1-based and inclusive on both ends.
_Black_ may still format lines outside of the ranges for multi-line statements.
Formatting more than one file or any ipynb files with this option is not supported. This
option cannot be specified in the `pyproject.toml` config.
Example: `black --line-ranges=1-10 --line-ranges=21-30 test.py` will format lines from
`1` to `10` and `21` to `30`.
This option is mainly for editor integrations, such as "Format Selection".
#### `--color` / `--no-color`
Show (or do not show) colored diff. Only applies when `--diff` is given.

View File

@ -13,6 +13,7 @@
from pathlib import Path
from typing import (
Any,
Collection,
Dict,
Generator,
Iterator,
@ -77,6 +78,7 @@
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
from black.parsing import InvalidInput # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
from black.report import Changed, NothingChanged, Report
from black.trans import iter_fexpr_spans
from blib2to3.pgen2 import token
@ -163,6 +165,12 @@ def read_pyproject_toml(
"extend-exclude", "Config key extend-exclude must be a string"
)
line_ranges = config.get("line_ranges")
if line_ranges is not None:
raise click.BadOptionUsage(
"line-ranges", "Cannot use line-ranges in the pyproject.toml file."
)
default_map: Dict[str, Any] = {}
if ctx.default_map:
default_map.update(ctx.default_map)
@ -304,6 +312,19 @@ def validate_regex(
is_flag=True,
help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
"--line-ranges",
multiple=True,
metavar="START-END",
help=(
"When specified, _Black_ will try its best to only format these lines. This"
" option can be specified multiple times, and a union of the lines will be"
" formatted. Each range must be specified as two integers connected by a `-`:"
" `<START>-<END>`. The `<START>` and `<END>` integer indices are 1-based and"
" inclusive on both ends."
),
default=(),
)
@click.option(
"--color/--no-color",
is_flag=True,
@ -443,6 +464,7 @@ def main( # noqa: C901
target_version: List[TargetVersion],
check: bool,
diff: bool,
line_ranges: Sequence[str],
color: bool,
fast: bool,
pyi: bool,
@ -544,6 +566,18 @@ def main( # noqa: C901
python_cell_magics=set(python_cell_magics),
)
lines: List[Tuple[int, int]] = []
if line_ranges:
if ipynb:
err("Cannot use --line-ranges with ipynb files.")
ctx.exit(1)
try:
lines = parse_line_ranges(line_ranges)
except ValueError as e:
err(str(e))
ctx.exit(1)
if code is not None:
# Run in quiet mode by default with -c; the extra output isn't useful.
# You can still pass -v to get verbose output.
@ -553,7 +587,12 @@ def main( # noqa: C901
if code is not None:
reformat_code(
content=code, fast=fast, write_back=write_back, mode=mode, report=report
content=code,
fast=fast,
write_back=write_back,
mode=mode,
report=report,
lines=lines,
)
else:
assert root is not None # root is only None if code is not None
@ -588,10 +627,14 @@ def main( # noqa: C901
write_back=write_back,
mode=mode,
report=report,
lines=lines,
)
else:
from black.concurrency import reformat_many
if lines:
err("Cannot use --line-ranges to format multiple files.")
ctx.exit(1)
reformat_many(
sources=sources,
fast=fast,
@ -714,7 +757,13 @@ def path_empty(
def reformat_code(
content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
content: str,
fast: bool,
write_back: WriteBack,
mode: Mode,
report: Report,
*,
lines: Collection[Tuple[int, int]] = (),
) -> None:
"""
Reformat and print out `content` without spawning child processes.
@ -727,7 +776,7 @@ def reformat_code(
try:
changed = Changed.NO
if format_stdin_to_stdout(
content=content, fast=fast, write_back=write_back, mode=mode
content=content, fast=fast, write_back=write_back, mode=mode, lines=lines
):
changed = Changed.YES
report.done(path, changed)
@ -741,7 +790,13 @@ def reformat_code(
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
src: Path,
fast: bool,
write_back: WriteBack,
mode: Mode,
report: "Report",
*,
lines: Collection[Tuple[int, int]] = (),
) -> None:
"""Reformat a single file under `src` without spawning child processes.
@ -766,7 +821,9 @@ def reformat_one(
mode = replace(mode, is_pyi=True)
elif src.suffix == ".ipynb":
mode = replace(mode, is_ipynb=True)
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
if format_stdin_to_stdout(
fast=fast, write_back=write_back, mode=mode, lines=lines
):
changed = Changed.YES
else:
cache = Cache.read(mode)
@ -774,7 +831,7 @@ def reformat_one(
if not cache.is_changed(src):
changed = Changed.CACHED
if changed is not Changed.CACHED and format_file_in_place(
src, fast=fast, write_back=write_back, mode=mode
src, fast=fast, write_back=write_back, mode=mode, lines=lines
):
changed = Changed.YES
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
@ -794,6 +851,8 @@ def format_file_in_place(
mode: Mode,
write_back: WriteBack = WriteBack.NO,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
*,
lines: Collection[Tuple[int, int]] = (),
) -> bool:
"""Format file under `src` path. Return True if changed.
@ -813,7 +872,9 @@ def format_file_in_place(
header = buf.readline()
src_contents, encoding, newline = decode_bytes(buf.read())
try:
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
dst_contents = format_file_contents(
src_contents, fast=fast, mode=mode, lines=lines
)
except NothingChanged:
return False
except JSONDecodeError:
@ -858,6 +919,7 @@ def format_stdin_to_stdout(
content: Optional[str] = None,
write_back: WriteBack = WriteBack.NO,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
) -> bool:
"""Format file on stdin. Return True if changed.
@ -876,7 +938,7 @@ def format_stdin_to_stdout(
dst = src
try:
dst = format_file_contents(src, fast=fast, mode=mode)
dst = format_file_contents(src, fast=fast, mode=mode, lines=lines)
return True
except NothingChanged:
@ -904,7 +966,11 @@ def format_stdin_to_stdout(
def check_stability_and_equivalence(
src_contents: str, dst_contents: str, *, mode: Mode
src_contents: str,
dst_contents: str,
*,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
) -> None:
"""Perform stability and equivalence checks.
@ -913,10 +979,16 @@ def check_stability_and_equivalence(
content differently.
"""
assert_equivalent(src_contents, dst_contents)
assert_stable(src_contents, dst_contents, mode=mode)
assert_stable(src_contents, dst_contents, mode=mode, lines=lines)
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
def format_file_contents(
src_contents: str,
*,
fast: bool,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
) -> FileContent:
"""Reformat contents of a file and return new contents.
If `fast` is False, additionally confirm that the reformatted code is
@ -926,13 +998,15 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
if mode.is_ipynb:
dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
else:
dst_contents = format_str(src_contents, mode=mode)
dst_contents = format_str(src_contents, mode=mode, lines=lines)
if src_contents == dst_contents:
raise NothingChanged
if not fast and not mode.is_ipynb:
# Jupyter notebooks will already have been checked above.
check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
check_stability_and_equivalence(
src_contents, dst_contents, mode=mode, lines=lines
)
return dst_contents
@ -1043,7 +1117,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
raise NothingChanged
def format_str(src_contents: str, *, mode: Mode) -> str:
def format_str(
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
) -> str:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
@ -1073,16 +1149,20 @@ def f(
hey
"""
dst_contents = _format_str_once(src_contents, mode=mode)
dst_contents = _format_str_once(src_contents, mode=mode, lines=lines)
# Forced second pass to work around optional trailing commas (becoming
# forced trailing commas on pass 2) interacting differently with optional
# parentheses. Admittedly ugly.
if src_contents != dst_contents:
return _format_str_once(dst_contents, mode=mode)
if lines:
lines = adjusted_lines(lines, src_contents, dst_contents)
return _format_str_once(dst_contents, mode=mode, lines=lines)
return dst_contents
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
def _format_str_once(
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
) -> str:
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_blocks: List[LinesBlock] = []
if mode.target_versions:
@ -1097,7 +1177,11 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
if supports_feature(versions, feature)
}
normalize_fmt_off(src_node, mode)
lines = LineGenerator(mode=mode, features=context_manager_features)
if lines:
# This should be called after normalize_fmt_off.
convert_unchanged_lines(src_node, lines)
line_generator = LineGenerator(mode=mode, features=context_manager_features)
elt = EmptyLineTracker(mode=mode)
split_line_features = {
feature
@ -1105,7 +1189,7 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
if supports_feature(versions, feature)
}
block: Optional[LinesBlock] = None
for current_line in lines.visit(src_node):
for current_line in line_generator.visit(src_node):
block = elt.maybe_empty_lines(current_line)
dst_blocks.append(block)
for line in transform_line(
@ -1373,12 +1457,16 @@ def assert_equivalent(src: str, dst: str) -> None:
) from None
def assert_stable(src: str, dst: str, mode: Mode) -> None:
def assert_stable(
src: str, dst: str, mode: Mode, *, lines: Collection[Tuple[int, int]] = ()
) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
# We shouldn't call format_str() here, because that formats the string
# twice and may hide a bug where we bounce back and forth between two
# versions.
newdst = _format_str_once(dst, mode=mode)
if lines:
lines = adjusted_lines(lines, src, dst)
newdst = _format_str_once(dst, mode=mode, lines=lines)
if dst != newdst:
log = dump_to_file(
str(mode),

View File

@ -935,3 +935,31 @@ def is_part_of_annotation(leaf: Leaf) -> bool:
return True
ancestor = ancestor.parent
return False
def first_leaf(node: LN) -> Optional[Leaf]:
"""Returns the first leaf of the ancestor node."""
if isinstance(node, Leaf):
return node
elif not node.children:
return None
else:
return first_leaf(node.children[0])
def last_leaf(node: LN) -> Optional[Leaf]:
"""Returns the last leaf of the ancestor node."""
if isinstance(node, Leaf):
return node
elif not node.children:
return None
else:
return last_leaf(node.children[-1])
def furthest_ancestor_with_last_leaf(leaf: Leaf) -> LN:
"""Returns the furthest ancestor that has this leaf node as the last leaf."""
node: LN = leaf
while node.parent and node.parent.children and node is node.parent.children[-1]:
node = node.parent
return node

496
src/black/ranges.py Normal file
View File

@ -0,0 +1,496 @@
"""Functions related to Black's formatting by line ranges feature."""
import difflib
from dataclasses import dataclass
from typing import Collection, Iterator, List, Sequence, Set, Tuple, Union
from black.nodes import (
LN,
STANDALONE_COMMENT,
Leaf,
Node,
Visitor,
first_leaf,
furthest_ancestor_with_last_leaf,
last_leaf,
syms,
)
from blib2to3.pgen2.token import ASYNC, NEWLINE
def parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]:
lines: List[Tuple[int, int]] = []
for lines_str in line_ranges:
parts = lines_str.split("-")
if len(parts) != 2:
raise ValueError(
"Incorrect --line-ranges format, expect 'START-END', found"
f" {lines_str!r}"
)
try:
start = int(parts[0])
end = int(parts[1])
except ValueError:
raise ValueError(
"Incorrect --line-ranges value, expect integer ranges, found"
f" {lines_str!r}"
) from None
else:
lines.append((start, end))
return lines
def is_valid_line_range(lines: Tuple[int, int]) -> bool:
"""Returns whether the line range is valid."""
return not lines or lines[0] <= lines[1]
def adjusted_lines(
lines: Collection[Tuple[int, int]],
original_source: str,
modified_source: str,
) -> List[Tuple[int, int]]:
"""Returns the adjusted line ranges based on edits from the original code.
This computes the new line ranges by diffing original_source and
modified_source, and adjust each range based on how the range overlaps with
the diffs.
Note the diff can contain lines outside of the original line ranges. This can
happen when the formatting has to be done in adjacent to maintain consistent
local results. For example:
1. def my_func(arg1, arg2,
2. arg3,):
3. pass
If it restricts to line 2-2, it can't simply reformat line 2, it also has
to reformat line 1:
1. def my_func(
2. arg1,
3. arg2,
4. arg3,
5. ):
6. pass
In this case, we will expand the line ranges to also include the whole diff
block.
Args:
lines: a collection of line ranges.
original_source: the original source.
modified_source: the modified source.
"""
lines_mappings = _calculate_lines_mappings(original_source, modified_source)
new_lines = []
# Keep an index of the current search. Since the lines and lines_mappings are
# sorted, this makes the search complexity linear.
current_mapping_index = 0
for start, end in sorted(lines):
start_mapping_index = _find_lines_mapping_index(
start,
lines_mappings,
current_mapping_index,
)
end_mapping_index = _find_lines_mapping_index(
end,
lines_mappings,
start_mapping_index,
)
current_mapping_index = start_mapping_index
if start_mapping_index >= len(lines_mappings) or end_mapping_index >= len(
lines_mappings
):
# Protect against invalid inputs.
continue
start_mapping = lines_mappings[start_mapping_index]
end_mapping = lines_mappings[end_mapping_index]
if start_mapping.is_changed_block:
# When the line falls into a changed block, expands to the whole block.
new_start = start_mapping.modified_start
else:
new_start = (
start - start_mapping.original_start + start_mapping.modified_start
)
if end_mapping.is_changed_block:
# When the line falls into a changed block, expands to the whole block.
new_end = end_mapping.modified_end
else:
new_end = end - end_mapping.original_start + end_mapping.modified_start
new_range = (new_start, new_end)
if is_valid_line_range(new_range):
new_lines.append(new_range)
return new_lines
def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) -> None:
"""Converts unchanged lines to STANDALONE_COMMENT.
The idea is similar to how `# fmt: on/off` is implemented. It also converts the
nodes between those markers as a single `STANDALONE_COMMENT` leaf node with
the unformatted code as its value. `STANDALONE_COMMENT` is a "fake" token
that will be formatted as-is with its prefix normalized.
Here we perform two passes:
1. Visit the top-level statements, and convert them to a single
`STANDALONE_COMMENT` when unchanged. This speeds up formatting when some
of the top-level statements aren't changed.
2. Convert unchanged "unwrapped lines" to `STANDALONE_COMMENT` nodes line by
line. "unwrapped lines" are divided by the `NEWLINE` token. e.g. a
multi-line statement is *one* "unwrapped line" that ends with `NEWLINE`,
even though this statement itself can span multiple lines, and the
tokenizer only sees the last '\n' as the `NEWLINE` token.
NOTE: During pass (2), comment prefixes and indentations are ALWAYS
normalized even when the lines aren't changed. This is fixable by moving
more formatting to pass (1). However, it's hard to get it correct when
incorrect indentations are used. So we defer this to future optimizations.
"""
lines_set: Set[int] = set()
for start, end in lines:
lines_set.update(range(start, end + 1))
visitor = _TopLevelStatementsVisitor(lines_set)
_ = list(visitor.visit(src_node)) # Consume all results.
_convert_unchanged_line_by_line(src_node, lines_set)
def _contains_standalone_comment(node: LN) -> bool:
if isinstance(node, Leaf):
return node.type == STANDALONE_COMMENT
else:
for child in node.children:
if _contains_standalone_comment(child):
return True
return False
class _TopLevelStatementsVisitor(Visitor[None]):
"""
A node visitor that converts unchanged top-level statements to
STANDALONE_COMMENT.
This is used in addition to _convert_unchanged_lines_by_flatterning, to
speed up formatting when there are unchanged top-level
classes/functions/statements.
"""
def __init__(self, lines_set: Set[int]):
self._lines_set = lines_set
def visit_simple_stmt(self, node: Node) -> Iterator[None]:
# This is only called for top-level statements, since `visit_suite`
# won't visit its children nodes.
yield from []
newline_leaf = last_leaf(node)
if not newline_leaf:
return
assert (
newline_leaf.type == NEWLINE
), f"Unexpectedly found leaf.type={newline_leaf.type}"
# We need to find the furthest ancestor with the NEWLINE as the last
# leaf, since a `suite` can simply be a `simple_stmt` when it puts
# its body on the same line. Example: `if cond: pass`.
ancestor = furthest_ancestor_with_last_leaf(newline_leaf)
if not _get_line_range(ancestor).intersection(self._lines_set):
_convert_node_to_standalone_comment(ancestor)
def visit_suite(self, node: Node) -> Iterator[None]:
yield from []
# If there is a STANDALONE_COMMENT node, it means parts of the node tree
# have fmt on/off/skip markers. Those STANDALONE_COMMENT nodes can't
# be simply converted by calling str(node). So we just don't convert
# here.
if _contains_standalone_comment(node):
return
# Find the semantic parent of this suite. For `async_stmt` and
# `async_funcdef`, the ASYNC token is defined on a separate level by the
# grammar.
semantic_parent = node.parent
if semantic_parent is not None:
if (
semantic_parent.prev_sibling is not None
and semantic_parent.prev_sibling.type == ASYNC
):
semantic_parent = semantic_parent.parent
if semantic_parent is not None and not _get_line_range(
semantic_parent
).intersection(self._lines_set):
_convert_node_to_standalone_comment(semantic_parent)
def _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None:
"""Converts unchanged to STANDALONE_COMMENT line by line."""
for leaf in node.leaves():
if leaf.type != NEWLINE:
# We only consider "unwrapped lines", which are divided by the NEWLINE
# token.
continue
if leaf.parent and leaf.parent.type == syms.match_stmt:
# The `suite` node is defined as:
# match_stmt: "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT
# Here we need to check `subject_expr`. The `case_block+` will be
# checked by their own NEWLINEs.
nodes_to_ignore: List[LN] = []
prev_sibling = leaf.prev_sibling
while prev_sibling:
nodes_to_ignore.insert(0, prev_sibling)
prev_sibling = prev_sibling.prev_sibling
if not _get_line_range(nodes_to_ignore).intersection(lines_set):
_convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)
elif leaf.parent and leaf.parent.type == syms.suite:
# The `suite` node is defined as:
# suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT
# We will check `simple_stmt` and `stmt+` separately against the lines set
parent_sibling = leaf.parent.prev_sibling
nodes_to_ignore = []
while parent_sibling and not parent_sibling.type == syms.suite:
# NOTE: Multiple suite nodes can exist as siblings in e.g. `if_stmt`.
nodes_to_ignore.insert(0, parent_sibling)
parent_sibling = parent_sibling.prev_sibling
# Special case for `async_stmt` and `async_funcdef` where the ASYNC
# token is on the grandparent node.
grandparent = leaf.parent.parent
if (
grandparent is not None
and grandparent.prev_sibling is not None
and grandparent.prev_sibling.type == ASYNC
):
nodes_to_ignore.insert(0, grandparent.prev_sibling)
if not _get_line_range(nodes_to_ignore).intersection(lines_set):
_convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)
else:
ancestor = furthest_ancestor_with_last_leaf(leaf)
# Consider multiple decorators as a whole block, as their
# newlines have different behaviors than the rest of the grammar.
if (
ancestor.type == syms.decorator
and ancestor.parent
and ancestor.parent.type == syms.decorators
):
ancestor = ancestor.parent
if not _get_line_range(ancestor).intersection(lines_set):
_convert_node_to_standalone_comment(ancestor)
def _convert_node_to_standalone_comment(node: LN) -> None:
"""Convert node to STANDALONE_COMMENT by modifying the tree inline."""
parent = node.parent
if not parent:
return
first = first_leaf(node)
last = last_leaf(node)
if not first or not last:
return
if first is last:
# This can happen on the following edge cases:
# 1. A block of `# fmt: off/on` code except the `# fmt: on` is placed
# on the end of the last line instead of on a new line.
# 2. A single backslash on its own line followed by a comment line.
# Ideally we don't want to format them when not requested, but fixing
# isn't easy. These cases are also badly formatted code, so it isn't
# too bad we reformat them.
return
# The prefix contains comments and indentation whitespaces. They are
# reformatted accordingly to the correct indentation level.
# This also means the indentation will be changed on the unchanged lines, and
# this is actually required to not break incremental reformatting.
prefix = first.prefix
first.prefix = ""
index = node.remove()
if index is not None:
# Remove the '\n', as STANDALONE_COMMENT will have '\n' appended when
# genearting the formatted code.
value = str(node)[:-1]
parent.insert_child(
index,
Leaf(
STANDALONE_COMMENT,
value,
prefix=prefix,
fmt_pass_converted_first_leaf=first,
),
)
def _convert_nodes_to_standalone_comment(nodes: Sequence[LN], *, newline: Leaf) -> None:
"""Convert nodes to STANDALONE_COMMENT by modifying the tree inline."""
if not nodes:
return
parent = nodes[0].parent
first = first_leaf(nodes[0])
if not parent or not first:
return
prefix = first.prefix
first.prefix = ""
value = "".join(str(node) for node in nodes)
# The prefix comment on the NEWLINE leaf is the trailing comment of the statement.
if newline.prefix:
value += newline.prefix
newline.prefix = ""
index = nodes[0].remove()
for node in nodes[1:]:
node.remove()
if index is not None:
parent.insert_child(
index,
Leaf(
STANDALONE_COMMENT,
value,
prefix=prefix,
fmt_pass_converted_first_leaf=first,
),
)
def _leaf_line_end(leaf: Leaf) -> int:
"""Returns the line number of the leaf node's last line."""
if leaf.type == NEWLINE:
return leaf.lineno
else:
# Leaf nodes like multiline strings can occupy multiple lines.
return leaf.lineno + str(leaf).count("\n")
def _get_line_range(node_or_nodes: Union[LN, List[LN]]) -> Set[int]:
"""Returns the line range of this node or list of nodes."""
if isinstance(node_or_nodes, list):
nodes = node_or_nodes
if not nodes:
return set()
first = first_leaf(nodes[0])
last = last_leaf(nodes[-1])
if first and last:
line_start = first.lineno
line_end = _leaf_line_end(last)
return set(range(line_start, line_end + 1))
else:
return set()
else:
node = node_or_nodes
if isinstance(node, Leaf):
return set(range(node.lineno, _leaf_line_end(node) + 1))
else:
first = first_leaf(node)
last = last_leaf(node)
if first and last:
return set(range(first.lineno, _leaf_line_end(last) + 1))
else:
return set()
@dataclass
class _LinesMapping:
"""1-based lines mapping from original source to modified source.
Lines [original_start, original_end] from original source
are mapped to [modified_start, modified_end].
The ranges are inclusive on both ends.
"""
original_start: int
original_end: int
modified_start: int
modified_end: int
# Whether this range corresponds to a changed block, or an unchanged block.
is_changed_block: bool
def _calculate_lines_mappings(
original_source: str,
modified_source: str,
) -> Sequence[_LinesMapping]:
"""Returns a sequence of _LinesMapping by diffing the sources.
For example, given the following diff:
import re
- def func(arg1,
- arg2, arg3):
+ def func(arg1, arg2, arg3):
pass
It returns the following mappings:
original -> modified
(1, 1) -> (1, 1), is_changed_block=False (the "import re" line)
(2, 3) -> (2, 2), is_changed_block=True (the diff)
(4, 4) -> (3, 3), is_changed_block=False (the "pass" line)
You can think of this visually as if it brings up a side-by-side diff, and tries
to map the line ranges from the left side to the right side:
(1, 1)->(1, 1) 1. import re 1. import re
(2, 3)->(2, 2) 2. def func(arg1, 2. def func(arg1, arg2, arg3):
3. arg2, arg3):
(4, 4)->(3, 3) 4. pass 3. pass
Args:
original_source: the original source.
modified_source: the modified source.
"""
matcher = difflib.SequenceMatcher(
None,
original_source.splitlines(keepends=True),
modified_source.splitlines(keepends=True),
)
matching_blocks = matcher.get_matching_blocks()
lines_mappings: List[_LinesMapping] = []
# matching_blocks is a sequence of "same block of code ranges", see
# https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher.get_matching_blocks
# Each block corresponds to a _LinesMapping with is_changed_block=False,
# and the ranges between two blocks corresponds to a _LinesMapping with
# is_changed_block=True,
# NOTE: matching_blocks is 0-based, but _LinesMapping is 1-based.
for i, block in enumerate(matching_blocks):
if i == 0:
if block.a != 0 or block.b != 0:
lines_mappings.append(
_LinesMapping(
original_start=1,
original_end=block.a,
modified_start=1,
modified_end=block.b,
is_changed_block=False,
)
)
else:
previous_block = matching_blocks[i - 1]
lines_mappings.append(
_LinesMapping(
original_start=previous_block.a + previous_block.size + 1,
original_end=block.a,
modified_start=previous_block.b + previous_block.size + 1,
modified_end=block.b,
is_changed_block=True,
)
)
if i < len(matching_blocks) - 1:
lines_mappings.append(
_LinesMapping(
original_start=block.a + 1,
original_end=block.a + block.size,
modified_start=block.b + 1,
modified_end=block.b + block.size,
is_changed_block=False,
)
)
return lines_mappings
def _find_lines_mapping_index(
original_line: int,
lines_mappings: Sequence[_LinesMapping],
start_index: int,
) -> int:
"""Returns the original index of the lines mappings for the original line."""
index = start_index
while index < len(lines_mappings):
mapping = lines_mappings[index]
if (
mapping.original_start <= original_line
and original_line <= mapping.original_end
):
return index
index += 1
return index

View File

@ -0,0 +1,107 @@
# flags: --line-ranges=5-6
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
def foo1(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
def foo2(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
def foo3(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
def foo4(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
# Adding some unformated code covering a wide range of syntaxes.
if True:
# Incorrectly indented prefix comments.
pass
import typing
from typing import (
Any ,
)
class MyClass( object): # Trailing comment with extra leading space.
#NOTE: The following indentation is incorrect:
@decor( 1 * 3 )
def my_func( arg):
pass
try: # Trailing comment with extra leading space.
for i in range(10): # Trailing comment with extra leading space.
while condition:
if something:
then_something( )
elif something_else:
then_something_else( )
except ValueError as e:
unformatted( )
finally:
unformatted( )
async def test_async_unformatted( ): # Trailing comment with extra leading space.
async for i in some_iter( unformatted ): # Trailing comment with extra leading space.
await asyncio.sleep( 1 )
async with some_context( unformatted ):
print( "unformatted" )
# output
# flags: --line-ranges=5-6
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
def foo1(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
def foo2(
parameter_1,
parameter_2,
parameter_3,
parameter_4,
parameter_5,
parameter_6,
parameter_7,
):
pass
def foo3(
parameter_1,
parameter_2,
parameter_3,
parameter_4,
parameter_5,
parameter_6,
parameter_7,
):
pass
def foo4(parameter_1, parameter_2, parameter_3, parameter_4, parameter_5, parameter_6, parameter_7): pass
# Adding some unformated code covering a wide range of syntaxes.
if True:
# Incorrectly indented prefix comments.
pass
import typing
from typing import (
Any ,
)
class MyClass( object): # Trailing comment with extra leading space.
#NOTE: The following indentation is incorrect:
@decor( 1 * 3 )
def my_func( arg):
pass
try: # Trailing comment with extra leading space.
for i in range(10): # Trailing comment with extra leading space.
while condition:
if something:
then_something( )
elif something_else:
then_something_else( )
except ValueError as e:
unformatted( )
finally:
unformatted( )
async def test_async_unformatted( ): # Trailing comment with extra leading space.
async for i in some_iter( unformatted ): # Trailing comment with extra leading space.
await asyncio.sleep( 1 )
async with some_context( unformatted ):
print( "unformatted" )

View File

@ -0,0 +1,49 @@
# flags: --line-ranges=7-7 --line-ranges=17-23
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
# fmt: off
import os
def myfunc( ): # Intentionally unformatted.
pass
# fmt: on
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: off
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: on
def myfunc( ): # This will be reformatted.
print( {"this will be reformatted"} )
# output
# flags: --line-ranges=7-7 --line-ranges=17-23
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
# fmt: off
import os
def myfunc( ): # Intentionally unformatted.
pass
# fmt: on
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: off
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: on
def myfunc(): # This will be reformatted.
print({"this will be reformatted"})

View File

@ -0,0 +1,27 @@
# flags: --line-ranges=12-12
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
# Regression test for an edge case involving decorators and fmt: off/on.
class MyClass:
# fmt: off
@decorator ( )
# fmt: on
def method():
print ( "str" )
# output
# flags: --line-ranges=12-12
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
# Regression test for an edge case involving decorators and fmt: off/on.
class MyClass:
# fmt: off
@decorator ( )
# fmt: on
def method():
print("str")

View File

@ -0,0 +1,37 @@
# flags: --line-ranges=11-17
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: off
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: on
def myfunc( ): # This will be reformatted.
print( {"this will be reformatted"} )
# output
# flags: --line-ranges=11-17
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: off
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
def myfunc( ): # This will not be reformatted.
print( {"also won't be reformatted"} )
# fmt: on
def myfunc(): # This will be reformatted.
print({"this will be reformatted"})

View File

@ -0,0 +1,9 @@
# flags: --line-ranges=8-8
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
# This test ensures no empty lines are added around import lines.
# It caused an issue before https://github.com/psf/black/pull/3610 is merged.
import os
import re
import sys

View File

@ -0,0 +1,27 @@
# flags: --line-ranges=5-5
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
if cond1:
print("first")
if cond2:
print("second")
else:
print("else")
if another_cond:
print("will not be changed")
# output
# flags: --line-ranges=5-5
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
if cond1:
print("first")
if cond2:
print("second")
else:
print("else")
if another_cond:
print("will not be changed")

View File

@ -0,0 +1,27 @@
# flags: --line-ranges=9-11
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
# This is a specific case for Black's two-pass formatting behavior in `format_str`.
# The second pass must respect the line ranges before the first pass.
def restrict_to_this_line(arg1,
arg2,
arg3):
print ( "This should not be formatted." )
print ( "Note that in the second pass, the original line range 9-11 will cover these print lines.")
# output
# flags: --line-ranges=9-11
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
# This is a specific case for Black's two-pass formatting behavior in `format_str`.
# The second pass must respect the line ranges before the first pass.
def restrict_to_this_line(arg1, arg2, arg3):
print ( "This should not be formatted." )
print ( "Note that in the second pass, the original line range 9-11 will cover these print lines.")

View File

@ -0,0 +1,25 @@
# flags: --line-ranges=5-5 --line-ranges=9-9 --line-ranges=13-13
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
alist = [
1, 2
]
adict = {
"key" : "value"
}
func_call (
arg = value
)
# output
# flags: --line-ranges=5-5 --line-ranges=9-9 --line-ranges=13-13
# NOTE: If you need to modify this file, pay special attention to the --line-ranges=
# flag above as it's formatting specifically these lines.
alist = [1, 2]
adict = {"key": "value"}
func_call(arg=value)

View File

@ -0,0 +1,2 @@
[tool.black]
line-ranges = "1-1"

View File

@ -0,0 +1,50 @@
"""Module doc."""
from typing import (
Callable,
Literal,
)
# fmt: off
class Unformatted:
def should_also_work(self):
pass
# fmt: on
a = [1, 2] # fmt: skip
# This should cover as many syntaxes as possible.
class Foo:
"""Class doc."""
def __init__(self) -> None:
pass
@add_logging
@memoize.memoize(max_items=2)
def plus_one(
self,
number: int,
) -> int:
return number + 1
async def async_plus_one(self, number: int) -> int:
await asyncio.sleep(1)
async with some_context():
return number + 1
try:
for i in range(10):
while condition:
if something:
then_something()
elif something_else:
then_something_else()
except ValueError as e:
handle(e)
finally:
done()

View File

@ -0,0 +1,25 @@
# flags: --minimum-version=3.10
def pattern_matching():
match status:
case 1:
return "1"
case [single]:
return "single"
case [
action,
obj,
]:
return "act on obj"
case Point(x=0):
return "class pattern"
case {"text": message}:
return "mapping"
case {
"text": message,
"format": _,
}:
return "mapping"
case _:
return "fallback"

View File

@ -8,6 +8,7 @@
import os
import re
import sys
import textwrap
import types
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, redirect_stderr
@ -1269,7 +1270,7 @@ def test_reformat_one_with_stdin_filename(self) -> None:
report=report,
)
fsts.assert_called_once_with(
fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE, lines=()
)
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
@ -1295,6 +1296,7 @@ def test_reformat_one_with_stdin_filename_pyi(self) -> None:
fast=True,
write_back=black.WriteBack.YES,
mode=replace(DEFAULT_MODE, is_pyi=True),
lines=(),
)
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
@ -1320,6 +1322,7 @@ def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
fast=True,
write_back=black.WriteBack.YES,
mode=replace(DEFAULT_MODE, is_ipynb=True),
lines=(),
)
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
@ -1941,6 +1944,88 @@ def test_equivalency_ast_parse_failure_includes_error(self) -> None:
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")
def test_line_ranges_with_code_option(self) -> None:
code = textwrap.dedent("""\
if a == b:
print ( "OK" )
""")
args = ["--line-ranges=1-1", "--code", code]
result = CliRunner().invoke(black.main, args)
expected = textwrap.dedent("""\
if a == b:
print ( "OK" )
""")
self.compare_results(result, expected, expected_exit_code=0)
def test_line_ranges_with_stdin(self) -> None:
code = textwrap.dedent("""\
if a == b:
print ( "OK" )
""")
runner = BlackRunner()
result = runner.invoke(
black.main, ["--line-ranges=1-1", "-"], input=BytesIO(code.encode("utf-8"))
)
expected = textwrap.dedent("""\
if a == b:
print ( "OK" )
""")
self.compare_results(result, expected, expected_exit_code=0)
def test_line_ranges_with_source(self) -> None:
with TemporaryDirectory() as workspace:
test_file = Path(workspace) / "test.py"
test_file.write_text(
textwrap.dedent("""\
if a == b:
print ( "OK" )
"""),
encoding="utf-8",
)
args = ["--line-ranges=1-1", str(test_file)]
result = CliRunner().invoke(black.main, args)
assert not result.exit_code
formatted = test_file.read_text(encoding="utf-8")
expected = textwrap.dedent("""\
if a == b:
print ( "OK" )
""")
assert expected == formatted
def test_line_ranges_with_multiple_sources(self) -> None:
with TemporaryDirectory() as workspace:
test1_file = Path(workspace) / "test1.py"
test1_file.write_text("", encoding="utf-8")
test2_file = Path(workspace) / "test2.py"
test2_file.write_text("", encoding="utf-8")
args = ["--line-ranges=1-1", str(test1_file), str(test2_file)]
result = CliRunner().invoke(black.main, args)
assert result.exit_code == 1
assert "Cannot use --line-ranges to format multiple files" in result.output
def test_line_ranges_with_ipynb(self) -> None:
with TemporaryDirectory() as workspace:
test_file = Path(workspace) / "test.ipynb"
test_file.write_text("{}", encoding="utf-8")
args = ["--line-ranges=1-1", "--ipynb", str(test_file)]
result = CliRunner().invoke(black.main, args)
assert "Cannot use --line-ranges with ipynb files" in result.output
assert result.exit_code == 1
def test_line_ranges_in_pyproject_toml(self) -> None:
config = THIS_DIR / "data" / "invalid_line_ranges.toml"
result = BlackRunner().invoke(
black.main, ["--code", "print()", "--config", str(config)]
)
assert result.exit_code == 2
assert result.stderr_bytes is not None
assert (
b"Cannot use line-ranges in the pyproject.toml file." in result.stderr_bytes
)
class TestCaching:
def test_get_cache_dir(

View File

@ -29,13 +29,19 @@ def check_file(subdir: str, filename: str, *, data: bool = True) -> None:
args.mode,
fast=args.fast,
minimum_version=args.minimum_version,
lines=args.lines,
)
if args.minimum_version is not None:
major, minor = args.minimum_version
target_version = TargetVersion[f"PY{major}{minor}"]
mode = replace(args.mode, target_versions={target_version})
assert_format(
source, expected, mode, fast=args.fast, minimum_version=args.minimum_version
source,
expected,
mode,
fast=args.fast,
minimum_version=args.minimum_version,
lines=args.lines,
)
@ -45,6 +51,24 @@ def test_simple_format(filename: str) -> None:
check_file("cases", filename)
@pytest.mark.parametrize("filename", all_data_cases("line_ranges_formatted"))
def test_line_ranges_line_by_line(filename: str) -> None:
args, source, expected = read_data_with_mode("line_ranges_formatted", filename)
assert (
source == expected
), "Test cases in line_ranges_formatted must already be formatted."
line_count = len(source.splitlines())
for line in range(1, line_count + 1):
assert_format(
source,
expected,
args.mode,
fast=args.fast,
minimum_version=args.minimum_version,
lines=[(line, line)],
)
# =============== #
# Unusual cases
# =============== #

185
tests/test_ranges.py Normal file
View File

@ -0,0 +1,185 @@
"""Test the black.ranges module."""
from typing import List, Tuple
import pytest
from black.ranges import adjusted_lines
@pytest.mark.parametrize(
"lines",
[[(1, 1)], [(1, 3)], [(1, 1), (3, 4)]],
)
def test_no_diff(lines: List[Tuple[int, int]]) -> None:
source = """\
import re
def func():
pass
"""
assert lines == adjusted_lines(lines, source, source)
@pytest.mark.parametrize(
"lines",
[
[(1, 0)],
[(-8, 0)],
[(-8, 8)],
[(1, 100)],
[(2, 1)],
[(0, 8), (3, 1)],
],
)
def test_invalid_lines(lines: List[Tuple[int, int]]) -> None:
original_source = """\
import re
def foo(arg):
'''This is the foo function.
This is foo function's
docstring with more descriptive texts.
'''
def func(arg1,
arg2, arg3):
pass
"""
modified_source = """\
import re
def foo(arg):
'''This is the foo function.
This is foo function's
docstring with more descriptive texts.
'''
def func(arg1, arg2, arg3):
pass
"""
assert not adjusted_lines(lines, original_source, modified_source)
@pytest.mark.parametrize(
"lines,adjusted",
[
(
[(1, 1)],
[(1, 1)],
),
(
[(1, 2)],
[(1, 1)],
),
(
[(1, 6)],
[(1, 2)],
),
(
[(6, 6)],
[],
),
],
)
def test_removals(
lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]]
) -> None:
original_source = """\
1. first line
2. second line
3. third line
4. fourth line
5. fifth line
6. sixth line
"""
modified_source = """\
2. second line
5. fifth line
"""
assert adjusted == adjusted_lines(lines, original_source, modified_source)
@pytest.mark.parametrize(
"lines,adjusted",
[
(
[(1, 1)],
[(2, 2)],
),
(
[(1, 2)],
[(2, 5)],
),
(
[(2, 2)],
[(5, 5)],
),
],
)
def test_additions(
lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]]
) -> None:
original_source = """\
1. first line
2. second line
"""
modified_source = """\
this is added
1. first line
this is added
this is added
2. second line
this is added
"""
assert adjusted == adjusted_lines(lines, original_source, modified_source)
@pytest.mark.parametrize(
"lines,adjusted",
[
(
[(1, 11)],
[(1, 10)],
),
(
[(1, 12)],
[(1, 11)],
),
(
[(10, 10)],
[(9, 9)],
),
([(1, 1), (9, 10)], [(1, 1), (9, 9)]),
([(9, 10), (1, 1)], [(1, 1), (9, 9)]),
],
)
def test_diffs(lines: List[Tuple[int, int]], adjusted: List[Tuple[int, int]]) -> None:
original_source = """\
1. import re
2. def foo(arg):
3. '''This is the foo function.
4.
5. This is foo function's
6. docstring with more descriptive texts.
7. '''
8.
9. def func(arg1,
10. arg2, arg3):
11. pass
12. # last line
"""
modified_source = """\
1. import re # changed
2. def foo(arg):
3. '''This is the foo function.
4.
5. This is foo function's
6. docstring with more descriptive texts.
7. '''
8.
9. def func(arg1, arg2, arg3):
11. pass
12. # last line changed
"""
assert adjusted == adjusted_lines(lines, original_source, modified_source)

View File

@ -8,13 +8,14 @@
from dataclasses import dataclass, field, replace
from functools import partial
from pathlib import Path
from typing import Any, Iterator, List, Optional, Tuple
from typing import Any, Collection, Iterator, List, Optional, Tuple
import black
from black.const import DEFAULT_LINE_LENGTH
from black.debug import DebugVisitor
from black.mode import TargetVersion
from black.output import diff, err, out
from black.ranges import parse_line_ranges
from . import conftest
@ -44,6 +45,7 @@ class TestCaseArgs:
mode: black.Mode = field(default_factory=black.Mode)
fast: bool = False
minimum_version: Optional[Tuple[int, int]] = None
lines: Collection[Tuple[int, int]] = ()
def _assert_format_equal(expected: str, actual: str) -> None:
@ -93,6 +95,7 @@ def assert_format(
*,
fast: bool = False,
minimum_version: Optional[Tuple[int, int]] = None,
lines: Collection[Tuple[int, int]] = (),
) -> None:
"""Convenience function to check that Black formats as expected.
@ -101,7 +104,7 @@ def assert_format(
separate from TargetVerson Mode configuration.
"""
_assert_format_inner(
source, expected, mode, fast=fast, minimum_version=minimum_version
source, expected, mode, fast=fast, minimum_version=minimum_version, lines=lines
)
# For both preview and non-preview tests, ensure that Black doesn't crash on
@ -113,6 +116,7 @@ def assert_format(
replace(mode, preview=not mode.preview),
fast=fast,
minimum_version=minimum_version,
lines=lines,
)
except Exception as e:
text = "non-preview" if mode.preview else "preview"
@ -129,6 +133,7 @@ def assert_format(
replace(mode, preview=False, line_length=1),
fast=fast,
minimum_version=minimum_version,
lines=lines,
)
except Exception as e:
raise FormatFailure(
@ -143,8 +148,9 @@ def _assert_format_inner(
*,
fast: bool = False,
minimum_version: Optional[Tuple[int, int]] = None,
lines: Collection[Tuple[int, int]] = (),
) -> None:
actual = black.format_str(source, mode=mode)
actual = black.format_str(source, mode=mode, lines=lines)
if expected is not None:
_assert_format_equal(expected, actual)
# It's not useful to run safety checks if we're expecting no changes anyway. The
@ -156,7 +162,7 @@ def _assert_format_inner(
# when checking modern code on older versions.
if minimum_version is None or sys.version_info >= minimum_version:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, mode=mode)
black.assert_stable(source, actual, mode=mode, lines=lines)
def dump_to_stderr(*output: str) -> str:
@ -239,6 +245,7 @@ def get_flags_parser() -> argparse.ArgumentParser:
" version works correctly."
),
)
parser.add_argument("--line-ranges", action="append")
return parser
@ -254,7 +261,13 @@ def parse_mode(flags_line: str) -> TestCaseArgs:
magic_trailing_comma=not args.skip_magic_trailing_comma,
preview=args.preview,
)
return TestCaseArgs(mode=mode, fast=args.fast, minimum_version=args.minimum_version)
if args.line_ranges:
lines = parse_line_ranges(args.line_ranges)
else:
lines = []
return TestCaseArgs(
mode=mode, fast=args.fast, minimum_version=args.minimum_version, lines=lines
)
def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
@ -267,6 +280,12 @@ def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
for line in lines:
if not _input and line.startswith("# flags: "):
mode = parse_mode(line[len("# flags: ") :])
if mode.lines:
# Retain the `# flags: ` line when using --line-ranges=. This requires
# the `# output` section to also include this line, but retaining the
# line is important to make the line ranges match what you see in the
# test file.
result.append(line)
continue
line = line.replace(EMPTY_LINE, "")
if line.rstrip() == "# output":