imetting_backend/app/services/async_knowledge_base_servic...

202 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import uuid
from datetime import datetime
from typing import Optional, Dict, Any, List
import redis
from app.core.database import get_db_connection
from app.services.llm_service import LLMService
class AsyncKnowledgeBaseService:
def __init__(self):
from app.core.config import REDIS_CONFIG
if 'decode_responses' not in REDIS_CONFIG:
REDIS_CONFIG['decode_responses'] = True
self.redis_client = redis.Redis(**REDIS_CONFIG)
self.llm_service = LLMService()
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
task_id = str(uuid.uuid4())
# If a cursor is passed, use it directly to avoid creating a new transaction
if cursor:
query = """
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at)
VALUES (%s, %s, %s, %s, NOW())
"""
cursor.execute(query, (task_id, user_id, kb_id, user_prompt))
else:
# Fallback to the old method if no cursor is provided
self._save_task_to_db(task_id, user_id, kb_id, user_prompt)
current_time = datetime.now().isoformat()
task_data = {
'task_id': task_id,
'user_id': str(user_id),
'kb_id': str(kb_id),
'user_prompt': user_prompt if user_prompt else "",
'status': 'pending',
'progress': '0',
'created_at': current_time,
'updated_at': current_time
}
self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data)
self.redis_client.expire(f"kb_task:{task_id}", 86400)
print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
return task_id
def _process_task(self, task_id: str):
print(f"Background task started for knowledge base task: {task_id}")
try:
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data:
print(f"Error: Task {task_id} not found in Redis for processing.")
return
kb_id = int(task_data['kb_id'])
user_prompt = task_data.get('user_prompt', '')
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# Get source meeting summaries
source_text = ""
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
kb_info = cursor.fetchone()
if kb_info and kb_info['source_meeting_ids']:
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
if meeting_ids:
summaries = []
for meeting_id in meeting_ids:
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
summary = cursor.fetchone()
if summary and summary['summary']:
summaries.append(summary['summary'])
source_text = "\n\n---\n\n".join(summaries)
# Get system prompt
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
system_prompt = self._get_knowledge_task_prompt(cursor)
# Build final prompt
final_prompt = f"{system_prompt}\n\n"
if source_text:
final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
final_prompt += f"用户要求:{user_prompt}"
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
generated_content = self.llm_service._call_llm_api(final_prompt)
if not generated_content:
raise Exception("LLM API call failed or returned empty content")
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
self._save_result_to_db(kb_id, generated_content, cursor)
self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
self._update_task_status_in_redis(task_id, 'completed', 100)
connection.commit()
print(f"Task {task_id} completed successfully")
except Exception as e:
error_msg = str(e)
print(f"Task {task_id} failed: {error_msg}")
# Use a new connection for error logging to avoid issues with a potentially broken transaction
with get_db_connection() as err_conn:
err_cursor = err_conn.cursor()
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
err_conn.commit()
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
def _get_knowledge_task_prompt(self, cursor) -> str:
query = """
SELECT p.content
FROM prompt_config pc
JOIN prompts p ON pc.prompt_id = p.id
WHERE pc.task_name = 'KNOWLEDGE_TASK'
"""
cursor.execute(query)
result = cursor.fetchone()
if result:
return result['content']
else:
# Fallback prompt
return "Please generate a knowledge base article based on the provided information."
def _save_result_to_db(self, kb_id: int, content: str, cursor):
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
cursor.execute(query, (content, kb_id))
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None):
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
cursor.execute(query, (status, progress, error_message, status, task_id))
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
update_data = {
'status': status,
'progress': str(progress),
'updated_at': datetime.now().isoformat()
}
if message: update_data['message'] = message
if error_message: update_data['error_message'] = error_message
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""获取任务状态 - 与 async_llm_service 保持一致"""
try:
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data:
task_data = self._get_task_from_db(task_id)
if not task_data:
return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}
return {
'task_id': task_id,
'status': task_data.get('status', 'unknown'),
'progress': int(task_data.get('progress', 0)),
'kb_id': int(task_data.get('kb_id', 0)),
'created_at': task_data.get('created_at'),
'updated_at': task_data.get('updated_at'),
'error_message': task_data.get('error_message')
}
except Exception as e:
print(f"Error getting task status: {e}")
return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
"""保存任务到数据库"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt))
connection.commit()
except Exception as e:
print(f"Error saving task to database: {e}")
raise
def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
"""从数据库获取任务信息"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
query = "SELECT * FROM knowledge_base_tasks WHERE task_id = %s"
cursor.execute(query, (task_id,))
task = cursor.fetchone()
if task:
# 确保所有字段都是字符串以匹配Redis的行为
return {k: v.isoformat() if isinstance(v, datetime) else str(v) for k, v in task.items()}
return None
except Exception as e:
print(f"Error getting task from database: {e}")
return None
# Module-level service instance, created once when this module is imported.
# Construction runs AsyncKnowledgeBaseService.__init__ at import time, which
# imports REDIS_CONFIG, builds the Redis client and creates an LLMService.
# Presumably other modules import this name rather than building their own
# instance; verify against callers.
async_kb_service = AsyncKnowledgeBaseService()