54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import Dict, Type
|
|
|
|
import httpx
|
|
|
|
from app.exceptions import ProviderConfigurationError
|
|
from app.models import LLMProvider, LLMRequest, LLMResponse
|
|
from app.providers import (
|
|
AnthropicProvider,
|
|
DeepSeekProvider,
|
|
GeminiProvider,
|
|
LLMProviderClient,
|
|
OpenAIProvider,
|
|
OpenRouterProvider,
|
|
QwenProvider,
|
|
)
|
|
|
|
|
|
class LLMGateway:
    """Route chat requests to provider clients that are built on demand.

    Every supported ``LLMProvider`` is mapped to a client class. The client
    for a provider is constructed the first time that provider is requested
    and the same instance is returned for every later request.
    """

    def __init__(self) -> None:
        # Client instances already constructed, keyed by provider.
        self._providers: Dict[LLMProvider, LLMProviderClient] = {}
        # Which client class to instantiate for each supported provider.
        self._factory: Dict[LLMProvider, Type[LLMProviderClient]] = {
            LLMProvider.OPENAI: OpenAIProvider,
            LLMProvider.ANTHROPIC: AnthropicProvider,
            LLMProvider.OPENROUTER: OpenRouterProvider,
            LLMProvider.GEMINI: GeminiProvider,
            LLMProvider.QWEN: QwenProvider,
            LLMProvider.DEEPSEEK: DeepSeekProvider,
        }

    def get_provider(self, provider: LLMProvider) -> LLMProviderClient:
        """Return the (cached) client for *provider*, creating it if needed.

        Raises:
            ProviderConfigurationError: if no client class is registered for
                *provider*.
        """
        if provider not in self._factory:
            raise ProviderConfigurationError(f"Unsupported provider '{provider.value}'.")

        cached_client = self._providers.get(provider)
        if cached_client is None:
            # First request for this provider: build it once and remember it.
            cached_client = self._build_provider(provider)
            self._providers[provider] = cached_client
        return cached_client

    def _build_provider(self, provider: LLMProvider) -> LLMProviderClient:
        """Instantiate the client class registered for *provider*.

        If the class declares an ``api_key_env`` attribute, the environment
        variable it names is read and its value is passed as the single
        positional constructor argument. When the attribute is missing or
        empty, or the variable is unset, ``None`` is passed instead.
        """
        client_cls = self._factory[provider]
        env_var_name = getattr(client_cls, "api_key_env", None)
        key = None if not env_var_name else os.getenv(env_var_name)
        return client_cls(key)

    async def chat(
        self, request: LLMRequest, client: httpx.AsyncClient
    ) -> LLMResponse:
        """Forward *request* to the client for ``request.provider``.

        The shared ``httpx.AsyncClient`` is handed through unchanged; the
        provider client's own ``chat`` result is returned as-is.
        """
        return await self.get_provider(request.provider).chat(request, client)
|