569 lines
17 KiB
Python
569 lines
17 KiB
Python
from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, BackgroundTasks
|
||
from app.core.database import get_db_connection
|
||
from app.core.config import BASE_DIR, AUDIO_DIR, TEMP_UPLOAD_DIR
|
||
from app.core.auth import get_current_user
|
||
from app.core.response import create_api_response
|
||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||
from app.services.async_meeting_service import async_meeting_service
|
||
from app.services.audio_service import handle_audio_upload
|
||
from pydantic import BaseModel
|
||
from typing import Optional, List
|
||
from datetime import datetime, timedelta
|
||
import os
|
||
import uuid
|
||
import shutil
|
||
import json
|
||
import re
|
||
from pathlib import Path
|
||
|
||
# Router on which this module's streaming audio-upload endpoints are registered.
router = APIRouter()

# Module-level transcription service instance.
transcription_service = AsyncTranscriptionService()

# Temporary upload directory (kept under the project directory); created at
# import time so per-session directories can be placed inside it.
TEMP_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||
|
||
# Configuration constants
MAX_CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB per chunk
MAX_TOTAL_SIZE = 500 * 1024 * 1024  # 500 MB per upload session
MAX_DURATION = 3600  # maximum recording length: 1 hour (in seconds)
SESSION_EXPIRE_HOURS = 1  # an upload session expires 1 hour after creation

# Supported audio MIME types, mapped to the file extension used for the
# merged output file.
SUPPORTED_MIME_TYPES = {
    'audio/webm;codecs=opus': '.webm',
    'audio/webm': '.webm',
    'audio/ogg;codecs=opus': '.ogg',
    'audio/mp4': '.m4a',
    'audio/mpeg': '.mp3',
}
|
||
|
||
|
||
# ============ Pydantic Models ============
|
||
|
||
class InitUploadRequest(BaseModel):
    """Request body for initialising a streaming upload session."""

    # Meeting the recording belongs to.
    meeting_id: int
    # MIME type of the audio that will be uploaded.
    mime_type: str
    # Estimated duration in seconds (optional).
    estimated_duration: Optional[int] = None
|
||
|
||
|
||
class CompleteUploadRequest(BaseModel):
    """Request body for finishing an upload session and merging its chunks."""

    # Upload session being completed.
    session_id: str
    # Meeting the merged audio is attached to.
    meeting_id: int
    # Number of chunks the client uploaded (indices 0 .. total_chunks - 1).
    total_chunks: int
    # MIME type of the uploaded audio; selects the output file extension.
    mime_type: str
    auto_transcribe: bool = True
    auto_summarize: bool = True
    # Prompt template ID (optional).
    prompt_id: Optional[int] = None
|
||
|
||
|
||
class CancelUploadRequest(BaseModel):
    """Request body for cancelling an upload session."""

    # Upload session to cancel.
    session_id: str
|
||
|
||
|
||
# ============ Utility functions ============
|
||
|
||
def validate_session_id(session_id: str) -> str:
    """Validate the session_id format to prevent path-injection attacks.

    A valid id looks like ``sess_<digits>_<alphanumerics>``. The whole string
    must match: ``re.fullmatch`` is used because ``re.match`` with ``$`` also
    accepts a trailing newline, and ``[0-9]`` is used because ``\\d`` also
    matches non-ASCII digits.

    Args:
        session_id: Session identifier supplied by the client.

    Returns:
        The unchanged ``session_id`` when it is valid.

    Raises:
        ValueError: If ``session_id`` is not a string of the expected format.
    """
    if not isinstance(session_id, str) or not re.fullmatch(
        r'sess_[0-9]+_[a-zA-Z0-9]+', session_id
    ):
        raise ValueError("Invalid session_id format")
    return session_id
|
||
|
||
|
||
def validate_mime_type(mime_type: str) -> str:
    """Check that a MIME type is supported and return its file extension.

    An exact match against ``SUPPORTED_MIME_TYPES`` is tried first. If that
    fails, the value is retried with all whitespace removed and lower-cased,
    so variants such as ``"audio/webm; codecs=opus"`` or ``"Audio/WebM"``
    are accepted as well.

    Args:
        mime_type: MIME type string supplied by the client.

    Returns:
        The file extension (including the leading dot) for the MIME type.

    Raises:
        ValueError: If the MIME type is not supported.
    """
    extension = SUPPORTED_MIME_TYPES.get(mime_type)
    if extension is None and isinstance(mime_type, str):
        normalized = ''.join(mime_type.split()).lower()
        extension = SUPPORTED_MIME_TYPES.get(normalized)
    if extension is None:
        raise ValueError(f"Unsupported MIME type: {mime_type}")
    return extension
|
||
|
||
|
||
def get_session_dir(session_id: str) -> Path:
    """Return the temporary directory path for an upload session.

    The id is validated first, so a ValueError from validate_session_id
    propagates for malformed ids.
    """
    checked_id = validate_session_id(session_id)
    return TEMP_UPLOAD_DIR / checked_id
|
||
|
||
|
||
def get_session_metadata_path(session_id: str) -> Path:
    """Return the path of the session's metadata.json file."""
    session_root = get_session_dir(session_id)
    return session_root / "metadata.json"
|
||
|
||
|
||
def create_session_metadata(session_id: str, meeting_id: int, mime_type: str, user_id: int) -> dict:
    """Build the metadata dict for a new upload session.

    The dict records the session, meeting and user ids, the MIME type, an
    empty list of received chunk indices, and ISO-formatted creation and
    expiry timestamps (expiry is SESSION_EXPIRE_HOURS after creation).
    Nothing is written to disk here.
    """
    created = datetime.now()
    lifetime = timedelta(hours=SESSION_EXPIRE_HOURS)

    return {
        "session_id": session_id,
        "meeting_id": meeting_id,
        "user_id": user_id,
        "mime_type": mime_type,
        "total_chunks": None,
        "received_chunks": [],
        "created_at": created.isoformat(),
        "expires_at": (created + lifetime).isoformat(),
    }
|
||
|
||
|
||
def save_session_metadata(session_id: str, metadata: dict):
    """Persist session metadata as UTF-8 JSON.

    The JSON is written to a uniquely named temporary file in the session
    directory and then moved over metadata.json with ``os.replace``, so a
    failure part-way through the write never leaves a truncated
    metadata.json behind.

    Args:
        session_id: Upload session identifier.
        metadata: JSON-serialisable metadata dict.
    """
    metadata_path = get_session_metadata_path(session_id)
    tmp_path = metadata_path.with_name(
        f"{metadata_path.name}.{uuid.uuid4().hex}.tmp"
    )

    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, metadata_path)
    except BaseException:
        # Do not leave the temporary file behind if the write or rename failed.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
|
||
|
||
|
||
def load_session_metadata(session_id: str) -> dict:
    """Load the metadata dict of an upload session.

    The file is opened directly rather than checked with ``exists()`` first,
    so a session removed between check and open still produces the same
    "not found" error.

    Args:
        session_id: Upload session identifier.

    Returns:
        The parsed contents of the session's metadata.json.

    Raises:
        FileNotFoundError: If the session has no metadata file.
        ValueError: If ``session_id`` is malformed (from validate_session_id).
    """
    metadata_path = get_session_metadata_path(session_id)

    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Session {session_id} not found") from None
|
||
|
||
|
||
def update_session_chunks(session_id: str, chunk_index: int):
    """Record a chunk index in the session's received-chunks list.

    The list is kept sorted and free of duplicates; the metadata file is
    rewritten on every call.
    """
    session_meta = load_session_metadata(session_id)
    received = session_meta['received_chunks']

    already_recorded = chunk_index in received
    if not already_recorded:
        received.append(chunk_index)
        received.sort()

    save_session_metadata(session_id, session_meta)
|
||
|
||
|
||
def get_session_total_size(session_id: str) -> int:
    """Return the combined size in bytes of the chunks uploaded so far.

    Only files matching chunk_*.webm in the session directory are counted;
    a missing session directory yields 0.
    """
    upload_dir = get_session_dir(session_id)
    if not upload_dir.exists():
        return 0

    return sum(
        chunk_file.stat().st_size
        for chunk_file in upload_dir.glob("chunk_*.webm")
    )
|
||
|
||
|
||
def merge_audio_chunks(session_id: str, meeting_id: int, total_chunks: int, mime_type: str) -> str:
    """Merge the uploaded chunks of a session into one audio file.

    Chunks ``chunk_0000.webm`` .. ``chunk_<total_chunks - 1>.webm`` are
    concatenated in index order into a new, uuid-named file under
    ``AUDIO_DIR/<meeting_id>/``. The session directory is removed afterwards.

    Args:
        session_id: Upload session identifier.
        meeting_id: Meeting the audio belongs to; names the output directory.
        total_chunks: Number of chunks expected; must be positive.
        mime_type: MIME type of the audio; selects the output file extension.

    Returns:
        The output path relative to ``BASE_DIR``, prefixed with "/".

    Raises:
        ValueError: If ``total_chunks`` is not positive, the session id is
            malformed, the MIME type is unsupported, or chunks are missing.
    """
    # Reject an empty merge early: range(0) would otherwise "succeed" with an
    # empty output file and delete the session.
    if total_chunks <= 0:
        raise ValueError("total_chunks must be a positive integer")

    # Validate the MIME type before touching the filesystem.
    file_extension = validate_mime_type(mime_type)
    session_dir = get_session_dir(session_id)

    # 1. Verify that every chunk is present
    chunk_paths = [
        session_dir / f"chunk_{i:04d}.webm" for i in range(total_chunks)
    ]
    missing = [i for i, path in enumerate(chunk_paths) if not path.exists()]
    if missing:
        raise ValueError(f"Missing chunks: {missing}")

    # 2. Create the output directory
    meeting_audio_dir = AUDIO_DIR / str(meeting_id)
    meeting_audio_dir.mkdir(parents=True, exist_ok=True)

    # 3. Generate the output file name
    output_filename = f"{uuid.uuid4()}{file_extension}"
    output_path = meeting_audio_dir / output_filename

    # 4. Concatenate the chunks in order
    try:
        with open(output_path, 'wb') as outfile:
            for chunk_path in chunk_paths:
                with open(chunk_path, 'rb') as infile:
                    outfile.write(infile.read())
    except BaseException:
        # Do not leave a partially written audio file behind.
        try:
            output_path.unlink()
        except OSError:
            pass
        raise

    # 5. Remove the temporary session files
    shutil.rmtree(session_dir)

    # Return the path relative to BASE_DIR
    return f"/{output_path.relative_to(BASE_DIR)}"
|
||
|
||
|
||
def cleanup_session(session_id: str):
    """Delete the session's temporary directory if it exists."""
    target_dir = get_session_dir(session_id)
    if not target_dir.exists():
        return
    shutil.rmtree(target_dir)
|
||
|
||
|
||
def cleanup_expired_sessions():
    """Remove expired upload sessions (may be called from a scheduled task).

    Every sub-directory of ``TEMP_UPLOAD_DIR`` that contains a metadata.json
    whose ``expires_at`` lies in the past is deleted. Directories without a
    metadata file are skipped. An error on one session is printed and does
    not stop the sweep.

    Returns:
        The number of session directories removed.
    """
    cleaned_count = 0

    if not TEMP_UPLOAD_DIR.exists():
        return cleaned_count

    now = datetime.now()

    for session_dir in TEMP_UPLOAD_DIR.iterdir():
        if not session_dir.is_dir():
            continue

        metadata_path = session_dir / "metadata.json"
        if not metadata_path.exists():
            continue

        try:
            # Read with the same encoding the metadata is written with.
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            expires_at = datetime.fromisoformat(metadata['expires_at'])
            if now > expires_at:
                shutil.rmtree(session_dir)
                cleaned_count += 1
                print(f"Cleaned up expired session: {session_dir.name}")
        except Exception as e:
            # Best effort: report and continue with the next session.
            print(f"Error cleaning up session {session_dir.name}: {e}")

    return cleaned_count
|
||
|
||
|
||
# ============ API Endpoints ============
|
||
|
||
@router.post("/audio/stream/init")
async def init_upload_session(
    request: InitUploadRequest,
    current_user: dict = Depends(get_current_user)
):
    """
    Initialise a streaming audio upload session.

    Checks that the meeting exists and belongs to the current user, validates
    the MIME type, creates the temporary session directory with its metadata
    file, and returns the generated session_id for the subsequent chunk
    uploads. Any unexpected error is reported as a "500" API response.
    """
    try:
        # 1. Verify that the meeting exists and belongs to the current user
        with get_db_connection() as connection:
            cursor = connection.cursor(dictionary=True)
            try:
                cursor.execute(
                    "SELECT user_id FROM meetings WHERE meeting_id = %s",
                    (request.meeting_id,)
                )
                meeting = cursor.fetchone()
            finally:
                # Release the cursor even if the query fails.
                cursor.close()

        if not meeting:
            return create_api_response(
                code="404",
                message="会议不存在"
            )

        if meeting['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="无权限操作此会议"
            )

        # 2. Validate the MIME type
        try:
            validate_mime_type(request.mime_type)
        except ValueError as e:
            return create_api_response(
                code="400",
                message=str(e)
            )

        # 3. Generate the session_id (millisecond timestamp + random suffix)
        timestamp = int(datetime.now().timestamp() * 1000)
        random_str = uuid.uuid4().hex[:8]
        session_id = f"sess_{timestamp}_{random_str}"

        # 4. Create the session directory
        session_dir = get_session_dir(session_id)
        session_dir.mkdir(parents=True, exist_ok=True)

        # 5. Create and save the metadata. A directory without metadata is
        #    never picked up by cleanup_expired_sessions, so remove it again
        #    if this step fails.
        try:
            metadata = create_session_metadata(
                session_id=session_id,
                meeting_id=request.meeting_id,
                mime_type=request.mime_type,
                user_id=current_user['user_id']
            )
            save_session_metadata(session_id, metadata)
        except Exception:
            shutil.rmtree(session_dir, ignore_errors=True)
            raise

        # 6. Clean up expired sessions. This is opportunistic housekeeping;
        #    a failure here must not fail the session that was just created.
        try:
            cleanup_expired_sessions()
        except Exception as cleanup_error:
            print(f"Error cleaning up expired sessions: {cleanup_error}")

        return create_api_response(
            code="200",
            message="上传会话初始化成功",
            data={
                "session_id": session_id,
                "chunk_size": MAX_CHUNK_SIZE,
                "max_chunks": 1000
            }
        )

    except Exception as e:
        print(f"Error initializing upload session: {e}")
        return create_api_response(
            code="500",
            message=f"初始化上传会话失败: {str(e)}"
        )
|
||
|
||
|
||
@router.post("/audio/stream/chunk")
async def upload_audio_chunk(
    session_id: str = Form(...),
    chunk_index: int = Form(...),
    chunk: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """
    Upload one audio chunk.

    Validates the session and its ownership, enforces the per-chunk and
    per-session size limits, stores the chunk as chunk_<index>.webm in the
    session directory and records the index in the session metadata.
    Unexpected errors are reported as a "500" response with should_retry set.
    """
    try:
        # 1. Validate the session_id format
        try:
            validate_session_id(session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # A negative index can never be part of the 0..total_chunks-1 range
        # that the merge step expects.
        if chunk_index < 0:
            return create_api_response(
                code="400",
                message="Invalid chunk_index"
            )

        # 2. Load the session metadata
        try:
            metadata = load_session_metadata(session_id)
        except FileNotFoundError:
            return create_api_response(
                code="404",
                message="Session not found"
            )

        # 3. Verify session ownership
        if metadata['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="Permission denied"
            )

        # 4. Validate the chunk size. Read at most one byte beyond the limit:
        #    enough to detect an oversized chunk without loading all of it.
        chunk_data = await chunk.read(MAX_CHUNK_SIZE + 1)
        if len(chunk_data) > MAX_CHUNK_SIZE:
            return create_api_response(
                code="400",
                message=f"Chunk size exceeds {MAX_CHUNK_SIZE // (1024*1024)}MB limit"
            )

        # 5. Validate the total size. When a chunk is re-uploaded (retry) the
        #    old file is about to be overwritten, so it is not counted twice.
        session_dir = get_session_dir(session_id)
        chunk_path = session_dir / f"chunk_{chunk_index:04d}.webm"
        replaced_size = chunk_path.stat().st_size if chunk_path.exists() else 0

        session_total = get_session_total_size(session_id) - replaced_size
        if session_total + len(chunk_data) > MAX_TOTAL_SIZE:
            return create_api_response(
                code="400",
                message=f"Total size exceeds {MAX_TOTAL_SIZE // (1024*1024)}MB limit"
            )

        # 6. Save the chunk file
        with open(chunk_path, 'wb') as f:
            f.write(chunk_data)

        # 7. Update the metadata
        update_session_chunks(session_id, chunk_index)

        # 8. Count the chunks received so far
        metadata = load_session_metadata(session_id)
        total_received = len(metadata['received_chunks'])

        return create_api_response(
            code="200",
            message="分片上传成功",
            data={
                "session_id": session_id,
                "chunk_index": chunk_index,
                "received": True,
                "total_received": total_received
            }
        )

    except Exception as e:
        print(f"Error uploading chunk: {e}")
        return create_api_response(
            code="500",
            message=f"分片上传失败: {str(e)}",
            data={
                "session_id": session_id,
                "chunk_index": chunk_index,
                "should_retry": True
            }
        )
|
||
|
||
|
||
@router.post("/audio/stream/complete")
async def complete_upload(
    request: CompleteUploadRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """
    Complete the upload and merge the chunks.

    Validates the session, its ownership and the request parameters, merges
    all chunks into the final audio file, then hands the file to
    handle_audio_upload (database update and optional transcription /
    summary). Unexpected errors are reported as a "500" API response.
    """
    try:
        # 1. Validate the session_id
        try:
            validate_session_id(request.session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # 2. Load the session metadata
        try:
            metadata = load_session_metadata(request.session_id)
        except FileNotFoundError:
            return create_api_response(
                code="404",
                message="Session not found"
            )

        # 3. Verify session ownership
        if metadata['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="Permission denied"
            )

        # 4. Verify that the meeting ID matches the session
        if metadata['meeting_id'] != request.meeting_id:
            return create_api_response(
                code="400",
                message="Meeting ID mismatch"
            )

        # 5. Reject bad request parameters up front, so they are reported as
        #    client errors instead of a retryable merge failure.
        if request.total_chunks <= 0:
            return create_api_response(
                code="400",
                message="total_chunks must be a positive integer"
            )

        try:
            validate_mime_type(request.mime_type)
        except ValueError as e:
            return create_api_response(
                code="400",
                message=str(e)
            )

        # 6. Merge the audio chunks
        try:
            file_path = merge_audio_chunks(
                session_id=request.session_id,
                meeting_id=request.meeting_id,
                total_chunks=request.total_chunks,
                mime_type=request.mime_type
            )
        except ValueError as e:
            # Chunks are incomplete
            return create_api_response(
                code="500",
                message=f"音频合并失败:{str(e)}",
                data={
                    "should_retry": True
                }
            )

        # 7. Collect file information
        full_path = BASE_DIR / file_path.lstrip('/')
        file_size = full_path.stat().st_size
        file_name = full_path.name

        # 8. Let audio_service process the file (database update, start
        #    transcription and summary)
        result = handle_audio_upload(
            file_path=file_path,
            file_name=file_name,
            file_size=file_size,
            meeting_id=request.meeting_id,
            current_user=current_user,
            auto_summarize=request.auto_summarize,
            background_tasks=background_tasks,
            prompt_id=request.prompt_id  # pass the prompt template ID through
        )

        # If processing failed, return the error response it produced
        if not result["success"]:
            return result["response"]

        # 9. Return the success response
        transcription_task_id = result["transcription_task_id"]
        message_suffix = ""
        if transcription_task_id:
            if request.auto_summarize:
                message_suffix = ",正在进行转录和总结"
            else:
                message_suffix = ",正在进行转录"

        return create_api_response(
            code="200",
            message="音频上传完成" + message_suffix,
            data={
                "meeting_id": request.meeting_id,
                "file_path": file_path,
                "file_size": file_size,
                "duration": None,  # could be obtained via ffprobe, but is not required
                "task_id": transcription_task_id,
                "task_status": "pending" if transcription_task_id else None,
                "auto_summarize": request.auto_summarize
            }
        )

    except Exception as e:
        print(f"Error completing upload: {e}")
        return create_api_response(
            code="500",
            message=f"完成上传失败: {str(e)}"
        )
|
||
|
||
|
||
@router.delete("/audio/stream/cancel")
async def cancel_upload(
    request: CancelUploadRequest,
    current_user: dict = Depends(get_current_user)
):
    """
    Cancel an upload session.

    Removes the session's temporary files and directory. A session that no
    longer exists is treated as already cleaned up.
    """
    try:
        # 1. Validate the session_id
        try:
            validate_session_id(request.session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # 2. Load the session metadata; a missing session counts as
        #    already cleaned up
        session_found = True
        try:
            metadata = load_session_metadata(request.session_id)
        except FileNotFoundError:
            session_found = False

        if session_found:
            # Verify session ownership before deleting anything
            if metadata['user_id'] != current_user['user_id']:
                return create_api_response(
                    code="403",
                    message="Permission denied"
                )

            # 3. Remove the session files
            cleanup_session(request.session_id)

        return create_api_response(
            code="200",
            message="上传会话已取消",
            data={
                "session_id": request.session_id,
                "cleaned": True
            }
        )

    except Exception as e:
        print(f"Error canceling upload: {e}")
        return create_api_response(
            code="500",
            message=f"取消上传失败: {str(e)}"
        )
|