Update write_badge to support pathlib.Path

This commit is contained in:
Jon Grace-Cox
2022-08-13 09:42:47 -07:00
parent ea0ad53a37
commit 74b5e13067
2 changed files with 27 additions and 11 deletions

View File

@@ -746,21 +746,28 @@ class Badge:
(color, ", ".join(list(Color.__members__.keys()))),
)
def write_badge(self, file_path, overwrite=False) -> None:
def write_badge(self, file_path: Union[str, Path], overwrite=False) -> None:
"""Write badge to file."""
if isinstance(file_path, str):
if file_path.endswith("/"):
raise ValueError("File location may not be a directory.")
file: Path = Path(file_path)
else:
file = file_path
# Validate path (part 1)
if file_path.endswith("/"):
if file.is_dir():
raise ValueError("File location may not be a directory.")
# Get absolute filepath
path = os.path.abspath(file_path)
if not path.lower().endswith(".svg"):
path += ".svg"
# Ensure we're using a .svg extension
file = file.with_suffix(".svg")
# Validate path (part 2)
if not overwrite and os.path.exists(path):
raise RuntimeError('File "{}" already exists.'.format(path))
if not overwrite and file.exists():
raise RuntimeError('File "{}" already exists.'.format(file))
with open(path, mode="w") as file_handle:
with open(file, mode="w") as file_handle:
file_handle.write(self.badge_svg_text)

View File

@@ -4,6 +4,9 @@ from anybadge import Badge
from anybadge.cli import main, parse_args
TESTS_DIR = Path(__file__).parent
class TestAnybadge(TestCase):
"""Test case class for anybadge package."""
@@ -267,8 +270,14 @@ class TestAnybadge(TestCase):
with self.assertRaisesRegex(
RuntimeError, r'File ".*tests\/exists\.svg" already exists\.'
):
badge.write_badge("tests/exists")
badge.write_badge("tests/exists")
badge.write_badge(TESTS_DIR / Path("exists"))
badge.write_badge(TESTS_DIR / Path("exists"))
with self.assertRaisesRegex(
RuntimeError, r'File ".*tests\/exists\.svg" already exists\.'
):
badge.write_badge(str(TESTS_DIR / Path("exists")))
badge.write_badge(str(TESTS_DIR / Path("exists")))
def test_arg_parsing(self):
args = parse_args(["-l", "label", "-v", "value"])