import dashscope
from http import HTTPStatus
from typing import Optional, Dict, Generator, Any

import app.core.config as config_module
from app.core.database import get_db_connection
from app.services.system_config_service import SystemConfigService


class LLMService:
    """LLM service - focused on LLM API calls and prompt management."""

    def __init__(self):
        # Set the dashscope API key
        dashscope.api_key = config_module.QWEN_API_KEY

    def _get_llm_call_params(self) -> Dict[str, Any]:
        """
        Build the keyword-argument dict required by dashscope.Generation.call().

        Returns:
            Dict: parameter dict containing model, timeout, temperature, and top_p
        """
        return {
            'model': SystemConfigService.get_llm_model_name(),
            'timeout': SystemConfigService.get_llm_timeout(),
            'temperature': SystemConfigService.get_llm_temperature(),
            'top_p': SystemConfigService.get_llm_top_p(),
        }

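    # For reference, the dict above unpacks into Generation.call(...) via **;
    # an illustrative shape (the values here are assumptions, the real ones
    # come from SystemConfigService):
    #   {'model': 'qwen-plus', 'timeout': 60, 'temperature': 0.7, 'top_p': 0.8}
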
    def get_task_prompt(self, task_type: str, cursor=None, prompt_id: Optional[int] = None) -> str:
        """
        Unified prompt retrieval method.

        Args:
            task_type: Task type, e.g. 'MEETING_TASK' or 'KNOWLEDGE_TASK'
            cursor: Database cursor; used if provided, otherwise a new connection is opened
            prompt_id: Optional prompt ID; if given, that prompt is fetched, otherwise
                the default prompt for the task type is used

        Returns:
            str: Prompt content; the built-in default prompt if nothing is found
        """
        # If a prompt_id was given, fetch that prompt directly
        if prompt_id:
            query = """
                SELECT content
                FROM prompts
                WHERE id = %s AND task_type = %s AND is_active = TRUE
                LIMIT 1
            """
            params = (prompt_id, task_type)
        else:
            # Otherwise fetch the default prompt for this task type
            query = """
                SELECT content
                FROM prompts
                WHERE task_type = %s
                  AND is_default = TRUE
                  AND is_active = TRUE
                LIMIT 1
            """
            params = (task_type,)

        if cursor:
            cursor.execute(query, params)
            result = cursor.fetchone()
            if result:
                # The caller's cursor may return rows as dicts or as tuples
                return result['content'] if isinstance(result, dict) else result[0]
        else:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                cursor.execute(query, params)
                result = cursor.fetchone()
                if result:
                    return result['content']

        # Fall back to the built-in default prompt
        return self._get_default_prompt(task_type)

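    # Assumed schema of the prompts table implied by the queries above; a
    # hypothetical sketch for orientation only, the real DDL may differ:
    #
    #   CREATE TABLE prompts (
    #       id         INT          PRIMARY KEY AUTO_INCREMENT,
    #       task_type  VARCHAR(64)  NOT NULL,
    #       content    TEXT         NOT NULL,
    #       is_default BOOLEAN      NOT NULL DEFAULT FALSE,
    #       is_active  BOOLEAN      NOT NULL DEFAULT TRUE
    #   );
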
    def _get_default_prompt(self, task_name: str) -> str:
        """Return the built-in default prompt for a task type."""
        fallback = "Please summarize and analyze the provided content."
        system_prompt = config_module.LLM_CONFIG.get("system_prompt", fallback)
        default_prompts = {
            'MEETING_TASK': system_prompt,
            'KNOWLEDGE_TASK': "Please generate a knowledge base article from the provided information.",
        }
        return default_prompts.get(task_name, fallback)

    def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
        """Call the Alibaba Qwen LLM API in streaming mode."""
        try:
            responses = dashscope.Generation.call(
                **self._get_llm_call_params(),
                prompt=prompt,
                stream=True,
                incremental_output=True
            )

            for response in responses:
                if response.status_code == HTTPStatus.OK:
                    # With incremental_output=True each response carries only the new text
                    new_content = response.output.get('text', '')
                    if new_content:
                        yield new_content
                else:
                    error_msg = f"Request failed with status code: {response.status_code}, Error: {response.message}"
                    print(error_msg)
                    yield f"error: {error_msg}"
                    break

        except Exception as e:
            error_msg = f"Streaming LLM API call error: {e}"
            print(error_msg)
            yield f"error: {error_msg}"

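    # Minimal sketch of consuming the stream above (hypothetical caller code).
    # Errors are yielded inline with an "error: " prefix rather than raised,
    # so callers should inspect each chunk:
    #
    #   chunks = []
    #   for chunk in llm_service._call_llm_api_stream(prompt):
    #       if chunk.startswith("error: "):
    #           break  # or surface the error to the client
    #       chunks.append(chunk)
    #   full_text = ''.join(chunks)
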
    def _call_llm_api(self, prompt: str) -> Optional[str]:
        """Call the Alibaba Qwen LLM API (non-streaming)."""
        try:
            response = dashscope.Generation.call(
                **self._get_llm_call_params(),
                prompt=prompt
            )

            if response.status_code == HTTPStatus.OK:
                return response.output.get('text', '')
            else:
                print(f"LLM API call failed: {response.status_code}, {response.message}")
                return None

        except Exception as e:
            print(f"LLM API call error: {e}")
            return None


# Test code
if __name__ == '__main__':
    print("--- Running LLM service tests ---")
    llm_service = LLMService()

    # Test task prompt retrieval
    meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
    print(f"Meeting task prompt: {meeting_prompt[:100]}...")

    knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
    print(f"Knowledge task prompt: {knowledge_prompt[:100]}...")

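    # Optionally exercise the real API paths. A minimal sketch that assumes an
    # empty/unset QWEN_API_KEY is falsy, meaning no credentials are configured;
    # these calls hit the DashScope service over the network.
    if config_module.QWEN_API_KEY:
        answer = llm_service._call_llm_api("Say hello in one sentence.")
        print(f"Non-streaming answer: {answer}")

        print("Streaming answer: ", end='')
        for chunk in llm_service._call_llm_api_stream("Say hello in one sentence."):
            print(chunk, end='', flush=True)
        print()
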
print("--- LLM服务测试完成 ---")
|