from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, BackgroundTasks
from app.core.database import get_db_connection
from app.core.config import BASE_DIR, AUDIO_DIR, TEMP_UPLOAD_DIR
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.async_transcription_service import AsyncTranscriptionService
from app.services.async_meeting_service import async_meeting_service
from app.services.audio_service import handle_audio_upload
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime, timedelta
import os
import uuid
import shutil
import json
import re
from pathlib import Path

router = APIRouter()
transcription_service = AsyncTranscriptionService()

# Temporary upload directory - kept under the project directory
TEMP_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Configuration constants
MAX_CHUNK_SIZE = 2 * 1024 * 1024  # 2MB per chunk
MAX_TOTAL_SIZE = 500 * 1024 * 1024  # 500MB total
MAX_DURATION = 3600  # 1 hour max recording
MAX_CHUNKS = 1000  # Max chunks per session; advertised to clients by /init
# A session expires this long after its last activity (creation or a chunk
# upload). The deadline is refreshed on every chunk so that a recording of
# MAX_DURATION length is not swept away right before /complete is called.
SESSION_EXPIRE_HOURS = 1

# Session ids are generated as "sess_<millis>_<hex>" by init_upload_session.
_SESSION_ID_RE = re.compile(r'sess_\d+_[a-zA-Z0-9]+')

# Supported audio formats: MIME type -> extension of the merged output file
SUPPORTED_MIME_TYPES = {
    'audio/webm;codecs=opus': '.webm',
    'audio/webm': '.webm',
    'audio/ogg;codecs=opus': '.ogg',
    'audio/mp4': '.m4a',
    'audio/mpeg': '.mp3'
}


# ============ Pydantic Models ============

class InitUploadRequest(BaseModel):
    meeting_id: int
    mime_type: str
    estimated_duration: Optional[int] = None  # Expected duration (seconds)


class CompleteUploadRequest(BaseModel):
    session_id: str
    meeting_id: int
    total_chunks: int
    mime_type: str
    # NOTE(review): auto_transcribe is accepted but not forwarded to
    # handle_audio_upload anywhere in this module - confirm intended.
    auto_transcribe: bool = True
    auto_summarize: bool = True
    prompt_id: Optional[int] = None  # Optional prompt template ID


class CancelUploadRequest(BaseModel):
    session_id: str


# ============ Helper functions ============

def validate_session_id(session_id: str) -> str:
    """Validate the session_id format to prevent path-injection attacks.

    The whole string must match; ``re.match`` with ``$`` would also accept a
    value with a trailing newline.

    Returns:
        The unchanged ``session_id``.

    Raises:
        ValueError: if ``session_id`` is not of the form ``sess_<digits>_<alnum>``.
    """
    if not _SESSION_ID_RE.fullmatch(session_id):
        raise ValueError("Invalid session_id format")
    return session_id


def validate_mime_type(mime_type: str) -> str:
    """Check that a MIME type is supported and return its file extension.

    Raises:
        ValueError: if ``mime_type`` is not a key of SUPPORTED_MIME_TYPES.
    """
    if mime_type not in SUPPORTED_MIME_TYPES:
        raise ValueError(f"Unsupported MIME type: {mime_type}")
    return SUPPORTED_MIME_TYPES[mime_type]


def get_session_dir(session_id: str) -> Path:
    """Return the temp directory of a session (validates the id first)."""
    validate_session_id(session_id)
    return TEMP_UPLOAD_DIR / session_id


def get_session_metadata_path(session_id: str) -> Path:
    """Return the path of a session's metadata.json file."""
    return get_session_dir(session_id) / "metadata.json"


def get_chunk_path(session_id: str, chunk_index: int) -> Path:
    """Return the on-disk path of one chunk of a session.

    Chunks are always stored with a ``.webm`` suffix regardless of the real
    MIME type; the final extension is chosen when the chunks are merged.
    """
    return get_session_dir(session_id) / f"chunk_{chunk_index:04d}.webm"


def create_session_metadata(session_id: str, meeting_id: int, mime_type: str, user_id: int) -> dict:
    """Build the initial metadata dict for a new upload session."""
    now = datetime.now()
    expires_at = now + timedelta(hours=SESSION_EXPIRE_HOURS)

    metadata = {
        "session_id": session_id,
        "meeting_id": meeting_id,
        "user_id": user_id,
        "mime_type": mime_type,
        "total_chunks": None,
        "received_chunks": [],
        "created_at": now.isoformat(),
        "expires_at": expires_at.isoformat()
    }
    return metadata


def save_session_metadata(session_id: str, metadata: dict):
    """Persist session metadata as JSON, atomically.

    The JSON is written to a uniquely named temp file in the same directory
    and then moved over metadata.json with ``os.replace``, so a concurrent
    ``load_session_metadata`` never observes a half-written file.
    """
    metadata_path = get_session_metadata_path(session_id)
    tmp_path = metadata_path.with_name(f"metadata.{uuid.uuid4().hex}.tmp")
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, metadata_path)
    finally:
        # Only still present if writing or replacing failed.
        tmp_path.unlink(missing_ok=True)


def load_session_metadata(session_id: str) -> dict:
    """Load session metadata.

    Raises:
        FileNotFoundError: if the session has no metadata.json.
    """
    metadata_path = get_session_metadata_path(session_id)
    if not metadata_path.exists():
        raise FileNotFoundError(f"Session {session_id} not found")

    with open(metadata_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def update_session_chunks(session_id: str, chunk_index: int) -> dict:
    """Record a received chunk and refresh the session's expiry deadline.

    Returns:
        The updated metadata dict (callers may ignore it).
    """
    metadata = load_session_metadata(session_id)
    if chunk_index not in metadata['received_chunks']:
        metadata['received_chunks'].append(chunk_index)
        metadata['received_chunks'].sort()
    # Sliding expiry: an actively uploading session must not be cleaned up.
    metadata['expires_at'] = (
        datetime.now() + timedelta(hours=SESSION_EXPIRE_HOURS)
    ).isoformat()
    save_session_metadata(session_id, metadata)
    return metadata


def get_session_total_size(session_id: str) -> int:
    """Return the total size in bytes of the chunks uploaded so far."""
    session_dir = get_session_dir(session_id)
    total_size = 0

    if session_dir.exists():
        for chunk_file in session_dir.glob("chunk_*.webm"):
            total_size += chunk_file.stat().st_size

    return total_size


def merge_audio_chunks(session_id: str, meeting_id: int, total_chunks: int, mime_type: str) -> str:
    """Concatenate chunks 0..total_chunks-1 into the meeting's audio directory.

    On success the session's temp directory is removed.

    Returns:
        The output path relative to BASE_DIR, prefixed with "/".

    Raises:
        ValueError: if any chunk is missing or ``mime_type`` is unsupported.
    """
    session_dir = get_session_dir(session_id)

    # 1. Verify that every chunk is present
    chunk_paths = [get_chunk_path(session_id, i) for i in range(total_chunks)]
    missing = [i for i, path in enumerate(chunk_paths) if not path.exists()]
    if missing:
        raise ValueError(f"Missing chunks: {missing}")

    # 2. Pick the output file name (validates the MIME type before touching disk)
    file_extension = validate_mime_type(mime_type)
    output_filename = f"{uuid.uuid4()}{file_extension}"

    # 3. Create the output directory
    meeting_audio_dir = AUDIO_DIR / str(meeting_id)
    meeting_audio_dir.mkdir(parents=True, exist_ok=True)
    output_path = meeting_audio_dir / output_filename

    # 4. Concatenate the chunks in order, streaming instead of loading each
    #    chunk fully into memory; drop the partial file if anything fails.
    try:
        with open(output_path, 'wb') as outfile:
            for chunk_path in chunk_paths:
                with open(chunk_path, 'rb') as infile:
                    shutil.copyfileobj(infile, outfile)
    except BaseException:
        output_path.unlink(missing_ok=True)
        raise

    # 5. Remove the temporary files
    shutil.rmtree(session_dir)

    # Return the path relative to BASE_DIR
    return f"/{output_path.relative_to(BASE_DIR)}"


def cleanup_session(session_id: str):
    """Delete a session's temp directory if it exists."""
    session_dir = get_session_dir(session_id)
    if session_dir.exists():
        shutil.rmtree(session_dir)


def cleanup_expired_sessions():
    """Delete sessions whose expires_at is in the past (usable from a scheduled job).

    Errors on individual sessions are printed and skipped.

    Returns:
        The number of sessions removed.
    """
    now = datetime.now()
    cleaned_count = 0

    if not TEMP_UPLOAD_DIR.exists():
        return cleaned_count

    for session_dir in TEMP_UPLOAD_DIR.iterdir():
        if not session_dir.is_dir():
            continue

        metadata_path = session_dir / "metadata.json"
        if metadata_path.exists():
            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)

                expires_at = datetime.fromisoformat(metadata['expires_at'])
                if now > expires_at:
                    shutil.rmtree(session_dir)
                    cleaned_count += 1
                    print(f"Cleaned up expired session: {session_dir.name}")
            except Exception as e:
                print(f"Error cleaning up session {session_dir.name}: {e}")

    return cleaned_count


# ============ API Endpoints ============

@router.post("/audio/stream/init")
async def init_upload_session(
    request: InitUploadRequest,
    current_user: dict = Depends(get_current_user)
):
    """
    Initialize a streaming audio upload session.

    Creates a temp directory and generates a session_id that the client uses
    for the subsequent chunk uploads.
    """
    try:
        # 1. Check that the meeting exists and belongs to the current user
        with get_db_connection() as connection:
            cursor = connection.cursor(dictionary=True)
            cursor.execute(
                "SELECT user_id FROM meetings WHERE meeting_id = %s",
                (request.meeting_id,)
            )
            meeting = cursor.fetchone()

            if not meeting:
                return create_api_response(
                    code="404",
                    message="会议不存在"
                )

            if meeting['user_id'] != current_user['user_id']:
                return create_api_response(
                    code="403",
                    message="无权限操作此会议"
                )

        # 2. Validate the MIME type
        try:
            validate_mime_type(request.mime_type)
        except ValueError as e:
            return create_api_response(
                code="400",
                message=str(e)
            )

        # 3. Generate the session_id
        timestamp = int(datetime.now().timestamp() * 1000)
        random_str = uuid.uuid4().hex[:8]
        session_id = f"sess_{timestamp}_{random_str}"

        # 4. Create the session directory
        session_dir = get_session_dir(session_id)
        session_dir.mkdir(parents=True, exist_ok=True)

        # 5. Create and persist the metadata
        metadata = create_session_metadata(
            session_id=session_id,
            meeting_id=request.meeting_id,
            mime_type=request.mime_type,
            user_id=current_user['user_id']
        )
        save_session_metadata(session_id, metadata)

        # 6. Opportunistically clean up expired sessions
        cleanup_expired_sessions()

        return create_api_response(
            code="200",
            message="上传会话初始化成功",
            data={
                "session_id": session_id,
                "chunk_size": MAX_CHUNK_SIZE,
                "max_chunks": MAX_CHUNKS
            }
        )

    except Exception as e:
        print(f"Error initializing upload session: {e}")
        return create_api_response(
            code="500",
            message=f"初始化上传会话失败: {str(e)}"
        )


@router.post("/audio/stream/chunk")
async def upload_audio_chunk(
    session_id: str = Form(...),
    chunk_index: int = Form(...),
    chunk: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """
    Upload one audio chunk.

    Receives a chunk and stores it in the session's temp directory. Re-uploading
    an index (a client retry) overwrites the earlier copy.
    """
    try:
        # 1. Validate the session_id format
        try:
            validate_session_id(session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # 1b. Validate the chunk index (negative values would produce odd
        #     file names; the upper bound matches max_chunks returned by /init)
        if not 0 <= chunk_index < MAX_CHUNKS:
            return create_api_response(
                code="400",
                message="Invalid chunk_index"
            )

        # 2. Load the session metadata
        try:
            metadata = load_session_metadata(session_id)
        except FileNotFoundError:
            return create_api_response(
                code="404",
                message="Session not found"
            )

        # 3. Check session ownership
        if metadata['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="Permission denied"
            )

        # 4. Check the chunk size (read at most one byte past the limit)
        chunk_data = await chunk.read(MAX_CHUNK_SIZE + 1)
        if len(chunk_data) > MAX_CHUNK_SIZE:
            return create_api_response(
                code="400",
                message=f"Chunk size exceeds {MAX_CHUNK_SIZE // (1024*1024)}MB limit"
            )

        # 5. Check the total size. A retried chunk replaces its earlier copy,
        #    so that copy must not be counted twice.
        chunk_path = get_chunk_path(session_id, chunk_index)
        previous_size = chunk_path.stat().st_size if chunk_path.exists() else 0
        session_total = get_session_total_size(session_id) - previous_size
        if session_total + len(chunk_data) > MAX_TOTAL_SIZE:
            return create_api_response(
                code="400",
                message=f"Total size exceeds {MAX_TOTAL_SIZE // (1024*1024)}MB limit"
            )

        # 6. Write the chunk file
        with open(chunk_path, 'wb') as f:
            f.write(chunk_data)

        # 7. Update the metadata (also refreshes the expiry deadline)
        metadata = update_session_chunks(session_id, chunk_index)

        # 8. Number of distinct chunks received so far
        total_received = len(metadata['received_chunks'])

        return create_api_response(
            code="200",
            message="分片上传成功",
            data={
                "session_id": session_id,
                "chunk_index": chunk_index,
                "received": True,
                "total_received": total_received
            }
        )

    except Exception as e:
        print(f"Error uploading chunk: {e}")
        return create_api_response(
            code="500",
            message=f"分片上传失败: {str(e)}",
            data={
                "session_id": session_id,
                "chunk_index": chunk_index,
                "should_retry": True
            }
        )


@router.post("/audio/stream/complete")
async def complete_upload(
    request: CompleteUploadRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """
    Finish the upload and merge the chunks.

    Verifies that all chunks are present, merges them into the final audio
    file, then hands it to handle_audio_upload, which may start transcription
    and automatic summarization.
    """
    try:
        # 1. Validate the session_id
        try:
            validate_session_id(request.session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # 2. Load the session metadata
        try:
            metadata = load_session_metadata(request.session_id)
        except FileNotFoundError:
            return create_api_response(
                code="404",
                message="Session not found"
            )

        # 3. Check session ownership
        if metadata['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="Permission denied"
            )

        # 4. Check that the meeting ID matches the session
        if metadata['meeting_id'] != request.meeting_id:
            return create_api_response(
                code="400",
                message="Meeting ID mismatch"
            )

        # 4b. Validate the chunk count (0 would "merge" an empty file)
        if not 1 <= request.total_chunks <= MAX_CHUNKS:
            return create_api_response(
                code="400",
                message="Invalid total_chunks"
            )

        # 4c. An unsupported MIME type is a client error, not a retryable
        #     merge failure
        try:
            validate_mime_type(request.mime_type)
        except ValueError as e:
            return create_api_response(
                code="400",
                message=str(e)
            )

        # 5. Merge the audio chunks
        try:
            file_path = merge_audio_chunks(
                session_id=request.session_id,
                meeting_id=request.meeting_id,
                total_chunks=request.total_chunks,
                mime_type=request.mime_type
            )
        except ValueError as e:
            # Missing chunks
            return create_api_response(
                code="500",
                message=f"音频合并失败:{str(e)}",
                data={
                    "should_retry": True
                }
            )

        # 6. Collect file information
        full_path = BASE_DIR / file_path.lstrip('/')
        file_size = full_path.stat().st_size
        file_name = full_path.name

        # 7. Let audio_service handle the file (DB update, transcription, summary)
        result = handle_audio_upload(
            file_path=file_path,
            file_name=file_name,
            file_size=file_size,
            meeting_id=request.meeting_id,
            current_user=current_user,
            auto_summarize=request.auto_summarize,
            background_tasks=background_tasks,
            prompt_id=request.prompt_id  # Forward the prompt template ID
        )

        # Return the service's error response if processing failed
        if not result["success"]:
            return result["response"]

        # 8. Build the success response
        transcription_task_id = result["transcription_task_id"]
        message_suffix = ""
        if transcription_task_id:
            if request.auto_summarize:
                message_suffix = ",正在进行转录和总结"
            else:
                message_suffix = ",正在进行转录"

        return create_api_response(
            code="200",
            message="音频上传完成" + message_suffix,
            data={
                "meeting_id": request.meeting_id,
                "file_path": file_path,
                "file_size": file_size,
                "duration": None,  # Could be obtained via ffprobe; not required
                "task_id": transcription_task_id,
                "task_status": "pending" if transcription_task_id else None,
                "auto_summarize": request.auto_summarize
            }
        )

    except Exception as e:
        print(f"Error completing upload: {e}")
        return create_api_response(
            code="500",
            message=f"完成上传失败: {str(e)}"
        )


@router.delete("/audio/stream/cancel")
async def cancel_upload(
    request: CancelUploadRequest,
    current_user: dict = Depends(get_current_user)
):
    """
    Cancel an upload session.

    Removes the session's temporary files and directory. A session that no
    longer exists is reported as already cleaned.
    """
    try:
        # 1. Validate the session_id
        try:
            validate_session_id(request.session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # 2. Load the session metadata (for the ownership check)
        try:
            metadata = load_session_metadata(request.session_id)

            # Check session ownership
            if metadata['user_id'] != current_user['user_id']:
                return create_api_response(
                    code="403",
                    message="Permission denied"
                )
        except FileNotFoundError:
            # Session does not exist: treat it as already cleaned up
            return create_api_response(
                code="200",
                message="上传会话已取消",
                data={
                    "session_id": request.session_id,
                    "cleaned": True
                }
            )

        # 3. Remove the session files
        cleanup_session(request.session_id)

        return create_api_response(
            code="200",
            message="上传会话已取消",
            data={
                "session_id": request.session_id,
                "cleaned": True
            }
        )

    except Exception as e:
        print(f"Error canceling upload: {e}")
        return create_api_response(
            code="500",
            message=f"取消上传失败: {str(e)}"
        )