# nex_basse/backend/app/services/meeting_service.py
from sqlalchemy.orm import Session
from app.core.db import SessionLocal
from typing import Callable, Awaitable, Optional, List
from app.models import Meeting, SummarizeTask, PromptTemplate, AIModel, TranscriptSegment, TranscriptTask, MeetingAudio, Hotword
from app.services.llm_service import LLMService
import uuid
import json
import logging
import math
import httpx
import asyncio
from datetime import datetime
from app.core.redis import redis_client
from app.core.config import get_settings
from pathlib import Path
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Application settings snapshot; presumably get_settings() caches — verify in app.core.config.
settings = get_settings()
class MeetingService:
    @staticmethod
    async def create_transcript_task(
        db: Session,
        meeting_id: int,
        model_id: int,
        language: str = "auto"
    ):
        """Create a pending transcription task for a meeting's latest audio.

        Validates that the meeting exists and has at least one uploaded
        audio file, persists a new pending TranscriptTask, flips the meeting
        status to "transcribing", and returns the task.

        Raises:
            Exception: if the meeting or its audio upload is missing.
        """
        target = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
        if target is None:
            raise Exception("Meeting not found")

        # The most recently uploaded audio file is the one transcribed.
        latest_audio = (
            db.query(MeetingAudio)
            .filter(MeetingAudio.meeting_id == meeting_id)
            .order_by(MeetingAudio.upload_time.desc())
            .first()
        )
        if latest_audio is None:
            raise Exception("No audio file found for this meeting")

        task = TranscriptTask(
            task_id=str(uuid.uuid4()),
            meeting_id=meeting_id,
            model_id=model_id,
            language=language,
            status="pending",
            progress=0,
            created_at=datetime.utcnow()
        )
        db.add(task)
        target.status = "transcribing"
        db.commit()

        # Note: pushing to the Redis transcribe queue is intentionally
        # disabled here; the API layer triggers background processing with
        # the returned task instead.
        return task
@staticmethod
async def process_transcript_task(db: Session, task_id: str):
    """Background worker body: transcribe a meeting's latest audio via ASR.

    Drives the full lifecycle of a TranscriptTask: marks it processing,
    resolves the ASR model and visible hotwords, calls the local ASR
    service with a progress callback, persists the returned segments, then
    either completes the task (auto-triggering summarization for uploaded
    meetings) or marks the task/audio/meeting as failed.
    """
    print(f"[DEBUG] Processing transcript task {task_id}")
    task = db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()
    if not task:
        print(f"[ERROR] Task {task_id} not found in DB")
        return
    try:
        task.status = "processing"
        task.progress = 10
        # Blocking commits are offloaded to a thread so the event loop stays responsive.
        await asyncio.to_thread(db.commit)
        # 1. Resolve the model configuration.
        model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
        if not model_config:
            # No explicit model on the task: fall back to any configured ASR model.
            model_config = db.query(AIModel).filter(AIModel.model_type == "asr").first()
            if not model_config:
                raise Exception("No ASR model configuration found")
        # 2. Fetch the most recently uploaded audio file for the meeting.
        audio = db.query(MeetingAudio).filter(
            MeetingAudio.meeting_id == task.meeting_id
        ).order_by(MeetingAudio.upload_time.desc()).first()
        if not audio:
            raise Exception("Audio file missing")
        task.progress = 20
        await asyncio.to_thread(db.commit)
        # 3. Call the ASR service (local model only).
        # Prefer the model config's base_url, otherwise fall back to settings.
        asr_base_url = model_config.base_url if model_config.base_url else settings.asr_api_base_url
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if not meeting:
            raise Exception("Meeting not found")
        # Hotwords visible to this meeting: shared scopes, plus the owner's
        # personal scope when the meeting has a user.
        hotword_filters = [Hotword.scope.in_(["public", "global"])]
        if meeting.user_id:
            hotword_filters.append(
                (Hotword.scope == "personal") & (Hotword.user_id == meeting.user_id)
            )
        hotwords = db.query(Hotword).filter(
            (hotword_filters[0]) if len(hotword_filters) == 1 else (hotword_filters[0] | hotword_filters[1])
        ).all()
        hotword_entries = []
        hotword_string_parts = []
        for hw in hotwords:
            hotword_entries.append({"word": hw.word, "weight": hw.weight})
            # The string form encodes a non-default weight as "word:weight".
            if hw.weight and hw.weight != 1:
                hotword_string_parts.append(f"{hw.word}:{hw.weight}")
            else:
                hotword_string_parts.append(hw.word)
        hotword_string = " ".join([p for p in hotword_string_parts if p])
        logger.info(f"Task {task_id}: Starting transcription with ASR Model at {asr_base_url}")
        print(f"[DEBUG] Calling ASR at {asr_base_url} for file: {audio.file_path}")

        # Progress callback invoked by the ASR poller.
        async def update_progress(p: int, msg: str = None):
            # Map ASR progress (0-100) onto task progress (20-80) and never
            # let the stored value move backwards.
            new_progress = 20 + int(p * 0.6)
            print(f"[DEBUG] update_progress called: p={p}, new_progress={new_progress}, current={task.progress}")
            if new_progress > task.progress:
                task.progress = new_progress
                # Note: commits on a shared session inside an async loop are
                # delicate; offload the blocking commit to a thread to avoid
                # freezing the event loop.
                try:
                    await asyncio.to_thread(db.commit)
                    print(f"[DEBUG] DB Updated: progress={new_progress}")
                except Exception as dbe:
                    logger.error(f"DB Commit failed: {dbe}")
                    db.rollback()
            # Publish a human-readable status message in Redis (1h TTL).
            if msg:
                try:
                    await redis_client.setex(f"task:status:{task_id}", 3600, msg)
                    print(f"[DEBUG] Redis Updated: key=task:status:{task_id}, msg={msg}")
                except Exception as e:
                    logger.warning(f"Failed to update task status in Redis: {e}")

        # Transcribe through the local model API.
        segments_data = await MeetingService._call_local_asr_api(
            audio_path=audio.file_path,
            base_url=asr_base_url,
            language=task.language,
            hotwords=hotword_entries if hotword_entries else None,
            hotword_string=hotword_string,
            progress_callback=update_progress
        )
        print(f"[DEBUG] Received {len(segments_data)} segments from ASR")
        task.progress = 80
        await asyncio.to_thread(db.commit)
        # 4. Persist the transcription segments.
        last_saved_end_ms = 0

        def parse_number(v):
            # Best-effort numeric coercion; returns None when not parseable.
            if isinstance(v, (int, float)):
                return float(v)
            if isinstance(v, str):
                try:
                    return float(v.strip())
                except Exception:
                    return None
            return None

        for seg in segments_data:
            # Local API returns: { "text": "...", "timestamp": [[start, end]], "speaker": "..." }
            # We need to adapt it to our schema.
            text = (seg.get("text") or "").strip()
            if not text:
                continue
            start_ms = None
            end_ms = None
            ts = seg.get("timestamp")
            if ts:
                if isinstance(ts, list) and len(ts) > 0:
                    try:
                        if isinstance(ts[0], list):
                            # Shape [[s, e], ...]: span from the earliest
                            # start to the latest end across all pairs.
                            starts = []
                            ends = []
                            for pair in ts:
                                if not isinstance(pair, (list, tuple)) or len(pair) < 2:
                                    continue
                                s = parse_number(pair[0])
                                e = parse_number(pair[1])
                                if s is not None and e is not None:
                                    starts.append(s)
                                    ends.append(e)
                            if starts and ends:
                                raw_start = min(starts)
                                raw_end = max(ends)
                                if raw_end < raw_start:
                                    raise ValueError("timestamp end < start")
                                # Heuristic: values under 1000 are presumably
                                # seconds, so scale to milliseconds — TODO confirm
                                # against the ASR service's actual unit.
                                if raw_end < 1000:
                                    raw_start *= 1000.0
                                    raw_end *= 1000.0
                                start_ms = int(raw_start)
                                end_ms = int(raw_end)
                        elif len(ts) >= 2:
                            # Shape [s, e]: a single flat pair.
                            raw_start = parse_number(ts[0])
                            raw_end = parse_number(ts[1])
                            if raw_start is None or raw_end is None:
                                raise ValueError("timestamp not numeric")
                            if raw_end < raw_start:
                                raise ValueError("timestamp end < start")
                            if raw_end < 1000:
                                raw_start *= 1000.0
                                raw_end *= 1000.0
                            start_ms = int(raw_start)
                            end_ms = int(raw_end)
                    except Exception:
                        start_ms = None
                        end_ms = None
            if start_ms is None or end_ms is None:
                # Fallback: some responses carry begin_time/start_time/end_time.
                bt = parse_number(seg.get("begin_time") or seg.get("start_time"))
                et = parse_number(seg.get("end_time"))
                if bt is not None and et is not None:
                    if bt < 1000 and et < 1000:
                        bt *= 1000.0
                        et *= 1000.0
                    start_ms = int(bt)
                    end_ms = int(et)
            if start_ms is None or end_ms is None or end_ms < start_ms:
                # Last resort: synthesize a 1-second window right after the
                # previously saved segment so ordering stays monotonic.
                start_ms = last_saved_end_ms + 1
                end_ms = start_ms + 1000
            last_saved_end_ms = max(last_saved_end_ms, end_ms)
            transcript_segment = TranscriptSegment(
                meeting_id=task.meeting_id,
                audio_id=audio.audio_id,
                speaker_id=0,
                speaker_tag=seg.get("speaker", "Unknown"),
                start_time_ms=start_ms,
                end_time_ms=end_ms,
                text_content=text
            )
            db.add(transcript_segment)
        # 5. Finish the task.
        task.status = "completed"
        task.progress = 100
        task.completed_at = datetime.utcnow()
        # Mark the audio as processed.
        audio.processing_status = "completed"
        # Update the meeting status; summarization may be auto-triggered
        # below, or started later by the user.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.status = "transcribed"  # distinct status before summarizing
            # Auto-trigger summarization for uploaded meetings.
            if meeting.type == 'upload':
                try:
                    # Resolve the LLM model: the meeting's chosen model first,
                    # otherwise the default active LLM.
                    llm_model = None
                    if meeting.summary_model_id:
                        llm_model = db.query(AIModel).filter(
                            AIModel.model_id == meeting.summary_model_id,
                            AIModel.status == 1
                        ).first()
                    if not llm_model:
                        llm_model = db.query(AIModel).filter(
                            AIModel.model_type == 'llm',
                            AIModel.is_default == 1,
                            AIModel.status == 1
                        ).first()
                    # Resolve the prompt template: meeting's choice first,
                    # otherwise the first active (system-preferred) template.
                    prompt_tmpl = None
                    if meeting.summary_prompt_id:
                        prompt_tmpl = db.query(PromptTemplate).filter(
                            PromptTemplate.id == meeting.summary_prompt_id
                        ).first()
                    if not prompt_tmpl:
                        prompt_tmpl = db.query(PromptTemplate).filter(
                            PromptTemplate.status == 1
                        ).order_by(PromptTemplate.is_system.desc(), PromptTemplate.id.asc()).first()
                    if llm_model and prompt_tmpl:
                        logger.info(f"Auto-triggering summary for meeting {meeting.meeting_id}")
                        # create_summarize_task commits on this same shared
                        # session, so flush our own pending changes first.
                        await asyncio.to_thread(db.commit)
                        new_sum_task = await MeetingService.create_summarize_task(
                            db,
                            meeting_id=meeting.meeting_id,
                            prompt_id=prompt_tmpl.id,
                            model_id=llm_model.model_id
                        )
                        # No FastAPI BackgroundTasks object is available here,
                        # so schedule the summarize worker on the event loop.
                        # The worker opens its own DB session, because this
                        # session is closed when this function returns.
                        asyncio.create_task(MeetingService.run_summarize_worker(new_sum_task.task_id))
                        # Reflect that summarization is now in flight.
                        meeting.status = "summarizing"
                    else:
                        logger.warning(f"Skipping auto-summary: No default LLM or Prompt found (LLM: {llm_model}, Prompt: {prompt_tmpl})")
                except Exception as sum_e:
                    # Best-effort: a failed auto-trigger must not fail the transcription.
                    logger.error(f"Failed to auto-trigger summary: {sum_e}")
        await asyncio.to_thread(db.commit)
        logger.info(f"Task {task_id} transcription completed")
    except Exception as e:
        logger.error(f"Task {task_id} failed: {str(e)}")
        task.status = "failed"
        task.error_message = str(e)
        # Propagate the failure to the latest audio record.
        audio = db.query(MeetingAudio).filter(
            MeetingAudio.meeting_id == task.meeting_id
        ).order_by(MeetingAudio.upload_time.desc()).first()
        if audio:
            audio.processing_status = "failed"
            audio.error_message = str(e)
        # Mark the meeting failed so the frontend knows to stop polling.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.status = "failed"
        await asyncio.to_thread(db.commit)
@staticmethod
async def _call_local_asr_api(
    audio_path: str,
    base_url: str = "http://localhost:3050",
    language: str = "auto",
    hotwords: Optional[List[dict]] = None,
    hotword_string: Optional[str] = None,
    progress_callback: Optional[Callable[[int, Optional[str]], Awaitable[None]]] = None
) -> list:
    """
    Call local ASR API for transcription.
    Flow: Create Task -> Poll Status -> Get Result

    Task creation degrades gracefully through several payload shapes
    (file_path JSON -> audio_path JSON -> multipart upload) to cope with
    different server versions, and falls back to mock segments when the
    service is unreachable. Polling reports progress through the optional
    callback and returns the recognized segment list.
    """
    create_url = f"{base_url.rstrip('/')}/api/tasks/recognition"
    async with httpx.AsyncClient(timeout=30.0) as client:
        # 1. Create Task
        try:
            # Normalize hotwords into {word: int_weight}; default weight 20.
            normalized_hotwords = {}
            if hotwords:
                for hw in hotwords:
                    if isinstance(hw, dict):
                        word = hw.get("word")
                        weight = hw.get("weight")
                        if word:
                            normalized_hotwords[word] = int(weight) if weight is not None else 20
                    elif isinstance(hw, str):
                        normalized_hotwords[hw] = 20
            payload = {
                "file_path": audio_path,
                "language": language,
                "use_spk_id": True
            }
            if normalized_hotwords:
                payload["hotwords"] = normalized_hotwords
            elif hotword_string:
                payload["hotword"] = hotword_string
            response = await client.post(create_url, json=payload)
            # Server error: retry once without hotwords in case they caused it.
            if response.status_code >= 500:
                base_payload = {
                    "file_path": audio_path,
                    "language": language,
                    "use_spk_id": True
                }
                response = await client.post(create_url, json=base_payload)
            # 422: some server versions expect "audio_path" instead of "file_path".
            if response.status_code == 422:
                alt_payload = {
                    "audio_path": audio_path,
                    "language": language,
                    "use_spk_id": True
                }
                if normalized_hotwords:
                    alt_payload["hotwords"] = normalized_hotwords
                elif hotword_string:
                    alt_payload["hotword"] = hotword_string
                response = await client.post(create_url, json=alt_payload)
                # Again retry the alternate shape without hotwords on 5xx.
                if response.status_code >= 500:
                    alt_base_payload = {
                        "audio_path": audio_path,
                        "language": language,
                        "use_spk_id": True
                    }
                    response = await client.post(create_url, json=alt_base_payload)
            # Last resort: upload the audio file itself as multipart form data.
            if response.status_code == 422 or response.status_code >= 500:
                with open(audio_path, "rb") as f:
                    files = {
                        "file": (Path(audio_path).name, f, "application/octet-stream")
                    }
                    data_fields = {
                        "language": language,
                        "use_spk_id": "true"
                    }
                    if normalized_hotwords:
                        data_fields["hotwords"] = json.dumps(normalized_hotwords, ensure_ascii=False)
                    elif hotword_string:
                        data_fields["hotword"] = hotword_string
                    response = await client.post(create_url, data=data_fields, files=files)
            if response.status_code >= 400:
                # Log the body before raising so failures are diagnosable.
                try:
                    logger.error(f"ASR create response: {response.status_code} {response.text}")
                except Exception:
                    pass
            response.raise_for_status()
            data = response.json()
            # Handle nested data structure {code: 200, data: {task_id: ...}}
            if "data" in data and isinstance(data["data"], dict) and "task_id" in data["data"]:
                task_id = data["data"]["task_id"]
            else:
                task_id = data.get("task_id")
            if not task_id:
                raise Exception("Failed to get task_id from ASR service")
        except Exception as e:
            logger.error(f"Failed to create ASR task: {e}")
            # Fallback for testing/mock if service is not running.
            if "Connection refused" in str(e) or "ConnectError" in str(e):
                logger.warning("ASR Service not reachable, using mock data for testing")
                await asyncio.sleep(2)
                return [
                    {"timestamp": [[0, 2500]], "text": "This is a mock transcription (Local ASR unreachable).", "speaker": "System"},
                    {"timestamp": [[2500, 5000]], "text": "Please ensure the ASR service is running at localhost:3050.", "speaker": "System"}
                ]
            raise
    # 2. Poll Status
    status_url = f"{base_url.rstrip('/')}/api/tasks/{task_id}"
    # Separate client for the (potentially long) polling phase.
    poll_client = httpx.AsyncClient(timeout=10.0)
    max_retries = 600  # 20 minutes max wait (assuming 2s sleep)
    try:
        for _ in range(max_retries):
            try:
                res = await poll_client.get(status_url)
                res.raise_for_status()
                status_data = res.json()
                progress = 0
                # Handle nested data structure {code, data: {...}}.
                if "data" in status_data and isinstance(status_data["data"], dict):
                    inner_data = status_data["data"]
                    status = inner_data.get("status")
                    result = inner_data.get("result", {})
                    error_msg = inner_data.get("msg") or status_data.get("msg")
                    # Progress may arrive under several key names.
                    if "progress" in inner_data:
                        raw_progress = inner_data["progress"]
                    elif "percent" in inner_data:
                        raw_progress = inner_data["percent"]
                    elif "percentage" in inner_data:
                        raw_progress = inner_data["percentage"]
                    else:
                        raw_progress = 0
                    # Prefer a non-empty "message" for the status text.
                    if "message" in inner_data and inner_data["message"]:
                        error_msg = inner_data["message"]
                else:
                    # Flat response shape.
                    status = status_data.get("status")
                    result = status_data.get("result", {})
                    error_msg = status_data.get("msg")
                    if "progress" in status_data:
                        raw_progress = status_data["progress"]
                    elif "percent" in status_data:
                        raw_progress = status_data["percent"]
                    elif "percentage" in status_data:
                        raw_progress = status_data["percentage"]
                    else:
                        raw_progress = 0
                    if "message" in status_data and status_data["message"]:
                        error_msg = status_data["message"]
                # Normalize progress: float ratio (0.0-1.0), int (0-100),
                # or string ("45%" / "0.45") all map to an int percentage.
                progress = 0
                try:
                    if isinstance(raw_progress, (int, float)):
                        if 0.0 <= raw_progress <= 1.0 and isinstance(raw_progress, float):
                            progress = int(raw_progress * 100)
                        else:
                            progress = int(raw_progress)
                    elif isinstance(raw_progress, str):
                        if raw_progress.endswith("%"):
                            progress = int(float(raw_progress.strip("%")))
                        else:
                            val = float(raw_progress)
                            if 0.0 <= val <= 1.0:
                                progress = int(val * 100)
                            else:
                                progress = int(val)
                except (ValueError, TypeError):
                    progress = 0
                # Log raw status for debugging.
                print(f"[DEBUG] ASR Polling: status={status}, raw_progress={raw_progress}, progress={progress}, msg={error_msg}")
                # Report progress to the caller, tolerating callback errors.
                if progress_callback:
                    try:
                        status_msg = error_msg if error_msg else f"正在转写中... {progress}%"
                        print(f"[DEBUG] Invoking callback with progress={progress}, msg={status_msg}")
                        await progress_callback(progress, status_msg)
                    except Exception as e:
                        logger.warning(f"Progress callback error: {e}")
                if status == "completed" or status == "success":
                    # Return the segments list, whichever shape it came in.
                    if isinstance(result, dict) and "segments" in result:
                        return result["segments"]
                    elif isinstance(result, list):
                        return result
                    else:
                        # Fallback or empty.
                        return []
                elif status == "failed":
                    raise Exception(f"ASR Task failed: {error_msg}")
                # Still processing.
                await asyncio.sleep(2)
            except httpx.RequestError as e:
                # Transient network error while polling: log and keep trying.
                logger.warning(f"Error polling ASR task {task_id}: {e}")
                await asyncio.sleep(2)
    finally:
        await poll_client.aclose()
    raise Exception("ASR Task timed out")
@staticmethod
async def create_summarize_task(
    db: Session,
    meeting_id: int,
    prompt_id: int,
    model_id: int,
    extra_prompt: str = ""
):
    """Build and enqueue an LLM summarization task for a meeting.

    Formats every transcript segment as "[mm:ss] speaker: text", combines
    that material with the caller's optional extra instructions into the
    task's user prompt, persists a pending SummarizeTask, marks the
    meeting as "summarizing", and pushes the task id onto the Redis queue.

    Raises:
        Exception: if the meeting does not exist.
    """
    meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
    if meeting is None:
        raise Exception("Meeting not found")

    # Gather the transcript, ordered by start time, as prompt material.
    segments = (
        db.query(TranscriptSegment)
        .filter(TranscriptSegment.meeting_id == meeting_id)
        .order_by(TranscriptSegment.start_time_ms)
        .all()
    )
    lines = []
    for seg in segments:
        minutes, seconds = divmod(int(seg.start_time_ms // 1000), 60)
        speaker = seg.speaker_tag or f"发言人{seg.speaker_id or '?'}"
        lines.append(f"[{minutes:02d}:{seconds:02d}] {speaker}: {seg.text_content}")
    meeting_text = "\n".join(lines)

    # Final user prompt = transcript material + optional user instructions.
    user_prompt_content = f"### 会议转译内容 ###\n{meeting_text}"
    if extra_prompt:
        user_prompt_content += f"\n\n### 用户的额外指令 ###\n{extra_prompt}"

    # Persist the task row matching the actual database schema.
    task_id = str(uuid.uuid4())
    new_task = SummarizeTask(
        task_id=task_id,
        meeting_id=meeting_id,
        prompt_id=prompt_id,
        model_id=model_id,
        user_prompt=user_prompt_content,
        status="pending",
        progress=0,
        created_at=datetime.utcnow()
    )
    db.add(new_task)
    # Flip the meeting into its "summarizing" state.
    meeting.status = "summarizing"
    await asyncio.to_thread(db.commit)

    # Hand the task to the background worker queue.
    await redis_client.lpush("meeting:summarize:queue", task_id)
    return new_task
@staticmethod
async def run_summarize_worker(task_id: str):
    """
    Worker entry point for a summarize task that owns its DB session.

    Opens a dedicated SessionLocal, runs process_summarize_task with it,
    logs (but does not propagate) failures, and always closes the session
    — safe to schedule via asyncio.create_task from other contexts.
    """
    session = SessionLocal()
    try:
        await MeetingService.process_summarize_task(session, task_id)
    except Exception as exc:
        logger.error(f"Summarize worker (task {task_id}) failed: {exc}")
    finally:
        session.close()
@staticmethod
async def process_summarize_task(db: Session, task_id: str):
    """Background worker body: execute a summarization task end to end.

    Loads the task, resolves the AI model and prompt template, calls the
    LLM, and writes the summary back to both the task row and the meeting.
    On failure the task is marked "failed" and the meeting is returned to
    "transcribed" so the user can retry the summary.
    """
    task = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
    if not task:
        return
    try:
        task.status = "processing"
        task.progress = 15
        # Blocking commits run in a thread so the event loop is not stalled.
        await asyncio.to_thread(db.commit)
        # 1. Resolve the AI model configuration.
        model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
        if not model_config:
            raise Exception("AI 模型配置不存在")
        # 2. Fetch the prompt template content at execution time, so edits
        #    made after task creation are picked up.
        prompt_tmpl = db.query(PromptTemplate).filter(PromptTemplate.id == task.prompt_id).first()
        system_prompt = prompt_tmpl.content if prompt_tmpl else "请根据提供的会议转译内容生成准确的总结。"
        task.progress = 30
        await asyncio.to_thread(db.commit)
        # 3. Build the chat message payload.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": task.user_prompt}
        ]
        # Sampling parameters; model_config.config presumably is a dict-like
        # column — TODO confirm against the AIModel schema.
        config_params = model_config.config or {}
        temperature = config_params.get("temperature", 0.7)
        top_p = config_params.get("top_p", 0.9)
        logger.info(f"Task {task_id}: Launching LLM request to {model_config.model_name}")
        task.progress = 50
        await asyncio.to_thread(db.commit)
        # 4. Call the LLM service.
        summary_result = await LLMService.chat_completion(
            api_key=model_config.api_key,
            base_url=model_config.base_url or "https://api.openai.com/v1",
            model_name=model_config.model_name,
            messages=messages,
            api_path=model_config.api_path or "/chat/completions",
            temperature=float(temperature),
            top_p=float(top_p)
        )
        # 5. Persist the result on the task and mirror it onto the meeting.
        task.result = summary_result
        task.status = "completed"
        task.progress = 100
        task.completed_at = datetime.utcnow()
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.summary = summary_result
            meeting.status = "completed"
        await asyncio.to_thread(db.commit)
        logger.info(f"Task {task_id} completed successfully")
    except Exception as e:
        logger.error(f"Task {task_id} execution error: {str(e)}")
        # Fix: if the failure originated in the DB layer (e.g. one of the
        # progress commits above), the session is left in a pending/invalid
        # state and the failure-status commit below would itself fail.
        # Roll back first to restore the session; committed progress updates
        # are unaffected by the rollback.
        try:
            db.rollback()
        except Exception as rb_e:
            logger.warning(f"Rollback after failure also failed: {rb_e}")
        task.status = "failed"
        task.error_message = str(e)
        # Restore meeting status to "transcribed" (not draft) so the user can
        # retry the summary.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.status = "transcribed"
        await asyncio.to_thread(db.commit)
@staticmethod
def get_task_status(db: Session, task_id: str):
    """Look up a task by id, checking summarize tasks before transcript tasks.

    Returns the matching SummarizeTask or TranscriptTask row, or None when
    neither table contains the id.
    """
    summarize = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
    return summarize or db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()