UnisKB/apps/models_provider/impl/siliconCloud_model_provider/model/reranker.py

75 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# coding=utf-8
"""
@project: MaxKB
@Author
@file siliconcloud_reranker.py
@date2024/9/10 9:45
@desc: SiliconCloud 文档重排封装
"""
from typing import Sequence, Optional, Any, Dict
import requests
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from models_provider.base_model_provider import MaxKBBaseModel
from django.utils.translation import gettext as _
class SiliconCloudReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Document compressor that reranks documents via the SiliconCloud ``/rerank`` HTTP endpoint.

    The query and the documents' ``page_content`` are POSTed as JSON to
    ``{api_base}/rerank`` with a Bearer token; every entry of the response's
    ``results`` list becomes a new ``Document``.
    """

    api_base: Optional[str]
    """SiliconCloud API base URL (``/rerank`` is appended to it)."""
    model: Optional[str]
    """SiliconCloud rerank model ID."""
    api_key: Optional[str]
    """API key, sent as a Bearer token in the ``Authorization`` header."""
    top_n: Optional[int] = 3  # keep only the N most relevant results
    request_timeout: Optional[float] = 60  # seconds; ``None`` waits indefinitely

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a reranker from a credential dict and optional keyword settings.

        Args:
            model_type: Not used by this implementation.
            model_name: Stored as the rerank model ID.
            model_credential: Mapping read for the ``api_base`` and ``api_key`` keys.
            **model_kwargs: Only ``top_n`` is read (defaults to 3).

        Returns:
            A configured ``SiliconCloudReranker``.
        """
        return SiliconCloudReranker(
            api_base=model_credential.get('api_base'),
            model=model_name,
            api_key=model_credential.get('api_key'),
            top_n=model_kwargs.get('top_n', 3)
        )

    def compress_documents(
            self,
            documents: Sequence[Document],
            query: str,
            callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` using the SiliconCloud API.

        Args:
            documents: Candidate documents; only their ``page_content`` is sent.
            query: Text the documents are ranked against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            One new ``Document`` per entry in the response's ``results`` list,
            in the order the API returned them. Each carries only
            ``relevance_score`` in its metadata; metadata of the input
            documents is not copied over. An empty input yields ``[]``
            without any network call.

        Raises:
            RuntimeError: If the API answers with a status code other than 200.
            requests.exceptions.RequestException: On connection failure or timeout.
        """
        if not documents:
            return []

        # Only the raw text is submitted to the API.
        texts = [doc.page_content for doc in documents]

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model,
            "query": query,
            "documents": texts,
            "top_n": self.top_n,
            "return_documents": True,
        }

        # Strip a trailing slash so a base URL such as ".../v1/" does not
        # produce ".../v1//rerank".
        base_url = (self.api_base or '').rstrip('/')
        # A timeout keeps a stalled endpoint from blocking the caller forever.
        response = requests.post(
            f"{base_url}/rerank",
            json=payload,
            headers=headers,
            timeout=self.request_timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(f"SiliconCloud API 请求失败: {response.text}")
        res = response.json()

        reranked = []
        # ``or []`` / ``or {}`` guard against explicit JSON nulls, which
        # ``dict.get(key, default)`` passes through unchanged.
        for item in res.get('results') or []:
            returned_doc = item.get('document') or {}
            text = returned_doc.get('text')
            if text is None:
                # The response did not echo the text back: recover it from the
                # submitted list through the result's ``index`` when usable.
                index = item.get('index')
                if isinstance(index, int) and 0 <= index < len(texts):
                    text = texts[index]
                else:
                    text = ''
            reranked.append(
                Document(
                    page_content=text,
                    metadata={'relevance_score': item.get('relevance_score')}
                )
            )
        return reranked