"""Normalize token-usage metadata reported by different LLM providers."""
from __future__ import annotations

from typing import Any, Dict, Iterable, Optional

# Provider-specific spellings for each usage counter. The FIRST entry of
# every tuple is the canonical key; the remaining entries are aliases
# (OpenAI-style snake_case, Google-style camelCase) that get folded into it.
PROMPT_TOKEN_KEYS: tuple[str, ...] = ("prompt_tokens", "input_tokens", "promptTokenCount")
COMPLETION_TOKEN_KEYS: tuple[str, ...] = ("completion_tokens", "output_tokens", "candidatesTokenCount")
TOTAL_TOKEN_KEYS: tuple[str, ...] = ("total_tokens", "totalTokenCount")

# Keys under which a provider response may nest its usage dictionary.
USAGE_CONTAINER_KEYS: tuple[str, ...] = ("usage", "usageMetadata", "usage_metadata")
def _normalize_usage_value(value: Any) -> Any:
|
|
if isinstance(value, (int, float)):
|
|
return int(value)
|
|
|
|
if isinstance(value, str):
|
|
stripped = value.strip()
|
|
if not stripped:
|
|
return None
|
|
try:
|
|
numeric = float(stripped)
|
|
except ValueError:
|
|
return None
|
|
return int(numeric)
|
|
|
|
if isinstance(value, dict):
|
|
normalized: Dict[str, Any] = {}
|
|
for key, nested_value in value.items():
|
|
normalized_value = _normalize_usage_value(nested_value)
|
|
if normalized_value is not None:
|
|
normalized[key] = normalized_value
|
|
return normalized or None
|
|
|
|
if isinstance(value, (list, tuple, set)):
|
|
normalized_list = [
|
|
item for item in (_normalize_usage_value(element) for element in value) if item is not None
|
|
]
|
|
return normalized_list or None
|
|
|
|
return None
|
|
|
|
|
|
def _first_numeric(payload: Dict[str, Any], keys: Iterable[str]) -> Optional[int]:
|
|
for key in keys:
|
|
value = payload.get(key)
|
|
if isinstance(value, (int, float)):
|
|
return int(value)
|
|
return None
|
|
|
|
|
|
def _canonicalize_counts(payload: Dict[str, Any]) -> None:
    """Rewrite *payload* in place so token counts use canonical key names.

    For each counter the first numeric value among its aliases wins and is
    stored under the canonical key ("prompt_tokens", "completion_tokens",
    "total_tokens"); a missing counter has its canonical key removed. A
    missing total is derived as prompt + completion when both are present.
    All non-canonical alias keys are dropped afterwards.

    Args:
        payload: A normalized usage dict (values already coerced to ints).
    """
    prompt = _first_numeric(payload, PROMPT_TOKEN_KEYS)
    completion = _first_numeric(payload, COMPLETION_TOKEN_KEYS)
    total = _first_numeric(payload, TOTAL_TOKEN_KEYS)

    # Derive the total only when the provider did not report one itself.
    if total is None and prompt is not None and completion is not None:
        total = prompt + completion

    # Single data-driven pass replaces three copies of the same
    # set-or-pop logic (the previous triplicated if/else blocks).
    for canonical_key, count in (
        ("prompt_tokens", prompt),
        ("completion_tokens", completion),
        ("total_tokens", total),
    ):
        if count is not None:
            payload[canonical_key] = count
        else:
            payload.pop(canonical_key, None)

    # Drop every alias spelling; index 0 of each tuple is the canonical key.
    for alias_group in (PROMPT_TOKEN_KEYS, COMPLETION_TOKEN_KEYS, TOTAL_TOKEN_KEYS):
        for alias in alias_group[1:]:
            payload.pop(alias, None)
def _extract_usage_container(candidate: Any) -> Optional[Dict[str, Any]]:
    """Return the first dict-valued usage section nested in *candidate*.

    Checks each container key in :data:`USAGE_CONTAINER_KEYS` order and
    returns ``None`` when *candidate* is not a dict or holds no usage dict.
    """
    if not isinstance(candidate, dict):
        return None
    return next(
        (
            section
            for section in (candidate.get(key) for key in USAGE_CONTAINER_KEYS)
            if isinstance(section, dict)
        ),
        None,
    )
def extract_usage(payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Unified helper to parse token usage metadata from diverse provider responses."""
    if not isinstance(payload, dict):
        # Defensive: callers may hand us arbitrary decoded JSON.
        return None

    # Usage may live at the top level or one level down under "raw".
    container = _extract_usage_container(payload)
    if container is None:
        container = _extract_usage_container(payload.get("raw"))
    if container is None:
        return None

    cleaned = _normalize_usage_value(container)
    if not isinstance(cleaned, dict):
        return None

    # Mutates `cleaned` in place to canonical prompt/completion/total keys.
    _canonicalize_counts(cleaned)
    return cleaned or None
__all__ = ["extract_usage"]  # public API; all other helpers are module-private
|