from sqlalchemy.orm import Session
from app.core.db import SessionLocal
from typing import Callable, Awaitable, Optional, List
from app.models import Meeting, SummarizeTask, PromptTemplate, AIModel, TranscriptSegment, TranscriptTask, MeetingAudio, Hotword
from app.services.llm_service import LLMService
import uuid
import json
import logging
import math
import httpx
import asyncio
from datetime import datetime
from app.core.redis import redis_client
from app.core.config import get_settings
from pathlib import Path

logger = logging.getLogger(__name__)
settings = get_settings()


class MeetingService:
    """Service layer for meeting transcription (ASR) and summarization (LLM) tasks.

    Task lifecycle: a ``create_*_task`` method persists a task row and returns it;
    a ``process_*_task`` worker method then executes it, committing progress
    updates along the way.  Blocking ``db.commit()`` calls inside async workers
    are offloaded via ``asyncio.to_thread`` to avoid stalling the event loop.
    """

    @staticmethod
    async def create_transcript_task(
        db: Session,
        meeting_id: int,
        model_id: int,
        language: str = "auto"
    ):
        """Create a pending transcription task for a meeting's latest audio file.

        Marks the meeting as "transcribing" and commits.  Raises ``Exception``
        if the meeting or its audio is missing.  Returns the new
        ``TranscriptTask`` so the API layer can kick off background processing.
        """
        # 1. Validate that the meeting and its audio exist.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
        if not meeting:
            raise Exception("Meeting not found")

        # Use the most recently uploaded audio file for this meeting.
        audio = db.query(MeetingAudio).filter(
            MeetingAudio.meeting_id == meeting_id
        ).order_by(MeetingAudio.upload_time.desc()).first()
        if not audio:
            raise Exception("No audio file found for this meeting")

        # 2. Create the transcription task record.
        task_id = str(uuid.uuid4())
        new_task = TranscriptTask(
            task_id=task_id,
            meeting_id=meeting_id,
            model_id=model_id,
            language=language,
            status="pending",
            progress=0,
            created_at=datetime.utcnow()
        )
        db.add(new_task)
        # Flip the meeting into the transcribing state.
        meeting.status = "transcribing"
        db.commit()

        # 3. Redis transcribe queue (intentionally disabled).
        # Note: In a real worker system, we would push to a queue.
        # For now, we return the task so the API can trigger background processing.
        # await redis_client.lpush("meeting:transcribe:queue", task_id)

        return new_task

    @staticmethod
    async def process_transcript_task(db: Session, task_id: str):
        """Worker: run a transcription task end-to-end.

        Steps: load model config and audio -> collect hotwords -> call the
        local ASR API (with live progress reporting) -> persist transcript
        segments -> mark task/audio/meeting completed.  For uploaded meetings
        it auto-triggers summarization.  On any failure the task, audio and
        meeting are all marked "failed" so the frontend stops polling.
        """
        print(f"[DEBUG] Processing transcript task {task_id}")
        task = db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()
        if not task:
            print(f"[ERROR] Task {task_id} not found in DB")
            return

        try:
            task.status = "processing"
            task.progress = 10
            await asyncio.to_thread(db.commit)

            # 1. Resolve the ASR model configuration.
            model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
            if not model_config:
                # No explicit model on the task; fall back to any ASR model.
                model_config = db.query(AIModel).filter(AIModel.model_type == "asr").first()
                if not model_config:
                    raise Exception("No ASR model configuration found")

            # 2. Load the latest audio file for the meeting.
            audio = db.query(MeetingAudio).filter(
                MeetingAudio.meeting_id == task.meeting_id
            ).order_by(MeetingAudio.upload_time.desc()).first()
            if not audio:
                raise Exception("Audio file missing")

            task.progress = 20
            await asyncio.to_thread(db.commit)

            # 3. Call the ASR service (local model only).
            # Use model config base_url if available, otherwise fall back to settings.
            asr_base_url = model_config.base_url if model_config.base_url else settings.asr_api_base_url

            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if not meeting:
                raise Exception("Meeting not found")

            # Collect hotwords visible to this meeting: public/global plus,
            # when the meeting has an owner, that user's personal hotwords.
            hotword_filters = [Hotword.scope.in_(["public", "global"])]
            if meeting.user_id:
                hotword_filters.append(
                    (Hotword.scope == "personal") & (Hotword.user_id == meeting.user_id)
                )
            hotwords = db.query(Hotword).filter(
                (hotword_filters[0]) if len(hotword_filters) == 1 else (hotword_filters[0] | hotword_filters[1])
            ).all()

            # Build both representations the ASR API variants accept:
            # a list of {word, weight} dicts and a "word:weight" space-joined string.
            hotword_entries = []
            hotword_string_parts = []
            for hw in hotwords:
                hotword_entries.append({"word": hw.word, "weight": hw.weight})
                if hw.weight and hw.weight != 1:
                    hotword_string_parts.append(f"{hw.word}:{hw.weight}")
                else:
                    hotword_string_parts.append(hw.word)
            hotword_string = " ".join([p for p in hotword_string_parts if p])

            logger.info(f"Task {task_id}: Starting transcription with ASR Model at {asr_base_url}")
            print(f"[DEBUG] Calling ASR at {asr_base_url} for file: {audio.file_path}")

            # Progress callback handed to the ASR poller.  It mutates the
            # shared `task` row and commits on the same session — assumes a
            # single worker owns this session (TODO confirm no concurrent use).
            async def update_progress(p: int, msg: Optional[str] = None):
                # Map ASR progress (0-100) onto task progress range 20-80,
                # and never let the stored progress move backwards.
                new_progress = 20 + int(p * 0.6)
                print(f"[DEBUG] update_progress called: p={p}, new_progress={new_progress}, current={task.progress}")
                if new_progress > task.progress:
                    task.progress = new_progress
                    # Note: We must be careful with db.commit() in async loop if db session is shared
                    try:
                        # Offload blocking DB commit to thread to avoid freezing the event loop
                        await asyncio.to_thread(db.commit)
                        print(f"[DEBUG] DB Updated: progress={new_progress}")
                    except Exception as dbe:
                        logger.error(f"DB Commit failed: {dbe}")
                        db.rollback()
                # Publish the human-readable status message to Redis (TTL 1h).
                # NOTE(review): placed at callback top level so messages are
                # published even when progress did not advance — confirm intent.
                if msg:
                    try:
                        await redis_client.setex(f"task:status:{task_id}", 3600, msg)
                        print(f"[DEBUG] Redis Updated: key=task:status:{task_id}, msg={msg}")
                    except Exception as e:
                        logger.warning(f"Failed to update task status in Redis: {e}")

            # Transcribe via the local-model HTTP API.
            segments_data = await MeetingService._call_local_asr_api(
                audio_path=audio.file_path,
                base_url=asr_base_url,
                language=task.language,
                hotwords=hotword_entries if hotword_entries else None,
                hotword_string=hotword_string,
                progress_callback=update_progress
            )

            print(f"[DEBUG] Received {len(segments_data)} segments from ASR")
            task.progress = 80
            await asyncio.to_thread(db.commit)

            # 4. Persist the transcription result as TranscriptSegment rows.
            last_saved_end_ms = 0

            def parse_number(v):
                # Lenient numeric coercion: int/float pass through, numeric
                # strings are parsed, anything else yields None.
                if isinstance(v, (int, float)):
                    return float(v)
                if isinstance(v, str):
                    try:
                        return float(v.strip())
                    except Exception:
                        return None
                return None

            for seg in segments_data:
                # Local API returns: { "text": "...", "timestamp": [[start, end]], "speaker": "..." }
                # Adapt each segment to our schema.
                text = (seg.get("text") or "").strip()
                if not text:
                    continue

                start_ms = None
                end_ms = None
                ts = seg.get("timestamp")
                if ts:
                    if isinstance(ts, list) and len(ts) > 0:
                        try:
                            if isinstance(ts[0], list):
                                # List of [start, end] pairs: span from the
                                # earliest start to the latest end.
                                starts = []
                                ends = []
                                for pair in ts:
                                    if not isinstance(pair, (list, tuple)) or len(pair) < 2:
                                        continue
                                    s = parse_number(pair[0])
                                    e = parse_number(pair[1])
                                    if s is not None and e is not None:
                                        starts.append(s)
                                        ends.append(e)
                                if starts and ends:
                                    raw_start = min(starts)
                                    raw_end = max(ends)
                                    if raw_end < raw_start:
                                        raise ValueError("timestamp end < start")
                                    # Heuristic: values under 1000 are assumed
                                    # to be seconds and converted to ms.
                                    if raw_end < 1000:
                                        raw_start *= 1000.0
                                        raw_end *= 1000.0
                                    start_ms = int(raw_start)
                                    end_ms = int(raw_end)
                            elif len(ts) >= 2:
                                # Flat [start, end] form.
                                raw_start = parse_number(ts[0])
                                raw_end = parse_number(ts[1])
                                if raw_start is None or raw_end is None:
                                    raise ValueError("timestamp not numeric")
                                if raw_end < raw_start:
                                    raise ValueError("timestamp end < start")
                                if raw_end < 1000:
                                    raw_start *= 1000.0
                                    raw_end *= 1000.0
                                start_ms = int(raw_start)
                                end_ms = int(raw_end)
                        except Exception:
                            # Malformed timestamps fall through to the
                            # begin_time/end_time fallback below.
                            start_ms = None
                            end_ms = None

                if start_ms is None or end_ms is None:
                    # Fallback: some ASR responses carry begin_time/end_time
                    # (or start_time) fields instead of "timestamp".
                    bt = parse_number(seg.get("begin_time") or seg.get("start_time"))
                    et = parse_number(seg.get("end_time"))
                    if bt is not None and et is not None:
                        if bt < 1000 and et < 1000:
                            bt *= 1000.0
                            et *= 1000.0
                        start_ms = int(bt)
                        end_ms = int(et)

                if start_ms is None or end_ms is None or end_ms < start_ms:
                    # Last resort: synthesize a 1-second window right after
                    # the previously saved segment so ordering is preserved.
                    start_ms = last_saved_end_ms + 1
                    end_ms = start_ms + 1000

                last_saved_end_ms = max(last_saved_end_ms, end_ms)

                transcript_segment = TranscriptSegment(
                    meeting_id=task.meeting_id,
                    audio_id=audio.audio_id,
                    speaker_id=0,
                    speaker_tag=seg.get("speaker", "Unknown"),
                    start_time_ms=start_ms,
                    end_time_ms=end_ms,
                    text_content=text
                )
                db.add(transcript_segment)

            # 5. Mark the task as finished.
            task.status = "completed"
            task.progress = 100
            task.completed_at = datetime.utcnow()

            # Mark the audio as processed.
            audio.processing_status = "completed"

            # Update the meeting status (summarization may be auto-triggered
            # below, or left for the user to trigger manually).
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.status = "transcribed"  # distinct status before summarizing

                # Auto-trigger summarization for uploaded meetings
                if meeting.type == 'upload':
                    try:
                        # Resolve the LLM model: the meeting's chosen model if
                        # set and active, otherwise the default active LLM.
                        llm_model = None
                        if meeting.summary_model_id:
                            llm_model = db.query(AIModel).filter(
                                AIModel.model_id == meeting.summary_model_id,
                                AIModel.status == 1
                            ).first()
                        if not llm_model:
                            llm_model = db.query(AIModel).filter(
                                AIModel.model_type == 'llm',
                                AIModel.is_default == 1,
                                AIModel.status == 1
                            ).first()

                        # Resolve the prompt template: the meeting's chosen one,
                        # otherwise the first active template (system ones first).
                        prompt_tmpl = None
                        if meeting.summary_prompt_id:
                            prompt_tmpl = db.query(PromptTemplate).filter(
                                PromptTemplate.id == meeting.summary_prompt_id
                            ).first()
                        if not prompt_tmpl:
                            prompt_tmpl = db.query(PromptTemplate).filter(
                                PromptTemplate.status == 1
                            ).order_by(PromptTemplate.is_system.desc(), PromptTemplate.id.asc()).first()

                        if llm_model and prompt_tmpl:
                            logger.info(f"Auto-triggering summary for meeting {meeting.meeting_id}")
                            # create_summarize_task commits on the same shared
                            # session, so commit our pending changes first.
                            await asyncio.to_thread(db.commit)
                            new_sum_task = await MeetingService.create_summarize_task(
                                db,
                                meeting_id=meeting.meeting_id,
                                prompt_id=prompt_tmpl.id,
                                model_id=llm_model.model_id
                            )
                            # No FastAPI BackgroundTasks object is available in
                            # this worker context, so schedule the summary via
                            # asyncio.create_task.  run_summarize_worker opens
                            # its own session to avoid "Session is closed"
                            # errors once the current session is released.
                            asyncio.create_task(MeetingService.run_summarize_worker(new_sum_task.task_id))
                            # Reflect that summarization is now in flight.
                            meeting.status = "summarizing"
                        else:
                            logger.warning(f"Skipping auto-summary: No default LLM or Prompt found (LLM: {llm_model}, Prompt: {prompt_tmpl})")
                    except Exception as sum_e:
                        # Best-effort: a failed auto-trigger must not fail the
                        # transcription task itself.
                        logger.error(f"Failed to auto-trigger summary: {sum_e}")

            await asyncio.to_thread(db.commit)
            logger.info(f"Task {task_id} transcription completed")

        except Exception as e:
            logger.error(f"Task {task_id} failed: {str(e)}")
            task.status = "failed"
            task.error_message = str(e)
            # Propagate the failure onto the audio record.
            # NOTE(review): if the original exception came from the session
            # itself, this commit may need a prior rollback — confirm.
            audio = db.query(MeetingAudio).filter(
                MeetingAudio.meeting_id == task.meeting_id
            ).order_by(MeetingAudio.upload_time.desc()).first()
            if audio:
                audio.processing_status = "failed"
                audio.error_message = str(e)
            # Update meeting status to failed so frontend knows to stop polling
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.status = "failed"
            await asyncio.to_thread(db.commit)

    @staticmethod
    async def _call_local_asr_api(
        audio_path: str,
        base_url: str = "http://localhost:3050",
        language: str = "auto",
        hotwords: Optional[List[dict]] = None,
        hotword_string: Optional[str] = None,
        progress_callback: Optional[Callable[[int, Optional[str]], Awaitable[None]]] = None
    ) -> list:
        """Call the local ASR HTTP API for transcription.

        Flow: create task -> poll status (invoking ``progress_callback`` per
        poll) -> return the segment list.  Request creation tries a cascade of
        payload shapes for compatibility with different ASR server versions:
        JSON with ``file_path`` -> JSON without hotwords -> JSON with
        ``audio_path`` -> multipart file upload.  If the service is
        unreachable, returns mock segments so the pipeline can be exercised.
        Raises ``Exception`` if the ASR task fails or polling times out
        (~20 minutes).
        """
        create_url = f"{base_url.rstrip('/')}/api/tasks/recognition"

        async with httpx.AsyncClient(timeout=30.0) as client:
            # 1. Create the recognition task on the ASR server.
            try:
                # Normalize hotwords to a {word: int_weight} mapping;
                # unspecified weights default to 20.
                normalized_hotwords = {}
                if hotwords:
                    for hw in hotwords:
                        if isinstance(hw, dict):
                            word = hw.get("word")
                            weight = hw.get("weight")
                            if word:
                                normalized_hotwords[word] = int(weight) if weight is not None else 20
                        elif isinstance(hw, str):
                            normalized_hotwords[hw] = 20

                payload = {
                    "file_path": audio_path,
                    "language": language,
                    "use_spk_id": True
                }
                if normalized_hotwords:
                    payload["hotwords"] = normalized_hotwords
                elif hotword_string:
                    payload["hotword"] = hotword_string

                response = await client.post(create_url, json=payload)
                if response.status_code >= 500:
                    # Server choked — retry without hotwords in case they
                    # are the unsupported part of the payload.
                    base_payload = {
                        "file_path": audio_path,
                        "language": language,
                        "use_spk_id": True
                    }
                    response = await client.post(create_url, json=base_payload)

                if response.status_code == 422:
                    # Validation error — older/newer servers expect the key
                    # "audio_path" instead of "file_path".
                    alt_payload = {
                        "audio_path": audio_path,
                        "language": language,
                        "use_spk_id": True
                    }
                    if normalized_hotwords:
                        alt_payload["hotwords"] = normalized_hotwords
                    elif hotword_string:
                        alt_payload["hotword"] = hotword_string
                    response = await client.post(create_url, json=alt_payload)
                    if response.status_code >= 500:
                        # Same retry-without-hotwords fallback for the alt key.
                        alt_base_payload = {
                            "audio_path": audio_path,
                            "language": language,
                            "use_spk_id": True
                        }
                        response = await client.post(create_url, json=alt_base_payload)

                if response.status_code == 422 or response.status_code >= 500:
                    # Final fallback: multipart upload of the audio bytes.
                    with open(audio_path, "rb") as f:
                        files = {
                            "file": (Path(audio_path).name, f, "application/octet-stream")
                        }
                        data_fields = {
                            "language": language,
                            "use_spk_id": "true"
                        }
                        if normalized_hotwords:
                            data_fields["hotwords"] = json.dumps(normalized_hotwords, ensure_ascii=False)
                        elif hotword_string:
                            data_fields["hotword"] = hotword_string
                        response = await client.post(create_url, data=data_fields, files=files)

                if response.status_code >= 400:
                    try:
                        logger.error(f"ASR create response: {response.status_code} {response.text}")
                    except Exception:
                        pass
                response.raise_for_status()
                data = response.json()

                # Handle nested data structure {code: 200, data: {task_id: ...}}
                if "data" in data and isinstance(data["data"], dict) and "task_id" in data["data"]:
                    task_id = data["data"]["task_id"]
                else:
                    task_id = data.get("task_id")
                if not task_id:
                    raise Exception("Failed to get task_id from ASR service")
            except Exception as e:
                logger.error(f"Failed to create ASR task: {e}")
                # Fallback for testing/mock if service is not running
                if "Connection refused" in str(e) or "ConnectError" in str(e):
                    logger.warning("ASR Service not reachable, using mock data for testing")
                    await asyncio.sleep(2)
                    return [
                        {"timestamp": [[0, 2500]], "text": "This is a mock transcription (Local ASR unreachable).", "speaker": "System"},
                        {"timestamp": [[2500, 5000]], "text": "Please ensure the ASR service is running at localhost:3050.", "speaker": "System"}
                    ]
                raise

        # 2. Poll task status until completion, failure, or timeout.
        # NOTE(review): polling placed after the create-task client context
        # (it uses its own poll_client) — confirm against original layout.
        status_url = f"{base_url.rstrip('/')}/api/tasks/{task_id}"
        # Separate client with a shorter per-request timeout for polling.
        poll_client = httpx.AsyncClient(timeout=10.0)
        max_retries = 600  # 20 minutes max wait (assuming 2s sleep)
        try:
            for _ in range(max_retries):
                try:
                    res = await poll_client.get(status_url)
                    res.raise_for_status()
                    status_data = res.json()

                    progress = 0
                    # Handle nested data structure: {code, data: {status, ...}}
                    if "data" in status_data and isinstance(status_data["data"], dict):
                        inner_data = status_data["data"]
                        status = inner_data.get("status")
                        result = inner_data.get("result", {})
                        error_msg = inner_data.get("msg") or status_data.get("msg")
                        # Progress may appear under several key names.
                        if "progress" in inner_data:
                            raw_progress = inner_data["progress"]
                        elif "percent" in inner_data:
                            raw_progress = inner_data["percent"]
                        elif "percentage" in inner_data:
                            raw_progress = inner_data["percentage"]
                        else:
                            raw_progress = 0
                        # Prefer an explicit "message" field when present.
                        if "message" in inner_data and inner_data["message"]:
                            error_msg = inner_data["message"]
                    else:
                        # Flat response shape.
                        status = status_data.get("status")
                        result = status_data.get("result", {})
                        error_msg = status_data.get("msg")
                        if "progress" in status_data:
                            raw_progress = status_data["progress"]
                        elif "percent" in status_data:
                            raw_progress = status_data["percent"]
                        elif "percentage" in status_data:
                            raw_progress = status_data["percentage"]
                        else:
                            raw_progress = 0
                        if "message" in status_data and status_data["message"]:
                            error_msg = status_data["message"]

                    # Normalize progress: accept float ratio (0.0-1.0),
                    # int percent (0-100), or strings like "45%" / "0.45".
                    progress = 0
                    try:
                        if isinstance(raw_progress, (int, float)):
                            if 0.0 <= raw_progress <= 1.0 and isinstance(raw_progress, float):
                                progress = int(raw_progress * 100)
                            else:
                                progress = int(raw_progress)
                        elif isinstance(raw_progress, str):
                            if raw_progress.endswith("%"):
                                progress = int(float(raw_progress.strip("%")))
                            else:
                                val = float(raw_progress)
                                if 0.0 <= val <= 1.0:
                                    progress = int(val * 100)
                                else:
                                    progress = int(val)
                    except (ValueError, TypeError):
                        progress = 0

                    # Log raw status for debugging
                    print(f"[DEBUG] ASR Polling: status={status}, raw_progress={raw_progress}, progress={progress}, msg={error_msg}")

                    # Report progress upstream; callback errors are non-fatal.
                    if progress_callback:
                        try:
                            status_msg = error_msg if error_msg else f"正在转写中... {progress}%"
                            print(f"[DEBUG] Invoking callback with progress={progress}, msg={status_msg}")
                            await progress_callback(progress, status_msg)
                        except Exception as e:
                            logger.warning(f"Progress callback error: {e}")

                    if status == "completed" or status == "success":
                        # Return the segments list in whichever shape the
                        # server used; empty list if none recognizable.
                        if isinstance(result, dict) and "segments" in result:
                            return result["segments"]
                        elif isinstance(result, list):
                            return result
                        else:
                            return []
                    elif status == "failed":
                        raise Exception(f"ASR Task failed: {error_msg}")

                    # Still processing — wait before the next poll.
                    await asyncio.sleep(2)
                except httpx.RequestError as e:
                    # Transient network errors: log and keep polling.
                    logger.warning(f"Error polling ASR task {task_id}: {e}")
                    await asyncio.sleep(2)
        finally:
            await poll_client.aclose()

        raise Exception("ASR Task timed out")

    @staticmethod
    async def create_summarize_task(
        db: Session,
        meeting_id: int,
        prompt_id: int,
        model_id: int,
        extra_prompt: str = ""
    ):
        """Create a pending summarization task from a meeting's transcript.

        Formats all transcript segments into a timestamped "[mm:ss] speaker:
        text" transcript, appends any ``extra_prompt`` user instructions,
        stores it as the task's ``user_prompt``, marks the meeting
        "summarizing", commits, and pushes the task id onto the Redis
        summarize queue.  Raises ``Exception`` if the meeting is missing.
        """
        # 1. Basic validation.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
        if not meeting:
            raise Exception("Meeting not found")

        # 2. Format the meeting transcript (source material for user_prompt).
        segments = db.query(TranscriptSegment).filter(
            TranscriptSegment.meeting_id == meeting_id
        ).order_by(TranscriptSegment.start_time_ms).all()

        formatted_content = []
        for s in segments:
            secs = int(s.start_time_ms // 1000)
            m, sc = divmod(secs, 60)
            timestamp = f"[{m:02d}:{sc:02d}]"
            speaker = s.speaker_tag or f"发言人{s.speaker_id or '?'}"
            formatted_content.append(f"{timestamp} {speaker}: {s.text_content}")
        meeting_text = "\n".join(formatted_content)

        # Assemble the final user_prompt (transcript + optional user instructions).
        user_prompt_content = f"### 会议转译内容 ###\n{meeting_text}"
        if extra_prompt:
            user_prompt_content += f"\n\n### 用户的额外指令 ###\n{extra_prompt}"

        # 3. Create the task record (matching the actual DB columns).
        task_id = str(uuid.uuid4())
        new_task = SummarizeTask(
            task_id=task_id,
            meeting_id=meeting_id,
            prompt_id=prompt_id,
            model_id=model_id,
            user_prompt=user_prompt_content,
            status="pending",
            progress=0,
            created_at=datetime.utcnow()
        )
        db.add(new_task)
        # Mark the meeting as "summarizing".
        meeting.status = "summarizing"
        await asyncio.to_thread(db.commit)

        # 4. Enqueue in Redis for a worker to pick up.
        await redis_client.lpush("meeting:summarize:queue", task_id)
        return new_task

    @staticmethod
    async def run_summarize_worker(task_id: str):
        """Worker entry point for a summarize task that manages its own DB session.

        Safe to schedule via ``asyncio.create_task`` from other contexts,
        since it does not share the caller's (possibly closed) session.
        Exceptions are logged, never propagated.
        """
        db = SessionLocal()
        try:
            await MeetingService.process_summarize_task(db, task_id)
        except Exception as e:
            logger.error(f"Summarize worker (task {task_id}) failed: {e}")
        finally:
            db.close()

    @staticmethod
    async def process_summarize_task(db: Session, task_id: str):
        """Worker: execute a summarization task.

        Loads the model config and prompt template, builds the chat messages,
        calls the LLM, and writes the summary back to both the task and the
        meeting.  On failure the meeting is restored to "transcribed" so the
        user can retry the summary.
        """
        task = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
        if not task:
            return
        try:
            task.status = "processing"
            task.progress = 15
            await asyncio.to_thread(db.commit)

            # 1. Load the LLM model configuration.
            model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
            if not model_config:
                raise Exception("AI 模型配置不存在")

            # 2. Fetch the prompt template content at execution time (so edits
            # to the template apply), with a generic fallback system prompt.
            prompt_tmpl = db.query(PromptTemplate).filter(PromptTemplate.id == task.prompt_id).first()
            system_prompt = prompt_tmpl.content if prompt_tmpl else "请根据提供的会议转译内容生成准确的总结。"

            task.progress = 30
            await asyncio.to_thread(db.commit)

            # 3. Build the chat message structure.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": task.user_prompt}
            ]

            # Resolve sampling parameters from the model config (with defaults).
            config_params = model_config.config or {}
            temperature = config_params.get("temperature", 0.7)
            top_p = config_params.get("top_p", 0.9)

            logger.info(f"Task {task_id}: Launching LLM request to {model_config.model_name}")
            task.progress = 50
            await asyncio.to_thread(db.commit)

            # 4. Call the LLM service.
            summary_result = await LLMService.chat_completion(
                api_key=model_config.api_key,
                base_url=model_config.base_url or "https://api.openai.com/v1",
                model_name=model_config.model_name,
                messages=messages,
                api_path=model_config.api_path or "/chat/completions",
                temperature=float(temperature),
                top_p=float(top_p)
            )

            # 5. Task finished — write back the result.
            task.result = summary_result
            task.status = "completed"
            task.progress = 100
            task.completed_at = datetime.utcnow()

            # Mirror the summary and final status onto the meeting row.
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.summary = summary_result
                meeting.status = "completed"

            await asyncio.to_thread(db.commit)
            logger.info(f"Task {task_id} completed successfully")

        except Exception as e:
            logger.error(f"Task {task_id} execution error: {str(e)}")
            task.status = "failed"
            task.error_message = str(e)
            # Restore meeting status to transcribed (so user can retry summary) instead of draft
            meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
            if meeting:
                meeting.status = "transcribed"
            await asyncio.to_thread(db.commit)

    @staticmethod
    def get_task_status(db: Session, task_id: str):
        """Look up a task by id: summarize tasks first, then transcript tasks.

        Returns the matching task row, or ``None`` if neither table has it.
        """
        task = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
        if task:
            return task
        return db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()