728 lines
33 KiB
Python
728 lines
33 KiB
Python
from sqlalchemy.orm import Session
|
|
from app.core.db import SessionLocal
|
|
from typing import Callable, Awaitable, Optional, List
|
|
from app.models import Meeting, SummarizeTask, PromptTemplate, AIModel, TranscriptSegment, TranscriptTask, MeetingAudio, Hotword
|
|
from app.services.llm_service import LLMService
|
|
import uuid
|
|
import json
|
|
import logging
|
|
import math
|
|
import httpx
|
|
import asyncio
|
|
from datetime import datetime
|
|
from app.core.redis import redis_client
|
|
from app.core.config import get_settings
|
|
from pathlib import Path
|
|
|
|
logger = logging.getLogger(__name__)
|
|
settings = get_settings()
|
|
|
|
class MeetingService:
    """Service layer for meeting transcription (ASR) and summarization (LLM) workflows."""

    @staticmethod
    async def create_transcript_task(
        db: Session,
        meeting_id: int,
        model_id: int,
        language: str = "auto"
    ):
        """Create a pending transcription task for a meeting's latest audio file.

        Args:
            db: SQLAlchemy session owned by the caller; committed in place.
            meeting_id: Target meeting primary key.
            model_id: ASR model configuration id to use.
            language: Language hint passed to the ASR service ("auto" by default).

        Returns:
            The newly persisted TranscriptTask row (status "pending").

        Raises:
            Exception: If the meeting does not exist or has no uploaded audio.
        """
        # 1. Validate the meeting and its audio.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
        if not meeting:
            raise Exception("Meeting not found")

        # Use the most recently uploaded audio file for this meeting.
        audio = db.query(MeetingAudio).filter(
            MeetingAudio.meeting_id == meeting_id
        ).order_by(MeetingAudio.upload_time.desc()).first()

        if not audio:
            raise Exception("No audio file found for this meeting")

        # 2. Create the transcription task row.
        task_id = str(uuid.uuid4())
        new_task = TranscriptTask(
            task_id=task_id,
            meeting_id=meeting_id,
            model_id=model_id,
            language=language,
            status="pending",
            progress=0,
            created_at=datetime.utcnow()
        )
        db.add(new_task)

        # Flip the meeting into the transcribing state.
        meeting.status = "transcribing"
        # FIX: offload the blocking commit to a worker thread so the event loop
        # is not frozen — consistent with every other async method in this service.
        await asyncio.to_thread(db.commit)

        # 3. Enqueue in Redis (transcribe queue).
        # Note: In a real worker system, we would push to a queue.
        # For now, we return the task so the API can trigger background processing.
        # await redis_client.lpush("meeting:transcribe:queue", task_id)

        return new_task
|
|
|
    @staticmethod
    async def process_transcript_task(db: Session, task_id: str):
        """Background worker body: run ASR for one transcript task end-to-end.

        Loads the task, resolves the ASR model and latest audio, calls the local
        ASR service (with hotword hints and a progress callback), persists the
        returned transcript segments, and updates task/audio/meeting status.
        For uploaded meetings a summarization task is auto-triggered afterwards.
        On any failure the task, audio and meeting are all marked "failed".

        Args:
            db: SQLAlchemy session owned by the caller; committed in place here.
            task_id: UUID of the TranscriptTask row to process.
        """
        print(f"[DEBUG] Processing transcript task {task_id}")
        task = db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()
        if not task:
            print(f"[ERROR] Task {task_id} not found in DB")
            return

        try:
            task.status = "processing"
            task.progress = 10
            # Blocking commit is offloaded so the event loop keeps running.
            await asyncio.to_thread(db.commit)

            # 1. Resolve the model configuration.
            model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
            if not model_config:
                # No explicit model on the task: fall back to any configured ASR model.
                model_config = db.query(AIModel).filter(AIModel.model_type == "asr").first()

            if not model_config:
                raise Exception("No ASR model configuration found")

            # 2. Fetch the latest uploaded audio for this meeting.
            audio = db.query(MeetingAudio).filter(
                MeetingAudio.meeting_id == task.meeting_id
            ).order_by(MeetingAudio.upload_time.desc()).first()

            if not audio:
                raise Exception("Audio file missing")

            task.progress = 20
            await asyncio.to_thread(db.commit)

            # 3. Call the ASR service (local model only).
            # Use the model config's base_url if set, otherwise fall back to settings.
            asr_base_url = model_config.base_url if model_config.base_url else settings.asr_api_base_url

            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if not meeting:
                raise Exception("Meeting not found")

            # Collect hotwords visible to this meeting: public/global ones plus,
            # when the meeting has an owner, that user's personal hotwords.
            hotword_filters = [Hotword.scope.in_(["public", "global"])]
            if meeting.user_id:
                hotword_filters.append(
                    (Hotword.scope == "personal") & (Hotword.user_id == meeting.user_id)
                )
            hotwords = db.query(Hotword).filter(
                (hotword_filters[0]) if len(hotword_filters) == 1 else (hotword_filters[0] | hotword_filters[1])
            ).all()
            # Two representations are built because different ASR endpoints accept
            # either a structured word/weight list or a space-separated string.
            hotword_entries = []
            hotword_string_parts = []
            for hw in hotwords:
                hotword_entries.append({"word": hw.word, "weight": hw.weight})
                if hw.weight and hw.weight != 1:
                    hotword_string_parts.append(f"{hw.word}:{hw.weight}")
                else:
                    hotword_string_parts.append(hw.word)
            hotword_string = " ".join([p for p in hotword_string_parts if p])

            logger.info(f"Task {task_id}: Starting transcription with ASR Model at {asr_base_url}")
            print(f"[DEBUG] Calling ASR at {asr_base_url} for file: {audio.file_path}")

            # Progress callback handed to the ASR poller.
            async def update_progress(p: int, msg: str = None):
                # Map ASR progress (0-100) onto task progress (20-80) and make
                # sure progress never moves backwards.
                new_progress = 20 + int(p * 0.6)
                print(f"[DEBUG] update_progress called: p={p}, new_progress={new_progress}, current={task.progress}")

                if new_progress > task.progress:
                    task.progress = new_progress
                    # NOTE: db.commit() on a session shared with the enclosing
                    # coroutine — keep commits serialized through this callback.
                    try:
                        # Offload blocking DB commit to a thread to avoid freezing the event loop.
                        await asyncio.to_thread(db.commit)
                        print(f"[DEBUG] DB Updated: progress={new_progress}")
                    except Exception as dbe:
                        logger.error(f"DB Commit failed: {dbe}")
                        db.rollback()

                # Mirror the human-readable status message into Redis (best effort,
                # 1-hour TTL); failures are logged but never abort transcription.
                if msg:
                    try:
                        await redis_client.setex(f"task:status:{task_id}", 3600, msg)
                        print(f"[DEBUG] Redis Updated: key=task:status:{task_id}, msg={msg}")
                    except Exception as e:
                        logger.warning(f"Failed to update task status in Redis: {e}")

            # Transcribe via the local model API.
            segments_data = await MeetingService._call_local_asr_api(
                audio_path=audio.file_path,
                base_url=asr_base_url,
                language=task.language,
                hotwords=hotword_entries if hotword_entries else None,
                hotword_string=hotword_string,
                progress_callback=update_progress
            )
            print(f"[DEBUG] Received {len(segments_data)} segments from ASR")

            task.progress = 80
            await asyncio.to_thread(db.commit)

            # 4. Persist the transcription result.
            last_saved_end_ms = 0

            def parse_number(v):
                # Best-effort numeric coercion; returns None when not parseable.
                if isinstance(v, (int, float)):
                    return float(v)
                if isinstance(v, str):
                    try:
                        return float(v.strip())
                    except Exception:
                        return None
                return None

            for seg in segments_data:
                # Local API returns: { "text": "...", "timestamp": [[start, end]], "speaker": "..." }
                # We need to adapt it to our schema.

                text = (seg.get("text") or "").strip()
                if not text:
                    continue

                start_ms = None
                end_ms = None
                ts = seg.get("timestamp")
                if ts:
                    if isinstance(ts, list) and len(ts) > 0:
                        try:
                            if isinstance(ts[0], list):
                                # List of [start, end] pairs: span from min start to max end.
                                starts = []
                                ends = []
                                for pair in ts:
                                    if not isinstance(pair, (list, tuple)) or len(pair) < 2:
                                        continue
                                    s = parse_number(pair[0])
                                    e = parse_number(pair[1])
                                    if s is not None and e is not None:
                                        starts.append(s)
                                        ends.append(e)
                                if starts and ends:
                                    raw_start = min(starts)
                                    raw_end = max(ends)
                                    if raw_end < raw_start:
                                        raise ValueError("timestamp end < start")
                                    # Heuristic: values below 1000 are assumed to be
                                    # seconds and scaled to milliseconds.
                                    if raw_end < 1000:
                                        raw_start *= 1000.0
                                        raw_end *= 1000.0
                                    start_ms = int(raw_start)
                                    end_ms = int(raw_end)
                            elif len(ts) >= 2:
                                # Flat [start, end] form.
                                raw_start = parse_number(ts[0])
                                raw_end = parse_number(ts[1])
                                if raw_start is None or raw_end is None:
                                    raise ValueError("timestamp not numeric")
                                if raw_end < raw_start:
                                    raise ValueError("timestamp end < start")
                                if raw_end < 1000:
                                    raw_start *= 1000.0
                                    raw_end *= 1000.0
                                start_ms = int(raw_start)
                                end_ms = int(raw_end)
                        except Exception:
                            # Malformed timestamps fall through to the fallbacks below.
                            start_ms = None
                            end_ms = None

                if start_ms is None or end_ms is None:
                    # Fallback: explicit begin/end fields on the segment.
                    bt = parse_number(seg.get("begin_time") or seg.get("start_time"))
                    et = parse_number(seg.get("end_time"))
                    if bt is not None and et is not None:
                        if bt < 1000 and et < 1000:
                            bt *= 1000.0
                            et *= 1000.0
                        start_ms = int(bt)
                        end_ms = int(et)

                if start_ms is None or end_ms is None or end_ms < start_ms:
                    # Last resort: synthesize a 1-second window right after the
                    # last saved segment so ordering stays monotonic.
                    start_ms = last_saved_end_ms + 1
                    end_ms = start_ms + 1000

                last_saved_end_ms = max(last_saved_end_ms, end_ms)

                transcript_segment = TranscriptSegment(
                    meeting_id=task.meeting_id,
                    audio_id=audio.audio_id,
                    speaker_id=0,
                    speaker_tag=seg.get("speaker", "Unknown"),
                    start_time_ms=start_ms,
                    end_time_ms=end_ms,
                    text_content=text
                )
                db.add(transcript_segment)

            # 5. Mark the task complete.
            task.status = "completed"
            task.progress = 100
            task.completed_at = datetime.utcnow()

            # Update audio status.
            audio.processing_status = "completed"

            # Update meeting status (summarization may be auto-triggered below,
            # or started later by the user).
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.status = "transcribed"  # distinct status before summarizing

                # Auto-trigger summarization for uploaded meetings.
                if meeting.type == 'upload':
                    try:
                        # Resolve the LLM model: the meeting's chosen model first,
                        # then the default active LLM.
                        llm_model = None
                        if meeting.summary_model_id:
                            llm_model = db.query(AIModel).filter(
                                AIModel.model_id == meeting.summary_model_id,
                                AIModel.status == 1
                            ).first()
                        if not llm_model:
                            llm_model = db.query(AIModel).filter(
                                AIModel.model_type == 'llm',
                                AIModel.is_default == 1,
                                AIModel.status == 1
                            ).first()

                        # Resolve the prompt template: meeting's choice first,
                        # then any active template (system templates preferred).
                        prompt_tmpl = None
                        if meeting.summary_prompt_id:
                            prompt_tmpl = db.query(PromptTemplate).filter(
                                PromptTemplate.id == meeting.summary_prompt_id
                            ).first()
                        if not prompt_tmpl:
                            prompt_tmpl = db.query(PromptTemplate).filter(
                                PromptTemplate.status == 1
                            ).order_by(PromptTemplate.is_system.desc(), PromptTemplate.id.asc()).first()

                        if llm_model and prompt_tmpl:
                            logger.info(f"Auto-triggering summary for meeting {meeting.meeting_id}")
                            # create_summarize_task commits on the shared session,
                            # so flush our pending changes first.
                            await asyncio.to_thread(db.commit)

                            new_sum_task = await MeetingService.create_summarize_task(
                                db,
                                meeting_id=meeting.meeting_id,
                                prompt_id=prompt_tmpl.id,
                                model_id=llm_model.model_id
                            )
                            # No FastAPI BackgroundTasks here; fire-and-forget via
                            # asyncio. run_summarize_worker opens its OWN session,
                            # because `db` will be closed when this function returns.
                            asyncio.create_task(MeetingService.run_summarize_worker(new_sum_task.task_id))

                            # Meeting now enters the summarizing state.
                            meeting.status = "summarizing"
                        else:
                            logger.warning(f"Skipping auto-summary: No default LLM or Prompt found (LLM: {llm_model}, Prompt: {prompt_tmpl})")
                    except Exception as sum_e:
                        # Auto-summary is best effort; transcription still succeeds.
                        logger.error(f"Failed to auto-trigger summary: {sum_e}")

            await asyncio.to_thread(db.commit)
            logger.info(f"Task {task_id} transcription completed")

        except Exception as e:
            logger.error(f"Task {task_id} failed: {str(e)}")
            task.status = "failed"
            task.error_message = str(e)

            # Update audio status.
            audio = db.query(MeetingAudio).filter(
                MeetingAudio.meeting_id == task.meeting_id
            ).order_by(MeetingAudio.upload_time.desc()).first()
            if audio:
                audio.processing_status = "failed"
                audio.error_message = str(e)

            # Update meeting status to failed so frontend knows to stop polling.
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.status = "failed"

            await asyncio.to_thread(db.commit)
|
|
|
    @staticmethod
    async def _call_local_asr_api(
        audio_path: str,
        base_url: str = "http://localhost:3050",
        language: str = "auto",
        hotwords: Optional[List[dict]] = None,
        hotword_string: Optional[str] = None,
        progress_callback: Optional[Callable[[int, Optional[str]], Awaitable[None]]] = None
    ) -> list:
        """Call the local ASR API for transcription.

        Flow: Create Task -> Poll Status -> Get Result.

        The create step negotiates several payload shapes because different ASR
        builds expect different field names: JSON with "file_path", JSON with
        "audio_path", and finally a multipart file upload. If the service is
        unreachable entirely, mock segments are returned for testing.

        Args:
            audio_path: Server-local path of the audio file to transcribe.
            base_url: ASR service root URL.
            language: Language hint forwarded to the service.
            hotwords: Structured [{"word": ..., "weight": ...}] hotword hints.
            hotword_string: Space-separated fallback representation of hotwords.
            progress_callback: Awaited with (progress 0-100, status message)
                on every poll iteration; its errors are swallowed.

        Returns:
            A list of segment dicts as returned by the ASR service (may be empty).

        Raises:
            Exception: On task-creation failure (other than connection refused),
                on ASR-reported task failure, or on poll timeout (~20 minutes).
        """
        create_url = f"{base_url.rstrip('/')}/api/tasks/recognition"

        async with httpx.AsyncClient(timeout=30.0) as client:
            # 1. Create Task
            try:
                # Normalize hotwords to a {word: int_weight} mapping; default
                # weight is 20 when unspecified.
                normalized_hotwords = {}
                if hotwords:
                    for hw in hotwords:
                        if isinstance(hw, dict):
                            word = hw.get("word")
                            weight = hw.get("weight")
                            if word:
                                normalized_hotwords[word] = int(weight) if weight is not None else 20
                        elif isinstance(hw, str):
                            normalized_hotwords[hw] = 20

                # Attempt 1: JSON payload keyed by "file_path".
                payload = {
                    "file_path": audio_path,
                    "language": language,
                    "use_spk_id": True
                }
                if normalized_hotwords:
                    payload["hotwords"] = normalized_hotwords
                elif hotword_string:
                    payload["hotword"] = hotword_string
                response = await client.post(create_url, json=payload)
                if response.status_code >= 500:
                    # 5xx may be caused by the hotword fields — retry without them.
                    base_payload = {
                        "file_path": audio_path,
                        "language": language,
                        "use_spk_id": True
                    }
                    response = await client.post(create_url, json=base_payload)
                if response.status_code == 422:
                    # Attempt 2: some builds expect "audio_path" instead.
                    alt_payload = {
                        "audio_path": audio_path,
                        "language": language,
                        "use_spk_id": True
                    }
                    if normalized_hotwords:
                        alt_payload["hotwords"] = normalized_hotwords
                    elif hotword_string:
                        alt_payload["hotword"] = hotword_string
                    response = await client.post(create_url, json=alt_payload)
                    if response.status_code >= 500:
                        # Again retry without hotword fields on 5xx.
                        alt_base_payload = {
                            "audio_path": audio_path,
                            "language": language,
                            "use_spk_id": True
                        }
                        response = await client.post(create_url, json=alt_base_payload)
                if response.status_code == 422 or response.status_code >= 500:
                    # Attempt 3: multipart upload of the audio file itself.
                    # NOTE(review): synchronous open() in a coroutine — brief
                    # event-loop block; acceptable for a one-shot upload.
                    with open(audio_path, "rb") as f:
                        files = {
                            "file": (Path(audio_path).name, f, "application/octet-stream")
                        }
                        data_fields = {
                            "language": language,
                            "use_spk_id": "true"
                        }
                        if normalized_hotwords:
                            data_fields["hotwords"] = json.dumps(normalized_hotwords, ensure_ascii=False)
                        elif hotword_string:
                            data_fields["hotword"] = hotword_string
                        response = await client.post(create_url, data=data_fields, files=files)
                if response.status_code >= 400:
                    # Log the error body before raising for diagnostics.
                    try:
                        logger.error(f"ASR create response: {response.status_code} {response.text}")
                    except Exception:
                        pass
                response.raise_for_status()
                data = response.json()

                # Handle nested data structure {code: 200, data: {task_id: ...}}
                if "data" in data and isinstance(data["data"], dict) and "task_id" in data["data"]:
                    task_id = data["data"]["task_id"]
                else:
                    task_id = data.get("task_id")

                if not task_id:
                    raise Exception("Failed to get task_id from ASR service")
            except Exception as e:
                logger.error(f"Failed to create ASR task: {e}")
                # Fallback for testing/mock if service is not running.
                if "Connection refused" in str(e) or "ConnectError" in str(e):
                    logger.warning("ASR Service not reachable, using mock data for testing")
                    await asyncio.sleep(2)
                    return [
                        {"timestamp": [[0, 2500]], "text": "This is a mock transcription (Local ASR unreachable).", "speaker": "System"},
                        {"timestamp": [[2500, 5000]], "text": "Please ensure the ASR service is running at localhost:3050.", "speaker": "System"}
                    ]
                raise

        # 2. Poll Status
        status_url = f"{base_url.rstrip('/')}/api/tasks/{task_id}"
        # Separate client with a shorter timeout for polling; closed in finally.
        poll_client = httpx.AsyncClient(timeout=10.0)

        max_retries = 600  # 20 minutes max wait (assuming 2s sleep)

        try:
            for _ in range(max_retries):
                try:
                    res = await poll_client.get(status_url)
                    res.raise_for_status()
                    status_data = res.json()

                    progress = 0

                    # Handle nested data structure: {code, data: {status, ...}}
                    # vs a flat {status, ...} response.
                    if "data" in status_data and isinstance(status_data["data"], dict):
                        inner_data = status_data["data"]
                        status = inner_data.get("status")
                        result = inner_data.get("result", {})
                        error_msg = inner_data.get("msg") or status_data.get("msg")
                        # Progress may appear under several field names.
                        if "progress" in inner_data:
                            raw_progress = inner_data["progress"]
                        elif "percent" in inner_data:
                            raw_progress = inner_data["percent"]
                        elif "percentage" in inner_data:
                            raw_progress = inner_data["percentage"]
                        else:
                            raw_progress = 0

                        # A "message" field, when present, overrides the msg.
                        if "message" in inner_data and inner_data["message"]:
                            error_msg = inner_data["message"]
                    else:
                        status = status_data.get("status")
                        result = status_data.get("result", {})
                        error_msg = status_data.get("msg")
                        # Progress may appear under several field names.
                        if "progress" in status_data:
                            raw_progress = status_data["progress"]
                        elif "percent" in status_data:
                            raw_progress = status_data["percent"]
                        elif "percentage" in status_data:
                            raw_progress = status_data["percentage"]
                        else:
                            raw_progress = 0

                        # A "message" field, when present, overrides the msg.
                        if "message" in status_data and status_data["message"]:
                            error_msg = status_data["message"]

                    # Normalize progress: float ratio (0.0-1.0), int percent
                    # (0-100), or a string like "45%" / "0.45".
                    progress = 0
                    try:
                        if isinstance(raw_progress, (int, float)):
                            if 0.0 <= raw_progress <= 1.0 and isinstance(raw_progress, float):
                                progress = int(raw_progress * 100)
                            else:
                                progress = int(raw_progress)
                        elif isinstance(raw_progress, str):
                            # Handle string like "45%" or "0.45"
                            if raw_progress.endswith("%"):
                                progress = int(float(raw_progress.strip("%")))
                            else:
                                val = float(raw_progress)
                                if 0.0 <= val <= 1.0:
                                    progress = int(val * 100)
                                else:
                                    progress = int(val)
                    except (ValueError, TypeError):
                        progress = 0

                    # Log raw status for debugging.
                    print(f"[DEBUG] ASR Polling: status={status}, raw_progress={raw_progress}, progress={progress}, msg={error_msg}")

                    # Report progress upstream if a callback was provided.
                    if progress_callback:
                        try:
                            status_msg = error_msg if error_msg else f"正在转写中... {progress}%"
                            print(f"[DEBUG] Invoking callback with progress={progress}, msg={status_msg}")
                            await progress_callback(progress, status_msg)
                        except Exception as e:
                            logger.warning(f"Progress callback error: {e}")

                    if status == "completed" or status == "success":
                        # Return the segments list in whichever shape it arrived.
                        if isinstance(result, dict) and "segments" in result:
                            return result["segments"]
                        elif isinstance(result, list):
                            return result
                        else:
                            # Fallback or empty
                            return []
                    elif status == "failed":
                        raise Exception(f"ASR Task failed: {error_msg}")

                    # Still processing — wait before the next poll.
                    await asyncio.sleep(2)
                except httpx.RequestError as e:
                    # Transient network errors are retried after a short sleep.
                    logger.warning(f"Error polling ASR task {task_id}: {e}")
                    await asyncio.sleep(2)
        finally:
            await poll_client.aclose()

        raise Exception("ASR Task timed out")
|
|
|
@staticmethod
|
|
async def create_summarize_task(
|
|
db: Session,
|
|
meeting_id: int,
|
|
prompt_id: int,
|
|
model_id: int,
|
|
extra_prompt: str = ""
|
|
):
|
|
# 1. 基础数据校验
|
|
meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
|
|
if not meeting:
|
|
raise Exception("Meeting not found")
|
|
|
|
# 2. 格式化会议转译内容 (作为 user_prompt 素材)
|
|
segments = db.query(TranscriptSegment).filter(
|
|
TranscriptSegment.meeting_id == meeting_id
|
|
).order_by(TranscriptSegment.start_time_ms).all()
|
|
|
|
formatted_content = []
|
|
for s in segments:
|
|
secs = int(s.start_time_ms // 1000)
|
|
m, sc = divmod(secs, 60)
|
|
timestamp = f"[{m:02d}:{sc:02d}]"
|
|
speaker = s.speaker_tag or f"发言人{s.speaker_id or '?'}"
|
|
formatted_content.append(f"{timestamp} {speaker}: {s.text_content}")
|
|
|
|
meeting_text = "\n".join(formatted_content)
|
|
|
|
# 组合最终 user_prompt (素材 + 用户的附加要求)
|
|
user_prompt_content = f"### 会议转译内容 ###\n{meeting_text}"
|
|
if extra_prompt:
|
|
user_prompt_content += f"\n\n### 用户的额外指令 ###\n{extra_prompt}"
|
|
|
|
# 3. 创建任务记录 (按照数据库实际字段)
|
|
task_id = str(uuid.uuid4())
|
|
new_task = SummarizeTask(
|
|
task_id=task_id,
|
|
meeting_id=meeting_id,
|
|
prompt_id=prompt_id,
|
|
model_id=model_id,
|
|
user_prompt=user_prompt_content,
|
|
status="pending",
|
|
progress=0,
|
|
created_at=datetime.utcnow()
|
|
)
|
|
db.add(new_task)
|
|
|
|
# 更新会议状态为“总结中”
|
|
meeting.status = "summarizing"
|
|
await asyncio.to_thread(db.commit)
|
|
|
|
# 4. 进入 Redis 队列
|
|
await redis_client.lpush("meeting:summarize:queue", task_id)
|
|
|
|
return new_task
|
|
|
|
@staticmethod
|
|
async def run_summarize_worker(task_id: str):
|
|
"""
|
|
Worker entry point for summarize task that manages its own DB session.
|
|
This is safe to call via asyncio.create_task from other contexts.
|
|
"""
|
|
db = SessionLocal()
|
|
try:
|
|
await MeetingService.process_summarize_task(db, task_id)
|
|
except Exception as e:
|
|
logger.error(f"Summarize worker (task {task_id}) failed: {e}")
|
|
finally:
|
|
db.close()
|
|
|
|
    @staticmethod
    async def process_summarize_task(db: Session, task_id: str):
        """Background worker body: generate a meeting summary via the LLM.

        Loads the task, resolves the model configuration and prompt template,
        calls the LLM chat-completion service, and writes the summary back to
        both the task and the meeting. On failure the meeting is returned to
        "transcribed" so the user can retry summarization.

        Args:
            db: SQLAlchemy session owned by the caller; committed in place.
            task_id: UUID of the SummarizeTask row to process.
        """
        task = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
        if not task:
            return

        try:
            task.status = "processing"
            task.progress = 15
            # Blocking commits are offloaded so the event loop keeps running.
            await asyncio.to_thread(db.commit)

            # 1. Resolve the model configuration.
            model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
            if not model_config:
                raise Exception("AI 模型配置不存在")

            # 2. Fetch the prompt template content at run time; fall back to a
            # generic summarization instruction when the template is missing.
            prompt_tmpl = db.query(PromptTemplate).filter(PromptTemplate.id == task.prompt_id).first()
            system_prompt = prompt_tmpl.content if prompt_tmpl else "请根据提供的会议转译内容生成准确的总结。"

            task.progress = 30
            await asyncio.to_thread(db.commit)

            # 3. Build the chat message structure.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": task.user_prompt}
            ]

            # Resolve sampling parameters from the model config (with defaults).
            config_params = model_config.config or {}
            temperature = config_params.get("temperature", 0.7)
            top_p = config_params.get("top_p", 0.9)

            logger.info(f"Task {task_id}: Launching LLM request to {model_config.model_name}")
            task.progress = 50
            await asyncio.to_thread(db.commit)

            # 4. Call the LLM service.
            summary_result = await LLMService.chat_completion(
                api_key=model_config.api_key,
                base_url=model_config.base_url or "https://api.openai.com/v1",
                model_name=model_config.model_name,
                messages=messages,
                api_path=model_config.api_path or "/chat/completions",
                temperature=float(temperature),
                top_p=float(top_p)
            )

            # 5. Task finished: write back the result.
            task.result = summary_result
            task.status = "completed"
            task.progress = 100
            task.completed_at = datetime.utcnow()

            # Also mirror the summary and final status onto the meeting row.
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.summary = summary_result
                meeting.status = "completed"

            await asyncio.to_thread(db.commit)
            logger.info(f"Task {task_id} completed successfully")

        except Exception as e:
            logger.error(f"Task {task_id} execution error: {str(e)}")
            task.status = "failed"
            task.error_message = str(e)

            # Restore meeting status to transcribed (so user can retry summary) instead of draft.
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.status = "transcribed"

            await asyncio.to_thread(db.commit)
|
|
|
@staticmethod
|
|
def get_task_status(db: Session, task_id: str):
|
|
task = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
|
|
if task:
|
|
return task
|
|
return db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()
|