整理了会议和知识库的代码结构
parent
5bdab4a405
commit
976ea854b6
|
|
@ -5,7 +5,7 @@ from app.core.config import BASE_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSION
|
|||
import app.core.config as config_module
|
||||
from app.services.llm_service import LLMService
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.services.async_llm_service import async_llm_service
|
||||
from app.services.async_meeting_service import async_meeting_service
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from typing import List, Optional
|
||||
|
|
@ -449,8 +449,8 @@ def generate_meeting_summary_async(meeting_id: int, request: GenerateSummaryRequ
|
|||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
return create_api_response(code="404", message="Meeting not found")
|
||||
task_id = async_llm_service.start_summary_generation(meeting_id, request.user_prompt)
|
||||
background_tasks.add_task(async_llm_service._process_task, task_id)
|
||||
task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt)
|
||||
background_tasks.add_task(async_meeting_service._process_task, task_id)
|
||||
return create_api_response(code="200", message="Summary generation task has been accepted.", data={
|
||||
"task_id": task_id, "status": "pending", "meeting_id": meeting_id
|
||||
})
|
||||
|
|
@ -465,7 +465,7 @@ def get_meeting_llm_tasks(meeting_id: int, current_user: dict = Depends(get_curr
|
|||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
return create_api_response(code="404", message="Meeting not found")
|
||||
tasks = async_llm_service.get_meeting_llm_tasks(meeting_id)
|
||||
tasks = async_meeting_service.get_meeting_llm_tasks(meeting_id)
|
||||
return create_api_response(code="200", message="LLM tasks retrieved successfully", data={
|
||||
"tasks": tasks, "total": len(tasks)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.auth import get_current_admin_user
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.response import create_api_response
|
||||
|
||||
|
|
@ -23,14 +23,14 @@ class PromptListResponse(BaseModel):
|
|||
total: int
|
||||
|
||||
@router.post("/prompts")
|
||||
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
|
||||
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
|
||||
"""Create a new prompt."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO prompts (name, tags, content) VALUES (%s, %s, %s)",
|
||||
(prompt.name, prompt.tags, prompt.content)
|
||||
"INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)",
|
||||
(prompt.name, prompt.tags, prompt.content, current_user["user_id"])
|
||||
)
|
||||
connection.commit()
|
||||
new_id = cursor.lastrowid
|
||||
|
|
@ -41,23 +41,27 @@ def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_adm
|
|||
return create_api_response(code="500", message=f"创建提示词失败: {e}")
|
||||
|
||||
@router.get("/prompts")
|
||||
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_admin_user)):
|
||||
"""Get a paginated list of prompts."""
|
||||
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)):
|
||||
"""Get a paginated list of prompts filtered by current user."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute("SELECT COUNT(*) as total FROM prompts")
|
||||
# 只获取当前用户创建的提示词
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s",
|
||||
(current_user["user_id"],)
|
||||
)
|
||||
total = cursor.fetchone()['total']
|
||||
|
||||
offset = (page - 1) * size
|
||||
cursor.execute(
|
||||
"SELECT id, name, tags, content, created_at FROM prompts ORDER BY created_at DESC LIMIT %s OFFSET %s",
|
||||
(size, offset)
|
||||
"SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s",
|
||||
(current_user["user_id"], size, offset)
|
||||
)
|
||||
prompts = cursor.fetchall()
|
||||
return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
|
||||
|
||||
@router.get("/prompts/{prompt_id}")
|
||||
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
|
||||
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""Get a single prompt by its ID."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
|
@ -68,7 +72,7 @@ def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_us
|
|||
return create_api_response(code="200", message="获取提示词成功", data=prompt)
|
||||
|
||||
@router.put("/prompts/{prompt_id}")
|
||||
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_admin_user)):
|
||||
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
|
||||
"""Update an existing prompt."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
|
@ -87,12 +91,24 @@ def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends
|
|||
return create_api_response(code="500", message=f"更新提示词失败: {e}")
|
||||
|
||||
@router.delete("/prompts/{prompt_id}")
|
||||
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_admin_user)):
|
||||
"""Delete a prompt."""
|
||||
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""Delete a prompt. Only the creator can delete their own prompts."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
|
||||
if cursor.rowcount == 0:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
# 首先检查提示词是否存在以及是否属于当前用户
|
||||
cursor.execute(
|
||||
"SELECT creator_id FROM prompts WHERE id = %s",
|
||||
(prompt_id,)
|
||||
)
|
||||
prompt = cursor.fetchone()
|
||||
|
||||
if not prompt:
|
||||
return create_api_response(code="404", message="提示词不存在")
|
||||
|
||||
if prompt['creator_id'] != current_user["user_id"]:
|
||||
return create_api_response(code="403", message="无权删除其他用户的提示词")
|
||||
|
||||
# 删除提示词
|
||||
cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
|
||||
connection.commit()
|
||||
return create_api_response(code="200", message="提示词删除成功")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.response import create_api_response
|
||||
from app.core.auth import get_current_user
|
||||
from app.models.models import Tag
|
||||
from typing import List
|
||||
import mysql.connector
|
||||
|
|
@ -24,16 +25,16 @@ def get_all_tags():
|
|||
return create_api_response(code="500", message="获取标签失败")
|
||||
|
||||
@router.post("/tags/")
|
||||
def create_tag(tag_in: Tag):
|
||||
def create_tag(tag_in: Tag, current_user: dict = Depends(get_current_user)):
|
||||
"""_summary_
|
||||
创建一个新标签
|
||||
创建一个新标签,并记录创建者
|
||||
"""
|
||||
query = "INSERT INTO tags (name, color) VALUES (%s, %s)"
|
||||
query = "INSERT INTO tags (name, color, creator_id) VALUES (%s, %s, %s)"
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
with connection.cursor(dictionary=True) as cursor:
|
||||
try:
|
||||
cursor.execute(query, (tag_in.name, tag_in.color))
|
||||
cursor.execute(query, (tag_in.name, tag_in.color, current_user["user_id"]))
|
||||
connection.commit()
|
||||
tag_id = cursor.lastrowid
|
||||
new_tag = {"id": tag_id, "name": tag_in.name, "color": tag_in.color}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
|||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.services.async_llm_service import async_llm_service
|
||||
from app.services.async_meeting_service import async_meeting_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ def get_transcription_task_status(task_id: str, current_user: dict = Depends(get
|
|||
def get_llm_task_status(task_id: str, current_user: dict = Depends(get_current_user)):
|
||||
"""获取LLM总结任务状态(包括进度)"""
|
||||
try:
|
||||
status = async_llm_service.get_task_status(task_id)
|
||||
status = async_meeting_service.get_task_status(task_id)
|
||||
if status.get('status') == 'not_found':
|
||||
return create_api_response(code="404", message="Task not found")
|
||||
return create_api_response(code="200", message="Task status retrieved", data=status)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
"""
|
||||
异步知识库服务 - 处理知识库生成的异步任务
|
||||
采用FastAPI BackgroundTasks模式
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
|
@ -7,6 +11,7 @@ from app.core.database import get_db_connection
|
|||
from app.services.llm_service import LLMService
|
||||
|
||||
class AsyncKnowledgeBaseService:
|
||||
"""异步知识库服务类 - 处理知识库相关的异步任务"""
|
||||
|
||||
def __init__(self):
|
||||
from app.core.config import REDIS_CONFIG
|
||||
|
|
@ -16,12 +21,25 @@ class AsyncKnowledgeBaseService:
|
|||
self.llm_service = LLMService()
|
||||
|
||||
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
|
||||
"""
|
||||
创建异步知识库生成任务
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
kb_id: 知识库ID
|
||||
user_prompt: 用户提示词
|
||||
source_meeting_ids: 源会议ID列表
|
||||
cursor: 数据库游标(可选)
|
||||
|
||||
Returns:
|
||||
str: 任务ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
# If a cursor is passed, use it directly to avoid creating a new transaction
|
||||
if cursor:
|
||||
query = """
|
||||
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at)
|
||||
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at)
|
||||
VALUES (%s, %s, %s, %s, NOW())
|
||||
"""
|
||||
cursor.execute(query, (task_id, user_id, kb_id, user_prompt))
|
||||
|
|
@ -42,13 +60,17 @@ class AsyncKnowledgeBaseService:
|
|||
}
|
||||
self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data)
|
||||
self.redis_client.expire(f"kb_task:{task_id}", 86400)
|
||||
|
||||
|
||||
print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
|
||||
return task_id
|
||||
|
||||
def _process_task(self, task_id: str):
|
||||
"""
|
||||
处理单个异步任务的函数,设计为由BackgroundTasks调用。
|
||||
"""
|
||||
print(f"Background task started for knowledge base task: {task_id}")
|
||||
try:
|
||||
# 从Redis获取任务数据
|
||||
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
|
||||
if not task_data:
|
||||
print(f"Error: Task {task_id} not found in Redis for processing.")
|
||||
|
|
@ -57,99 +79,136 @@ class AsyncKnowledgeBaseService:
|
|||
kb_id = int(task_data['kb_id'])
|
||||
user_prompt = task_data.get('user_prompt', '')
|
||||
|
||||
# 1. 更新状态为processing
|
||||
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
# 2. 获取关联的会议总结
|
||||
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
|
||||
source_text = self._get_meeting_summaries(kb_id)
|
||||
|
||||
# Get source meeting summaries
|
||||
source_text = ""
|
||||
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
|
||||
kb_info = cursor.fetchone()
|
||||
if kb_info and kb_info['source_meeting_ids']:
|
||||
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
|
||||
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
|
||||
if meeting_ids:
|
||||
summaries = []
|
||||
for meeting_id in meeting_ids:
|
||||
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
summary = cursor.fetchone()
|
||||
if summary and summary['summary']:
|
||||
summaries.append(summary['summary'])
|
||||
source_text = "\n\n---\n\n".join(summaries)
|
||||
# 3. 构建提示词
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...")
|
||||
full_prompt = self._build_prompt(source_text, user_prompt)
|
||||
|
||||
# Get system prompt
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
|
||||
system_prompt = self._get_knowledge_task_prompt(cursor)
|
||||
|
||||
# Build final prompt
|
||||
final_prompt = f"{system_prompt}\n\n"
|
||||
if source_text:
|
||||
final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
|
||||
final_prompt += f"用户要求:{user_prompt}"
|
||||
# 4. 调用LLM API
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
|
||||
generated_content = self.llm_service._call_llm_api(full_prompt)
|
||||
if not generated_content:
|
||||
raise Exception("LLM API调用失败或返回空内容")
|
||||
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
|
||||
generated_content = self.llm_service._call_llm_api(final_prompt)
|
||||
# 5. 保存结果到数据库
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
|
||||
self._save_result_to_db(kb_id, generated_content)
|
||||
|
||||
if not generated_content:
|
||||
raise Exception("LLM API call failed or returned empty content")
|
||||
# 6. 任务完成
|
||||
self._update_task_in_db(task_id, 'completed', 100)
|
||||
self._update_task_status_in_redis(task_id, 'completed', 100)
|
||||
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
|
||||
self._save_result_to_db(kb_id, generated_content, cursor)
|
||||
|
||||
self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
|
||||
self._update_task_status_in_redis(task_id, 'completed', 100)
|
||||
|
||||
connection.commit()
|
||||
print(f"Task {task_id} completed successfully")
|
||||
print(f"Task {task_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"Task {task_id} failed: {error_msg}")
|
||||
# Use a new connection for error logging to avoid issues with a potentially broken transaction
|
||||
with get_db_connection() as err_conn:
|
||||
err_cursor = err_conn.cursor()
|
||||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
|
||||
err_conn.commit()
|
||||
# 更新失败状态
|
||||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
|
||||
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
|
||||
|
||||
def _get_knowledge_task_prompt(self, cursor) -> str:
|
||||
query = """
|
||||
SELECT p.content
|
||||
FROM prompt_config pc
|
||||
JOIN prompts p ON pc.prompt_id = p.id
|
||||
WHERE pc.task_name = 'KNOWLEDGE_TASK'
|
||||
# --- 知识库相关方法 ---
|
||||
|
||||
def _get_meeting_summaries(self, kb_id: int) -> str:
|
||||
"""
|
||||
cursor.execute(query)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content']
|
||||
else:
|
||||
# Fallback prompt
|
||||
return "Please generate a knowledge base article based on the provided information."
|
||||
从数据库获取知识库关联的会议总结
|
||||
|
||||
Args:
|
||||
kb_id: 知识库ID
|
||||
|
||||
Returns:
|
||||
str: 拼接后的会议总结文本
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
def _save_result_to_db(self, kb_id: int, content: str, cursor):
|
||||
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
|
||||
cursor.execute(query, (content, kb_id))
|
||||
# 获取知识库的源会议ID列表
|
||||
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
|
||||
kb_info = cursor.fetchone()
|
||||
|
||||
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None):
|
||||
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
|
||||
cursor.execute(query, (status, progress, error_message, status, task_id))
|
||||
if not kb_info or not kb_info['source_meeting_ids']:
|
||||
return ""
|
||||
|
||||
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
|
||||
update_data = {
|
||||
'status': status,
|
||||
'progress': str(progress),
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
if message: update_data['message'] = message
|
||||
if error_message: update_data['error_message'] = error_message
|
||||
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
|
||||
# 解析会议ID列表
|
||||
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
|
||||
if not meeting_ids:
|
||||
return ""
|
||||
|
||||
# 获取所有会议的总结
|
||||
summaries = []
|
||||
for meeting_id in meeting_ids:
|
||||
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
summary = cursor.fetchone()
|
||||
if summary and summary['summary']:
|
||||
summaries.append(summary['summary'])
|
||||
|
||||
# 用分隔符拼接多个会议总结
|
||||
return "\n\n---\n\n".join(summaries)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议总结错误: {e}")
|
||||
return ""
|
||||
|
||||
def _build_prompt(self, source_text: str, user_prompt: str) -> str:
|
||||
"""
|
||||
构建完整的提示词
|
||||
使用数据库中配置的KNOWLEDGE_TASK提示词模板
|
||||
|
||||
Args:
|
||||
source_text: 源会议总结文本
|
||||
user_prompt: 用户自定义提示词
|
||||
|
||||
Returns:
|
||||
str: 完整的提示词
|
||||
"""
|
||||
# 从数据库获取知识库任务的提示词模板
|
||||
system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK')
|
||||
|
||||
prompt = f"{system_prompt}\n\n"
|
||||
|
||||
if source_text:
|
||||
prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
|
||||
|
||||
prompt += f"用户要求:{user_prompt}"
|
||||
|
||||
return prompt
|
||||
|
||||
def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]:
    """
    Store generated content on a knowledge-base row.

    Runs an UPDATE on ``knowledge_bases`` that sets ``content`` and refreshes
    ``updated_at`` for the given ``kb_id``, then commits.

    Args:
        kb_id: Knowledge-base ID of the row to update.
        content: Generated content to store.

    Returns:
        Optional[int]: ``kb_id`` on success; None when any exception occurs
        (the error is printed, not raised).
    """
    update_sql = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
    try:
        with get_db_connection() as db:
            db_cursor = db.cursor()
            db_cursor.execute(update_sql, (content, kb_id))
            db.commit()

            print(f"成功保存知识库内容,kb_id: {kb_id}")
            return kb_id
    except Exception as exc:
        # Failures are reported through the None return value, not an exception.
        print(f"保存知识库内容错误: {exc}")
        return None
|
||||
|
||||
# --- 状态查询和数据库操作方法 ---
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""获取任务状态 - 与 async_llm_service 保持一致"""
|
||||
"""获取任务状态"""
|
||||
try:
|
||||
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
|
||||
if not task_data:
|
||||
|
|
@ -170,6 +229,20 @@ class AsyncKnowledgeBaseService:
|
|||
print(f"Error getting task status: {e}")
|
||||
return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}
|
||||
|
||||
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
|
||||
"""更新Redis中的任务状态"""
|
||||
try:
|
||||
update_data = {
|
||||
'status': status,
|
||||
'progress': str(progress),
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
if message: update_data['message'] = message
|
||||
if error_message: update_data['error_message'] = error_message
|
||||
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
|
||||
except Exception as e:
|
||||
print(f"Error updating task status in Redis: {e}")
|
||||
|
||||
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
|
||||
"""保存任务到数据库"""
|
||||
try:
|
||||
|
|
@ -182,6 +255,17 @@ class AsyncKnowledgeBaseService:
|
|||
print(f"Error saving task to database: {e}")
|
||||
raise
|
||||
|
||||
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None):
    """
    Update a task's row in ``knowledge_base_tasks`` and commit.

    Sets ``status``, ``progress``, ``error_message`` and ``updated_at``. The
    SQL sets ``completed_at`` to NOW() only when the status is 'completed'
    and otherwise leaves it as it was. Any exception is printed and swallowed.

    Args:
        task_id: ID of the task row to update.
        status: New status value.
        progress: New progress value.
        error_message: Optional error description; stored as given.
    """
    update_sql = (
        "UPDATE knowledge_base_tasks SET status = %s, progress = %s, "
        "error_message = %s, updated_at = NOW(), "
        "completed_at = IF(%s = 'completed', NOW(), completed_at) "
        "WHERE task_id = %s"
    )
    # status is bound twice: once for the column, once for the IF() test.
    bind_values = (status, progress, error_message, status, task_id)
    try:
        with get_db_connection() as db:
            db_cursor = db.cursor()
            db_cursor.execute(update_sql, bind_values)
            db.commit()
    except Exception as exc:
        print(f"Error updating task in database: {exc}")
|
||||
|
||||
def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
|
||||
"""从数据库获取任务信息"""
|
||||
try:
|
||||
|
|
@ -198,4 +282,5 @@ class AsyncKnowledgeBaseService:
|
|||
print(f"Error getting task from database: {e}")
|
||||
return None
|
||||
|
||||
# 创建全局实例
|
||||
async_kb_service = AsyncKnowledgeBaseService()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
异步LLM服务 - 处理会议总结生成的异步任务
|
||||
异步会议服务 - 处理会议总结生成的异步任务
|
||||
采用FastAPI BackgroundTasks模式
|
||||
"""
|
||||
import uuid
|
||||
|
|
@ -12,30 +12,30 @@ from app.core.config import REDIS_CONFIG
|
|||
from app.core.database import get_db_connection
|
||||
from app.services.llm_service import LLMService
|
||||
|
||||
class AsyncLLMService:
|
||||
"""异步LLM服务类 - 采用FastAPI BackgroundTasks模式"""
|
||||
|
||||
class AsyncMeetingService:
|
||||
"""异步会议服务类 - 处理会议相关的异步任务"""
|
||||
|
||||
def __init__(self):
    """
    Create the Redis client and the LLM service used by this instance.

    ``REDIS_CONFIG`` is updated in place: ``decode_responses`` is set to True
    when the key is absent, and an existing value is left untouched.
    """
    # Have the Redis client decode responses automatically, which keeps the
    # rest of the code simpler.
    REDIS_CONFIG.setdefault('decode_responses', True)
    self.redis_client = redis.Redis(**REDIS_CONFIG)
    # Reuse the existing synchronous LLM service.
    self.llm_service = LLMService()
|
||||
|
||||
|
||||
def start_summary_generation(self, meeting_id: int, user_prompt: str = "") -> str:
|
||||
"""
|
||||
创建异步总结任务,任务的执行将由外部(如API层的BackgroundTasks)触发。
|
||||
|
||||
|
||||
Args:
|
||||
meeting_id: 会议ID
|
||||
user_prompt: 用户额外提示词
|
||||
|
||||
|
||||
Returns:
|
||||
str: 任务ID
|
||||
"""
|
||||
try:
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
# 在数据库中创建任务记录
|
||||
self._save_task_to_db(task_id, meeting_id, user_prompt)
|
||||
|
||||
|
|
@ -52,10 +52,10 @@ class AsyncLLMService:
|
|||
}
|
||||
self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data)
|
||||
self.redis_client.expire(f"llm_task:{task_id}", 86400)
|
||||
|
||||
print(f"LLM summary task created: {task_id} for meeting: {meeting_id}")
|
||||
|
||||
print(f"Meeting summary task created: {task_id} for meeting: {meeting_id}")
|
||||
return task_id
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error starting summary generation: {e}")
|
||||
raise e
|
||||
|
|
@ -64,7 +64,7 @@ class AsyncLLMService:
|
|||
"""
|
||||
处理单个异步任务的函数,设计为由BackgroundTasks调用。
|
||||
"""
|
||||
print(f"Background task started for LLM task: {task_id}")
|
||||
print(f"Background task started for meeting summary task: {task_id}")
|
||||
try:
|
||||
# 从Redis获取任务数据
|
||||
task_data = self.redis_client.hgetall(f"llm_task:{task_id}")
|
||||
|
|
@ -80,13 +80,13 @@ class AsyncLLMService:
|
|||
|
||||
# 2. 获取会议转录内容
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取会议转录内容...")
|
||||
transcript_text = self.llm_service._get_meeting_transcript(meeting_id)
|
||||
transcript_text = self._get_meeting_transcript(meeting_id)
|
||||
if not transcript_text:
|
||||
raise Exception("无法获取会议转录内容")
|
||||
|
||||
# 3. 构建提示词
|
||||
self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...")
|
||||
full_prompt = self.llm_service._build_prompt(transcript_text, user_prompt)
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt)
|
||||
|
||||
# 4. 调用LLM API
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...")
|
||||
|
|
@ -96,7 +96,7 @@ class AsyncLLMService:
|
|||
|
||||
# 5. 保存结果到主表
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...")
|
||||
self.llm_service._save_summary_to_db(meeting_id, summary_content, user_prompt)
|
||||
self._save_summary_to_db(meeting_id, summary_content, user_prompt)
|
||||
|
||||
# 6. 任务完成
|
||||
self._update_task_in_db(task_id, 'completed', 100, result=summary_content)
|
||||
|
|
@ -110,6 +110,78 @@ class AsyncLLMService:
|
|||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
|
||||
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
|
||||
|
||||
# --- 会议相关方法 ---
|
||||
|
||||
def _get_meeting_transcript(self, meeting_id: int) -> str:
    """
    Load a meeting's transcript segments and render them as plain text.

    Segments are read from ``transcript_segments`` ordered by
    ``start_time_ms``. Each one becomes a line holding an ``[MM:SS]`` stamp
    derived from the start time, the speaker tag and the text.

    Args:
        meeting_id: ID of the meeting whose segments are read.

    Returns:
        str: The rendered lines joined with newlines; an empty string when
        there are no segments or when any exception occurs (the error is
        printed, not raised).
    """
    try:
        with get_db_connection() as db:
            db_cursor = db.cursor()
            segment_sql = """
            SELECT speaker_tag, start_time_ms, end_time_ms, text_content
            FROM transcript_segments
            WHERE meeting_id = %s
            ORDER BY start_time_ms
            """
            db_cursor.execute(segment_sql, (meeting_id,))
            rows = db_cursor.fetchall()

            if not rows:
                return ""

            def render(row):
                # The end time is selected by the query but not used here.
                speaker, begin_ms, _end_ms, spoken = row
                # Convert milliseconds to minutes and whole seconds.
                minutes, remainder_ms = divmod(begin_ms, 60000)
                seconds = remainder_ms // 1000
                return f"[{minutes:02d}:{seconds:02d}] 说话人{speaker}: {spoken}"

            return "\n".join([render(row) for row in rows])
    except Exception as exc:
        print(f"获取会议转录内容错误: {exc}")
        return ""
|
||||
|
||||
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
|
||||
"""
|
||||
构建完整的提示词
|
||||
使用数据库中配置的MEETING_TASK提示词模板
|
||||
"""
|
||||
# 从数据库获取会议任务的提示词模板
|
||||
system_prompt = self.llm_service.get_task_prompt('MEETING_TASK')
|
||||
|
||||
prompt = f"{system_prompt}\n\n"
|
||||
|
||||
if user_prompt:
|
||||
prompt += f"用户额外要求:{user_prompt}\n\n"
|
||||
|
||||
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
|
||||
|
||||
return prompt
|
||||
|
||||
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
    """
    Store a generated summary on the meeting row.

    Runs an UPDATE on ``meetings`` that sets ``summary`` and ``user_prompt``
    and refreshes ``updated_at`` for the given ``meeting_id``, then commits.

    Args:
        meeting_id: ID of the meeting row to update.
        summary_content: Generated summary text.
        user_prompt: User prompt stored alongside the summary.

    Returns:
        Optional[int]: ``meeting_id`` on success; None when any exception
        occurs (the error is printed, not raised).
    """
    try:
        with get_db_connection() as db:
            db_cursor = db.cursor()

            summary_sql = """
            UPDATE meetings
            SET summary = %s, user_prompt = %s, updated_at = NOW()
            WHERE meeting_id = %s
            """
            bind_values = (summary_content, user_prompt, meeting_id)
            db_cursor.execute(summary_sql, bind_values)
            db.commit()

            print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}")
            return meeting_id
    except Exception as exc:
        # Failures are reported through the None return value, not an exception.
        print(f"保存总结到数据库错误: {exc}")
        return None
|
||||
|
||||
# --- 状态查询和数据库操作方法 ---
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
|
|
@ -120,7 +192,7 @@ class AsyncLLMService:
|
|||
task_data = self._get_task_from_db(task_id)
|
||||
if not task_data:
|
||||
return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}
|
||||
|
||||
|
||||
return {
|
||||
'task_id': task_id,
|
||||
'status': task_data.get('status', 'unknown'),
|
||||
|
|
@ -189,7 +261,7 @@ class AsyncLLMService:
|
|||
params.insert(2, result)
|
||||
else:
|
||||
query = "UPDATE llm_tasks SET status = %s, progress = %s, error_message = %s WHERE task_id = %s"
|
||||
|
||||
|
||||
cursor.execute(query, tuple(params))
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
|
|
@ -212,4 +284,4 @@ class AsyncLLMService:
|
|||
return None
|
||||
|
||||
# 创建全局实例
|
||||
async_llm_service = AsyncLLMService()
|
||||
async_meeting_service = AsyncMeetingService()
|
||||
|
|
@ -7,6 +7,8 @@ from app.core.database import get_db_connection
|
|||
|
||||
|
||||
class LLMService:
|
||||
"""LLM服务 - 专注于大模型API调用和提示词管理"""
|
||||
|
||||
def __init__(self):
|
||||
# 设置dashscope API key
|
||||
dashscope.api_key = config_module.QWEN_API_KEY
|
||||
|
|
@ -35,125 +37,49 @@ class LLMService:
|
|||
def top_p(self):
|
||||
"""动态获取top_p"""
|
||||
return config_module.LLM_CONFIG["top_p"]
|
||||
|
||||
def generate_meeting_summary_stream(self, meeting_id: int, user_prompt: str = "") -> Generator[str, None, None]:
|
||||
|
||||
def get_task_prompt(self, task_name: str, cursor=None) -> str:
|
||||
"""
|
||||
流式生成会议总结
|
||||
统一的提示词获取方法
|
||||
|
||||
Args:
|
||||
meeting_id: 会议ID
|
||||
user_prompt: 用户额外提示词
|
||||
|
||||
Yields:
|
||||
str: 流式输出的内容片段
|
||||
"""
|
||||
try:
|
||||
# 获取会议转录内容
|
||||
transcript_text = self._get_meeting_transcript(meeting_id)
|
||||
if not transcript_text:
|
||||
yield "error: 无法获取会议转录内容"
|
||||
return
|
||||
|
||||
# 构建完整提示词
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt)
|
||||
|
||||
# 调用大模型API进行流式生成
|
||||
full_content = ""
|
||||
for chunk in self._call_llm_api_stream(full_prompt):
|
||||
if chunk.startswith("error:"):
|
||||
yield chunk
|
||||
return
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
|
||||
# 保存完整总结到数据库
|
||||
if full_content:
|
||||
self._save_summary_to_db(meeting_id, full_content, user_prompt)
|
||||
|
||||
except Exception as e:
|
||||
print(f"流式生成会议总结错误: {e}")
|
||||
yield f"error: {str(e)}"
|
||||
|
||||
def generate_meeting_summary(self, meeting_id: int, user_prompt: str = "") -> Optional[Dict]:
|
||||
"""
|
||||
生成会议总结(非流式,保持向后兼容)
|
||||
|
||||
Args:
|
||||
meeting_id: 会议ID
|
||||
user_prompt: 用户额外提示词
|
||||
task_name: 任务名称,如 'MEETING_TASK', 'KNOWLEDGE_TASK' 等
|
||||
cursor: 数据库游标,如果传入则使用,否则创建新连接
|
||||
|
||||
Returns:
|
||||
包含总结内容的字典,如果失败返回None
|
||||
str: 提示词内容,如果未找到返回默认提示词
|
||||
"""
|
||||
query = """
|
||||
SELECT p.content
|
||||
FROM prompt_config pc
|
||||
JOIN prompts p ON pc.prompt_id = p.id
|
||||
WHERE pc.task_name = %s
|
||||
"""
|
||||
try:
|
||||
# 获取会议转录内容
|
||||
transcript_text = self._get_meeting_transcript(meeting_id)
|
||||
if not transcript_text:
|
||||
return {"error": "无法获取会议转录内容"}
|
||||
|
||||
# 构建完整提示词
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt)
|
||||
|
||||
# 调用大模型API
|
||||
response = self._call_llm_api(full_prompt)
|
||||
|
||||
if response:
|
||||
# 保存总结到数据库
|
||||
summary_id = self._save_summary_to_db(meeting_id, response, user_prompt)
|
||||
return {
|
||||
"summary_id": summary_id,
|
||||
"content": response,
|
||||
"meeting_id": meeting_id
|
||||
}
|
||||
else:
|
||||
return {"error": "大模型API调用失败"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成会议总结错误: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _get_meeting_transcript(self, meeting_id: int) -> str:
|
||||
"""从数据库获取会议转录内容"""
|
||||
try:
|
||||
if cursor:
|
||||
cursor.execute(query, (task_name,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content'] if isinstance(result, dict) else result[0]
|
||||
else:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = """
|
||||
SELECT speaker_tag, start_time_ms, end_time_ms, text_content
|
||||
FROM transcript_segments
|
||||
WHERE meeting_id = %s
|
||||
ORDER BY start_time_ms
|
||||
"""
|
||||
cursor.execute(query, (meeting_id,))
|
||||
segments = cursor.fetchall()
|
||||
|
||||
if not segments:
|
||||
return ""
|
||||
|
||||
# 组装转录文本
|
||||
transcript_lines = []
|
||||
for speaker_tag, start_time, end_time, text in segments:
|
||||
# 将毫秒转换为分:秒格式
|
||||
start_min = start_time // 60000
|
||||
start_sec = (start_time % 60000) // 1000
|
||||
transcript_lines.append(f"[{start_min:02d}:{start_sec:02d}] 说话人{speaker_tag}: {text}")
|
||||
|
||||
return "\n".join(transcript_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议转录内容错误: {e}")
|
||||
return ""
|
||||
|
||||
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
|
||||
"""构建完整的提示词"""
|
||||
prompt = f"{self.system_prompt}\n\n"
|
||||
|
||||
if user_prompt:
|
||||
prompt += f"用户额外要求:{user_prompt}\n\n"
|
||||
|
||||
prompt += f"会议转录内容:\n{transcript_text}\n\n请根据以上内容生成会议总结:"
|
||||
|
||||
return prompt
|
||||
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute(query, (task_name,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content']
|
||||
|
||||
# 返回默认提示词
|
||||
return self._get_default_prompt(task_name)
|
||||
|
||||
def _get_default_prompt(self, task_name: str) -> str:
|
||||
"""获取默认提示词"""
|
||||
default_prompts = {
|
||||
'MEETING_TASK': self.system_prompt, # 使用配置文件中的系统提示词
|
||||
'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。",
|
||||
}
|
||||
return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。")
|
||||
|
||||
def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
|
||||
"""流式调用阿里Qwen3大模型API"""
|
||||
try:
|
||||
|
|
@ -185,7 +111,7 @@ class LLMService:
|
|||
yield f"error: {error_msg}"
|
||||
|
||||
def _call_llm_api(self, prompt: str) -> Optional[str]:
|
||||
"""调用阿里Qwen3大模型API(非流式,保持向后兼容)"""
|
||||
"""调用阿里Qwen3大模型API(非流式)"""
|
||||
try:
|
||||
response = dashscope.Generation.call(
|
||||
model=self.model_name,
|
||||
|
|
@ -204,96 +130,18 @@ class LLMService:
|
|||
except Exception as e:
|
||||
print(f"调用大模型API错误: {e}")
|
||||
return None
|
||||
|
||||
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
|
||||
"""保存总结到数据库 - 更新meetings表的summary字段"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
|
||||
# 更新meetings表的summary字段
|
||||
update_query = """
|
||||
UPDATE meetings
|
||||
SET summary = %s
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(update_query, (summary_content, meeting_id))
|
||||
connection.commit()
|
||||
|
||||
print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}")
|
||||
return meeting_id
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存总结到数据库错误: {e}")
|
||||
return None
|
||||
|
||||
def get_meeting_summaries(self, meeting_id: int) -> List[Dict]:
|
||||
"""获取会议的当前总结 - 从meetings表的summary字段获取"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = """
|
||||
SELECT summary
|
||||
FROM meetings
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(query, (meeting_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
# 如果有总结内容,返回一个包含当前总结的列表格式(保持API一致性)
|
||||
if result and result[0]:
|
||||
return [{
|
||||
"id": meeting_id,
|
||||
"content": result[0],
|
||||
"user_prompt": "", # meetings表中没有user_prompt字段
|
||||
"created_at": None # meetings表中没有单独的总结创建时间
|
||||
}]
|
||||
else:
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议总结错误: {e}")
|
||||
return []
|
||||
|
||||
def get_current_meeting_summary(self, meeting_id: int) -> Optional[str]:
|
||||
"""获取会议当前的总结内容 - 从meetings表的summary字段获取"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = """
|
||||
SELECT summary
|
||||
FROM meetings
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(query, (meeting_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
return result[0] if result and result[0] else None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取会议当前总结错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == '__main__':
|
||||
# 测试LLM服务
|
||||
test_meeting_id = 38
|
||||
test_user_prompt = "请重点关注决策事项和待办任务"
|
||||
|
||||
print("--- 运行LLM服务测试 ---")
|
||||
llm_service = LLMService()
|
||||
|
||||
# 生成总结
|
||||
result = llm_service.generate_meeting_summary(test_meeting_id, test_user_prompt)
|
||||
if result.get("error"):
|
||||
print(f"生成总结失败: {result['error']}")
|
||||
else:
|
||||
print(f"总结生成成功,ID: {result.get('summary_id')}")
|
||||
print(f"总结内容: {result.get('content')[:200]}...")
|
||||
|
||||
# 获取历史总结
|
||||
summaries = llm_service.get_meeting_summaries(test_meeting_id)
|
||||
print(f"获取到 {len(summaries)} 个历史总结")
|
||||
|
||||
print("--- LLM服务测试完成 ---")
|
||||
|
||||
# 测试获取任务提示词
|
||||
meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
|
||||
print(f"会议任务提示词: {meeting_prompt[:100]}...")
|
||||
|
||||
knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
|
||||
print(f"知识库任务提示词: {knowledge_prompt[:100]}...")
|
||||
|
||||
print("--- LLM服务测试完成 ---")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
-- 为meetings表添加updated_at和user_prompt字段
|
||||
-- 执行日期: 2025-10-28
|
||||
|
||||
-- 添加updated_at字段
|
||||
ALTER TABLE meetings
|
||||
ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
AFTER created_at;
|
||||
|
||||
-- 添加user_prompt字段
|
||||
ALTER TABLE meetings
|
||||
ADD COLUMN user_prompt TEXT
|
||||
AFTER summary;
|
||||
|
||||
-- 为现有记录设置updated_at为created_at的值
|
||||
UPDATE meetings
|
||||
SET updated_at = created_at
|
||||
WHERE updated_at IS NULL;
|
||||
|
|
@ -6,10 +6,10 @@ import sys
|
|||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.services.async_llm_service import AsyncLLMService
|
||||
from app.services.async_meeting_service import AsyncMeetingService
|
||||
|
||||
# 创建服务实例
|
||||
service = AsyncLLMService()
|
||||
service = AsyncMeetingService()
|
||||
|
||||
# 创建测试任务
|
||||
meeting_id = 38
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ import time
|
|||
import threading
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.services.async_llm_service import AsyncLLMService
|
||||
from app.services.async_meeting_service import AsyncMeetingService
|
||||
|
||||
# 创建服务实例
|
||||
service = AsyncLLMService()
|
||||
service = AsyncMeetingService()
|
||||
|
||||
# 直接调用处理任务方法测试
|
||||
print("测试直接调用_process_tasks方法...")
|
||||
|
|
|
|||
Loading…
Reference in New Issue