From b988532aaa97b7ff01cde776614da45490bcb1b3 Mon Sep 17 00:00:00 2001 From: claudi Date: Tue, 10 Mar 2026 16:02:24 +0100 Subject: [PATCH] feat: implement brand-specific configuration and update management for Agravity Bridge --- config.example.json | 8 +- src/webdrop_bridge/config.py | 119 ++++++++++++++-- src/webdrop_bridge/core/config_manager.py | 19 ++- src/webdrop_bridge/core/updater.py | 162 +++++++++++++++++++--- src/webdrop_bridge/main.py | 2 + src/webdrop_bridge/ui/main_window.py | 28 +++- src/webdrop_bridge/ui/settings_dialog.py | 4 +- tests/unit/test_config.py | 53 ++++++- tests/unit/test_updater.py | 114 +++++++++++++++ 9 files changed, 461 insertions(+), 48 deletions(-) diff --git a/config.example.json b/config.example.json index c97367e..d93d339 100644 --- a/config.example.json +++ b/config.example.json @@ -1,6 +1,12 @@ { - "app_name": "WebDrop Bridge", + "brand_id": "agravity", + "config_dir_name": "agravity_bridge", + "app_name": "Agravity Bridge", "webapp_url": "https://dev.agravity.io/", + "update_base_url": "https://git.him-tools.de", + "update_repo": "HIM-public/webdrop-bridge", + "update_channel": "stable", + "update_manifest_name": "release-manifest.json", "url_mappings": [ { "url_prefix": "https://devagravitystg.file.core.windows.net/devagravitysync/", diff --git a/src/webdrop_bridge/config.py b/src/webdrop_bridge/config.py index c20c93d..f4c035f 100644 --- a/src/webdrop_bridge/config.py +++ b/src/webdrop_bridge/config.py @@ -3,6 +3,7 @@ import json import logging import os +import sys from dataclasses import dataclass, field from pathlib import Path from typing import List @@ -11,6 +12,13 @@ from dotenv import load_dotenv logger = logging.getLogger(__name__) +DEFAULT_BRAND_ID = "webdrop_bridge" +DEFAULT_CONFIG_DIR_NAME = "webdrop_bridge" +DEFAULT_UPDATE_BASE_URL = "https://git.him-tools.de" +DEFAULT_UPDATE_REPO = "HIM-public/webdrop-bridge" +DEFAULT_UPDATE_CHANNEL = "stable" +DEFAULT_UPDATE_MANIFEST_NAME = "release-manifest.json" + class ConfigurationError(Exception): """Raised when configuration is invalid.""" @@ -60,6 +68,12 @@ class Config: enable_logging: Whether to write logs to file enable_checkout: Whether to check asset checkout status and show checkout dialog on drag. Disabled by default as checkout support is optional. + brand_id: Stable brand identifier used for packaging and update selection + config_dir_name: AppData/config directory name for this branded variant + update_base_url: Base Forgejo URL used for release checks + update_repo: Forgejo repository containing shared releases + update_channel: Update channel name used by release manifest selection + update_manifest_name: Asset name of the shared release manifest Raises: ConfigurationError: If configuration values are invalid @@ -82,6 +96,12 @@ class Config: enable_logging: bool = True enable_checkout: bool = False language: str = "auto" + brand_id: str = DEFAULT_BRAND_ID + config_dir_name: str = DEFAULT_CONFIG_DIR_NAME + update_base_url: str = DEFAULT_UPDATE_BASE_URL + update_repo: str = DEFAULT_UPDATE_REPO + update_channel: str = DEFAULT_UPDATE_CHANNEL + update_manifest_name: str = DEFAULT_UPDATE_MANIFEST_NAME @classmethod def from_file(cls, config_path: Path) -> "Config": @@ -124,6 +144,9 @@ class Config: elif not root.is_dir(): raise ConfigurationError(f"Allowed root is not a directory: {root}") + brand_id = data.get("brand_id", DEFAULT_BRAND_ID) + config_dir_name = data.get("config_dir_name", cls._slugify_config_dir_name(brand_id)) + # Get log file path log_file = None if data.get("enable_logging", True): @@ -132,10 +155,10 @@ class Config: log_file = Path(log_file_str) # If relative path, resolve relative to app data directory instead of cwd if not log_file.is_absolute(): - log_file = Config.get_default_log_dir() / log_file + log_file = Config.get_default_log_dir(config_dir_name) / log_file else: # Use default log path in app data - log_file = Config.get_default_log_path() + log_file = Config.get_default_log_path(config_dir_name) app_name = data.get("app_name", "WebDrop Bridge") stored_window_title = data.get("window_title", "") @@ -174,6 +197,12 @@ class Config: enable_logging=data.get("enable_logging", True), enable_checkout=data.get("enable_checkout", False), language=data.get("language", "auto"), + brand_id=brand_id, + config_dir_name=config_dir_name, + update_base_url=data.get("update_base_url", DEFAULT_UPDATE_BASE_URL), + update_repo=data.get("update_repo", DEFAULT_UPDATE_REPO), + update_channel=data.get("update_channel", DEFAULT_UPDATE_CHANNEL), + update_manifest_name=data.get("update_manifest_name", DEFAULT_UPDATE_MANIFEST_NAME), ) @classmethod @@ -201,6 +230,8 @@ class Config: from webdrop_bridge import __version__ app_version = __version__ + brand_id = os.getenv("BRAND_ID", DEFAULT_BRAND_ID) + config_dir_name = os.getenv("APP_CONFIG_DIR_NAME", cls._slugify_config_dir_name(brand_id)) log_level = os.getenv("LOG_LEVEL", "INFO").upper() log_file_str = os.getenv("LOG_FILE", None) @@ -215,6 +246,10 @@ class Config: enable_logging = os.getenv("ENABLE_LOGGING", "true").lower() == "true" enable_checkout = os.getenv("ENABLE_CHECKOUT", "false").lower() == "true" language = os.getenv("LANGUAGE", "auto") + update_base_url = os.getenv("UPDATE_BASE_URL", DEFAULT_UPDATE_BASE_URL) + update_repo = os.getenv("UPDATE_REPO", DEFAULT_UPDATE_REPO) + update_channel = os.getenv("UPDATE_CHANNEL", DEFAULT_UPDATE_CHANNEL) + update_manifest_name = os.getenv("UPDATE_MANIFEST_NAME", DEFAULT_UPDATE_MANIFEST_NAME) # Validate log level valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} @@ -254,10 +289,10 @@ class Config: log_file = Path(log_file_str) # If relative path, resolve relative to app data directory instead of cwd if not log_file.is_absolute(): - log_file = Config.get_default_log_dir() / log_file + log_file = Config.get_default_log_dir(config_dir_name) / log_file else: # Use default log path in app data - log_file = Config.get_default_log_path() + log_file = Config.get_default_log_path(config_dir_name) # Validate webapp URL is not empty if not webapp_url: @@ -308,6 +343,12 @@ class Config: enable_logging=enable_logging, enable_checkout=enable_checkout, language=language, + brand_id=brand_id, + config_dir_name=config_dir_name, + update_base_url=update_base_url, + update_repo=update_repo, + update_channel=update_channel, + update_manifest_name=update_manifest_name, ) def to_file(self, config_path: Path) -> None: @@ -337,6 +378,12 @@ class Config: "enable_logging": self.enable_logging, "enable_checkout": self.enable_checkout, "language": self.language, + "brand_id": self.brand_id, + "config_dir_name": self.config_dir_name, + "update_base_url": self.update_base_url, + "update_repo": self.update_repo, + "update_channel": self.update_channel, + "update_manifest_name": self.update_manifest_name, } config_path.parent.mkdir(parents=True, exist_ok=True) @@ -344,7 +391,49 @@ class Config: json.dump(data, f, indent=2) @staticmethod - def get_default_config_path() -> Path: + def load_bootstrap_env(env_file: str | None = None) -> Path | None: + """Load a bootstrap .env before configuration path lookup. + + This lets branded builds decide their config directory before the main + config file is loaded. + + Args: + env_file: Optional explicit .env path + + Returns: + Path to the loaded .env file, or None if nothing was loaded + """ + candidate_paths: list[Path] = [] + if env_file: + candidate_paths.append(Path(env_file).resolve()) + else: + if getattr(sys, "frozen", False): + candidate_paths.append(Path(sys.executable).resolve().parent / ".env") + + candidate_paths.append(Path.cwd() / ".env") + candidate_paths.append(Path(__file__).resolve().parents[2] / ".env") + + for path in candidate_paths: + if path.exists(): + load_dotenv(path, override=False) + logger.debug(f"Loaded bootstrap environment from {path}") + return path + + return None + + @staticmethod + def _slugify_config_dir_name(value: str) -> str: + """Convert brand-like identifiers into a filesystem-safe directory name.""" + sanitized = "".join(c.lower() if c.isalnum() else "_" for c in value).strip("_") + return sanitized or DEFAULT_CONFIG_DIR_NAME + + @staticmethod + def get_default_config_dir_name() -> str: + """Get the default config directory name from environment or fallback.""" + return os.getenv("APP_CONFIG_DIR_NAME", DEFAULT_CONFIG_DIR_NAME) + + @staticmethod + def get_default_config_path(config_dir_name: str | None = None) -> Path: """Get the default configuration file path. Returns: @@ -356,10 +445,10 @@ class Config: base = Path.home() / "AppData" / "Roaming" else: base = Path.home() / ".config" - return base / "webdrop_bridge" / "config.json" + return base / (config_dir_name or Config.get_default_config_dir_name()) / "config.json" @staticmethod - def get_default_log_dir() -> Path: + def get_default_log_dir(config_dir_name: str | None = None) -> Path: """Get the default directory for log files. Always uses user's AppData directory to ensure permissions work @@ -374,21 +463,31 @@ class Config: base = Path.home() / "AppData" / "Roaming" else: base = Path.home() / ".local" / "share" - return base / "webdrop_bridge" / "logs" + return base / (config_dir_name or Config.get_default_config_dir_name()) / "logs" @staticmethod - def get_default_log_path() -> Path: + def get_default_log_path(config_dir_name: str | None = None) -> Path: """Get the default log file path. Returns: Path to default log file in user's AppData/Roaming/webdrop_bridge/logs """ - return Config.get_default_log_dir() / "webdrop_bridge.log" + dir_name = config_dir_name or Config.get_default_config_dir_name() + return Config.get_default_log_dir(dir_name) / f"{dir_name}.log" + + def get_config_path(self) -> Path: + """Get the default config file path for this configured brand.""" + return self.get_default_config_path(self.config_dir_name) + + def get_cache_dir(self) -> Path: + """Get the update/cache directory for this configured brand.""" + return self.get_default_config_path(self.config_dir_name).parent / "cache" def __repr__(self) -> str: """Return developer-friendly representation.""" return ( f"Config(app={self.app_name} v{self.app_version}, " + f"brand={self.brand_id}, " f"log_level={self.log_level}, " f"allowed_roots={len(self.allowed_roots)} dirs, " f"window={self.window_width}x{self.window_height})" diff --git a/src/webdrop_bridge/core/config_manager.py b/src/webdrop_bridge/core/config_manager.py index 52798ee..4c4be27 100644 --- a/src/webdrop_bridge/core/config_manager.py +++ b/src/webdrop_bridge/core/config_manager.py @@ -101,14 +101,13 @@ class ConfigValidator: class ConfigProfile: """Manages named configuration profiles. - Profiles are stored in ~/.webdrop_bridge/profiles/ directory as JSON files. + Profiles are stored in the brand-specific app config directory. """ - PROFILES_DIR = Path.home() / ".webdrop_bridge" / "profiles" - - def __init__(self) -> None: + def __init__(self, config_dir_name: str = "webdrop_bridge") -> None: """Initialize profile manager.""" - self.PROFILES_DIR.mkdir(parents=True, exist_ok=True) + self.profiles_dir = Config.get_default_config_path(config_dir_name).parent / "profiles" + self.profiles_dir.mkdir(parents=True, exist_ok=True) def save_profile(self, profile_name: str, config: Config) -> Path: """Save configuration as a named profile. @@ -126,7 +125,7 @@ class ConfigProfile: if not profile_name or "/" in profile_name or "\\" in profile_name: raise ConfigurationError(f"Invalid profile name: {profile_name}") - profile_path = self.PROFILES_DIR / f"{profile_name}.json" + profile_path = self.profiles_dir / f"{profile_name}.json" config_data = { "app_name": config.app_name, @@ -160,7 +159,7 @@ class ConfigProfile: Raises: ConfigurationError: If profile not found or invalid """ - profile_path = self.PROFILES_DIR / f"{profile_name}.json" + profile_path = self.profiles_dir / f"{profile_name}.json" if not profile_path.exists(): raise ConfigurationError(f"Profile not found: {profile_name}") @@ -179,10 +178,10 @@ class ConfigProfile: Returns: List of profile names (without .json extension) """ - if not self.PROFILES_DIR.exists(): + if not self.profiles_dir.exists(): return [] - return sorted([p.stem for p in self.PROFILES_DIR.glob("*.json")]) + return sorted([p.stem for p in self.profiles_dir.glob("*.json")]) def delete_profile(self, profile_name: str) -> None: """Delete a profile. @@ -193,7 +192,7 @@ class ConfigProfile: Raises: ConfigurationError: If profile not found """ - profile_path = self.PROFILES_DIR / f"{profile_name}.json" + profile_path = self.profiles_dir / f"{profile_name}.json" if not profile_path.exists(): raise ConfigurationError(f"Profile not found: {profile_name}") diff --git a/src/webdrop_bridge/core/updater.py b/src/webdrop_bridge/core/updater.py index 2f2b3b6..92fe794 100644 --- a/src/webdrop_bridge/core/updater.py +++ b/src/webdrop_bridge/core/updater.py @@ -5,9 +5,11 @@ verifying checksums from Forgejo releases. """ import asyncio +import fnmatch import hashlib import json import logging +import platform import socket from dataclasses import dataclass from datetime import datetime, timedelta @@ -34,7 +36,16 @@ class Release: class UpdateManager: """Manages auto-updates via Forgejo releases API.""" - def __init__(self, current_version: str, config_dir: Optional[Path] = None): + def __init__( + self, + current_version: str, + config_dir: Optional[Path] = None, + brand_id: str = "webdrop_bridge", + forgejo_url: str = "https://git.him-tools.de", + repo: str = "HIM-public/webdrop-bridge", + update_channel: str = "stable", + manifest_name: str = "release-manifest.json", + ): """Initialize update manager. Args: @@ -42,8 +53,11 @@ class UpdateManager: config_dir: Directory for storing update cache. Defaults to temp. """ self.current_version = current_version - self.forgejo_url = "https://git.him-tools.de" - self.repo = "HIM-public/webdrop-bridge" + self.brand_id = brand_id + self.forgejo_url = forgejo_url.rstrip("/") + self.repo = repo + self.update_channel = update_channel + self.manifest_name = manifest_name self.api_endpoint = f"{self.forgejo_url}/api/v1/repos/{self.repo}/releases/latest" # Cache management @@ -52,6 +66,128 @@ class UpdateManager: self.cache_file = self.cache_dir / "update_check.json" self.cache_ttl = timedelta(hours=24) + def _get_platform_key(self) -> str: + """Return the release-manifest platform key for the current system.""" + system = platform.system() + machine = platform.machine().lower() + + if system == "Windows": + arch = "x64" if machine in {"amd64", "x86_64"} else machine + return f"windows-{arch}" + if system == "Darwin": + return "macos-universal" + return f"{system.lower()}-{machine}" + + def _find_asset(self, assets: list[dict], asset_name: str) -> Optional[dict]: + """Find an asset by exact name.""" + for asset in assets: + if asset.get("name") == asset_name: + return asset + return None + + def _find_manifest_asset(self, release: Release) -> Optional[dict]: + """Find the shared release manifest asset if present.""" + return self._find_asset(release.assets, self.manifest_name) + + def _download_json_asset(self, url: str) -> Optional[dict]: + """Download and parse a JSON asset from a release.""" + try: + with urlopen(url, timeout=10) as response: + return json.loads(response.read().decode("utf-8")) + except (URLError, json.JSONDecodeError) as e: + logger.error(f"Failed to download JSON asset: {e}") + return None + + async def _load_release_manifest(self, release: Release) -> Optional[dict]: + """Load the shared release manifest if present.""" + manifest_asset = self._find_manifest_asset(release) + if not manifest_asset: + return None + + loop = asyncio.get_event_loop() + return await asyncio.wait_for( + loop.run_in_executor( + None, self._download_json_asset, manifest_asset["browser_download_url"] + ), + timeout=15, + ) + + def _resolve_assets_from_manifest( + self, release: Release, manifest: dict + ) -> tuple[Optional[dict], Optional[dict]]: + """Resolve installer and checksum assets from a shared release manifest.""" + if manifest.get("channel") not in {None, "", self.update_channel}: + logger.info( + "Release manifest channel %s does not match configured channel %s", + manifest.get("channel"), + self.update_channel, + ) + return None, None + + brand_entry = manifest.get("brands", {}).get(self.brand_id, {}) + platform_entry = brand_entry.get(self._get_platform_key(), {}) + installer_name = platform_entry.get("installer") + checksum_name = platform_entry.get("checksum") + + if not installer_name: + logger.warning( + "No installer entry found for brand=%s platform=%s in release manifest", + self.brand_id, + self._get_platform_key(), + ) + return None, None + + return self._find_asset(release.assets, installer_name), self._find_asset( + release.assets, checksum_name + ) + + def _resolve_assets_legacy(self, release: Release) -> tuple[Optional[dict], Optional[dict]]: + """Resolve installer and checksum assets using legacy filename matching.""" + is_windows = platform.system() == "Windows" + extension = ".msi" if is_windows else ".dmg" + brand_prefix = f"{self.brand_id}-*" + + installer_asset = None + for asset in release.assets: + asset_name = asset.get("name", "") + if not asset_name.endswith(extension): + continue + + if self.brand_id != "webdrop_bridge" and fnmatch.fnmatch( + asset_name.lower(), brand_prefix.lower() + ): + installer_asset = asset + break + + if self.brand_id == "webdrop_bridge": + installer_asset = asset + break + + if not installer_asset: + return None, None + + checksum_asset = self._find_asset(release.assets, f"{installer_asset['name']}.sha256") + return installer_asset, checksum_asset + + async def _resolve_release_assets( + self, release: Release + ) -> tuple[Optional[dict], Optional[dict]]: + """Resolve installer and checksum assets for the configured brand.""" + try: + manifest = await self._load_release_manifest(release) + except asyncio.TimeoutError: + logger.warning( + "Timed out while loading release manifest, falling back to legacy lookup" + ) + manifest = None + + if manifest: + installer_asset, checksum_asset = self._resolve_assets_from_manifest(release, manifest) + if installer_asset: + return installer_asset, checksum_asset + + return self._resolve_assets_legacy(release) + def _parse_version(self, version_str: str) -> tuple[int, int, int]: """Parse semantic version string to tuple. @@ -253,12 +389,7 @@ class UpdateManager: logger.error("No assets found in release") return None - # Find .msi or .dmg file - installer_asset = None - for asset in release.assets: - if asset["name"].endswith((".msi", ".dmg")): - installer_asset = asset - break + installer_asset, _ = await self._resolve_release_assets(release) if not installer_asset: logger.error("No installer found in release assets") @@ -345,14 +476,11 @@ class UpdateManager: Returns: True if checksum matches, False otherwise """ - # Find .sha256 file matching the installer name (e.g. Setup.msi.sha256) - # Fall back to any .sha256 only if no specific match exists - installer_name = file_path.name - checksum_asset = None - for asset in release.assets: - if asset["name"] == f"{installer_name}.sha256": - checksum_asset = asset - break + installer_asset, checksum_asset = await self._resolve_release_assets(release) + installer_name = installer_asset["name"] if installer_asset else file_path.name + + if not checksum_asset: + checksum_asset = self._find_asset(release.assets, f"{installer_name}.sha256") if not checksum_asset: logger.warning("No checksum file found in release") diff --git a/src/webdrop_bridge/main.py b/src/webdrop_bridge/main.py index 4e90a7b..1194d69 100644 --- a/src/webdrop_bridge/main.py +++ b/src/webdrop_bridge/main.py @@ -30,6 +30,8 @@ def main() -> int: int: Exit code (0 for success, non-zero for error) """ try: + Config.load_bootstrap_env() + # Load configuration from file if it exists, otherwise from environment config_path = Config.get_default_config_path() if config_path.exists(): diff --git a/src/webdrop_bridge/ui/main_window.py b/src/webdrop_bridge/ui/main_window.py index 6462ca6..c4f9967 100644 --- a/src/webdrop_bridge/ui/main_window.py +++ b/src/webdrop_bridge/ui/main_window.py @@ -1872,8 +1872,16 @@ class MainWindow(QMainWindow): try: # Create update manager - cache_dir = Path.home() / ".webdrop_bridge" - manager = UpdateManager(current_version=self.config.app_version, config_dir=cache_dir) + cache_dir = self.config.get_cache_dir() + manager = UpdateManager( + current_version=self.config.app_version, + config_dir=cache_dir, + brand_id=self.config.brand_id, + forgejo_url=self.config.update_base_url, + repo=self.config.update_repo, + update_channel=self.config.update_channel, + manifest_name=self.config.update_manifest_name, + ) # Run async check in background self._run_async_check(manager) @@ -2090,7 +2098,13 @@ class MainWindow(QMainWindow): # Create update manager manager = UpdateManager( - current_version=self.config.app_version, config_dir=Path.home() / ".webdrop_bridge" + current_version=self.config.app_version, + config_dir=self.config.get_cache_dir(), + brand_id=self.config.brand_id, + forgejo_url=self.config.update_base_url, + repo=self.config.update_repo, + update_channel=self.config.update_channel, + manifest_name=self.config.update_manifest_name, ) # Create and start background thread @@ -2229,7 +2243,13 @@ class MainWindow(QMainWindow): from webdrop_bridge.core.updater import UpdateManager manager = UpdateManager( - current_version=self.config.app_version, config_dir=Path.home() / ".webdrop_bridge" + current_version=self.config.app_version, + config_dir=self.config.get_cache_dir(), + brand_id=self.config.brand_id, + forgejo_url=self.config.update_base_url, + repo=self.config.update_repo, + update_channel=self.config.update_channel, + manifest_name=self.config.update_manifest_name, ) if manager.install_update(installer_path): diff --git a/src/webdrop_bridge/ui/settings_dialog.py b/src/webdrop_bridge/ui/settings_dialog.py index 935aee1..99f5241 100644 --- a/src/webdrop_bridge/ui/settings_dialog.py +++ b/src/webdrop_bridge/ui/settings_dialog.py @@ -42,7 +42,7 @@ class SettingsDialog(QDialog): """ super().__init__(parent) self.config = config - self.profile_manager = ConfigProfile() + self.profile_manager = ConfigProfile(config.config_dir_name) self.setWindowTitle(tr("settings.title")) self.setGeometry(100, 100, 600, 500) @@ -96,7 +96,7 @@ class SettingsDialog(QDialog): self.config.window_width = config_data["window_width"] self.config.window_height = config_data["window_height"] - config_path = Config.get_default_config_path() + config_path = self.config.get_config_path() self.config.to_file(config_path) logger.info(f"Configuration saved to {config_path}") diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index fdeda3d..2c2e9ce 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -12,14 +12,26 @@ def clear_env(): """Clear environment variables before each test to avoid persistence.""" # Save current env saved_env = os.environ.copy() - + # Clear relevant variables for key in list(os.environ.keys()): - if key.startswith(('APP_', 'LOG_', 'ALLOWED_', 'WEBAPP_', 'WINDOW_', 'ENABLE_')): + if key.startswith( + ( + "APP_", + "LOG_", + "ALLOWED_", + "WEBAPP_", + "WINDOW_", + "ENABLE_", + "BRAND_", + "UPDATE_", + "LANGUAGE", + ) + ): del os.environ[key] - + yield - + # Restore env (cleanup) os.environ.clear() os.environ.update(saved_env) @@ -64,6 +76,28 @@ class TestConfigFromEnv: assert config.window_width == 1200 assert config.window_height == 800 + def test_from_env_with_branding_values(self, tmp_path): + """Test loading branding and update metadata from environment.""" + env_file = tmp_path / ".env" + root1 = tmp_path / "root1" + root1.mkdir() + env_file.write_text( + f"BRAND_ID=agravity\n" + f"APP_CONFIG_DIR_NAME=agravity_bridge\n" + f"UPDATE_REPO=HIM-public/webdrop-bridge\n" + f"UPDATE_CHANNEL=stable\n" + f"UPDATE_MANIFEST_NAME=release-manifest.json\n" + f"ALLOWED_ROOTS={root1}\n" + ) + + config = Config.from_env(str(env_file)) + + assert config.brand_id == "agravity" + assert config.config_dir_name == "agravity_bridge" + assert config.update_repo == "HIM-public/webdrop-bridge" + assert config.update_channel == "stable" + assert config.update_manifest_name == "release-manifest.json" + def test_from_env_with_defaults(self, tmp_path): """Test loading config uses defaults when env vars not set.""" # Create empty .env file @@ -73,8 +107,11 @@ class TestConfigFromEnv: config = Config.from_env(str(env_file)) assert config.app_name == "WebDrop Bridge" + assert config.brand_id == "webdrop_bridge" + assert config.config_dir_name == "webdrop_bridge" # Version should come from __init__.py (dynamic, not hardcoded) from webdrop_bridge import __version__ + assert config.app_version == __version__ assert config.log_level == "INFO" assert config.window_width == 1024 @@ -187,3 +224,11 @@ class TestConfigValidation: config = Config.from_env(str(env_file)) assert config.allowed_urls == ["example.com", "test.org"] + + def test_brand_specific_default_paths(self): + """Test brand-specific config and log directories.""" + config_path = Config.get_default_config_path("agravity_bridge") + log_path = Config.get_default_log_path("agravity_bridge") + + assert config_path.parts[-2:] == ("agravity_bridge", "config.json") + assert log_path.parts[-2:] == ("logs", "agravity_bridge.log") diff --git a/tests/unit/test_updater.py b/tests/unit/test_updater.py index 1685f20..f3f09a4 100644 --- a/tests/unit/test_updater.py +++ b/tests/unit/test_updater.py @@ -16,6 +16,17 @@ def update_manager(tmp_path): return UpdateManager(current_version="0.0.1", config_dir=tmp_path) +@pytest.fixture +def agravity_update_manager(tmp_path): + """Create a brand-aware UpdateManager instance for Agravity Bridge.""" + return UpdateManager( + current_version="0.0.1", + config_dir=tmp_path, + brand_id="agravity", + update_channel="stable", + ) + + @pytest.fixture def sample_release(): """Sample release data from API.""" @@ -252,6 +263,109 @@ class TestDownloading: assert result is None + @pytest.mark.asyncio + async def test_download_update_uses_release_manifest(self, agravity_update_manager, tmp_path): + """Test branded download selection from a shared release manifest.""" + release = Release( + tag_name="v0.0.2", + name="WebDropBridge v0.0.2", + version="0.0.2", + body="Release notes", + assets=[ + { + "name": "AgravityBridge-0.0.2-win-x64.msi", + "browser_download_url": "https://example.com/AgravityBridge-0.0.2-win-x64.msi", + }, + { + "name": "AgravityBridge-0.0.2-win-x64.msi.sha256", + "browser_download_url": "https://example.com/AgravityBridge-0.0.2-win-x64.msi.sha256", + }, + { + "name": "OtherBridge-0.0.2-win-x64.msi", + "browser_download_url": "https://example.com/OtherBridge-0.0.2-win-x64.msi", + }, + { + "name": "release-manifest.json", + "browser_download_url": "https://example.com/release-manifest.json", + }, + ], + published_at="2026-01-29T10:00:00Z", + ) + + manifest = { + "version": "0.0.2", + "channel": "stable", + "brands": { + "agravity": { + "windows-x64": { + "installer": "AgravityBridge-0.0.2-win-x64.msi", + "checksum": "AgravityBridge-0.0.2-win-x64.msi.sha256", + } + } + }, + } + + with ( + patch.object(UpdateManager, "_download_json_asset", return_value=manifest), + patch.object(UpdateManager, "_download_file", return_value=True) as mock_download, + ): + result = await agravity_update_manager.download_update(release, tmp_path) + + assert result is not None + assert result.name == "AgravityBridge-0.0.2-win-x64.msi" + mock_download.assert_called_once() + + @pytest.mark.asyncio + async def test_verify_checksum_uses_release_manifest(self, agravity_update_manager, tmp_path): + """Test branded checksum selection from a shared release manifest.""" + test_file = tmp_path / "AgravityBridge-0.0.2-win-x64.msi" + test_file.write_bytes(b"test content") + + import hashlib + + checksum = hashlib.sha256(b"test content").hexdigest() + release = Release( + tag_name="v0.0.2", + name="WebDropBridge v0.0.2", + version="0.0.2", + body="Release notes", + assets=[ + { + "name": "AgravityBridge-0.0.2-win-x64.msi", + "browser_download_url": "https://example.com/AgravityBridge-0.0.2-win-x64.msi", + }, + { + "name": "AgravityBridge-0.0.2-win-x64.msi.sha256", + "browser_download_url": "https://example.com/AgravityBridge-0.0.2-win-x64.msi.sha256", + }, + { + "name": "release-manifest.json", + "browser_download_url": "https://example.com/release-manifest.json", + }, + ], + published_at="2026-01-29T10:00:00Z", + ) + manifest = { + "version": "0.0.2", + "channel": "stable", + "brands": { + "agravity": { + "windows-x64": { + "installer": "AgravityBridge-0.0.2-win-x64.msi", + "checksum": "AgravityBridge-0.0.2-win-x64.msi.sha256", + } + } + }, + } + + with ( + patch.object(UpdateManager, "_download_json_asset", return_value=manifest), + patch.object(UpdateManager, "_download_checksum", return_value=checksum), + ): + result = await agravity_update_manager.verify_checksum(test_file, release) + + assert result is True + class TestChecksumVerification: """Test checksum verification."""