Improve get_future_imports implementation.

Closes #389.
This commit is contained in:
Zsolt Dollenstein 2018-07-02 17:48:48 +01:00
parent 3bdd423891
commit dd8bde6d2f
3 changed files with 33 additions and 12 deletions

View File

@ -20,6 +20,7 @@
Callable, Callable,
Collection, Collection,
Dict, Dict,
Generator,
Generic, Generic,
Iterable, Iterable,
Iterator, Iterator,
@ -2910,7 +2911,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
def get_future_imports(node: Node) -> Set[str]: def get_future_imports(node: Node) -> Set[str]:
"""Return a set of __future__ imports in the file.""" """Return a set of __future__ imports in the file."""
imports = set() imports: Set[str] = set()
def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
for child in children:
if isinstance(child, Leaf):
if child.type == token.NAME:
yield child.value
elif child.type == syms.import_as_name:
orig_name = child.children[0]
assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
yield orig_name.value
elif child.type == syms.import_as_names:
yield from get_imports_from_children(child.children)
else:
assert False, "Invalid syntax parsing imports"
for child in node.children: for child in node.children:
if child.type != syms.simple_stmt: if child.type != syms.simple_stmt:
break break
@ -2929,15 +2946,7 @@ def get_future_imports(node: Node) -> Set[str]:
module_name = first_child.children[1] module_name = first_child.children[1]
if not isinstance(module_name, Leaf) or module_name.value != "__future__": if not isinstance(module_name, Leaf) or module_name.value != "__future__":
break break
for import_from_child in first_child.children[3:]: imports |= set(get_imports_from_children(first_child.children[3:]))
if isinstance(import_from_child, Leaf):
if import_from_child.type == token.NAME:
imports.add(import_from_child.value)
else:
assert import_from_child.type == syms.import_as_names
for leaf in import_from_child.children:
if isinstance(leaf, Leaf) and leaf.type == token.NAME:
imports.add(leaf.value)
else: else:
break break
return imports return imports

View File

@ -1,5 +1,7 @@
#!/usr/bin/env python2 #!/usr/bin/env python2
from __future__ import unicode_literals from __future__ import unicode_literals as _unicode_literals
from __future__ import absolute_import
from __future__ import print_function as lol, with_function
u'hello' u'hello'
U"hello" U"hello"
@ -9,7 +11,9 @@
#!/usr/bin/env python2 #!/usr/bin/env python2
from __future__ import unicode_literals from __future__ import unicode_literals as _unicode_literals
from __future__ import absolute_import
from __future__ import print_function as lol, with_function
"hello" "hello"
"hello" "hello"

View File

@ -735,6 +735,14 @@ def test_get_future_imports(self) -> None:
self.assertEqual(set(), black.get_future_imports(node)) self.assertEqual(set(), black.get_future_imports(node))
node = black.lib2to3_parse("from some.module import black\n") node = black.lib2to3_parse("from some.module import black\n")
self.assertEqual(set(), black.get_future_imports(node)) self.assertEqual(set(), black.get_future_imports(node))
node = black.lib2to3_parse(
"from __future__ import unicode_literals as _unicode_literals"
)
self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
node = black.lib2to3_parse(
"from __future__ import unicode_literals as _lol, print"
)
self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
def test_debug_visitor(self) -> None: def test_debug_visitor(self) -> None:
source, _ = read_data("debug_visitor.py") source, _ = read_data("debug_visitor.py")