from http import HTTPStatus
from typing import Generator, Optional

import dashscope

import app.core.config as config_module
from app.core.database import get_db_connection


class LLMService:
    """LLM service - focused on LLM API calls and prompt management."""

    def __init__(self):
        # Configure the dashscope API key (the SDK reads it globally)
        dashscope.api_key = config_module.QWEN_API_KEY

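    # The properties below re-read app.core.config on each access so that
    # configuration changes take effect without restarting the service.
    # LLM_CONFIG is assumed to look roughly like this (a sketch inferred
    # from the keys used below; the values are illustrative only):
    #   LLM_CONFIG = {
    #       "model_name": "qwen-plus",
    #       "system_prompt": "...",
    #       "time_out": 60,
    #       "temperature": 0.7,
    #       "top_p": 0.8,
    #   }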
    @property
    def model_name(self):
        """Dynamically fetch the model name."""
        return config_module.LLM_CONFIG["model_name"]

    @property
    def system_prompt(self):
        """Dynamically fetch the system prompt."""
        return config_module.LLM_CONFIG["system_prompt"]

    @property
    def time_out(self):
        """Dynamically fetch the request timeout."""
        return config_module.LLM_CONFIG["time_out"]

    @property
    def temperature(self):
        """Dynamically fetch the sampling temperature."""
        return config_module.LLM_CONFIG["temperature"]

    @property
    def top_p(self):
        """Dynamically fetch top_p."""
        return config_module.LLM_CONFIG["top_p"]

    def get_task_prompt(self, task_name: str, cursor=None) -> str:
        """
        Unified prompt lookup.

        Args:
            task_name: Task name, e.g. 'MEETING_TASK' or 'KNOWLEDGE_TASK'.
            cursor: Database cursor; used if provided, otherwise a new
                connection is created.

        Returns:
            str: The prompt content, or a default prompt if none is found.
        """
        query = """
            SELECT p.content
            FROM prompt_config pc
            JOIN prompts p ON pc.prompt_id = p.id
            WHERE pc.task_name = %s
        """

        if cursor:
            cursor.execute(query, (task_name,))
            result = cursor.fetchone()
            if result:
                return result['content'] if isinstance(result, dict) else result[0]
        else:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                cursor.execute(query, (task_name,))
                result = cursor.fetchone()
                if result:
                    return result['content']

        # Fall back to the default prompt when no row matched
        return self._get_default_prompt(task_name)

    def _get_default_prompt(self, task_name: str) -> str:
        """Return the built-in default prompt for a task."""
        default_prompts = {
            'MEETING_TASK': self.system_prompt,  # system prompt from the config file
            'KNOWLEDGE_TASK': "Please generate a knowledge-base article from the provided information.",
        }
        return default_prompts.get(task_name, "Please summarize and analyze the provided content.")

    def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
        """Call the Alibaba Qwen3 LLM API in streaming mode."""
        try:
            responses = dashscope.Generation.call(
                model=self.model_name,
                prompt=prompt,
                stream=True,
                timeout=self.time_out,
                temperature=self.temperature,
                top_p=self.top_p,
                incremental_output=True  # each chunk contains only the new text
            )

            for response in responses:
                if response.status_code == HTTPStatus.OK:
                    # Yield the incremental text of this chunk
                    new_content = response.output.get('text', '')
                    if new_content:
                        yield new_content
                else:
                    error_msg = f"Request failed with status code: {response.status_code}, Error: {response.message}"
                    print(error_msg)
                    # Surface the failure to the consumer, then stop streaming
                    yield f"error: {error_msg}"
                    break

        except Exception as e:
            error_msg = f"Streaming LLM API call failed: {e}"
            print(error_msg)
            yield f"error: {error_msg}"

    def _call_llm_api(self, prompt: str) -> Optional[str]:
        """Call the Alibaba Qwen3 LLM API (non-streaming)."""
        try:
            response = dashscope.Generation.call(
                model=self.model_name,
                prompt=prompt,
                timeout=self.time_out,
                temperature=self.temperature,
                top_p=self.top_p
            )

            if response.status_code == HTTPStatus.OK:
                return response.output.get('text', '')
            else:
                print(f"API call failed: {response.status_code}, {response.message}")
                return None

        except Exception as e:
            print(f"LLM API call failed: {e}")
            return None


# Test code
if __name__ == '__main__':
    print("--- Running LLM service tests ---")
    llm_service = LLMService()

    # Test task prompt lookup
    meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
    print(f"Meeting task prompt: {meeting_prompt[:100]}...")

    knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
    print(f"Knowledge task prompt: {knowledge_prompt[:100]}...")
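
    # Optional live smoke test of the API wrappers (a sketch: it needs a
    # valid QWEN_API_KEY in app.core.config plus network access, so it is
    # disabled by default; set RUN_LIVE_TEST = True to try it).
    RUN_LIVE_TEST = False
    if RUN_LIVE_TEST:
        print("Non-streaming call:")
        print(llm_service._call_llm_api("Introduce yourself in one sentence."))

        print("Streaming call:")
        for chunk in llm_service._call_llm_api_stream("Introduce yourself in one sentence."):
            print(chunk, end="", flush=True)
        print()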

    print("--- LLM service tests complete ---")