from __future__ import annotations import argparse import re import tomllib from pathlib import Path ROOT = Path(__file__).resolve().parent.parent PYPROJECT_PATH = ROOT / "pyproject.toml" PROJECT_VERSION_PATTERN = re.compile( r'(?ms)(^\[project\]\s*$.*?^version\s*=\s*")([^"\r\n]+)(")' ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Update the project version in pyproject.toml.") parser.add_argument("version", nargs="?", help="New project version.") parser.add_argument("--print-current", action="store_true", help="Print the current project version.") args = parser.parse_args() if not args.print_current and not args.version: parser.error("Provide a version or use --print-current.") return args def validate_version(version: str) -> str: normalized = version.strip() if not normalized: raise ValueError("Version cannot be empty.") if normalized != version: raise ValueError("Version cannot start or end with whitespace.") if any(character.isspace() for character in version): raise ValueError("Version cannot contain whitespace.") if '"' in version or "'" in version: raise ValueError("Version cannot contain quote characters.") return version def extract_project_version(pyproject_text: str) -> str: match = PROJECT_VERSION_PATTERN.search(pyproject_text) if match is None: raise ValueError("Could not find [project].version in pyproject.toml.") return match.group(2) def set_project_version_in_text(pyproject_text: str, new_version: str) -> str: validate_version(new_version) def replace(match: re.Match[str]) -> str: return f"{match.group(1)}{new_version}{match.group(3)}" updated_text, replacements = PROJECT_VERSION_PATTERN.subn(replace, pyproject_text, count=1) if replacements != 1: raise ValueError("Could not update [project].version in pyproject.toml.") tomllib.loads(updated_text) return updated_text def write_project_version(pyproject_path: Path, new_version: str) -> tuple[str, str]: original_text = pyproject_path.read_text(encoding="utf-8") current_version = extract_project_version(original_text) updated_text = set_project_version_in_text(original_text, new_version) if updated_text != original_text: pyproject_path.write_text(updated_text, encoding="utf-8") return current_version, new_version def main() -> int: args = parse_args() pyproject_text = PYPROJECT_PATH.read_text(encoding="utf-8") if args.print_current: print(extract_project_version(pyproject_text)) return 0 previous_version, updated_version = write_project_version(PYPROJECT_PATH, args.version) if previous_version == updated_version: print(f"Version already set to {updated_version}") else: print(f"Updated version: {previous_version} -> {updated_version}") return 0 if __name__ == "__main__": raise SystemExit(main())