diff --git a/.DS_Store b/.DS_Store index 38b7f27..ef921b4 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/app.zip b/app.zip index e60e17b..47f81f0 100644 Binary files a/app.zip and b/app.zip differ diff --git a/app/api/endpoints/meetings.py b/app/api/endpoints/meetings.py index 3f80b93..9ef76b2 100644 --- a/app/api/endpoints/meetings.py +++ b/app/api/endpoints/meetings.py @@ -5,7 +5,7 @@ from app.core.config import BASE_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSION import app.core.config as config_module from app.services.llm_service import LLMService from app.services.async_transcription_service import AsyncTranscriptionService -from app.services.async_llm_service import async_llm_service +from app.services.async_meeting_service import async_meeting_service from app.core.auth import get_current_user from app.core.response import create_api_response from typing import List, Optional @@ -449,8 +449,8 @@ def generate_meeting_summary_async(meeting_id: int, request: GenerateSummaryRequ cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,)) if not cursor.fetchone(): return create_api_response(code="404", message="Meeting not found") - task_id = async_llm_service.start_summary_generation(meeting_id, request.user_prompt) - background_tasks.add_task(async_llm_service._process_task, task_id) + task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt) + background_tasks.add_task(async_meeting_service._process_task, task_id) return create_api_response(code="200", message="Summary generation task has been accepted.", data={ "task_id": task_id, "status": "pending", "meeting_id": meeting_id }) @@ -465,7 +465,7 @@ def get_meeting_llm_tasks(meeting_id: int, current_user: dict = Depends(get_curr cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,)) if not cursor.fetchone(): return create_api_response(code="404", message="Meeting not found") - tasks = async_llm_service.get_meeting_llm_tasks(meeting_id) + tasks = async_meeting_service.get_meeting_llm_tasks(meeting_id) return create_api_response(code="200", message="LLM tasks retrieved successfully", data={ "tasks": tasks, "total": len(tasks) }) diff --git a/app/api/endpoints/prompts.py b/app/api/endpoints/prompts.py index b541228..a6a093b 100644 --- a/app/api/endpoints/prompts.py +++ b/app/api/endpoints/prompts.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from typing import List, Optional -from app.core.auth import get_current_admin_user +from app.core.auth import get_current_user from app.core.database import get_db_connection from app.core.response import create_api_response @@ -23,14 +23,14 @@ class PromptListResponse(BaseModel): total: int @router.post("/prompts") -def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)): +def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)): """Create a new prompt.""" with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) try: cursor.execute( - "INSERT INTO prompts (name, tags, content) VALUES (%s, %s, %s)", - (prompt.name, prompt.tags, prompt.content) + "INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)", + (prompt.name, prompt.tags, prompt.content, current_user["user_id"]) ) connection.commit() new_id = cursor.lastrowid @@ -41,23 +41,27 @@ def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_adm return create_api_response(code="500", message=f"创建提示词失败: {e}") @router.get("/prompts") -def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_admin_user)): - """Get a paginated list of prompts.""" +def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)): + """Get a paginated list of prompts filtered by current user.""" with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) - cursor.execute("SELECT COUNT(*) as total FROM prompts") + # 只获取当前用户创建的提示词 + cursor.execute( + "SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s", + (current_user["user_id"],) + ) total = cursor.fetchone()['total'] offset = (page - 1) * size cursor.execute( - "SELECT id, name, tags, content, created_at FROM prompts ORDER BY created_at DESC LIMIT %s OFFSET %s", - (size, offset) + "SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s", + (current_user["user_id"], size, offset) ) prompts = cursor.fetchall() return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total}) @router.get("/prompts/{prompt_id}") -def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)): +def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)): """Get a single prompt by its ID.""" with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) @@ -68,7 +72,7 @@ def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_us return create_api_response(code="200", message="获取提示词成功", data=prompt) @router.put("/prompts/{prompt_id}") -def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)): +def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)): """Update an existing prompt.""" with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) @@ -87,12 +91,24 @@ def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends return create_api_response(code="500", message=f"更新提示词失败: {e}") @router.delete("/prompts/{prompt_id}") -def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)): - """Delete a prompt.""" +def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)): + """Delete a prompt. Only the creator can delete their own prompts.""" with get_db_connection() as connection: - cursor = connection.cursor() - cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,)) - if cursor.rowcount == 0: + cursor = connection.cursor(dictionary=True) + # 首先检查提示词是否存在以及是否属于当前用户 + cursor.execute( + "SELECT creator_id FROM prompts WHERE id = %s", + (prompt_id,) + ) + prompt = cursor.fetchone() + + if not prompt: return create_api_response(code="404", message="提示词不存在") + + if prompt['creator_id'] != current_user["user_id"]: + return create_api_response(code="403", message="无权删除其他用户的提示词") + + # 删除提示词 + cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,)) connection.commit() return create_api_response(code="200", message="提示词删除成功") diff --git a/app/api/endpoints/tags.py b/app/api/endpoints/tags.py index 0964963..b4c802f 100644 --- a/app/api/endpoints/tags.py +++ b/app/api/endpoints/tags.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends from app.core.database import get_db_connection from app.core.response import create_api_response +from app.core.auth import get_current_user from app.models.models import Tag from typing import List import mysql.connector @@ -24,16 +25,16 @@ def get_all_tags(): return create_api_response(code="500", message="获取标签失败") @router.post("/tags/") -def create_tag(tag_in: Tag): +def create_tag(tag_in: Tag, current_user: dict = Depends(get_current_user)): """_summary_ - 创建一个新标签 + 创建一个新标签,并记录创建者 """ - query = "INSERT INTO tags (name, color) VALUES (%s, %s)" + query = "INSERT INTO tags (name, color, creator_id) VALUES (%s, %s, %s)" try: with get_db_connection() as connection: with connection.cursor(dictionary=True) as cursor: try: - cursor.execute(query, (tag_in.name, tag_in.color)) + cursor.execute(query, (tag_in.name, tag_in.color, current_user["user_id"])) connection.commit() tag_id = cursor.lastrowid new_tag = {"id": tag_id, "name": tag_in.name, "color": tag_in.color} diff --git a/app/api/endpoints/tasks.py b/app/api/endpoints/tasks.py index dd64973..384c3b8 100644 --- a/app/api/endpoints/tasks.py +++ b/app/api/endpoints/tasks.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends from app.core.auth import get_current_user from app.core.response import create_api_response from app.services.async_transcription_service import AsyncTranscriptionService -from app.services.async_llm_service import async_llm_service +from app.services.async_meeting_service import async_meeting_service router = APIRouter() @@ -23,7 +23,7 @@ def get_transcription_task_status(task_id: str, current_user: dict = Depends(get def get_llm_task_status(task_id: str, current_user: dict = Depends(get_current_user)): """获取LLM总结任务状态(包括进度)""" try: - status = async_llm_service.get_task_status(task_id) + status = async_meeting_service.get_task_status(task_id) if status.get('status') == 'not_found': return create_api_response(code="404", message="Task not found") return create_api_response(code="200", message="Task status retrieved", data=status) diff --git a/app/services/async_knowledge_base_service.py b/app/services/async_knowledge_base_service.py index 49cecb5..4f19838 100644 --- a/app/services/async_knowledge_base_service.py +++ b/app/services/async_knowledge_base_service.py @@ -1,3 +1,7 @@ +""" +异步知识库服务 - 处理知识库生成的异步任务 +采用FastAPI BackgroundTasks模式 +""" import uuid from datetime import datetime from typing import Optional, Dict, Any, List @@ -7,6 +11,7 @@ from app.core.database import get_db_connection from app.services.llm_service import LLMService class AsyncKnowledgeBaseService: + """异步知识库服务类 - 处理知识库相关的异步任务""" def __init__(self): from app.core.config import REDIS_CONFIG @@ -16,12 +21,25 @@ class AsyncKnowledgeBaseService: self.llm_service = LLMService() def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str: + """ + 创建异步知识库生成任务 + + Args: + user_id: 用户ID + kb_id: 知识库ID + user_prompt: 用户提示词 + source_meeting_ids: 源会议ID列表 + cursor: 数据库游标(可选) + + Returns: + str: 任务ID + """ task_id = str(uuid.uuid4()) - + # If a cursor is passed, use it directly to avoid creating a new transaction if cursor: query = """ - INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at) + INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at) VALUES (%s, %s, %s, %s, NOW()) """ cursor.execute(query, (task_id, user_id, kb_id, user_prompt)) @@ -42,13 +60,17 @@ class AsyncKnowledgeBaseService: } self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data) self.redis_client.expire(f"kb_task:{task_id}", 86400) - + print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}") return task_id def _process_task(self, task_id: str): + """ + 处理单个异步任务的函数,设计为由BackgroundTasks调用。 + """ print(f"Background task started for knowledge base task: {task_id}") try: + # 从Redis获取任务数据 task_data = self.redis_client.hgetall(f"kb_task:{task_id}") if not task_data: print(f"Error: Task {task_id} not found in Redis for processing.") @@ -57,99 +79,136 @@ class AsyncKnowledgeBaseService: kb_id = int(task_data['kb_id']) user_prompt = task_data.get('user_prompt', '') + # 1. 更新状态为processing self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...") - with get_db_connection() as connection: - cursor = connection.cursor(dictionary=True) + # 2. 获取关联的会议总结 + self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...") + source_text = self._get_meeting_summaries(kb_id) - # Get source meeting summaries - source_text = "" - cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,)) - kb_info = cursor.fetchone() - if kb_info and kb_info['source_meeting_ids']: - self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...") - meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()] - if meeting_ids: - summaries = [] - for meeting_id in meeting_ids: - cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,)) - summary = cursor.fetchone() - if summary and summary['summary']: - summaries.append(summary['summary']) - source_text = "\n\n---\n\n".join(summaries) + # 3. 构建提示词 + self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...") + full_prompt = self._build_prompt(source_text, user_prompt) - # Get system prompt - self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...") - system_prompt = self._get_knowledge_task_prompt(cursor) - - # Build final prompt - final_prompt = f"{system_prompt}\n\n" - if source_text: - final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n" - final_prompt += f"用户要求:{user_prompt}" + # 4. 调用LLM API + self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...") + generated_content = self.llm_service._call_llm_api(full_prompt) + if not generated_content: + raise Exception("LLM API调用失败或返回空内容") - self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...") - generated_content = self.llm_service._call_llm_api(final_prompt) + # 5. 保存结果到数据库 + self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...") + self._save_result_to_db(kb_id, generated_content) - if not generated_content: - raise Exception("LLM API call failed or returned empty content") + # 6. 任务完成 + self._update_task_in_db(task_id, 'completed', 100) + self._update_task_status_in_redis(task_id, 'completed', 100) - self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...") - self._save_result_to_db(kb_id, generated_content, cursor) - - self._update_task_in_db(task_id, 'completed', 100, cursor=cursor) - self._update_task_status_in_redis(task_id, 'completed', 100) - - connection.commit() - print(f"Task {task_id} completed successfully") + print(f"Task {task_id} completed successfully") except Exception as e: error_msg = str(e) print(f"Task {task_id} failed: {error_msg}") - # Use a new connection for error logging to avoid issues with a potentially broken transaction - with get_db_connection() as err_conn: - err_cursor = err_conn.cursor() - self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor) - err_conn.commit() + # 更新失败状态 + self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg) self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg) - def _get_knowledge_task_prompt(self, cursor) -> str: - query = """ - SELECT p.content - FROM prompt_config pc - JOIN prompts p ON pc.prompt_id = p.id - WHERE pc.task_name = 'KNOWLEDGE_TASK' + # --- 知识库相关方法 --- + + def _get_meeting_summaries(self, kb_id: int) -> str: """ - cursor.execute(query) - result = cursor.fetchone() - if result: - return result['content'] - else: - # Fallback prompt - return "Please generate a knowledge base article based on the provided information." + 从数据库获取知识库关联的会议总结 + Args: + kb_id: 知识库ID + Returns: + str: 拼接后的会议总结文本 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) - def _save_result_to_db(self, kb_id: int, content: str, cursor): - query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s" - cursor.execute(query, (content, kb_id)) + # 获取知识库的源会议ID列表 + cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,)) + kb_info = cursor.fetchone() - def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None): - query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s" - cursor.execute(query, (status, progress, error_message, status, task_id)) + if not kb_info or not kb_info['source_meeting_ids']: + return "" - def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None): - update_data = { - 'status': status, - 'progress': str(progress), - 'updated_at': datetime.now().isoformat() - } - if message: update_data['message'] = message - if error_message: update_data['error_message'] = error_message - self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data) + # 解析会议ID列表 + meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()] + if not meeting_ids: + return "" + + # 获取所有会议的总结 + summaries = [] + for meeting_id in meeting_ids: + cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,)) + summary = cursor.fetchone() + if summary and summary['summary']: + summaries.append(summary['summary']) + + # 用分隔符拼接多个会议总结 + return "\n\n---\n\n".join(summaries) + + except Exception as e: + print(f"获取会议总结错误: {e}") + return "" + + def _build_prompt(self, source_text: str, user_prompt: str) -> str: + """ + 构建完整的提示词 + 使用数据库中配置的KNOWLEDGE_TASK提示词模板 + + Args: + source_text: 源会议总结文本 + user_prompt: 用户自定义提示词 + + Returns: + str: 完整的提示词 + """ + # 从数据库获取知识库任务的提示词模板 + system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK') + + prompt = f"{system_prompt}\n\n" + + if source_text: + prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n" + + prompt += f"用户要求:{user_prompt}" + + return prompt + + def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]: + """ + 保存生成结果到数据库 + + Args: + kb_id: 知识库ID + content: 生成的内容 + + Returns: + Optional[int]: 知识库ID,失败返回None + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor() + query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s" + cursor.execute(query, (content, kb_id)) + connection.commit() + + print(f"成功保存知识库内容,kb_id: {kb_id}") + return kb_id + + except Exception as e: + print(f"保存知识库内容错误: {e}") + return None + + # --- 状态查询和数据库操作方法 --- def get_task_status(self, task_id: str) -> Dict[str, Any]: - """获取任务状态 - 与 async_llm_service 保持一致""" + """获取任务状态""" try: task_data = self.redis_client.hgetall(f"kb_task:{task_id}") if not task_data: @@ -170,6 +229,20 @@ class AsyncKnowledgeBaseService: print(f"Error getting task status: {e}") return {'task_id': task_id, 'status': 'error', 'error_message': str(e)} + def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None): + """更新Redis中的任务状态""" + try: + update_data = { + 'status': status, + 'progress': str(progress), + 'updated_at': datetime.now().isoformat() + } + if message: update_data['message'] = message + if error_message: update_data['error_message'] = error_message + self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data) + except Exception as e: + print(f"Error updating task status in Redis: {e}") + def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str): """保存任务到数据库""" try: @@ -182,6 +255,17 @@ class AsyncKnowledgeBaseService: print(f"Error saving task to database: {e}") raise + def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None): + """更新数据库中的任务状态""" + try: + with get_db_connection() as connection: + cursor = connection.cursor() + query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s" + cursor.execute(query, (status, progress, error_message, status, task_id)) + connection.commit() + except Exception as e: + print(f"Error updating task in database: {e}") + def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]: """从数据库获取任务信息""" try: @@ -198,4 +282,5 @@ class AsyncKnowledgeBaseService: print(f"Error getting task from database: {e}") return None +# 创建全局实例 async_kb_service = AsyncKnowledgeBaseService() diff --git a/app/services/async_llm_service.py b/app/services/async_meeting_service.py similarity index 72% rename from app/services/async_llm_service.py rename to app/services/async_meeting_service.py index f26964e..7fd9b4b 100644 --- a/app/services/async_llm_service.py +++ b/app/services/async_meeting_service.py @@ -1,5 +1,5 @@ """ -异步LLM服务 - 处理会议总结生成的异步任务 +异步会议服务 - 处理会议总结生成的异步任务 采用FastAPI BackgroundTasks模式 """ import uuid @@ -12,30 +12,30 @@ from app.core.config import REDIS_CONFIG from app.core.database import get_db_connection from app.services.llm_service import LLMService -class AsyncLLMService: - """异步LLM服务类 - 采用FastAPI BackgroundTasks模式""" - +class AsyncMeetingService: + """异步会议服务类 - 处理会议相关的异步任务""" + def __init__(self): # 确保redis客户端自动解码响应,代码更简洁 if 'decode_responses' not in REDIS_CONFIG: REDIS_CONFIG['decode_responses'] = True self.redis_client = redis.Redis(**REDIS_CONFIG) self.llm_service = LLMService() # 复用现有的同步LLM服务 - + def start_summary_generation(self, meeting_id: int, user_prompt: str = "") -> str: """ 创建异步总结任务,任务的执行将由外部(如API层的BackgroundTasks)触发。 - + Args: meeting_id: 会议ID user_prompt: 用户额外提示词 - + Returns: str: 任务ID """ try: task_id = str(uuid.uuid4()) - + # 在数据库中创建任务记录 self._save_task_to_db(task_id, meeting_id, user_prompt) @@ -52,10 +52,10 @@ class AsyncLLMService: } self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data) self.redis_client.expire(f"llm_task:{task_id}", 86400) - - print(f"LLM summary task created: {task_id} for meeting: {meeting_id}") + + print(f"Meeting summary task created: {task_id} for meeting: {meeting_id}") return task_id - + except Exception as e: print(f"Error starting summary generation: {e}") raise e @@ -64,7 +64,7 @@ class AsyncLLMService: """ 处理单个异步任务的函数,设计为由BackgroundTasks调用。 """ - print(f"Background task started for LLM task: {task_id}") + print(f"Background task started for meeting summary task: {task_id}") try: # 从Redis获取任务数据 task_data = self.redis_client.hgetall(f"llm_task:{task_id}") @@ -80,13 +80,13 @@ class AsyncLLMService: # 2. 获取会议转录内容 self._update_task_status_in_redis(task_id, 'processing', 30, message="获取会议转录内容...") - transcript_text = self.llm_service._get_meeting_transcript(meeting_id) + transcript_text = self._get_meeting_transcript(meeting_id) if not transcript_text: raise Exception("无法获取会议转录内容") # 3. 构建提示词 self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...") - full_prompt = self.llm_service._build_prompt(transcript_text, user_prompt) + full_prompt = self._build_prompt(transcript_text, user_prompt) # 4. 调用LLM API self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...") @@ -96,7 +96,7 @@ class AsyncLLMService: # 5. 保存结果到主表 self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...") - self.llm_service._save_summary_to_db(meeting_id, summary_content, user_prompt) + self._save_summary_to_db(meeting_id, summary_content, user_prompt) # 6. 任务完成 self._update_task_in_db(task_id, 'completed', 100, result=summary_content) @@ -110,6 +110,78 @@ class AsyncLLMService: self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg) self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg) + # --- 会议相关方法 --- + + def _get_meeting_transcript(self, meeting_id: int) -> str: + """从数据库获取会议转录内容""" + try: + with get_db_connection() as connection: + cursor = connection.cursor() + query = """ + SELECT speaker_tag, start_time_ms, end_time_ms, text_content + FROM transcript_segments + WHERE meeting_id = %s + ORDER BY start_time_ms + """ + cursor.execute(query, (meeting_id,)) + segments = cursor.fetchall() + + if not segments: + return "" + + # 组装转录文本 + transcript_lines = [] + for speaker_tag, start_time, end_time, text in segments: + # 将毫秒转换为分:秒格式 + start_min = start_time // 60000 + start_sec = (start_time % 60000) // 1000 + transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}") + + return "\n".join(transcript_lines) + + except Exception as e: + print(f"获取会议转录内容错误: {e}") + return "" + + def _build_prompt(self, transcript_text: str, user_prompt: str) -> str: + """ + 构建完整的提示词 + 使用数据库中配置的MEETING_TASK提示词模板 + """ + # 从数据库获取会议任务的提示词模板 + system_prompt = self.llm_service.get_task_prompt('MEETING_TASK') + + prompt = f"{system_prompt}\n\n" + + if user_prompt: + prompt += f"用户额外要求:{user_prompt}\n\n" + + prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:" + + return prompt + + def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]: + """保存总结到数据库 - 更新meetings表的summary、user_prompt和updated_at字段""" + try: + with get_db_connection() as connection: + cursor = connection.cursor() + + # 更新meetings表的summary、user_prompt和updated_at字段 + update_query = """ + UPDATE meetings + SET summary = %s, user_prompt = %s, updated_at = NOW() + WHERE meeting_id = %s + """ + cursor.execute(update_query, (summary_content, user_prompt, meeting_id)) + connection.commit() + + print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}") + return meeting_id + + except Exception as e: + print(f"保存总结到数据库错误: {e}") + return None + # --- 状态查询和数据库操作方法 --- def get_task_status(self, task_id: str) -> Dict[str, Any]: @@ -120,7 +192,7 @@ class AsyncLLMService: task_data = self._get_task_from_db(task_id) if not task_data: return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'} - + return { 'task_id': task_id, 'status': task_data.get('status', 'unknown'), @@ -189,7 +261,7 @@ class AsyncLLMService: params.insert(2, result) else: query = "UPDATE llm_tasks SET status = %s, progress = %s, error_message = %s WHERE task_id = %s" - + cursor.execute(query, tuple(params)) connection.commit() except Exception as e: @@ -212,4 +284,4 @@ class AsyncLLMService: return None # 创建全局实例 -async_llm_service = AsyncLLMService() +async_meeting_service = AsyncMeetingService() diff --git a/app/services/llm_service.py b/app/services/llm_service.py index e1c2e59..6018862 100644 --- a/app/services/llm_service.py +++ b/app/services/llm_service.py @@ -7,6 +7,8 @@ from app.core.database import get_db_connection class LLMService: + """LLM服务 - 专注于大模型API调用和提示词管理""" + def __init__(self): # 设置dashscope API key dashscope.api_key = config_module.QWEN_API_KEY @@ -35,125 +37,49 @@ class LLMService: def top_p(self): """动态获取top_p""" return config_module.LLM_CONFIG["top_p"] - - def generate_meeting_summary_stream(self, meeting_id: int, user_prompt: str = "") -> Generator[str, None, None]: + + def get_task_prompt(self, task_name: str, cursor=None) -> str: """ - 流式生成会议总结 + 统一的提示词获取方法 Args: - meeting_id: 会议ID - user_prompt: 用户额外提示词 - - Yields: - str: 流式输出的内容片段 - """ - try: - # 获取会议转录内容 - transcript_text = self._get_meeting_transcript(meeting_id) - if not transcript_text: - yield "error: 无法获取会议转录内容" - return - - # 构建完整提示词 - full_prompt = self._build_prompt(transcript_text, user_prompt) - - # 调用大模型API进行流式生成 - full_content = "" - for chunk in self._call_llm_api_stream(full_prompt): - if chunk.startswith("error:"): - yield chunk - return - full_content += chunk - yield chunk - - # 保存完整总结到数据库 - if full_content: - self._save_summary_to_db(meeting_id, full_content, user_prompt) - - except Exception as e: - print(f"流式生成会议总结错误: {e}") - yield f"error: {str(e)}" - - def generate_meeting_summary(self, meeting_id: int, user_prompt: str = "") -> Optional[Dict]: - """ - 生成会议总结(非流式,保持向后兼容) - - Args: - meeting_id: 会议ID - user_prompt: 用户额外提示词 + task_name: 任务名称,如 'MEETING_TASK', 'KNOWLEDGE_TASK' 等 + cursor: 数据库游标,如果传入则使用,否则创建新连接 Returns: - 包含总结内容的字典,如果失败返回None + str: 提示词内容,如果未找到返回默认提示词 + """ + query = """ + SELECT p.content + FROM prompt_config pc + JOIN prompts p ON pc.prompt_id = p.id + WHERE pc.task_name = %s """ - try: - # 获取会议转录内容 - transcript_text = self._get_meeting_transcript(meeting_id) - if not transcript_text: - return {"error": "无法获取会议转录内容"} - # 构建完整提示词 - full_prompt = self._build_prompt(transcript_text, user_prompt) - - # 调用大模型API - response = self._call_llm_api(full_prompt) - - if response: - # 保存总结到数据库 - summary_id = self._save_summary_to_db(meeting_id, response, user_prompt) - return { - "summary_id": summary_id, - "content": response, - "meeting_id": meeting_id - } - else: - return {"error": "大模型API调用失败"} - - except Exception as e: - print(f"生成会议总结错误: {e}") - return {"error": str(e)} - - def _get_meeting_transcript(self, meeting_id: int) -> str: - """从数据库获取会议转录内容""" - try: + if cursor: + cursor.execute(query, (task_name,)) + result = cursor.fetchone() + if result: + return result['content'] if isinstance(result, dict) else result[0] + else: with get_db_connection() as connection: - cursor = connection.cursor() - query = """ - SELECT speaker_tag, start_time_ms, end_time_ms, text_content - FROM transcript_segments - WHERE meeting_id = %s - ORDER BY start_time_ms - """ - cursor.execute(query, (meeting_id,)) - segments = cursor.fetchall() - - if not segments: - return "" - - # 组装转录文本 - transcript_lines = [] - for speaker_tag, start_time, end_time, text in segments: - # 将毫秒转换为分:秒格式 - start_min = start_time // 60000 - start_sec = (start_time % 60000) // 1000 - transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}") - - return "\n".join(transcript_lines) - - except Exception as e: - print(f"获取会议转录内容错误: {e}") - return "" - - def _build_prompt(self, transcript_text: str, user_prompt: str) -> str: - """构建完整的提示词""" - prompt = f"{self.system_prompt}\n\n" - - if user_prompt: - prompt += f"用户额外要求:{user_prompt}\n\n" - - prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:" - - return prompt - + cursor = connection.cursor(dictionary=True) + cursor.execute(query, (task_name,)) + result = cursor.fetchone() + if result: + return result['content'] + + # 返回默认提示词 + return self._get_default_prompt(task_name) + + def _get_default_prompt(self, task_name: str) -> str: + """获取默认提示词""" + default_prompts = { + 'MEETING_TASK': self.system_prompt, # 使用配置文件中的系统提示词 + 'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。", + } + return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。") + def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]: """流式调用阿里Qwen3大模型API""" try: @@ -185,7 +111,7 @@ class LLMService: yield f"error: {error_msg}" def _call_llm_api(self, prompt: str) -> Optional[str]: - """调用阿里Qwen3大模型API(非流式,保持向后兼容)""" + """调用阿里Qwen3大模型API(非流式)""" try: response = dashscope.Generation.call( model=self.model_name, @@ -204,96 +130,18 @@ class LLMService: except Exception as e: print(f"调用大模型API错误: {e}") return None - - def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]: - """保存总结到数据库 - 更新meetings表的summary字段""" - try: - with get_db_connection() as connection: - cursor = connection.cursor() - - # 更新meetings表的summary字段 - update_query = """ - UPDATE meetings - SET summary = %s - WHERE meeting_id = %s - """ - cursor.execute(update_query, (summary_content, meeting_id)) - connection.commit() - - print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}") - return meeting_id - - except Exception as e: - print(f"保存总结到数据库错误: {e}") - return None - - def get_meeting_summaries(self, meeting_id: int) -> List[Dict]: - """获取会议的当前总结 - 从meetings表的summary字段获取""" - try: - with get_db_connection() as connection: - cursor = connection.cursor() - query = """ - SELECT summary - FROM meetings - WHERE meeting_id = %s - """ - cursor.execute(query, (meeting_id,)) - result = cursor.fetchone() - - # 如果有总结内容,返回一个包含当前总结的列表格式(保持API一致性) - if result and result[0]: - return [{ - "id": meeting_id, - "content": result[0], - "user_prompt": "", # meetings表中没有user_prompt字段 - "created_at": None # meetings表中没有单独的总结创建时间 - }] - else: - return [] - - except Exception as e: - print(f"获取会议总结错误: {e}") - return [] - - def get_current_meeting_summary(self, meeting_id: int) -> Optional[str]: - """获取会议当前的总结内容 - 从meetings表的summary字段获取""" - try: - with get_db_connection() as connection: - cursor = connection.cursor() - query = """ - SELECT summary - FROM meetings - WHERE meeting_id = %s - """ - cursor.execute(query, (meeting_id,)) - result = cursor.fetchone() - - return result[0] if result and result[0] else None - - except Exception as e: - print(f"获取会议当前总结错误: {e}") - return None # 测试代码 if __name__ == '__main__': - # 测试LLM服务 - test_meeting_id = 38 - test_user_prompt = "请重点关注决策事项和待办任务" - print("--- 运行LLM服务测试 ---") llm_service = LLMService() - - # 生成总结 - result = llm_service.generate_meeting_summary(test_meeting_id, test_user_prompt) - if result.get("error"): - print(f"生成总结失败: {result['error']}") - else: - print(f"总结生成成功,ID: {result.get('summary_id')}") - print(f"总结内容: {result.get('content')[:200]}...") - - # 获取历史总结 - summaries = llm_service.get_meeting_summaries(test_meeting_id) - print(f"获取到 {len(summaries)} 个历史总结") - - print("--- LLM服务测试完成 ---") \ No newline at end of file + + # 测试获取任务提示词 + meeting_prompt = llm_service.get_task_prompt('MEETING_TASK') + print(f"会议任务提示词: {meeting_prompt[:100]}...") + + knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK') + print(f"知识库任务提示词: {knowledge_prompt[:100]}...") + + print("--- LLM服务测试完成 ---") diff --git a/migrations/add_meetings_fields.sql b/migrations/add_meetings_fields.sql new file mode 100644 index 0000000..af532f7 --- /dev/null +++ b/migrations/add_meetings_fields.sql @@ -0,0 +1,17 @@ +-- 为meetings表添加updated_at和user_prompt字段 +-- 执行日期: 2025-10-28 + +-- 添加updated_at字段 +ALTER TABLE meetings +ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +AFTER created_at; + +-- 添加user_prompt字段 +ALTER TABLE meetings +ADD COLUMN user_prompt TEXT +AFTER summary; + +-- 为现有记录设置updated_at为created_at的值 +UPDATE meetings +SET updated_at = created_at +WHERE updated_at IS NULL; diff --git a/test/test_create_task.py b/test/test_create_task.py index 87228de..b029bb9 100644 --- a/test/test_create_task.py +++ b/test/test_create_task.py @@ -6,10 +6,10 @@ import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -from app.services.async_llm_service import AsyncLLMService +from app.services.async_meeting_service import AsyncMeetingService # 创建服务实例 -service = AsyncLLMService() +service = AsyncMeetingService() # 创建测试任务 meeting_id = 38 diff --git a/test/test_worker_thread.py b/test/test_worker_thread.py index 3c0c8c2..0e3b094 100644 --- a/test/test_worker_thread.py +++ b/test/test_worker_thread.py @@ -8,10 +8,10 @@ import time import threading sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -from app.services.async_llm_service import AsyncLLMService +from app.services.async_meeting_service import AsyncMeetingService # 创建服务实例 -service = AsyncLLMService() +service = AsyncMeetingService() # 直接调用处理任务方法测试 print("测试直接调用_process_tasks方法...")