import uuid
from datetime import datetime
from typing import Optional, Dict, Any, List

import redis

from app.core.database import get_db_connection
from app.services.llm_service import LLMService


class AsyncKnowledgeBaseService:
    """Create and run knowledge-base generation tasks.

    Task state is kept in two places:

    * a Redis hash ``kb_task:<task_id>`` (expires after ``TASK_TTL_SECONDS``)
      that is updated as the task progresses and read by ``get_task_status``;
    * a row in the ``knowledge_base_tasks`` table, used as the fallback when
      the Redis hash is missing.
    """

    # Lifetime of the Redis task hash, in seconds (24 hours).
    TASK_TTL_SECONDS = 86400

    # Shared by both task-creation paths so they insert identical rows.
    _INSERT_TASK_SQL = (
        "INSERT INTO knowledge_base_tasks "
        "(task_id, user_id, kb_id, user_prompt, status, progress, created_at) "
        "VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
    )

    def __init__(self):
        from app.core.config import REDIS_CONFIG

        # NOTE(review): this mutates the shared REDIS_CONFIG dict in place, so any
        # Redis client built from it afterwards also gets decoded (str) responses.
        # Left as-is because other modules may depend on that side effect.
        REDIS_CONFIG.setdefault('decode_responses', True)
        self.redis_client = redis.Redis(**REDIS_CONFIG)
        self.llm_service = LLMService()

    def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str],
                         source_meeting_ids: Optional[str], cursor=None) -> str:
        """Register a new pending generation task and return its id.

        Args:
            user_id: Id of the user who requested the generation.
            kb_id: Knowledge base whose content the task will produce.
            user_prompt: Free-text user requirement; stored as "" in Redis when falsy.
            source_meeting_ids: Not used by this method. ``_process_task`` reads
                the meeting ids from the ``knowledge_bases`` row instead.
            cursor: Optional DB cursor. When given, the task row is inserted on it
                and committing is the caller's responsibility. Otherwise the row
                is inserted and committed on a separate connection.

        Returns:
            The new task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())

        if cursor is not None:
            # Join the caller's transaction instead of opening a new one.
            cursor.execute(self._INSERT_TASK_SQL, (task_id, user_id, kb_id, user_prompt))
        else:
            self._save_task_to_db(task_id, user_id, kb_id, user_prompt)

        now = datetime.now().isoformat()
        task_data = {
            'task_id': task_id,
            'user_id': str(user_id),
            'kb_id': str(kb_id),
            'user_prompt': user_prompt if user_prompt else "",
            'status': 'pending',
            'progress': '0',
            'created_at': now,
            'updated_at': now,
        }
        redis_key = f"kb_task:{task_id}"
        self.redis_client.hset(redis_key, mapping=task_data)
        self.redis_client.expire(redis_key, self.TASK_TTL_SECONDS)

        print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
        return task_id

    def _process_task(self, task_id: str):
        """Run one generation task end to end.

        Reads the task from Redis, gathers linked meeting summaries and the
        system prompt, calls the LLM, and stores the result in
        ``knowledge_bases``. Any exception is caught and recorded as a
        ``'failed'`` status instead of propagating.
        """
        print(f"Background task started for knowledge base task: {task_id}")
        try:
            task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
            if not task_data:
                print(f"Error: Task {task_id} not found in Redis for processing.")
                return

            kb_id = int(task_data['kb_id'])
            user_prompt = task_data.get('user_prompt', '')
            self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")

            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)

                source_text = self._get_source_text(task_id, kb_id, cursor)

                self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
                system_prompt = self._get_knowledge_task_prompt(cursor)

                final_prompt = f"{system_prompt}\n\n"
                if source_text:
                    final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
                final_prompt += f"用户要求:{user_prompt}"

                self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
                generated_content = self.llm_service._call_llm_api(final_prompt)
                if not generated_content:
                    raise Exception("LLM API call failed or returned empty content")

                self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
                self._save_result_to_db(kb_id, generated_content, cursor)
                self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
                # Commit BEFORE announcing completion. Otherwise a poller could see
                # 'completed' and read the knowledge base before the content is visible.
                connection.commit()
                self._update_task_status_in_redis(task_id, 'completed', 100)

            print(f"Task {task_id} completed successfully")

        except Exception as e:
            error_msg = str(e)
            print(f"Task {task_id} failed: {error_msg}")
            self._mark_task_failed(task_id, error_msg)

    def _get_source_text(self, task_id: str, kb_id: int, cursor) -> str:
        """Return the linked meetings' summaries joined by separators, or "" if none.

        Meeting ids come from the comma-separated
        ``knowledge_bases.source_meeting_ids`` column. Non-numeric entries are
        skipped, and meetings with an empty summary are skipped.
        """
        cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        kb_info = cursor.fetchone()
        if not kb_info or not kb_info['source_meeting_ids']:
            return ""

        self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")

        # Strip each entry so values like "1, 2, 3" are accepted. isdecimal()
        # guarantees int() will not raise on the kept entries.
        raw_ids = (part.strip() for part in kb_info['source_meeting_ids'].split(','))
        meeting_ids = [int(m_id) for m_id in raw_ids if m_id.isdecimal()]

        summaries = []
        for meeting_id in meeting_ids:
            cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
            row = cursor.fetchone()
            if row and row['summary']:
                summaries.append(row['summary'])
        return "\n\n---\n\n".join(summaries)

    def _mark_task_failed(self, task_id: str, error_msg: str):
        """Record a task failure in the database (best-effort) and then in Redis."""
        # Use a fresh connection because the task's own transaction may be broken.
        try:
            with get_db_connection() as err_conn:
                err_cursor = err_conn.cursor()
                self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
                err_conn.commit()
        except Exception as db_err:
            # A DB outage must not prevent the Redis status from leaving 'processing'.
            print(f"Task {task_id}: could not record failure in database: {db_err}")
        self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)

    def _get_knowledge_task_prompt(self, cursor) -> str:
        """Return the configured KNOWLEDGE_TASK prompt, or a built-in fallback."""
        # LIMIT 1: the original code read only the first row anyway. Leaving extra
        # rows unread on an unbuffered cursor would break the next execute().
        query = """
            SELECT p.content
            FROM prompt_config pc
            JOIN prompts p ON pc.prompt_id = p.id
            WHERE pc.task_name = 'KNOWLEDGE_TASK'
            LIMIT 1
        """
        cursor.execute(query)
        result = cursor.fetchone()
        if result:
            return result['content']
        return "Please generate a knowledge base article based on the provided information."

    def _save_result_to_db(self, kb_id: int, content: str, cursor):
        """Write the generated content to the knowledge base row (caller commits)."""
        query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
        cursor.execute(query, (content, kb_id))

    def _update_task_in_db(self, task_id: str, status: str, progress: int,
                           error_message: Optional[str] = None, cursor=None):
        """Update a task row's status, progress and error message.

        ``completed_at`` is set to NOW() only when ``status`` is 'completed'.
        When ``cursor`` is given, the caller commits. When it is None, a
        dedicated connection is opened and committed here.
        """
        query = (
            "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, "
            "updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) "
            "WHERE task_id = %s"
        )
        params = (status, progress, error_message, status, task_id)
        if cursor is not None:
            cursor.execute(query, params)
            return
        with get_db_connection() as connection:
            own_cursor = connection.cursor()
            own_cursor.execute(query, params)
            connection.commit()

    def _update_task_status_in_redis(self, task_id: str, status: str, progress: int,
                                     message: Optional[str] = None,
                                     error_message: Optional[str] = None):
        """Overwrite status/progress/updated_at (and optional messages) in the Redis hash."""
        update_data = {
            'status': status,
            'progress': str(progress),
            'updated_at': datetime.now().isoformat(),
        }
        if message:
            update_data['message'] = message
        if error_message:
            update_data['error_message'] = error_message
        self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the task status, kept consistent with async_llm_service.

        Redis is checked first, then the database. Returns a 'not_found'
        status when neither has the task, and an 'error' status if the lookup
        itself raises.
        """
        try:
            task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
            if not task_data:
                task_data = self._get_task_from_db(task_id)
                if not task_data:
                    return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}

            return {
                'task_id': task_id,
                'status': task_data.get('status', 'unknown'),
                'progress': int(task_data.get('progress', 0)),
                'kb_id': int(task_data.get('kb_id', 0)),
                'created_at': task_data.get('created_at'),
                'updated_at': task_data.get('updated_at'),
                'error_message': task_data.get('error_message'),
            }
        except Exception as e:
            print(f"Error getting task status: {e}")
            return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}

    def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: Optional[str]):
        """Insert a pending task row on its own connection and commit it.

        Raises:
            Exception: any database error is printed and re-raised.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                cursor.execute(self._INSERT_TASK_SQL, (task_id, user_id, kb_id, user_prompt))
                connection.commit()
        except Exception as e:
            print(f"Error saving task to database: {e}")
            raise

    def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
        """Fetch a task row as a Redis-like dict of strings, or None if absent or on error."""
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                query = "SELECT * FROM knowledge_base_tasks WHERE task_id = %s"
                cursor.execute(query, (task_id,))
                task = cursor.fetchone()
                if task:
                    # Mimic Redis: every value is a string, and NULL columns are
                    # simply absent, so callers' .get() defaults apply instead of
                    # seeing the literal string 'None'.
                    return {
                        k: v.isoformat() if isinstance(v, datetime) else str(v)
                        for k, v in task.items()
                        if v is not None
                    }
                return None
        except Exception as e:
            print(f"Error getting task from database: {e}")
            return None


async_kb_service = AsyncKnowledgeBaseService()