imetting_backend/app/services/async_knowledge_base_servic...

287 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
异步知识库服务 - 处理知识库生成的异步任务
采用FastAPI BackgroundTasks模式
"""
import uuid
from datetime import datetime
from typing import Optional, Dict, Any, List
import redis
from app.core.database import get_db_connection
from app.services.llm_service import LLMService
class AsyncKnowledgeBaseService:
    """Asynchronous knowledge-base service.

    Creates knowledge-base generation tasks, mirrors their state in Redis
    (hash ``kb_task:<task_id>``) and in the ``knowledge_base_tasks`` table,
    and runs the generation pipeline: meeting summaries -> prompt -> LLM ->
    ``knowledge_bases.content``. ``_process_task`` is written to be scheduled
    through FastAPI ``BackgroundTasks``.
    """

    # Lifetime of the per-task Redis hash, in seconds (24 hours).
    TASK_TTL_SECONDS = 86400

    def __init__(self):
        """Create the Redis client and the LLM service used by the pipeline."""
        from app.core.config import REDIS_CONFIG

        # Task fields are read back as text, so responses must be decoded.
        # NOTE: the default is set on the shared config dict, as before, so
        # any other code relying on that side effect keeps working.
        REDIS_CONFIG.setdefault('decode_responses', True)
        self.redis_client = redis.Redis(**REDIS_CONFIG)
        self.llm_service = LLMService()

    def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
        """Create an asynchronous knowledge-base generation task.

        Args:
            user_id: ID of the user who requested the generation.
            kb_id: ID of the knowledge base to generate content for.
            user_prompt: Optional user-supplied prompt.
            source_meeting_ids: Source meeting IDs. Accepted for interface
                compatibility; this method does not read it (the pipeline
                loads the IDs from ``knowledge_bases`` instead).
            cursor: Optional database cursor. When given, the task row is
                inserted through it so the insert joins the caller's
                transaction; the caller is responsible for committing.

        Returns:
            The newly generated task ID (a UUID4 string).
        """
        task_id = str(uuid.uuid4())

        if cursor is not None:
            # Reuse the caller's cursor to avoid opening a second transaction.
            # status/progress are written explicitly so the row matches what
            # _save_task_to_db stores for the same kind of task.
            query = """
                INSERT INTO knowledge_base_tasks
                    (task_id, user_id, kb_id, user_prompt, status, progress, created_at)
                VALUES (%s, %s, %s, %s, 'pending', 0, NOW())
            """
            cursor.execute(query, (task_id, user_id, kb_id, user_prompt))
        else:
            # No cursor supplied: insert and commit on a dedicated connection.
            self._save_task_to_db(task_id, user_id, kb_id, user_prompt)

        current_time = datetime.now().isoformat()
        # Redis hash values must be strings, hence the explicit conversions.
        task_data = {
            'task_id': task_id,
            'user_id': str(user_id),
            'kb_id': str(kb_id),
            'user_prompt': user_prompt or "",
            'status': 'pending',
            'progress': '0',
            'created_at': current_time,
            'updated_at': current_time
        }
        redis_key = f"kb_task:{task_id}"
        self.redis_client.hset(redis_key, mapping=task_data)
        self.redis_client.expire(redis_key, self.TASK_TTL_SECONDS)

        print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
        return task_id

    def _process_task(self, task_id: str):
        """Run one generation task end to end.

        Designed to be called by ``BackgroundTasks``. Any exception is caught,
        logged, and recorded as a ``failed`` status in both the database and
        Redis; nothing is raised to the caller.

        Args:
            task_id: ID of a task previously created by ``start_generation``.
        """
        print(f"Background task started for knowledge base task: {task_id}")
        try:
            # Load the task payload from Redis.
            task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
            if not task_data:
                print(f"Error: Task {task_id} not found in Redis for processing.")
                return

            kb_id = int(task_data['kb_id'])
            user_prompt = task_data.get('user_prompt', '')

            # 1. Mark the task as processing.
            self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")

            # 2. Fetch the summaries of the linked meetings.
            self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
            source_text = self._get_meeting_summaries(kb_id)

            # 3. Build the prompt.
            self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...")
            full_prompt = self._build_prompt(source_text, user_prompt)

            # 4. Call the LLM API.
            self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
            generated_content = self.llm_service._call_llm_api(full_prompt)
            if not generated_content:
                raise Exception("LLM API调用失败或返回空内容")

            # 5. Persist the result. _save_result_to_db reports failure by
            #    returning None instead of raising, so check it explicitly;
            #    otherwise a failed save would still be marked as completed.
            self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
            if self._save_result_to_db(kb_id, generated_content) is None:
                raise Exception("保存知识库内容到数据库失败")

            # 6. Done.
            self._update_task_in_db(task_id, 'completed', 100)
            self._update_task_status_in_redis(task_id, 'completed', 100)
            print(f"Task {task_id} completed successfully")
        except Exception as e:
            error_msg = str(e)
            print(f"Task {task_id} failed: {error_msg}")
            # Record the failure in both stores.
            self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
            self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)

    # --- Knowledge-base helpers ---

    @staticmethod
    def _parse_meeting_ids(raw_ids: Optional[str]) -> List[int]:
        """Parse a comma-separated ID string into a list of ints.

        Whitespace around each entry is ignored; empty or non-numeric entries
        are skipped. ``None`` or an empty string yields an empty list.
        """
        if not raw_ids:
            return []
        meeting_ids = []
        for token in raw_ids.split(','):
            token = token.strip()
            # isdecimal() only accepts characters that int() can parse.
            if token.isdecimal():
                meeting_ids.append(int(token))
        return meeting_ids

    def _get_meeting_summaries(self, kb_id: int) -> str:
        """Load and join the summaries of the meetings linked to a knowledge base.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            The non-empty meeting summaries joined with a ``---`` separator,
            in the order the IDs are stored. Returns ``""`` when there are no
            linked meetings, no summaries, or a database error occurs (the
            error is logged, not raised).
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)

                # Read the knowledge base's source meeting ID list.
                cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
                kb_info = cursor.fetchone()
                if not kb_info or not kb_info['source_meeting_ids']:
                    return ""

                meeting_ids = self._parse_meeting_ids(kb_info['source_meeting_ids'])
                if not meeting_ids:
                    return ""

                # Collect every meeting's summary, skipping missing/empty ones.
                summaries = []
                for meeting_id in meeting_ids:
                    cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
                    row = cursor.fetchone()
                    if row and row['summary']:
                        summaries.append(row['summary'])

                # Join the summaries with a visible separator.
                return "\n\n---\n\n".join(summaries)
        except Exception as e:
            print(f"获取会议总结错误: {e}")
            return ""

    def _build_prompt(self, source_text: str, user_prompt: str) -> str:
        """Build the full LLM prompt.

        Starts from the ``KNOWLEDGE_TASK`` prompt template returned by the LLM
        service, appends the meeting summaries when there are any, and ends
        with the user's request.

        Args:
            source_text: Joined source meeting summaries (may be empty).
            user_prompt: User-supplied prompt text.

        Returns:
            The complete prompt string.
        """
        system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK')
        prompt = f"{system_prompt}\n\n"
        if source_text:
            prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
        prompt += f"用户要求:{user_prompt}"
        return prompt

    def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]:
        """Store the generated content on the knowledge base row.

        Args:
            kb_id: Knowledge base ID.
            content: Generated content.

        Returns:
            ``kb_id`` on success, ``None`` if a database error occurred (the
            error is logged, not raised).
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
                cursor.execute(query, (content, kb_id))
                connection.commit()
                print(f"成功保存知识库内容kb_id: {kb_id}")
                return kb_id
        except Exception as e:
            print(f"保存知识库内容错误: {e}")
            return None

    # --- Status lookup and database helpers ---

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the current status of a task.

        Redis is consulted first; if the hash is missing, the database row is
        used instead.

        Args:
            task_id: Task ID.

        Returns:
            A dict with ``task_id``, ``status``, ``progress``, ``kb_id``,
            ``created_at``, ``updated_at`` and ``error_message``. When the task
            is unknown, ``status`` is ``'not_found'``; when the lookup itself
            fails, ``status`` is ``'error'`` and ``error_message`` holds the
            exception text.
        """
        try:
            task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
            if not task_data:
                task_data = self._get_task_from_db(task_id)
            if not task_data:
                return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}

            # Database rows may carry None for NULL columns; fall back to the
            # same defaults used for missing keys instead of failing in int().
            return {
                'task_id': task_id,
                'status': task_data.get('status') or 'unknown',
                'progress': int(task_data.get('progress') or 0),
                'kb_id': int(task_data.get('kb_id') or 0),
                'created_at': task_data.get('created_at'),
                'updated_at': task_data.get('updated_at'),
                'error_message': task_data.get('error_message')
            }
        except Exception as e:
            print(f"Error getting task status: {e}")
            return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}

    def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
        """Update the task's status fields in Redis (best effort).

        Args:
            task_id: Task ID.
            status: New status value.
            progress: Progress percentage; stored as a string.
            message: Optional progress message; written only when non-empty.
            error_message: Optional error text; written only when non-empty.

        Redis errors are logged and swallowed so a status update can never
        abort the task itself.
        """
        try:
            update_data = {
                'status': status,
                'progress': str(progress),
                'updated_at': datetime.now().isoformat()
            }
            if message:
                update_data['message'] = message
            if error_message:
                update_data['error_message'] = error_message

            redis_key = f"kb_task:{task_id}"
            self.redis_client.hset(redis_key, mapping=update_data)
            # HSET recreates a missing key without a TTL; re-arm the expiry so
            # a hash rebuilt after expiring cannot linger forever.
            self.redis_client.expire(redis_key, self.TASK_TTL_SECONDS)
        except Exception as e:
            print(f"Error updating task status in Redis: {e}")

    def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
        """Insert a new ``pending`` task row and commit it.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
                cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt))
                connection.commit()
        except Exception as e:
            print(f"Error saving task to database: {e}")
            raise

    def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None):
        """Update the task row's status in the database (best effort).

        ``completed_at`` is set to the current time only when ``status`` is
        ``'completed'``. Database errors are logged and swallowed.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
                cursor.execute(query, (status, progress, error_message, status, task_id))
                connection.commit()
        except Exception as e:
            print(f"Error updating task in database: {e}")

    @staticmethod
    def _stringify_row(row: Dict[str, Any]) -> Dict[str, Optional[str]]:
        """Convert a database row to Redis-like string values.

        ``datetime`` values become ISO-8601 strings and other values go through
        ``str()``. ``None`` is preserved so NULL columns are not turned into
        the literal text ``'None'``.
        """
        converted: Dict[str, Optional[str]] = {}
        for key, value in row.items():
            if value is None:
                converted[key] = None
            elif isinstance(value, datetime):
                converted[key] = value.isoformat()
            else:
                converted[key] = str(value)
        return converted

    def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, Optional[str]]]:
        """Load a task row from the database.

        Returns:
            The row with values converted by ``_stringify_row``, or ``None``
            when the task does not exist or a database error occurs (the
            error is logged, not raised).
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                query = "SELECT * FROM knowledge_base_tasks WHERE task_id = %s"
                cursor.execute(query, (task_id,))
                task = cursor.fetchone()
                if task:
                    # Mirror Redis, which returns every field as a string.
                    return self._stringify_row(task)
                return None
        except Exception as e:
            print(f"Error getting task from database: {e}")
            return None
# Create the module-level shared service instance (constructed at import time).
async_kb_service = AsyncKnowledgeBaseService()