152 lines
5.4 KiB
Python
152 lines
5.4 KiB
Python
import json
import logging
import math
import uuid
from datetime import datetime, timezone

from sqlalchemy.orm import Session

from app.core.redis import redis_client
from app.models import Meeting, SummarizeTask, PromptTemplate, AIModel, TranscriptSegment
from app.services.llm_service import LLMService
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
class MeetingService:
    """Create and execute LLM summarization tasks for meetings.

    ``create_summarize_task`` stores a pending ``SummarizeTask`` and pushes its
    id onto a Redis list. ``process_summarize_task`` is the worker-side step:
    it calls the LLM and writes the result back to the task and the meeting.

    NOTE(review): these coroutines drive a synchronous SQLAlchemy ``Session``,
    so every query/commit blocks the event loop while it runs.
    """

    # Redis list that create_summarize_task pushes task ids onto.
    SUMMARIZE_QUEUE_KEY = "meeting:summarize:queue"
    # System prompt used when the task's prompt template is missing or empty.
    DEFAULT_SYSTEM_PROMPT = "请根据提供的会议转译内容生成准确的总结。"
    # Sampling defaults used when the model config omits a value (or sets it to null).
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_TOP_P = 0.9
    # Endpoint fallbacks used when the model config leaves them blank.
    DEFAULT_BASE_URL = "https://api.openai.com/v1"
    DEFAULT_API_PATH = "/chat/completions"

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Drop-in replacement for the deprecated ``datetime.utcnow()``; the
        stored value and its naive representation are unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _format_segment(segment) -> str:
        """Render one transcript segment as ``[MM:SS] speaker: text``."""
        # A missing start time renders as 00:00 instead of raising TypeError.
        total_seconds = int((segment.start_time_ms or 0) // 1000)
        minutes, seconds = divmod(total_seconds, 60)
        if segment.speaker_tag:
            speaker = segment.speaker_tag
        else:
            # `is not None` so that a speaker id of 0 is shown as 0, not "?".
            speaker_id = segment.speaker_id if segment.speaker_id is not None else "?"
            speaker = f"发言人{speaker_id}"
        return f"[{minutes:02d}:{seconds:02d}] {speaker}: {segment.text_content or ''}"

    @staticmethod
    def _build_user_prompt(segments, extra_prompt: str) -> str:
        """Build the user message: the formatted transcript plus optional extra instructions."""
        meeting_text = "\n".join(MeetingService._format_segment(s) for s in segments)
        user_prompt = f"### 会议转译内容 ###\n{meeting_text}"
        if extra_prompt:
            user_prompt += f"\n\n### 用户的额外指令 ###\n{extra_prompt}"
        return user_prompt

    @staticmethod
    def _sampling_params(config):
        """Return ``(temperature, top_p)`` as floats from a model's config mapping.

        Missing keys and explicit ``None`` values fall back to the class defaults.
        """
        params = config or {}
        temperature = params.get("temperature")
        top_p = params.get("top_p")
        if temperature is None:
            temperature = MeetingService.DEFAULT_TEMPERATURE
        if top_p is None:
            top_p = MeetingService.DEFAULT_TOP_P
        return float(temperature), float(top_p)

    @staticmethod
    async def create_summarize_task(
        db: Session,
        meeting_id: int,
        prompt_id: int,
        model_id: int,
        extra_prompt: str = "",
    ):
        """Persist a pending summarize task for a meeting and enqueue it.

        The meeting transcript is rendered into the task's ``user_prompt`` at
        creation time. The meeting is then marked ``"summarizing"`` and the task
        id is pushed onto ``SUMMARIZE_QUEUE_KEY``.

        Args:
            db: Database session; this method commits on it.
            meeting_id: Meeting to summarize.
            prompt_id: Prompt template id, resolved later by the worker.
            model_id: AI model id, resolved later by the worker.
            extra_prompt: Optional extra user instructions appended to the prompt.

        Returns:
            The newly created ``SummarizeTask``.

        Raises:
            Exception: If the meeting does not exist. Also re-raised if pushing
                the task onto the Redis queue fails; in that case the task is
                marked ``"failed"`` and the meeting status is restored first.
        """
        # 1. Validate input.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
        if not meeting:
            raise Exception("Meeting not found")

        # 2. Render the transcript, ordered by start time, as the user prompt.
        segments = (
            db.query(TranscriptSegment)
            .filter(TranscriptSegment.meeting_id == meeting_id)
            .order_by(TranscriptSegment.start_time_ms)
            .all()
        )
        user_prompt_content = MeetingService._build_user_prompt(segments, extra_prompt)

        # 3. Create the task record and flag the meeting as being summarized.
        task_id = str(uuid.uuid4())
        new_task = SummarizeTask(
            task_id=task_id,
            meeting_id=meeting_id,
            prompt_id=prompt_id,
            model_id=model_id,
            user_prompt=user_prompt_content,
            status="pending",
            progress=0,
            created_at=MeetingService._utcnow(),
        )
        db.add(new_task)
        previous_status = meeting.status
        meeting.status = "summarizing"
        db.commit()

        # 4. Hand the task to the worker via Redis.
        try:
            await redis_client.lpush(MeetingService.SUMMARIZE_QUEUE_KEY, task_id)
        except Exception as exc:
            # Without this, the task would stay "pending" and the meeting
            # "summarizing" forever, because no worker will ever see the id.
            logger.exception("Task %s: failed to enqueue summarize task", task_id)
            try:
                new_task.status = "failed"
                new_task.error_message = f"Failed to enqueue task: {exc}"
                meeting.status = previous_status
                db.commit()
            except Exception:
                db.rollback()
                logger.exception("Task %s: failed to record enqueue failure", task_id)
            raise

        return new_task

    @staticmethod
    async def process_summarize_task(db: Session, task_id: str):
        """Run one summarize task: call the LLM and store the result.

        This is the background worker's execution step. Progress is committed
        in stages (15 -> 30 -> 50 -> 100). On success, the result is written to
        the task and to the meeting's ``summary``, and both become
        ``"completed"``. On any error, the task becomes ``"failed"`` with the
        error message and the meeting status is reset to ``"draft"``. Errors
        are not re-raised. An unknown ``task_id`` is logged and ignored.
        """
        task = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
        if not task:
            logger.warning("Task %s not found; skipping", task_id)
            return

        try:
            task.status = "processing"
            task.progress = 15
            db.commit()

            # 1. Load the model configuration.
            model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
            if not model_config:
                raise Exception("AI 模型配置不存在")

            # 2. Resolve the prompt template now, so later template edits apply.
            prompt_tmpl = db.query(PromptTemplate).filter(PromptTemplate.id == task.prompt_id).first()
            if prompt_tmpl and prompt_tmpl.content:
                system_prompt = prompt_tmpl.content
            else:
                system_prompt = MeetingService.DEFAULT_SYSTEM_PROMPT

            task.progress = 30
            db.commit()

            # 3. Build the chat messages and sampling parameters.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": task.user_prompt},
            ]
            temperature, top_p = MeetingService._sampling_params(model_config.config)

            logger.info("Task %s: Launching LLM request to %s", task_id, model_config.model_name)
            task.progress = 50
            db.commit()

            # 4. Call the LLM service.
            summary_result = await LLMService.chat_completion(
                api_key=model_config.api_key,
                base_url=model_config.base_url or MeetingService.DEFAULT_BASE_URL,
                model_name=model_config.model_name,
                messages=messages,
                api_path=model_config.api_path or MeetingService.DEFAULT_API_PATH,
                temperature=temperature,
                top_p=top_p,
            )

            # 5. Store the result on the task.
            task.result = summary_result
            task.status = "completed"
            task.progress = 100
            task.completed_at = MeetingService._utcnow()

            # Mirror the summary and status onto the meeting record.
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.summary = summary_result
                meeting.status = "completed"

            db.commit()
            logger.info("Task %s completed successfully", task_id)

        except Exception as e:
            logger.exception("Task %s execution error: %s", task_id, e)
            # A failed flush/commit leaves the session unusable until it is
            # rolled back, so roll back before recording the failure. This also
            # discards any half-written, uncommitted result.
            db.rollback()
            task.status = "failed"
            # Some exceptions stringify to ""; keep at least the type name.
            task.error_message = str(e) or type(e).__name__

            # Reset the meeting status.
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.status = "draft"

            db.commit()

    @staticmethod
    def get_task_status(db: Session, task_id: str):
        """Return the ``SummarizeTask`` with ``task_id``, or ``None`` if it does not exist."""
        return db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
|