# LLM service module (OpenAI-compatible chat completion client)
import httpx
|
|
import json
|
|
import logging
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
# Module-level logger named after this module (standard logging pattern).
logger = logging.getLogger(__name__)
|
|
|
|
class LLMService:
    """Thin async client for OpenAI-compatible chat completion endpoints."""

    @staticmethod
    async def chat_completion(
        api_key: str,
        base_url: str,
        model_name: str,
        messages: List[Dict[str, str]],
        api_path: str = "/chat/completions",
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 4000
    ) -> str:
        """
        Call an OpenAI-compatible chat completion endpoint and return the
        assistant's reply text.

        Args:
            api_key: Bearer token sent in the Authorization header.
            base_url: Service root URL; a trailing slash is stripped before
                api_path is appended.
            model_name: Model identifier sent as "model" in the payload.
            messages: Chat history as [{"role": ..., "content": ...}, ...].
            api_path: Endpoint path appended to base_url.
            temperature: Sampling temperature forwarded verbatim to the API.
            top_p: Nucleus-sampling parameter forwarded verbatim to the API.
            max_tokens: Completion length limit forwarded verbatim to the API.

        Returns:
            The "content" string of the first choice's message.

        Raises:
            Exception: On an HTTP error status (message includes the server's
                error detail) or on any transport/parse failure.
        """
        # Strip any trailing slash so base_url + api_path never yields "//".
        url = f"{base_url.rstrip('/')}{api_path}"

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens
        }

        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                # Lazy %-args: the message is only formatted if INFO is enabled.
                logger.info("Sending request to LLM: %s (Model: %s)", url, model_name)
                response = await client.post(url, headers=headers, json=payload)
                response.raise_for_status()

                result = response.json()
                # OpenAI-compatible response shape: choices[0].message.content.
                return result["choices"][0]["message"]["content"]
            except httpx.HTTPStatusError as e:
                error_body = e.response.text
                logger.error("LLM API Error (%s): %s", e.response.status_code, error_body)
                # Prefer the structured error message when the body is JSON of
                # the expected {"error": {"message": ...}} shape.
                try:
                    err_json = json.loads(error_body)
                    detail = err_json.get("error", {}).get("message", error_body)
                except (json.JSONDecodeError, AttributeError, TypeError):
                    # Body is not JSON, or not the expected shape — fall back
                    # to the raw text. (Was a bare `except:`, which also
                    # swallowed KeyboardInterrupt/SystemExit.)
                    detail = error_body
                # Chain the cause so the original HTTP traceback is preserved.
                raise Exception(f"AI 服务调用失败: {detail}") from e
            except Exception as e:
                # Transport errors, JSON decode failures, or an unexpected
                # response shape (KeyError/IndexError) all land here.
                # NOTE: the Exception raised in the handler above propagates
                # directly and is NOT re-caught by this clause.
                logger.error("Error calling LLM: %s", e)
                raise Exception(f"连接 AI 服务发生错误: {str(e)}") from e