Normalize string quotes (#75)

* Normalize string quotes

Convert single-quoted strings to double-quoted. Convert triple single-quoted strings to triple double-quoted. Do not touch any strings where conversion would increase the number of backslashes.

Fixes #51.

* reformat Black itself
This commit is contained in:
Zsolt Dollenstein 2018-03-31 19:21:25 +01:00 committed by Łukasz Langa
parent 4dfec562ed
commit 80bd2b3134
14 changed files with 350 additions and 266 deletions

253
black.py
View File

@ -47,9 +47,9 @@
Priority = int Priority = int
Index = int Index = int
LN = Union[Leaf, Node] LN = Union[Leaf, Node]
SplitFunc = Callable[['Line', bool], Iterator['Line']] SplitFunc = Callable[["Line", bool], Iterator["Line"]]
out = partial(click.secho, bold=True, err=True) out = partial(click.secho, bold=True, err=True)
err = partial(click.secho, fg='red', err=True) err = partial(click.secho, fg="red", err=True)
class NothingChanged(UserWarning): class NothingChanged(UserWarning):
@ -94,15 +94,15 @@ class FormatOff(FormatError):
@click.command() @click.command()
@click.option( @click.option(
'-l', "-l",
'--line-length', "--line-length",
type=int, type=int,
default=DEFAULT_LINE_LENGTH, default=DEFAULT_LINE_LENGTH,
help='How many character per line to allow.', help="How many character per line to allow.",
show_default=True, show_default=True,
) )
@click.option( @click.option(
'--check', "--check",
is_flag=True, is_flag=True,
help=( help=(
"Don't write back the files, just return the status. Return code 0 " "Don't write back the files, just return the status. Return code 0 "
@ -111,13 +111,13 @@ class FormatOff(FormatError):
), ),
) )
@click.option( @click.option(
'--fast/--safe', "--fast/--safe",
is_flag=True, is_flag=True,
help='If --fast given, skip temporary sanity checks. [default: --safe]', help="If --fast given, skip temporary sanity checks. [default: --safe]",
) )
@click.version_option(version=__version__) @click.version_option(version=__version__)
@click.argument( @click.argument(
'src', "src",
nargs=-1, nargs=-1,
type=click.Path( type=click.Path(
exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
@ -136,17 +136,17 @@ def main(
elif p.is_file(): elif p.is_file():
# if a file was explicitly given, we don't care about its extension # if a file was explicitly given, we don't care about its extension
sources.append(p) sources.append(p)
elif s == '-': elif s == "-":
sources.append(Path('-')) sources.append(Path("-"))
else: else:
err(f'invalid path: {s}') err(f"invalid path: {s}")
if len(sources) == 0: if len(sources) == 0:
ctx.exit(0) ctx.exit(0)
elif len(sources) == 1: elif len(sources) == 1:
p = sources[0] p = sources[0]
report = Report(check=check) report = Report(check=check)
try: try:
if not p.is_file() and str(p) == '-': if not p.is_file() and str(p) == "-":
changed = format_stdin_to_stdout( changed = format_stdin_to_stdout(
line_length=line_length, fast=fast, write_back=not check line_length=line_length, fast=fast, write_back=not check
) )
@ -202,7 +202,7 @@ async def schedule_formatting(
report = Report(check=not write_back) report = Report(check=not write_back)
for src, task in tasks.items(): for src, task in tasks.items():
if not task.done(): if not task.done():
report.failed(src, 'timed out, cancelling') report.failed(src, "timed out, cancelling")
task.cancel() task.cancel()
cancelled.append(task) cancelled.append(task)
elif task.cancelled(): elif task.cancelled():
@ -214,7 +214,7 @@ async def schedule_formatting(
if cancelled: if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
else: else:
out('All done! ✨ 🍰 ✨') out("All done! ✨ 🍰 ✨")
click.echo(str(report)) click.echo(str(report))
return report.return_code return report.return_code
@ -272,7 +272,7 @@ def format_file_contents(
valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
`line_length` is passed to :func:`format_str`. `line_length` is passed to :func:`format_str`.
""" """
if src_contents.strip() == '': if src_contents.strip() == "":
raise NothingChanged raise NothingChanged
dst_contents = format_str(src_contents, line_length=line_length) dst_contents = format_str(src_contents, line_length=line_length)
@ -319,8 +319,8 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
def lib2to3_parse(src_txt: str) -> Node: def lib2to3_parse(src_txt: str) -> Node:
"""Given a string with source, return the lib2to3 Node.""" """Given a string with source, return the lib2to3 Node."""
grammar = pygram.python_grammar_no_print_statement grammar = pygram.python_grammar_no_print_statement
if src_txt[-1] != '\n': if src_txt[-1] != "\n":
nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n' nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
src_txt += nl src_txt += nl
for grammar in GRAMMARS: for grammar in GRAMMARS:
drv = driver.Driver(grammar, pytree.convert) drv = driver.Driver(grammar, pytree.convert)
@ -350,7 +350,7 @@ def lib2to3_unparse(node: Node) -> str:
return code return code
T = TypeVar('T') T = TypeVar("T")
class Visitor(Generic[T]): class Visitor(Generic[T]):
@ -370,7 +370,7 @@ def visit(self, node: LN) -> Iterator[T]:
name = token.tok_name[node.type] name = token.tok_name[node.type]
else: else:
name = type_repr(node.type) name = type_repr(node.type)
yield from getattr(self, f'visit_{name}', self.visit_default)(node) yield from getattr(self, f"visit_{name}", self.visit_default)(node)
def visit_default(self, node: LN) -> Iterator[T]: def visit_default(self, node: LN) -> Iterator[T]:
"""Default `visit_*()` implementation. Recurses to children of `node`.""" """Default `visit_*()` implementation. Recurses to children of `node`."""
@ -384,24 +384,24 @@ class DebugVisitor(Visitor[T]):
tree_depth: int = 0 tree_depth: int = 0
def visit_default(self, node: LN) -> Iterator[T]: def visit_default(self, node: LN) -> Iterator[T]:
indent = ' ' * (2 * self.tree_depth) indent = " " * (2 * self.tree_depth)
if isinstance(node, Node): if isinstance(node, Node):
_type = type_repr(node.type) _type = type_repr(node.type)
out(f'{indent}{_type}', fg='yellow') out(f"{indent}{_type}", fg="yellow")
self.tree_depth += 1 self.tree_depth += 1
for child in node.children: for child in node.children:
yield from self.visit(child) yield from self.visit(child)
self.tree_depth -= 1 self.tree_depth -= 1
out(f'{indent}/{_type}', fg='yellow', bold=False) out(f"{indent}/{_type}", fg="yellow", bold=False)
else: else:
_type = token.tok_name.get(node.type, str(node.type)) _type = token.tok_name.get(node.type, str(node.type))
out(f'{indent}{_type}', fg='blue', nl=False) out(f"{indent}{_type}", fg="blue", nl=False)
if node.prefix: if node.prefix:
# We don't have to handle prefixes for `Node` objects since # We don't have to handle prefixes for `Node` objects since
# that delegates to the first child anyway. # that delegates to the first child anyway.
out(f' {node.prefix!r}', fg='green', bold=False, nl=False) out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
out(f' {node.value!r}', fg='blue', bold=False) out(f" {node.value!r}", fg="blue", bold=False)
@classmethod @classmethod
def show(cls, code: str) -> None: def show(cls, code: str) -> None:
@ -415,7 +415,7 @@ def show(cls, code: str) -> None:
KEYWORDS = set(keyword.kwlist) KEYWORDS = set(keyword.kwlist)
WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
FLOW_CONTROL = {'return', 'raise', 'break', 'continue'} FLOW_CONTROL = {"return", "raise", "break", "continue"}
STATEMENT = { STATEMENT = {
syms.if_stmt, syms.if_stmt,
syms.while_stmt, syms.while_stmt,
@ -427,7 +427,7 @@ def show(cls, code: str) -> None:
syms.classdef, syms.classdef,
} }
STANDALONE_COMMENT = 153 STANDALONE_COMMENT = 153
LOGIC_OPERATORS = {'and', 'or'} LOGIC_OPERATORS = {"and", "or"}
COMPARATORS = { COMPARATORS = {
token.LESS, token.LESS,
token.GREATER, token.GREATER,
@ -500,14 +500,14 @@ def mark(self, leaf: Leaf) -> None:
self.delimiters[id(self.previous)] = STRING_PRIORITY self.delimiters[id(self.previous)] = STRING_PRIORITY
elif ( elif (
leaf.type == token.NAME leaf.type == token.NAME
and leaf.value == 'for' and leaf.value == "for"
and leaf.parent and leaf.parent
and leaf.parent.type in {syms.comp_for, syms.old_comp_for} and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
): ):
self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
elif ( elif (
leaf.type == token.NAME leaf.type == token.NAME
and leaf.value == 'if' and leaf.value == "if"
and leaf.parent and leaf.parent
and leaf.parent.type in {syms.comp_if, syms.old_comp_if} and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
): ):
@ -612,7 +612,7 @@ def is_class(self) -> bool:
return ( return (
bool(self) bool(self)
and self.leaves[0].type == token.NAME and self.leaves[0].type == token.NAME
and self.leaves[0].value == 'class' and self.leaves[0].value == "class"
) )
@property @property
@ -628,12 +628,12 @@ def is_def(self) -> bool:
except IndexError: except IndexError:
second_leaf = None second_leaf = None
return ( return (
(first_leaf.type == token.NAME and first_leaf.value == 'def') (first_leaf.type == token.NAME and first_leaf.value == "def")
or ( or (
first_leaf.type == token.ASYNC first_leaf.type == token.ASYNC
and second_leaf is not None and second_leaf is not None
and second_leaf.type == token.NAME and second_leaf.type == token.NAME
and second_leaf.value == 'def' and second_leaf.value == "def"
) )
) )
@ -655,7 +655,7 @@ def is_yield(self) -> bool:
return ( return (
bool(self) bool(self)
and self.leaves[0].type == token.NAME and self.leaves[0].type == token.NAME
and self.leaves[0].value == 'yield' and self.leaves[0].value == "yield"
) )
@property @property
@ -722,7 +722,7 @@ def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
To avoid splitting on the comma in this situation, increase the depth of To avoid splitting on the comma in this situation, increase the depth of
tokens between `for` and `in`. tokens between `for` and `in`.
""" """
if leaf.type == token.NAME and leaf.value == 'for': if leaf.type == token.NAME and leaf.value == "for":
self.has_for = True self.has_for = True
self.bracket_tracker.depth += 1 self.bracket_tracker.depth += 1
self._for_loop_variable = True self._for_loop_variable = True
@ -732,7 +732,7 @@ def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
"""See `maybe_increment_for_loop_variable` above for explanation.""" """See `maybe_increment_for_loop_variable` above for explanation."""
if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in': if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
self.bracket_tracker.depth -= 1 self.bracket_tracker.depth -= 1
self._for_loop_variable = False self._for_loop_variable = False
return True return True
@ -745,7 +745,7 @@ def append_comment(self, comment: Leaf) -> bool:
comment.type == STANDALONE_COMMENT comment.type == STANDALONE_COMMENT
and self.bracket_tracker.any_open_brackets() and self.bracket_tracker.any_open_brackets()
): ):
comment.prefix = '' comment.prefix = ""
return False return False
if comment.type != token.COMMENT: if comment.type != token.COMMENT:
@ -754,7 +754,7 @@ def append_comment(self, comment: Leaf) -> bool:
after = len(self.leaves) - 1 after = len(self.leaves) - 1
if after == -1: if after == -1:
comment.type = STANDALONE_COMMENT comment.type = STANDALONE_COMMENT
comment.prefix = '' comment.prefix = ""
return False return False
else: else:
@ -786,17 +786,17 @@ def remove_trailing_comma(self) -> None:
def __str__(self) -> str: def __str__(self) -> str:
"""Render the line.""" """Render the line."""
if not self: if not self:
return '\n' return "\n"
indent = ' ' * self.depth indent = " " * self.depth
leaves = iter(self.leaves) leaves = iter(self.leaves)
first = next(leaves) first = next(leaves)
res = f'{first.prefix}{indent}{first.value}' res = f"{first.prefix}{indent}{first.value}"
for leaf in leaves: for leaf in leaves:
res += str(leaf) res += str(leaf)
for _, comment in self.comments: for _, comment in self.comments:
res += str(comment) res += str(comment)
return res + '\n' return res + "\n"
def __bool__(self) -> bool: def __bool__(self) -> bool:
"""Return True if the line has leaves or comments.""" """Return True if the line has leaves or comments."""
@ -832,9 +832,9 @@ def __str__(self) -> str:
`depth` is not used for indentation in this case. `depth` is not used for indentation in this case.
""" """
if not self: if not self:
return '\n' return "\n"
res = '' res = ""
for leaf in self.leaves: for leaf in self.leaves:
res += str(leaf) res += str(leaf)
return res return res
@ -888,9 +888,9 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
if current_line.leaves: if current_line.leaves:
# Consume the first leaf's extra newlines. # Consume the first leaf's extra newlines.
first_leaf = current_line.leaves[0] first_leaf = current_line.leaves[0]
before = first_leaf.prefix.count('\n') before = first_leaf.prefix.count("\n")
before = min(before, max_allowed) before = min(before, max_allowed)
first_leaf.prefix = '' first_leaf.prefix = ""
else: else:
before = 0 before = 0
depth = current_line.depth depth = current_line.depth
@ -1009,6 +1009,8 @@ def visit_default(self, node: LN) -> Iterator[Line]:
else: else:
normalize_prefix(node, inside_brackets=any_open_brackets) normalize_prefix(node, inside_brackets=any_open_brackets)
if node.type == token.STRING:
normalize_string_quotes(node)
if node.type not in WHITESPACE: if node.type not in WHITESPACE:
self.current_line.append(node) self.current_line.append(node)
yield from super().visit_default(node) yield from super().visit_default(node)
@ -1098,14 +1100,14 @@ def visit_unformatted(self, node: LN) -> Iterator[Line]:
def __attrs_post_init__(self) -> None: def __attrs_post_init__(self) -> None:
"""You are in a twisty little maze of passages.""" """You are in a twisty little maze of passages."""
v = self.visit_stmt v = self.visit_stmt
self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'}) self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"})
self.visit_while_stmt = partial(v, keywords={'while', 'else'}) self.visit_while_stmt = partial(v, keywords={"while", "else"})
self.visit_for_stmt = partial(v, keywords={'for', 'else'}) self.visit_for_stmt = partial(v, keywords={"for", "else"})
self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'}) self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"})
self.visit_except_clause = partial(v, keywords={'except'}) self.visit_except_clause = partial(v, keywords={"except"})
self.visit_funcdef = partial(v, keywords={'def'}) self.visit_funcdef = partial(v, keywords={"def"})
self.visit_with_stmt = partial(v, keywords={'with'}) self.visit_with_stmt = partial(v, keywords={"with"})
self.visit_classdef = partial(v, keywords={'class'}) self.visit_classdef = partial(v, keywords={"class"})
self.visit_async_funcdef = self.visit_async_stmt self.visit_async_funcdef = self.visit_async_stmt
self.visit_decorated = self.visit_decorators self.visit_decorated = self.visit_decorators
@ -1119,9 +1121,9 @@ def __attrs_post_init__(self) -> None:
def whitespace(leaf: Leaf) -> str: # noqa C901 def whitespace(leaf: Leaf) -> str: # noqa C901
"""Return whitespace prefix if needed for the given `leaf`.""" """Return whitespace prefix if needed for the given `leaf`."""
NO = '' NO = ""
SPACE = ' ' SPACE = " "
DOUBLESPACE = ' ' DOUBLESPACE = " "
t = leaf.type t = leaf.type
p = leaf.parent p = leaf.parent
v = leaf.value v = leaf.value
@ -1185,7 +1187,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901
and prevp.parent.type == syms.shift_expr and prevp.parent.type == syms.shift_expr
and prevp.prev_sibling and prevp.prev_sibling
and prevp.prev_sibling.type == token.NAME and prevp.prev_sibling.type == token.NAME
and prevp.prev_sibling.value == 'print' # type: ignore and prevp.prev_sibling.value == "print" # type: ignore
): ):
# Python 2 print chevron # Python 2 print chevron
return NO return NO
@ -1342,7 +1344,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901
return NO return NO
elif t == token.NAME: elif t == token.NAME:
if v == 'import': if v == "import":
return SPACE return SPACE
if prev and prev.type == token.DOT: if prev and prev.type == token.DOT:
@ -1416,17 +1418,17 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
if not p: if not p:
return return
if '#' not in p: if "#" not in p:
return return
consumed = 0 consumed = 0
nlines = 0 nlines = 0
for index, line in enumerate(p.split('\n')): for index, line in enumerate(p.split("\n")):
consumed += len(line) + 1 # adding the length of the split '\n' consumed += len(line) + 1 # adding the length of the split '\n'
line = line.lstrip() line = line.lstrip()
if not line: if not line:
nlines += 1 nlines += 1
if not line.startswith('#'): if not line.startswith("#"):
continue continue
if index == 0 and leaf.type != token.ENDMARKER: if index == 0 and leaf.type != token.ENDMARKER:
@ -1434,12 +1436,12 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
else: else:
comment_type = STANDALONE_COMMENT comment_type = STANDALONE_COMMENT
comment = make_comment(line) comment = make_comment(line)
yield Leaf(comment_type, comment, prefix='\n' * nlines) yield Leaf(comment_type, comment, prefix="\n" * nlines)
if comment in {'# fmt: on', '# yapf: enable'}: if comment in {"# fmt: on", "# yapf: enable"}:
raise FormatOn(consumed) raise FormatOn(consumed)
if comment in {'# fmt: off', '# yapf: disable'}: if comment in {"# fmt: off", "# yapf: disable"}:
raise FormatOff(consumed) raise FormatOff(consumed)
nlines = 0 nlines = 0
@ -1455,13 +1457,13 @@ def make_comment(content: str) -> str:
""" """
content = content.rstrip() content = content.rstrip()
if not content: if not content:
return '#' return "#"
if content[0] == '#': if content[0] == "#":
content = content[1:] content = content[1:]
if content and content[0] not in ' !:#': if content and content[0] not in " !:#":
content = ' ' + content content = " " + content
return '#' + content return "#" + content
def split_line( def split_line(
@ -1481,10 +1483,10 @@ def split_line(
yield line yield line
return return
line_str = str(line).strip('\n') line_str = str(line).strip("\n")
if ( if (
len(line_str) <= line_length len(line_str) <= line_length
and '\n' not in line_str # multiline strings and "\n" not in line_str # multiline strings
and not line.contains_standalone_comments and not line.contains_standalone_comments
): ):
yield line yield line
@ -1504,7 +1506,7 @@ def split_line(
result: List[Line] = [] result: List[Line] = []
try: try:
for l in split_func(line, py36): for l in split_func(line, py36):
if str(l).strip('\n') == line_str: if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result") raise CannotSplit("Split function returned an unchanged result")
result.extend( result.extend(
@ -1703,7 +1705,7 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
and current_line.leaves[-1].type != token.COMMA and current_line.leaves[-1].type != token.COMMA
and trailing_comma_safe and trailing_comma_safe
): ):
current_line.append(Leaf(token.COMMA, ',')) current_line.append(Leaf(token.COMMA, ","))
yield current_line yield current_line
@ -1749,8 +1751,8 @@ def is_import(leaf: Leaf) -> bool:
return bool( return bool(
t == token.NAME t == token.NAME
and ( and (
(v == 'import' and p and p.type == syms.import_name) (v == "import" and p and p.type == syms.import_name)
or (v == 'from' and p and p.type == syms.import_from) or (v == "from" and p and p.type == syms.import_from)
) )
) )
@ -1762,15 +1764,52 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
Note: don't use backslashes for formatting or you'll lose your voting rights. Note: don't use backslashes for formatting or you'll lose your voting rights.
""" """
if not inside_brackets: if not inside_brackets:
spl = leaf.prefix.split('#') spl = leaf.prefix.split("#")
if '\\' not in spl[0]: if "\\" not in spl[0]:
nl_count = spl[-1].count('\n') nl_count = spl[-1].count("\n")
if len(spl) > 1: if len(spl) > 1:
nl_count -= 1 nl_count -= 1
leaf.prefix = '\n' * nl_count leaf.prefix = "\n" * nl_count
return return
leaf.prefix = '' leaf.prefix = ""
def normalize_string_quotes(leaf: Leaf) -> None:
value = leaf.value.lstrip("furbFURB")
if value[:3] == '"""':
return
elif value[:3] == "'''":
orig_quote = "'''"
new_quote = '"""'
elif value[0] == '"':
orig_quote = '"'
new_quote = "'"
else:
orig_quote = "'"
new_quote = '"'
first_quote_pos = leaf.value.find(orig_quote)
if first_quote_pos == -1:
return # There's an internal error
body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
new_body = body.replace(f"\\{orig_quote}", orig_quote).replace(
new_quote, f"\\{new_quote}"
)
if new_quote == '"""' and new_body[-1] == '"':
# edge case:
new_body = new_body[:-1] + '\\"'
orig_escape_count = body.count("\\")
new_escape_count = new_body.count("\\")
if new_escape_count > orig_escape_count:
return # Do not introduce more escaping
if new_escape_count == orig_escape_count and orig_quote == '"':
return # Prefer double quotes
prefix = leaf.value[:first_quote_pos]
leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
def is_python36(node: Node) -> bool: def is_python36(node: Node) -> bool:
@ -1783,7 +1822,7 @@ def is_python36(node: Node) -> bool:
for n in node.pre_order(): for n in node.pre_order():
if n.type == token.STRING: if n.type == token.STRING:
value_head = n.value[:2] # type: ignore value_head = n.value[:2] # type: ignore
if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}: if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
return True return True
elif ( elif (
@ -1798,9 +1837,9 @@ def is_python36(node: Node) -> bool:
return False return False
PYTHON_EXTENSIONS = {'.py'} PYTHON_EXTENSIONS = {".py"}
BLACKLISTED_DIRECTORIES = { BLACKLISTED_DIRECTORIES = {
'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv' "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
} }
@ -1830,16 +1869,16 @@ class Report:
def done(self, src: Path, changed: bool) -> None: def done(self, src: Path, changed: bool) -> None:
"""Increment the counter for successful reformatting. Write out a message.""" """Increment the counter for successful reformatting. Write out a message."""
if changed: if changed:
reformatted = 'would reformat' if self.check else 'reformatted' reformatted = "would reformat" if self.check else "reformatted"
out(f'{reformatted} {src}') out(f"{reformatted} {src}")
self.change_count += 1 self.change_count += 1
else: else:
out(f'{src} already well formatted, good job.', bold=False) out(f"{src} already well formatted, good job.", bold=False)
self.same_count += 1 self.same_count += 1
def failed(self, src: Path, message: str) -> None: def failed(self, src: Path, message: str) -> None:
"""Increment the counter for failed reformatting. Write out a message.""" """Increment the counter for failed reformatting. Write out a message."""
err(f'error: cannot format {src}: {message}') err(f"error: cannot format {src}: {message}")
self.failure_count += 1 self.failure_count += 1
@property @property
@ -1876,19 +1915,19 @@ def __str__(self) -> str:
failed = "failed to reformat" failed = "failed to reformat"
report = [] report = []
if self.change_count: if self.change_count:
s = 's' if self.change_count > 1 else '' s = "s" if self.change_count > 1 else ""
report.append( report.append(
click.style(f'{self.change_count} file{s} {reformatted}', bold=True) click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
) )
if self.same_count: if self.same_count:
s = 's' if self.same_count > 1 else '' s = "s" if self.same_count > 1 else ""
report.append(f'{self.same_count} file{s} {unchanged}') report.append(f"{self.same_count} file{s} {unchanged}")
if self.failure_count: if self.failure_count:
s = 's' if self.failure_count > 1 else '' s = "s" if self.failure_count > 1 else ""
report.append( report.append(
click.style(f'{self.failure_count} file{s} {failed}', fg='red') click.style(f"{self.failure_count} file{s} {failed}", fg="red")
) )
return ', '.join(report) + '.' return ", ".join(report) + "."
def assert_equivalent(src: str, dst: str) -> None: def assert_equivalent(src: str, dst: str) -> None:
@ -1935,17 +1974,17 @@ def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
try: try:
dst_ast = ast.parse(dst) dst_ast = ast.parse(dst)
except Exception as exc: except Exception as exc:
log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst) log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError( raise AssertionError(
f"INTERNAL ERROR: Black produced invalid code: {exc}. " f"INTERNAL ERROR: Black produced invalid code: {exc}. "
f"Please report a bug on https://github.com/ambv/black/issues. " f"Please report a bug on https://github.com/ambv/black/issues. "
f"This invalid output might be helpful: {log}" f"This invalid output might be helpful: {log}"
) from None ) from None
src_ast_str = '\n'.join(_v(src_ast)) src_ast_str = "\n".join(_v(src_ast))
dst_ast_str = '\n'.join(_v(dst_ast)) dst_ast_str = "\n".join(_v(dst_ast))
if src_ast_str != dst_ast_str: if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst')) log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError( raise AssertionError(
f"INTERNAL ERROR: Black produced code that is not equivalent to " f"INTERNAL ERROR: Black produced code that is not equivalent to "
f"the source. " f"the source. "
@ -1959,8 +1998,8 @@ def assert_stable(src: str, dst: str, line_length: int) -> None:
newdst = format_str(dst, line_length=line_length) newdst = format_str(dst, line_length=line_length)
if dst != newdst: if dst != newdst:
log = dump_to_file( log = dump_to_file(
diff(src, dst, 'source', 'first pass'), diff(src, dst, "source", "first pass"),
diff(dst, newdst, 'first pass', 'second pass'), diff(dst, newdst, "first pass", "second pass"),
) )
raise AssertionError( raise AssertionError(
f"INTERNAL ERROR: Black produced different code on the second pass " f"INTERNAL ERROR: Black produced different code on the second pass "
@ -1975,11 +2014,11 @@ def dump_to_file(*output: str) -> str:
import tempfile import tempfile
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
mode='w', prefix='blk_', suffix='.log', delete=False mode="w", prefix="blk_", suffix=".log", delete=False
) as f: ) as f:
for lines in output: for lines in output:
f.write(lines) f.write(lines)
f.write('\n') f.write("\n")
return f.name return f.name
@ -1987,9 +2026,9 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
"""Return a unified diff string between strings `a` and `b`.""" """Return a unified diff string between strings `a` and `b`."""
import difflib import difflib
a_lines = [line + '\n' for line in a.split('\n')] a_lines = [line + "\n" for line in a.split("\n")]
b_lines = [line + '\n' for line in b.split('\n')] b_lines = [line + "\n" for line in b.split("\n")]
return ''.join( return "".join(
difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
) )
@ -2023,5 +2062,5 @@ def shutdown(loop: BaseEventLoop) -> None:
loop.close() loop.close()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -11,48 +11,48 @@
def get_long_description(): def get_long_description():
readme_md = CURRENT_DIR / 'README.md' readme_md = CURRENT_DIR / "README.md"
with open(readme_md, encoding='utf8') as ld_file: with open(readme_md, encoding="utf8") as ld_file:
return ld_file.read() return ld_file.read()
def get_version(): def get_version():
black_py = CURRENT_DIR / 'black.py' black_py = CURRENT_DIR / "black.py"
_version_re = re.compile(r'__version__\s+=\s+(?P<version>.*)') _version_re = re.compile(r"__version__\s+=\s+(?P<version>.*)")
with open(black_py, 'r', encoding='utf8') as f: with open(black_py, "r", encoding="utf8") as f:
version = _version_re.search(f.read()).group('version') version = _version_re.search(f.read()).group("version")
return str(ast.literal_eval(version)) return str(ast.literal_eval(version))
setup( setup(
name='black', name="black",
version=get_version(), version=get_version(),
description="The uncompromising code formatter.", description="The uncompromising code formatter.",
long_description=get_long_description(), long_description=get_long_description(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords='automation formatter yapf autopep8 pyfmt gofmt rustfmt', keywords="automation formatter yapf autopep8 pyfmt gofmt rustfmt",
author='Łukasz Langa', author="Łukasz Langa",
author_email='lukasz@langa.pl', author_email="lukasz@langa.pl",
url='https://github.com/ambv/black', url="https://github.com/ambv/black",
license='MIT', license="MIT",
py_modules=['black'], py_modules=["black"],
packages=['blib2to3', 'blib2to3.pgen2'], packages=["blib2to3", "blib2to3.pgen2"],
package_data={'blib2to3': ['*.txt']}, package_data={"blib2to3": ["*.txt"]},
python_requires=">=3.6", python_requires=">=3.6",
zip_safe=False, zip_safe=False,
install_requires=['click', 'attrs>=17.4.0'], install_requires=["click", "attrs>=17.4.0"],
test_suite='tests.test_black', test_suite="tests.test_black",
classifiers=[ classifiers=[
'Development Status :: 3 - Alpha', "Development Status :: 3 - Alpha",
'Environment :: Console', "Environment :: Console",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'License :: OSI Approved :: MIT License', "License :: OSI Approved :: MIT License",
'Operating System :: OS Independent', "Operating System :: OS Independent",
'Programming Language :: Python', "Programming Language :: Python",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
'Programming Language :: Python :: 3 :: Only', "Programming Language :: Python :: 3 :: Only",
'Topic :: Software Development :: Libraries :: Python Modules', "Topic :: Software Development :: Libraries :: Python Modules",
'Topic :: Software Development :: Quality Assurance', "Topic :: Software Development :: Quality Assurance",
], ],
entry_points={'console_scripts': ['black=black:main']}, entry_points={"console_scripts": ["black=black:main"]},
) )

View File

@ -43,7 +43,7 @@ def function(default=None):
# Explains why we use global state. # Explains why we use global state.
GLOBAL_STATE = {'a': a(1), 'b': a(2), 'c': a(3)} GLOBAL_STATE = {"a": a(1), "b": a(2), "c": a(3)}
# Another comment! # Another comment!
@ -76,7 +76,7 @@ async def wat():
result = await x.method1() result = await x.method1()
# Comment after ending a block. # Comment after ending a block.
if result: if result:
print('A OK', file=sys.stdout) print("A OK", file=sys.stdout)
# Comment between things. # Comment between things.
print() print()

View File

@ -125,23 +125,23 @@ def inline_comments_in_brackets_ruin_everything():
__all__ = [ __all__ = [
# Super-special typing primitives. # Super-special typing primitives.
'Any', "Any",
'Callable', "Callable",
'ClassVar', "ClassVar",
# ABCs (from collections.abc). # ABCs (from collections.abc).
'AbstractSet', # collections.abc.Set. "AbstractSet", # collections.abc.Set.
'ByteString', "ByteString",
'Container', "Container",
# Concrete collection types. # Concrete collection types.
'Counter', "Counter",
'Deque', "Deque",
'Dict', "Dict",
'DefaultDict', "DefaultDict",
'List', "List",
'Set', "Set",
'FrozenSet', "FrozenSet",
'NamedTuple', # Not really a type. "NamedTuple", # Not really a type.
'Generator', "Generator",
] ]
# Comment before function. # Comment before function.
@ -212,7 +212,7 @@ def inline_comments_in_brackets_ruin_everything():
] ]
lcomp3 = [ lcomp3 = [
# This one is actually too long to fit in a single line. # This one is actually too long to fit in a single line.
element.split('\n', 1)[0] element.split("\n", 1)[0]
# yup # yup
for element in collection.select_elements() for element in collection.select_elements()
# right # right
@ -228,7 +228,7 @@ def inline_comments_in_brackets_ruin_everything():
# let's return # let's return
return Node( return Node(
syms.simple_stmt, syms.simple_stmt,
[Node(statement, result), Leaf(token.NEWLINE, '\n')], # FIXME: \r\n? [Node(statement, result), Leaf(token.NEWLINE, "\n")], # FIXME: \r\n?
) )

View File

@ -1,7 +1,7 @@
def func(): def func():
lcomp3 = [ lcomp3 = [
# This one is actually too long to fit in a single line. # This one is actually too long to fit in a single line.
element.split('\n', 1)[0] element.split("\n", 1)[0]
# yup # yup
for element in collection.select_elements() for element in collection.select_elements()
# right # right

View File

@ -61,7 +61,7 @@ def test_fails_invalid_post_data(
def foo(list_a, list_b): def foo(list_a, list_b):
results = ( results = (
User.query.filter(User.foo == 'bar').filter( # Because foo. User.query.filter(User.foo == "bar").filter( # Because foo.
db.or_(User.field_a.astext.in_(list_a), User.field_b.astext.in_(list_b)) db.or_(User.field_a.astext.in_(list_a), User.field_b.astext.in_(list_b))
).filter( ).filter(
User.xyz.is_(None) User.xyz.is_(None)

View File

@ -3,19 +3,19 @@ class C:
def test(self) -> None: def test(self) -> None:
with patch("black.out", print): with patch("black.out", print):
self.assertEqual( self.assertEqual(
unstyle(str(report)), '1 file reformatted, 1 file failed to reformat.' unstyle(str(report)), "1 file reformatted, 1 file failed to reformat."
) )
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'1 file reformatted, 1 file left unchanged, 1 file failed to reformat.', "1 file reformatted, 1 file left unchanged, 1 file failed to reformat.",
) )
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'2 files reformatted, 1 file left unchanged, ' "2 files reformatted, 1 file left unchanged, "
'1 file failed to reformat.', "1 file failed to reformat.",
) )
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'2 files reformatted, 2 files left unchanged, ' "2 files reformatted, 2 files left unchanged, "
'2 files failed to reformat.', "2 files failed to reformat.",
) )

View File

@ -64,7 +64,7 @@ def g():
return DOUBLESPACE return DOUBLESPACE
# Another comment because more comments # Another comment because more comments
assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" assert p is not None, f'INTERNAL ERROR: hand-made leaf without parent: {leaf!r}'
prev = leaf.prev_sibling prev = leaf.prev_sibling
if not prev: if not prev:
@ -90,9 +90,9 @@ def g():
def f(): def f():
NO = '' NO = ""
SPACE = ' ' SPACE = " "
DOUBLESPACE = ' ' DOUBLESPACE = " "
t = leaf.type t = leaf.type
p = leaf.parent # trailing comment p = leaf.parent # trailing comment
@ -139,9 +139,9 @@ def f():
def g(): def g():
NO = '' NO = ""
SPACE = ' ' SPACE = " "
DOUBLESPACE = ' ' DOUBLESPACE = " "
t = leaf.type t = leaf.type
p = leaf.parent p = leaf.parent

View File

@ -157,8 +157,8 @@ async def f():
... ...
'some_string' "some_string"
b'\\xa3' b"\\xa3"
Name Name
None None
True True
@ -193,18 +193,18 @@ async def f():
lambda arg: None lambda arg: None
lambda a=True: a lambda a=True: a
lambda a, b, c=True: a lambda a, b, c=True: a
lambda a, b, c=True, *, d=(1 << v2), e='str': a lambda a, b, c=True, *, d=(1 << v2), e="str": a
lambda a, b, c=True, *vararg, d=(v1 << 2), e='str', **kwargs: a + b lambda a, b, c=True, *vararg, d=(v1 << 2), e="str", **kwargs: a + b
1 if True else 2 1 if True else 2
str or None if True else str or bytes or None str or None if True else str or bytes or None
(str or None) if True else (str or bytes or None) (str or None) if True else (str or bytes or None)
str or None if (1 if True else 2) else str or bytes or None str or None if (1 if True else 2) else str or bytes or None
(str or None) if (1 if True else 2) else (str or bytes or None) (str or None) if (1 if True else 2) else (str or bytes or None)
{'2.7': dead, '3.7': (long_live or die_hard)} {"2.7": dead, "3.7": (long_live or die_hard)}
{'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}} {"2.7": dead, "3.7": (long_live or die_hard), **{"3.6": verygood}}
{**a, **b, **c} {**a, **b, **c}
{'2.7', '3.6', '3.7', '3.8', '3.9', ('4.0' if gilectomy else '3.10')} {"2.7", "3.6", "3.7", "3.8", "3.9", ("4.0" if gilectomy else "3.10")}
({'a': 'b'}, (True or False), (+value), 'string', b'bytes') or None ({"a": "b"}, (True or False), (+value), "string", b"bytes") or None
() ()
(1,) (1,)
(1, 2) (1, 2)
@ -214,14 +214,14 @@ async def f():
[1, 2, 3] [1, 2, 3]
{i for i in (1, 2, 3)} {i for i in (1, 2, 3)}
{(i ** 2) for i in (1, 2, 3)} {(i ** 2) for i in (1, 2, 3)}
{(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))} {(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))}
{((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)} {((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)}
[i for i in (1, 2, 3)] [i for i in (1, 2, 3)]
[(i ** 2) for i in (1, 2, 3)] [(i ** 2) for i in (1, 2, 3)]
[(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))] [(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))]
[((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)] [((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)]
{i: 0 for i in (1, 2, 3)} {i: 0 for i in (1, 2, 3)}
{i: j for i, j in ((1, 'a'), (2, 'b'), (3, 'c'))} {i: j for i, j in ((1, "a"), (2, "b"), (3, "c"))}
{a: b * 2 for a, b in dictionary.items()} {a: b * 2 for a, b in dictionary.items()}
{a: b * -2 for a, b in dictionary.items()} {a: b * -2 for a, b in dictionary.items()}
{ {
@ -232,14 +232,14 @@ async def f():
Life is Life Life is Life
call() call()
call(arg) call(arg)
call(kwarg='hey') call(kwarg="hey")
call(arg, kwarg='hey') call(arg, kwarg="hey")
call(arg, another, kwarg='hey', **kwargs) call(arg, another, kwarg="hey", **kwargs)
call( call(
this_is_a_very_long_variable_which_will_force_a_delimiter_split, this_is_a_very_long_variable_which_will_force_a_delimiter_split,
arg, arg,
another, another,
kwarg='hey', kwarg="hey",
**kwargs **kwargs
) # note: no trailing comma pre-3.6 ) # note: no trailing comma pre-3.6
call(*gidgets[:2]) call(*gidgets[:2])
@ -283,15 +283,15 @@ async def f():
numpy[:, ::-1] numpy[:, ::-1]
numpy[np.newaxis, :] numpy[np.newaxis, :]
(str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None) (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
{'2.7': dead, '3.7': long_live or die_hard} {"2.7": dead, "3.7": long_live or die_hard}
{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} {"2.7", "3.6", "3.7", "3.8", "3.9", "4.0" if gilectomy else "3.10"}
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
(SomeName) (SomeName)
SomeName SomeName
(Good, Bad, Ugly) (Good, Bad, Ugly)
(i for i in (1, 2, 3)) (i for i in (1, 2, 3))
((i ** 2) for i in (1, 2, 3)) ((i ** 2) for i in (1, 2, 3))
((i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))) ((i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c")))
(((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)) (((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3))
(*starred) (*starred)
a = (1,) a = (1,)

View File

@ -15,10 +15,10 @@ def func_no_args():
for i in range(10): for i in range(10):
print(i) print(i)
continue continue
exec("new-style exec", {}, {}) exec('new-style exec', {}, {})
return None return None
async def coroutine(arg, exec=False): async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat." 'Single-line docstring. Multiline is harder to reformat.'
async with some_connection() as conn: async with some_connection() as conn:
await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2) await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2)
await asyncio.sleep(1) await asyncio.sleep(1)
@ -27,7 +27,7 @@ async def coroutine(arg, exec=False):
with_args=True, with_args=True,
many_args=[1,2,3] many_args=[1,2,3]
) )
def function_signature_stress_test(number:int,no_annotation=None,text:str="default",* ,debug:bool=False,**kwargs) -> str: def function_signature_stress_test(number:int,no_annotation=None,text:str='default',* ,debug:bool=False,**kwargs) -> str:
return text[number:-1] return text[number:-1]
# fmt: on # fmt: on
def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''): def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''):
@ -83,7 +83,7 @@ def long_lines():
from library import some_connection, some_decorator from library import some_connection, some_decorator
f'trigger 3.6 mode' f"trigger 3.6 mode"
# fmt: off # fmt: off
def func_no_args(): def func_no_args():
a; b; c a; b; c
@ -92,10 +92,10 @@ def func_no_args():
for i in range(10): for i in range(10):
print(i) print(i)
continue continue
exec("new-style exec", {}, {}) exec('new-style exec', {}, {})
return None return None
async def coroutine(arg, exec=False): async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat." 'Single-line docstring. Multiline is harder to reformat.'
async with some_connection() as conn: async with some_connection() as conn:
await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2) await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2)
await asyncio.sleep(1) await asyncio.sleep(1)
@ -104,12 +104,12 @@ async def coroutine(arg, exec=False):
with_args=True, with_args=True,
many_args=[1,2,3] many_args=[1,2,3]
) )
def function_signature_stress_test(number:int,no_annotation=None,text:str="default",* ,debug:bool=False,**kwargs) -> str: def function_signature_stress_test(number:int,no_annotation=None,text:str='default',* ,debug:bool=False,**kwargs) -> str:
return text[number:-1] return text[number:-1]
# fmt: on # fmt: on
def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''): def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r""):
offset = attr.ib(default=attr.Factory(lambda: _r.uniform(10000, 200000))) offset = attr.ib(default=attr.Factory(lambda: _r.uniform(10000, 200000)))
assert task._cancel_stack[:len(old_stack)] == old_stack assert task._cancel_stack[:len(old_stack)] == old_stack
@ -123,7 +123,7 @@ def spaces_types(
f: int = -1, f: int = -1,
g: int = 1 if False else 2, g: int = 1 if False else 2,
h: str = "", h: str = "",
i: str = r'', i: str = r"",
): ):
... ...

View File

@ -1,5 +1,5 @@
f'f-string without formatted values is just a string' f"f-string without formatted values is just a string"
f'{{NOT a formatted value}}' f"{{NOT a formatted value}}"
f'some f-string with {a} {few():.2f} {formatted.values!r}' f"some f-string with {a} {few():.2f} {formatted.values!r}"
f"{f'{nested} inner'} outer" f"{f'{nested} inner'} outer"
f'space between opening braces: { {a for a in (1, 2, 3)}}' f"space between opening braces: { {a for a in (1, 2, 3)}}"

View File

@ -80,7 +80,7 @@ def long_lines():
from library import some_connection, some_decorator from library import some_connection, some_decorator
f'trigger 3.6 mode' f"trigger 3.6 mode"
def func_no_args(): def func_no_args():
@ -103,7 +103,7 @@ def func_no_args():
async def coroutine(arg, exec=False): async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat." "Single-line docstring. Multiline is harder to reformat."
async with some_connection() as conn: async with some_connection() as conn:
await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2) await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
await asyncio.sleep(1) await asyncio.sleep(1)
@ -120,7 +120,7 @@ def function_signature_stress_test(
return text[number:-1] return text[number:-1]
def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''): def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r""):
offset = attr.ib(default=attr.Factory(lambda: _r.uniform(10000, 200000))) offset = attr.ib(default=attr.Factory(lambda: _r.uniform(10000, 200000)))
assert task._cancel_stack[:len(old_stack)] == old_stack assert task._cancel_stack[:len(old_stack)] == old_stack
@ -134,7 +134,7 @@ def spaces_types(
f: int = -1, f: int = -1,
g: int = 1 if False else 2, g: int = 1 if False else 2,
h: str = "", h: str = "",
i: str = r'', i: str = r"",
): ):
... ...

37
tests/string_quotes.py Normal file
View File

@ -0,0 +1,37 @@
"Hello"
"Don't do that"
'Here is a "'
'What\'s the deal here?'
"What's the deal \"here\"?"
"And \"here\"?"
"""Strings with "" in them"""
'''Strings with "" in them'''
'''Here's a "'''
'''Here's a " '''
'''Just a normal triple
quote'''
f"just a normal {f} string"
f'''This is a triple-quoted {f}-string'''
f'MOAR {" ".join([])}'
f"MOAR {' '.join([])}"
r"raw string ftw"
# output
"Hello"
"Don't do that"
'Here is a "'
"What's the deal here?"
'What\'s the deal "here"?'
'And "here"?'
"""Strings with "" in them"""
"""Strings with "" in them"""
'''Here's a "'''
"""Here's a " """
"""Just a normal triple
quote"""
f"just a normal {f} string"
f"""This is a triple-quoted {f}-string"""
f'MOAR {" ".join([])}'
f"MOAR {' '.join([])}"
r"raw string ftw"

View File

@ -17,25 +17,25 @@
fs = partial(black.format_str, line_length=ll) fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__) THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent THIS_DIR = THIS_FILE.parent
EMPTY_LINE = '# EMPTY LINE WITH WHITESPACE' + ' (this comment will be removed)' EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
def dump_to_stderr(*output: str) -> str: def dump_to_stderr(*output: str) -> str:
return '\n' + '\n'.join(output) + '\n' return "\n" + "\n".join(output) + "\n"
def read_data(name: str) -> Tuple[str, str]: def read_data(name: str) -> Tuple[str, str]:
"""read_data('test_name') -> 'input', 'output'""" """read_data('test_name') -> 'input', 'output'"""
if not name.endswith(('.py', '.out')): if not name.endswith((".py", ".out")):
name += '.py' name += ".py"
_input: List[str] = [] _input: List[str] = []
_output: List[str] = [] _output: List[str] = []
with open(THIS_DIR / name, 'r', encoding='utf8') as test: with open(THIS_DIR / name, "r", encoding="utf8") as test:
lines = test.readlines() lines = test.readlines()
result = _input result = _input
for line in lines: for line in lines:
line = line.replace(EMPTY_LINE, '') line = line.replace(EMPTY_LINE, "")
if line.rstrip() == '# output': if line.rstrip() == "# output":
result = _output result = _output
continue continue
@ -43,23 +43,23 @@ def read_data(name: str) -> Tuple[str, str]:
if _input and not _output: if _input and not _output:
# If there's no output marker, treat the entire file as already pre-formatted. # If there's no output marker, treat the entire file as already pre-formatted.
_output = _input[:] _output = _input[:]
return ''.join(_input).strip() + '\n', ''.join(_output).strip() + '\n' return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
class BlackTestCase(unittest.TestCase): class BlackTestCase(unittest.TestCase):
maxDiff = None maxDiff = None
def assertFormatEqual(self, expected: str, actual: str) -> None: def assertFormatEqual(self, expected: str, actual: str) -> None:
if actual != expected and not os.environ.get('SKIP_AST_PRINT'): if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
bdv: black.DebugVisitor[Any] bdv: black.DebugVisitor[Any]
black.out('Expected tree:', fg='green') black.out("Expected tree:", fg="green")
try: try:
exp_node = black.lib2to3_parse(expected) exp_node = black.lib2to3_parse(expected)
bdv = black.DebugVisitor() bdv = black.DebugVisitor()
list(bdv.visit(exp_node)) list(bdv.visit(exp_node))
except Exception as ve: except Exception as ve:
black.err(str(ve)) black.err(str(ve))
black.out('Actual tree:', fg='red') black.out("Actual tree:", fg="red")
try: try:
exp_node = black.lib2to3_parse(actual) exp_node = black.lib2to3_parse(actual)
bdv = black.DebugVisitor() bdv = black.DebugVisitor()
@ -70,7 +70,7 @@ def assertFormatEqual(self, expected: str, actual: str) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_self(self) -> None: def test_self(self) -> None:
source, expected = read_data('test_black') source, expected = read_data("test_black")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -79,19 +79,19 @@ def test_self(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_black(self) -> None: def test_black(self) -> None:
source, expected = read_data('../black') source, expected = read_data("../black")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, line_length=ll)
self.assertFalse(ff(THIS_DIR / '..' / 'black.py')) self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
def test_piping(self) -> None: def test_piping(self) -> None:
source, expected = read_data('../black') source, expected = read_data("../black")
hold_stdin, hold_stdout = sys.stdin, sys.stdout hold_stdin, hold_stdout = sys.stdin, sys.stdout
try: try:
sys.stdin, sys.stdout = StringIO(source), StringIO() sys.stdin, sys.stdout = StringIO(source), StringIO()
sys.stdin.name = '<stdin>' sys.stdin.name = "<stdin>"
black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True) black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
sys.stdout.seek(0) sys.stdout.seek(0)
actual = sys.stdout.read() actual = sys.stdout.read()
@ -103,16 +103,16 @@ def test_piping(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_setup(self) -> None: def test_setup(self) -> None:
source, expected = read_data('../setup') source, expected = read_data("../setup")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll) black.assert_stable(source, actual, line_length=ll)
self.assertFalse(ff(THIS_DIR / '..' / 'setup.py')) self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_function(self) -> None: def test_function(self) -> None:
source, expected = read_data('function') source, expected = read_data("function")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -120,7 +120,7 @@ def test_function(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_expression(self) -> None: def test_expression(self) -> None:
source, expected = read_data('expression') source, expected = read_data("expression")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -128,7 +128,15 @@ def test_expression(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_fstring(self) -> None: def test_fstring(self) -> None:
source, expected = read_data('fstring') source, expected = read_data("fstring")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
@patch("black.dump_to_file", dump_to_stderr)
def test_string_quotes(self) -> None:
source, expected = read_data("string_quotes")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -136,7 +144,7 @@ def test_fstring(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments(self) -> None: def test_comments(self) -> None:
source, expected = read_data('comments') source, expected = read_data("comments")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -144,7 +152,7 @@ def test_comments(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments2(self) -> None: def test_comments2(self) -> None:
source, expected = read_data('comments2') source, expected = read_data("comments2")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -152,7 +160,7 @@ def test_comments2(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments3(self) -> None: def test_comments3(self) -> None:
source, expected = read_data('comments3') source, expected = read_data("comments3")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -160,7 +168,7 @@ def test_comments3(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_comments4(self) -> None: def test_comments4(self) -> None:
source, expected = read_data('comments4') source, expected = read_data("comments4")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -168,7 +176,7 @@ def test_comments4(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_cantfit(self) -> None: def test_cantfit(self) -> None:
source, expected = read_data('cantfit') source, expected = read_data("cantfit")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -176,7 +184,7 @@ def test_cantfit(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_import_spacing(self) -> None: def test_import_spacing(self) -> None:
source, expected = read_data('import_spacing') source, expected = read_data("import_spacing")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -184,7 +192,7 @@ def test_import_spacing(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_composition(self) -> None: def test_composition(self) -> None:
source, expected = read_data('composition') source, expected = read_data("composition")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -192,7 +200,7 @@ def test_composition(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_empty_lines(self) -> None: def test_empty_lines(self) -> None:
source, expected = read_data('empty_lines') source, expected = read_data("empty_lines")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -200,7 +208,7 @@ def test_empty_lines(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_python2(self) -> None: def test_python2(self) -> None:
source, expected = read_data('python2') source, expected = read_data("python2")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
# black.assert_equivalent(source, actual) # black.assert_equivalent(source, actual)
@ -208,7 +216,7 @@ def test_python2(self) -> None:
@patch("black.dump_to_file", dump_to_stderr) @patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None: def test_fmtonoff(self) -> None:
source, expected = read_data('fmtonoff') source, expected = read_data("fmtonoff")
actual = fs(source) actual = fs(source)
self.assertFormatEqual(expected, actual) self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual) black.assert_equivalent(source, actual)
@ -226,68 +234,68 @@ def err(msg: str, **kwargs: Any) -> None:
err_lines.append(msg) err_lines.append(msg)
with patch("black.out", out), patch("black.err", err): with patch("black.out", out), patch("black.err", err):
report.done(Path('f1'), changed=False) report.done(Path("f1"), changed=False)
self.assertEqual(len(out_lines), 1) self.assertEqual(len(out_lines), 1)
self.assertEqual(len(err_lines), 0) self.assertEqual(len(err_lines), 0)
self.assertEqual(out_lines[-1], 'f1 already well formatted, good job.') self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
self.assertEqual(unstyle(str(report)), '1 file left unchanged.') self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
self.assertEqual(report.return_code, 0) self.assertEqual(report.return_code, 0)
report.done(Path('f2'), changed=True) report.done(Path("f2"), changed=True)
self.assertEqual(len(out_lines), 2) self.assertEqual(len(out_lines), 2)
self.assertEqual(len(err_lines), 0) self.assertEqual(len(err_lines), 0)
self.assertEqual(out_lines[-1], 'reformatted f2') self.assertEqual(out_lines[-1], "reformatted f2")
self.assertEqual( self.assertEqual(
unstyle(str(report)), '1 file reformatted, 1 file left unchanged.' unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
) )
self.assertEqual(report.return_code, 0) self.assertEqual(report.return_code, 0)
report.check = True report.check = True
self.assertEqual(report.return_code, 1) self.assertEqual(report.return_code, 1)
report.check = False report.check = False
report.failed(Path('e1'), 'boom') report.failed(Path("e1"), "boom")
self.assertEqual(len(out_lines), 2) self.assertEqual(len(out_lines), 2)
self.assertEqual(len(err_lines), 1) self.assertEqual(len(err_lines), 1)
self.assertEqual(err_lines[-1], 'error: cannot format e1: boom') self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'1 file reformatted, 1 file left unchanged, ' "1 file reformatted, 1 file left unchanged, "
'1 file failed to reformat.', "1 file failed to reformat.",
) )
self.assertEqual(report.return_code, 123) self.assertEqual(report.return_code, 123)
report.done(Path('f3'), changed=True) report.done(Path("f3"), changed=True)
self.assertEqual(len(out_lines), 3) self.assertEqual(len(out_lines), 3)
self.assertEqual(len(err_lines), 1) self.assertEqual(len(err_lines), 1)
self.assertEqual(out_lines[-1], 'reformatted f3') self.assertEqual(out_lines[-1], "reformatted f3")
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'2 files reformatted, 1 file left unchanged, ' "2 files reformatted, 1 file left unchanged, "
'1 file failed to reformat.', "1 file failed to reformat.",
) )
self.assertEqual(report.return_code, 123) self.assertEqual(report.return_code, 123)
report.failed(Path('e2'), 'boom') report.failed(Path("e2"), "boom")
self.assertEqual(len(out_lines), 3) self.assertEqual(len(out_lines), 3)
self.assertEqual(len(err_lines), 2) self.assertEqual(len(err_lines), 2)
self.assertEqual(err_lines[-1], 'error: cannot format e2: boom') self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'2 files reformatted, 1 file left unchanged, ' "2 files reformatted, 1 file left unchanged, "
'2 files failed to reformat.', "2 files failed to reformat.",
) )
self.assertEqual(report.return_code, 123) self.assertEqual(report.return_code, 123)
report.done(Path('f4'), changed=False) report.done(Path("f4"), changed=False)
self.assertEqual(len(out_lines), 4) self.assertEqual(len(out_lines), 4)
self.assertEqual(len(err_lines), 2) self.assertEqual(len(err_lines), 2)
self.assertEqual(out_lines[-1], 'f4 already well formatted, good job.') self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'2 files reformatted, 2 files left unchanged, ' "2 files reformatted, 2 files left unchanged, "
'2 files failed to reformat.', "2 files failed to reformat.",
) )
self.assertEqual(report.return_code, 123) self.assertEqual(report.return_code, 123)
report.check = True report.check = True
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
'2 files would be reformatted, 2 files would be left unchanged, ' "2 files would be reformatted, 2 files would be left unchanged, "
'2 files would fail to reformat.', "2 files would fail to reformat.",
) )
def test_is_python36(self) -> None: def test_is_python36(self) -> None:
@ -297,20 +305,20 @@ def test_is_python36(self) -> None:
self.assertTrue(black.is_python36(node)) self.assertTrue(black.is_python36(node))
node = black.lib2to3_parse("def f(*, arg): f'string'\n") node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertTrue(black.is_python36(node)) self.assertTrue(black.is_python36(node))
source, expected = read_data('function') source, expected = read_data("function")
node = black.lib2to3_parse(source) node = black.lib2to3_parse(source)
self.assertTrue(black.is_python36(node)) self.assertTrue(black.is_python36(node))
node = black.lib2to3_parse(expected) node = black.lib2to3_parse(expected)
self.assertTrue(black.is_python36(node)) self.assertTrue(black.is_python36(node))
source, expected = read_data('expression') source, expected = read_data("expression")
node = black.lib2to3_parse(source) node = black.lib2to3_parse(source)
self.assertFalse(black.is_python36(node)) self.assertFalse(black.is_python36(node))
node = black.lib2to3_parse(expected) node = black.lib2to3_parse(expected)
self.assertFalse(black.is_python36(node)) self.assertFalse(black.is_python36(node))
def test_debug_visitor(self) -> None: def test_debug_visitor(self) -> None:
source, _ = read_data('debug_visitor.py') source, _ = read_data("debug_visitor.py")
expected, _ = read_data('debug_visitor.out') expected, _ = read_data("debug_visitor.out")
out_lines = [] out_lines = []
err_lines = [] err_lines = []
@ -322,8 +330,8 @@ def err(msg: str, **kwargs: Any) -> None:
with patch("black.out", out), patch("black.err", err): with patch("black.out", out), patch("black.err", err):
black.DebugVisitor.show(source) black.DebugVisitor.show(source)
actual = '\n'.join(out_lines) + '\n' actual = "\n".join(out_lines) + "\n"
log_name = '' log_name = ""
if expected != actual: if expected != actual:
log_name = black.dump_to_file(*out_lines) log_name = black.dump_to_file(*out_lines)
self.assertEqual( self.assertEqual(
@ -333,5 +341,5 @@ def err(msg: str, **kwargs: Any) -> None:
) )
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()