Enhance OAuth client to support scope options and update related tests

This commit is contained in:
claudi 2026-04-07 09:38:15 +02:00
parent 1a9f924764
commit bebf99d826
5 changed files with 80 additions and 15 deletions

View file

@ -1,7 +1,7 @@
from __future__ import annotations
import base64
from typing import Iterable
from typing import Iterable, Sequence
from urllib.parse import urlencode
import httpx
@ -44,10 +44,17 @@ class EbayOAuthClient:
query["prompt"] = prompt
return f"{self.config.auth_base_url}?{urlencode(query)}"
def get_valid_token(self, *, scopes: Iterable[str] | None = None) -> OAuthToken:
def get_valid_token(
self,
*,
scopes: Iterable[str] | None = None,
scope_options: Sequence[Iterable[str]] | None = None,
) -> OAuthToken:
token = self.token_store.get_token()
if token is None or token.is_expired() or not self._has_required_scopes(token, scopes):
token = self.fetch_client_credentials_token(scopes=scopes)
if token is None or token.is_expired() or not self._has_required_scopes(token, scopes=scopes, scope_options=scope_options):
token = self.fetch_client_credentials_token(
scopes=self._choose_requested_scopes(scopes=scopes, scope_options=scope_options)
)
return token
def fetch_client_credentials_token(self, *, scopes: Iterable[str] | None = None) -> OAuthToken:
@ -111,8 +118,34 @@ class EbayOAuthClient:
return base64.b64encode(raw).decode("ascii")
@staticmethod
def _has_required_scopes(token: OAuthToken, scopes: Iterable[str] | None) -> bool:
requested = {scope for scope in (scopes or []) if scope}
if not requested:
def _choose_requested_scopes(
*,
scopes: Iterable[str] | None = None,
scope_options: Sequence[Iterable[str]] | None = None,
) -> list[str] | None:
if scopes is not None:
requested = [scope for scope in scopes if scope]
return requested or None
if scope_options:
for option in scope_options:
requested = [scope for scope in option if scope]
if requested:
return requested
return None
@staticmethod
def _has_required_scopes(
token: OAuthToken,
*,
scopes: Iterable[str] | None = None,
scope_options: Sequence[Iterable[str]] | None = None,
) -> bool:
requested_sets: list[set[str]] = []
if scopes is not None:
requested_sets.append({scope for scope in scopes if scope})
if scope_options:
requested_sets.extend({scope for scope in option if scope} for option in scope_options)
if not requested_sets:
return True
return requested.issubset(token.scopes())
token_scopes = token.scopes()
return any(requested.issubset(token_scopes) for requested in requested_sets if requested)

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Mapping, TypeVar
from typing import Any, Mapping, Sequence, TypeVar
import httpx
from pydantic import BaseModel
@ -31,12 +31,13 @@ class ApiTransport:
path: str,
*,
scopes: list[str] | None = None,
scope_options: Sequence[Sequence[str]] | None = None,
params: Mapping[str, Any] | None = None,
json_body: Any | None = None,
headers: Mapping[str, str] | None = None,
content: bytes | None = None,
) -> httpx.Response:
token = self.oauth_client.get_valid_token(scopes=scopes)
token = self.oauth_client.get_valid_token(scopes=scopes, scope_options=scope_options)
request_headers = dict(self.default_headers)
request_headers.update(headers or {})
request_headers["Authorization"] = f"Bearer {token.access_token}"

View file

@ -20,6 +20,10 @@ from ebay_client.generated.notification.models import (
NOTIFICATION_SCOPE = "https://api.ebay.com/oauth/api_scope"
NOTIFICATION_SUBSCRIPTION_SCOPE = "https://api.ebay.com/oauth/api_scope/commerce.notification.subscription"
NOTIFICATION_SUBSCRIPTION_READ_SCOPE = "https://api.ebay.com/oauth/api_scope/commerce.notification.subscription.readonly"
NOTIFICATION_SUBSCRIPTION_READ_SCOPE_OPTIONS = [
[NOTIFICATION_SCOPE, NOTIFICATION_SUBSCRIPTION_READ_SCOPE],
[NOTIFICATION_SCOPE, NOTIFICATION_SUBSCRIPTION_SCOPE],
]
class NotificationClient:
@ -107,7 +111,7 @@ class NotificationClient:
SubscriptionSearchResponse,
"GET",
"/commerce/notification/v1/subscription",
scopes=[NOTIFICATION_SCOPE, NOTIFICATION_SUBSCRIPTION_READ_SCOPE],
scope_options=NOTIFICATION_SUBSCRIPTION_READ_SCOPE_OPTIONS,
params={"limit": limit, "continuation_token": continuation_token},
)
@ -125,7 +129,7 @@ class NotificationClient:
Subscription,
"GET",
f"/commerce/notification/v1/subscription/{subscription_id}",
scopes=[NOTIFICATION_SCOPE, NOTIFICATION_SUBSCRIPTION_READ_SCOPE],
scope_options=NOTIFICATION_SUBSCRIPTION_READ_SCOPE_OPTIONS,
)
def update_subscription(self, subscription_id: str, payload: UpdateSubscriptionRequest) -> None:
@ -162,7 +166,7 @@ class NotificationClient:
SubscriptionFilter,
"GET",
f"/commerce/notification/v1/subscription/{subscription_id}/filter/{filter_id}",
scopes=[NOTIFICATION_SCOPE, NOTIFICATION_SUBSCRIPTION_READ_SCOPE],
scope_options=NOTIFICATION_SUBSCRIPTION_READ_SCOPE_OPTIONS,
)
def delete_subscription_filter(self, subscription_id: str, filter_id: str) -> None:

View file

@ -33,3 +33,19 @@ def test_get_valid_token_reuses_unexpired_token() -> None:
token = client.get_valid_token(scopes=["scope.a"])
assert token.access_token == "cached-token"
def test_get_valid_token_reuses_token_when_any_scope_option_matches() -> None:
config = EbayOAuthConfig(client_id="client-id", client_secret="client-secret")
store = InMemoryTokenStore()
store.set_token(OAuthToken(access_token="cached-token", scope="scope.base scope.write"))
client = EbayOAuthClient(config, token_store=store)
token = client.get_valid_token(
scope_options=[
["scope.base", "scope.read"],
["scope.base", "scope.write"],
]
)
assert token.access_token == "cached-token"

View file

@ -30,8 +30,19 @@ from ebay_client.notification.client import NotificationClient
class DummyOAuthClient:
def get_valid_token(self, *, scopes: list[str] | None = None) -> OAuthToken:
return OAuthToken(access_token="test-token", scope=" ".join(scopes or []))
def get_valid_token(
self,
*,
scopes: list[str] | None = None,
scope_options: list[list[str]] | None = None,
) -> OAuthToken:
if scopes:
resolved_scopes = scopes
elif scope_options:
resolved_scopes = scope_options[0]
else:
resolved_scopes = []
return OAuthToken(access_token="test-token", scope=" ".join(resolved_scopes))
def build_transport() -> ApiTransport: