Find project root correctly (#1518)

Ensure root dir is a common parent of all inputs
Fixes #1493
This commit is contained in:
Lihu Ben-Ezri-Ravin 2020-06-24 05:09:07 -04:00 committed by GitHub
parent f90f50a743
commit 2471b9256d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 7 deletions

View File

@ -5825,8 +5825,8 @@ def gen_python_files(
def find_project_root(srcs: Iterable[str]) -> Path:
"""Return a directory containing .git, .hg, or pyproject.toml.
That directory can be one of the directories passed in `srcs` or their
common parent.
That directory will be a common parent of all files and directories
passed in `srcs`.
If no directory in the tree contains a marker that would specify it's the
project root, the root of the file system is returned.
@ -5834,11 +5834,20 @@ def find_project_root(srcs: Iterable[str]) -> Path:
if not srcs:
return Path("/").resolve()
common_base = min(Path(src).resolve() for src in srcs)
if common_base.is_dir():
# Append a fake file so `parents` below returns `common_base_dir`, too.
common_base /= "fake-file"
for directory in common_base.parents:
path_srcs = [Path(src).resolve() for src in srcs]
# A list of lists of parents for each 'src'. 'src' is included as a
# "parent" of itself if it is a directory
src_parents = [
list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
]
common_base = max(
set.intersection(*(set(parents) for parents in src_parents)),
key=lambda path: path.parts,
)
for directory in (common_base, *common_base.parents):
if (directory / ".git").exists():
return directory

View File

@ -1801,6 +1801,28 @@ def __init__(self) -> None:
self.assertEqual(config["exclude"], r"\.pyi?$")
self.assertEqual(config["include"], r"\.py?$")
def test_find_project_root(self) -> None:
with TemporaryDirectory() as workspace:
root = Path(workspace)
test_dir = root / "test"
test_dir.mkdir()
src_dir = root / "src"
src_dir.mkdir()
root_pyproject = root / "pyproject.toml"
root_pyproject.touch()
src_pyproject = src_dir / "pyproject.toml"
src_pyproject.touch()
src_python = src_dir / "foo.py"
src_python.touch()
self.assertEqual(
black.find_project_root((src_dir, test_dir)), root.resolve()
)
self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
class BlackDTestCase(AioHTTPTestCase):
async def get_application(self) -> web.Application: