Improve get_future_imports implementation.

Closes #389.
This commit is contained in:
Zsolt Dollenstein 2018-07-02 17:48:48 +01:00
parent 3bdd423891
commit dd8bde6d2f
3 changed files with 33 additions and 12 deletions

View File

@ -20,6 +20,7 @@
Callable,
Collection,
Dict,
Generator,
Generic,
Iterable,
Iterator,
@ -2910,7 +2911,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
def get_future_imports(node: Node) -> Set[str]:
"""Return a set of __future__ imports in the file."""
imports = set()
imports: Set[str] = set()
def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
for child in children:
if isinstance(child, Leaf):
if child.type == token.NAME:
yield child.value
elif child.type == syms.import_as_name:
orig_name = child.children[0]
assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
yield orig_name.value
elif child.type == syms.import_as_names:
yield from get_imports_from_children(child.children)
else:
assert False, "Invalid syntax parsing imports"
for child in node.children:
if child.type != syms.simple_stmt:
break
@ -2929,15 +2946,7 @@ def get_future_imports(node: Node) -> Set[str]:
module_name = first_child.children[1]
if not isinstance(module_name, Leaf) or module_name.value != "__future__":
break
for import_from_child in first_child.children[3:]:
if isinstance(import_from_child, Leaf):
if import_from_child.type == token.NAME:
imports.add(import_from_child.value)
else:
assert import_from_child.type == syms.import_as_names
for leaf in import_from_child.children:
if isinstance(leaf, Leaf) and leaf.type == token.NAME:
imports.add(leaf.value)
imports |= set(get_imports_from_children(first_child.children[3:]))
else:
break
return imports

View File

@ -1,5 +1,7 @@
#!/usr/bin/env python2
from __future__ import unicode_literals
from __future__ import unicode_literals as _unicode_literals
from __future__ import absolute_import
from __future__ import print_function as lol, with_function
u'hello'
U"hello"
@ -9,7 +11,9 @@
#!/usr/bin/env python2
from __future__ import unicode_literals
from __future__ import unicode_literals as _unicode_literals
from __future__ import absolute_import
from __future__ import print_function as lol, with_function
"hello"
"hello"

View File

@ -735,6 +735,14 @@ def test_get_future_imports(self) -> None:
self.assertEqual(set(), black.get_future_imports(node))
node = black.lib2to3_parse("from some.module import black\n")
self.assertEqual(set(), black.get_future_imports(node))
node = black.lib2to3_parse(
"from __future__ import unicode_literals as _unicode_literals"
)
self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
node = black.lib2to3_parse(
"from __future__ import unicode_literals as _lol, print"
)
self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
def test_debug_visitor(self) -> None:
source, _ = read_data("debug_visitor.py")