diff --git a/anybadge/badge.py b/anybadge/badge.py index c9f08c7..3e992e1 100644 --- a/anybadge/badge.py +++ b/anybadge/badge.py @@ -746,21 +746,28 @@ class Badge: (color, ", ".join(list(Color.__members__.keys()))), ) - def write_badge(self, file_path, overwrite=False) -> None: + def write_badge(self, file_path: Union[str, Path], overwrite=False) -> None: """Write badge to file.""" + if isinstance(file_path, str): + + if file_path.endswith("/"): + raise ValueError("File location may not be a directory.") + + file: Path = Path(file_path) + else: + file = file_path + # Validate path (part 1) - if file_path.endswith("/"): + if file.is_dir(): raise ValueError("File location may not be a directory.") - # Get absolute filepath - path = os.path.abspath(file_path) - if not path.lower().endswith(".svg"): - path += ".svg" + # Ensure we're using a .svg extension + file = file.with_suffix(".svg") # Validate path (part 2) - if not overwrite and os.path.exists(path): - raise RuntimeError('File "{}" already exists.'.format(path)) + if not overwrite and file.exists(): + raise RuntimeError('File "{}" already exists.'.format(file)) - with open(path, mode="w") as file_handle: + with open(file, mode="w") as file_handle: file_handle.write(self.badge_svg_text) diff --git a/tests/test_anybadge.py b/tests/test_anybadge.py index d7f3b30..7ca4f25 100644 --- a/tests/test_anybadge.py +++ b/tests/test_anybadge.py @@ -4,6 +4,9 @@ from anybadge import Badge from anybadge.cli import main, parse_args +TESTS_DIR = Path(__file__).parent + + class TestAnybadge(TestCase): """Test case class for anybadge package.""" @@ -267,8 +270,14 @@ class TestAnybadge(TestCase): with self.assertRaisesRegex( RuntimeError, r'File ".*tests\/exists\.svg" already exists\.' ): - badge.write_badge("tests/exists") - badge.write_badge("tests/exists") + badge.write_badge(TESTS_DIR / Path("exists")) + badge.write_badge(TESTS_DIR / Path("exists")) + + with self.assertRaisesRegex( + RuntimeError, r'File ".*tests\/exists\.svg" already exists\.' + ): + badge.write_badge(str(TESTS_DIR / Path("exists"))) + badge.write_badge(str(TESTS_DIR / Path("exists"))) def test_arg_parsing(self): args = parse_args(["-l", "label", "-v", "value"])