imetting_backend/app/services/llm_service.py

148 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import json
import dashscope
from http import HTTPStatus
from typing import Optional, Dict, List, Generator
import app.core.config as config_module
from app.core.database import get_db_connection
class LLMService:
"""LLM服务 - 专注于大模型API调用和提示词管理"""
def __init__(self):
# 设置dashscope API key
dashscope.api_key = config_module.QWEN_API_KEY
@property
def model_name(self):
"""动态获取模型名称"""
return config_module.LLM_CONFIG["model_name"]
@property
def system_prompt(self):
"""动态获取系统提示词"""
return config_module.LLM_CONFIG["system_prompt"]
@property
def time_out(self):
"""动态获取超时时间"""
return config_module.LLM_CONFIG["time_out"]
@property
def temperature(self):
"""动态获取temperature"""
return config_module.LLM_CONFIG["temperature"]
@property
def top_p(self):
"""动态获取top_p"""
return config_module.LLM_CONFIG["top_p"]
def get_task_prompt(self, task_name: str, cursor=None) -> str:
"""
统一的提示词获取方法
Args:
task_name: 任务名称,如 'MEETING_TASK', 'KNOWLEDGE_TASK'
cursor: 数据库游标,如果传入则使用,否则创建新连接
Returns:
str: 提示词内容,如果未找到返回默认提示词
"""
query = """
SELECT p.content
FROM prompt_config pc
JOIN prompts p ON pc.prompt_id = p.id
WHERE pc.task_name = %s
"""
if cursor:
cursor.execute(query, (task_name,))
result = cursor.fetchone()
if result:
return result['content'] if isinstance(result, dict) else result[0]
else:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
cursor.execute(query, (task_name,))
result = cursor.fetchone()
if result:
return result['content']
# 返回默认提示词
return self._get_default_prompt(task_name)
def _get_default_prompt(self, task_name: str) -> str:
"""获取默认提示词"""
default_prompts = {
'MEETING_TASK': self.system_prompt, # 使用配置文件中的系统提示词
'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。",
}
return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。")
def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
"""流式调用阿里Qwen3大模型API"""
try:
responses = dashscope.Generation.call(
model=self.model_name,
prompt=prompt,
stream=True,
timeout=self.time_out,
temperature=self.temperature,
top_p=self.top_p,
incremental_output=True # 开启增量输出模式
)
for response in responses:
if response.status_code == HTTPStatus.OK:
# 增量输出内容
new_content = response.output.get('text', '')
if new_content:
yield new_content
else:
error_msg = f"Request failed with status code: {response.status_code}, Error: {response.message}"
print(error_msg)
yield f"error: {error_msg}"
break
except Exception as e:
error_msg = f"流式调用大模型API错误: {e}"
print(error_msg)
yield f"error: {error_msg}"
def _call_llm_api(self, prompt: str) -> Optional[str]:
"""调用阿里Qwen3大模型API非流式"""
try:
response = dashscope.Generation.call(
model=self.model_name,
prompt=prompt,
timeout=self.time_out,
temperature=self.temperature,
top_p=self.top_p
)
if response.status_code == HTTPStatus.OK:
return response.output.get('text', '')
else:
print(f"API调用失败: {response.status_code}, {response.message}")
return None
except Exception as e:
print(f"调用大模型API错误: {e}")
return None
# Manual smoke test: print a 100-character preview of each task's prompt.
if __name__ == '__main__':
    print("--- 运行LLM服务测试 ---")
    service = LLMService()
    # (task name, label shown before its prompt preview), in print order.
    checks = (
        ('MEETING_TASK', '会议任务提示词'),
        ('KNOWLEDGE_TASK', '知识库任务提示词'),
    )
    for task, label in checks:
        prompt_text = service.get_task_prompt(task)
        print(f"{label}: {prompt_text[:100]}...")
    print("--- LLM服务测试完成 ---")