imetting_backend/app/api/endpoints/audio.py

569 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, BackgroundTasks
from app.core.database import get_db_connection
from app.core.config import BASE_DIR, AUDIO_DIR, TEMP_UPLOAD_DIR
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.async_transcription_service import AsyncTranscriptionService
from app.services.async_meeting_service import async_meeting_service
from app.services.audio_service import handle_audio_upload
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime, timedelta
import os
import uuid
import shutil
import json
import re
from pathlib import Path
router = APIRouter()
# NOTE(review): not referenced in the visible part of this module.
transcription_service = AsyncTranscriptionService()
# Temporary upload directory (kept under the project directory); make sure it
# exists as soon as the module is imported.
TEMP_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# Configuration constants
MAX_CHUNK_SIZE = 2 * 1024 * 1024 # 2MB per chunk
MAX_TOTAL_SIZE = 500 * 1024 * 1024 # 500MB total per upload session
MAX_DURATION = 3600 # 1 hour max recording (NOTE(review): not enforced in the visible code)
SESSION_EXPIRE_HOURS = 1 # a session expires 1 hour after it is created
# Supported audio formats: MIME type -> extension given to the merged audio file
SUPPORTED_MIME_TYPES = {
    'audio/webm;codecs=opus': '.webm',
    'audio/webm': '.webm',
    'audio/ogg;codecs=opus': '.ogg',
    'audio/mp4': '.m4a',
    'audio/mpeg': '.mp3'
}
# ============ Pydantic Models ============
class InitUploadRequest(BaseModel):
    """Request body for POST /audio/stream/init."""
    meeting_id: int  # meeting that will receive the uploaded audio
    mime_type: str  # must be a key of SUPPORTED_MIME_TYPES (checked by the endpoint)
    estimated_duration: Optional[int] = None  # estimated duration in seconds (optional; not read in the visible code)
class CompleteUploadRequest(BaseModel):
    """Request body for POST /audio/stream/complete."""
    session_id: str  # session ID returned by /audio/stream/init
    meeting_id: int  # must equal the meeting the session was created for
    total_chunks: int  # number of chunks to merge (indices 0 .. total_chunks-1)
    mime_type: str  # audio MIME type (see SUPPORTED_MIME_TYPES)
    auto_transcribe: bool = True  # NOTE(review): not read by the endpoints in this module - confirm intent
    auto_summarize: bool = True  # whether to summarize automatically after transcription
    prompt_id: Optional[int] = None  # optional prompt template ID
class CancelUploadRequest(BaseModel):
    """Request body for DELETE /audio/stream/cancel."""
    session_id: str  # session to cancel and clean up
# ============ Utility functions ============
def validate_session_id(session_id: str) -> str:
    """Validate the format of an upload session ID.

    The ID is joined onto TEMP_UPLOAD_DIR to form a directory name, so only a
    strict whitelist format is accepted; this prevents path injection.

    Args:
        session_id: expected to look like ``sess_<digits>_<alphanumerics>``.

    Returns:
        ``session_id`` unchanged, so the call can be used inline.

    Raises:
        ValueError: if ``session_id`` does not match the format.
    """
    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing "\n", which would let "sess_1_x\n" through. [0-9] keeps the
    # digits ASCII-only (\d also matches other Unicode digits).
    if not re.fullmatch(r'sess_[0-9]+_[a-zA-Z0-9]+', session_id):
        raise ValueError("Invalid session_id format")
    return session_id
def validate_mime_type(mime_type: str) -> str:
    """Return the file extension registered for ``mime_type``.

    Raises:
        ValueError: if ``mime_type`` is not a key of SUPPORTED_MIME_TYPES.
    """
    extension = SUPPORTED_MIME_TYPES.get(mime_type)
    if extension is None:
        raise ValueError(f"Unsupported MIME type: {mime_type}")
    return extension
def get_session_dir(session_id: str) -> Path:
    """Return the temporary directory of an upload session.

    The ID is validated first (ValueError on a bad format), so the result
    is always a direct child of TEMP_UPLOAD_DIR.
    """
    return TEMP_UPLOAD_DIR / validate_session_id(session_id)
def get_session_metadata_path(session_id: str) -> Path:
    """Return the path of the session's ``metadata.json`` (the ID is validated)."""
    return get_session_dir(session_id).joinpath("metadata.json")
def create_session_metadata(session_id: str, meeting_id: int, mime_type: str, user_id: int) -> dict:
    """Build the initial metadata record of a new upload session.

    The record has no chunk total yet, an empty list of received chunk
    indices, and expires SESSION_EXPIRE_HOURS after creation. Both
    timestamps are naive local times serialized with ``isoformat()``.
    """
    created = datetime.now()
    return dict(
        session_id=session_id,
        meeting_id=meeting_id,
        user_id=user_id,
        mime_type=mime_type,
        total_chunks=None,
        received_chunks=[],
        created_at=created.isoformat(),
        expires_at=(created + timedelta(hours=SESSION_EXPIRE_HOURS)).isoformat(),
    )
def save_session_metadata(session_id: str, metadata: dict):
    """Write a session's metadata to ``metadata.json`` as indented UTF-8 JSON.

    The JSON is first written to a uniquely named temporary file in the
    session directory and then moved over ``metadata.json`` with
    ``os.replace`` (atomic on POSIX). A failed write, e.g. a value that
    ``json`` cannot serialize or a full disk, therefore cannot leave a
    truncated ``metadata.json`` behind. Previously, such a file made every
    later load of the session fail.

    Raises:
        ValueError: if the session ID format is invalid.
        OSError: if the file cannot be written.
    """
    metadata_path = get_session_metadata_path(session_id)
    tmp_path = metadata_path.with_name(f".metadata.{uuid.uuid4().hex}.tmp")
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, metadata_path)
    except Exception:
        # Best-effort removal of the temporary file before re-raising.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
def load_session_metadata(session_id: str) -> dict:
    """Read and parse a session's ``metadata.json``.

    Raises:
        ValueError: if the session ID format is invalid.
        FileNotFoundError: if the session has no metadata file.
    """
    path = get_session_metadata_path(session_id)
    if path.exists():
        return json.loads(path.read_text(encoding='utf-8'))
    raise FileNotFoundError(f"Session {session_id} not found")
def update_session_chunks(session_id: str, chunk_index: int):
    """Mark ``chunk_index`` as received in the session's metadata file.

    Each index is stored at most once and the list is kept sorted. The
    metadata is saved again on every call, even if the index was present.
    """
    metadata = load_session_metadata(session_id)
    received = metadata['received_chunks']
    if chunk_index not in received:
        received[:] = sorted(received + [chunk_index])
    save_session_metadata(session_id, metadata)
def get_session_total_size(session_id: str) -> int:
    """Return the combined size in bytes of the session's ``chunk_*.webm`` files.

    Returns 0 when the session directory does not exist.
    """
    session_dir = get_session_dir(session_id)
    if not session_dir.exists():
        return 0
    return sum(chunk_file.stat().st_size for chunk_file in session_dir.glob("chunk_*.webm"))
def merge_audio_chunks(session_id: str, meeting_id: int, total_chunks: int, mime_type: str) -> str:
    """Concatenate a session's chunks into one audio file and delete the session.

    Chunks ``chunk_0000.webm`` .. ``chunk_<total_chunks-1>`` are appended in
    index order to ``AUDIO_DIR/<meeting_id>/<uuid4><ext>``. The extension
    is taken from ``mime_type``. After a successful merge the session's
    temporary directory is removed.

    Args:
        session_id: upload session ID (validated here).
        meeting_id: meeting whose audio directory receives the file.
        total_chunks: number of chunks to merge; must be at least 1.
        mime_type: a key of SUPPORTED_MIME_TYPES.

    Returns:
        Path of the merged file relative to BASE_DIR, with a leading "/" and
        forward slashes.

    Raises:
        ValueError: invalid session ID, unsupported MIME type,
            ``total_chunks < 1``, or missing chunks. No audio file is written
            in these cases and the session is kept.
    """
    session_dir = get_session_dir(session_id)

    # 1. Validate every input before touching the filesystem.
    file_extension = validate_mime_type(mime_type)
    if total_chunks < 1:
        # Without this check, zero chunks would "merge" into an empty audio
        # file and delete the session.
        raise ValueError(f"Invalid total_chunks: {total_chunks}")
    chunk_paths = [session_dir / f"chunk_{i:04d}.webm" for i in range(total_chunks)]
    missing = [i for i, chunk_path in enumerate(chunk_paths) if not chunk_path.exists()]
    if missing:
        raise ValueError(f"Missing chunks: {missing}")

    # 2. Prepare the output location.
    meeting_audio_dir = AUDIO_DIR / str(meeting_id)
    meeting_audio_dir.mkdir(parents=True, exist_ok=True)
    output_path = meeting_audio_dir / f"{uuid.uuid4()}{file_extension}"
    # Compute the return value before anything irreversible happens, so a
    # failure here cannot occur after the session has been deleted.
    relative_path = f"/{output_path.relative_to(BASE_DIR).as_posix()}"

    # 3. Append the chunks in order; remove a partial output file on failure.
    try:
        with open(output_path, 'wb') as outfile:
            for chunk_path in chunk_paths:
                with open(chunk_path, 'rb') as infile:
                    shutil.copyfileobj(infile, outfile)
    except Exception:
        try:
            output_path.unlink()
        except OSError:
            pass
        raise

    # 4. The session's temporary files are no longer needed.
    shutil.rmtree(session_dir)
    return relative_path
def cleanup_session(session_id: str):
    """Delete a session's temporary directory, if present (the ID is validated)."""
    target_dir = get_session_dir(session_id)
    if not target_dir.exists():
        return
    shutil.rmtree(target_dir)
def _session_expiry(session_dir: Path) -> Optional[datetime]:
    """Return when a session directory expires, or None if it must be kept.

    The ``expires_at`` field of ``metadata.json`` is authoritative. If the
    metadata is missing or unreadable (e.g. init failed between creating the
    directory and writing the metadata), a directory whose name has the
    session-ID format expires SESSION_EXPIRE_HOURS after its last
    modification. Any other directory is left alone.
    """
    metadata_path = session_dir / "metadata.json"
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            return datetime.fromisoformat(json.load(f)['expires_at'])
    except (OSError, ValueError, KeyError, TypeError):
        pass
    try:
        validate_session_id(session_dir.name)
    except ValueError:
        return None
    last_modified = datetime.fromtimestamp(session_dir.stat().st_mtime)
    return last_modified + timedelta(hours=SESSION_EXPIRE_HOURS)


def cleanup_expired_sessions():
    """Delete expired upload-session directories under TEMP_UPLOAD_DIR.

    Can be called by a scheduled task; init_upload_session also calls it.
    An error for one directory is printed and that directory is skipped, so
    one bad entry does not stop the sweep.

    Returns:
        Number of session directories removed.
    """
    now = datetime.now()
    cleaned_count = 0
    if not TEMP_UPLOAD_DIR.exists():
        return cleaned_count
    for session_dir in TEMP_UPLOAD_DIR.iterdir():
        if not session_dir.is_dir():
            continue
        try:
            expires_at = _session_expiry(session_dir)
            if expires_at is None or now <= expires_at:
                continue
            shutil.rmtree(session_dir)
            cleaned_count += 1
            print(f"Cleaned up expired session: {session_dir.name}")
        except Exception as e:
            print(f"Error cleaning up session {session_dir.name}: {e}")
    return cleaned_count
# ============ API Endpoints ============
@router.post("/audio/stream/init")
async def init_upload_session(
    request: InitUploadRequest,
    current_user: dict = Depends(get_current_user)
):
    """
    Initialize a streaming audio upload session.

    Checks that the meeting exists and belongs to the current user and that
    the MIME type is supported. It then creates the session's temporary
    directory and metadata, and returns the generated session_id plus chunk
    limits for the subsequent chunk uploads.
    """
    try:
        # 1. The meeting must exist and belong to the current user.
        with get_db_connection() as connection:
            cursor = connection.cursor(dictionary=True)
            cursor.execute(
                "SELECT user_id FROM meetings WHERE meeting_id = %s",
                (request.meeting_id,)
            )
            meeting = cursor.fetchone()
            if not meeting:
                return create_api_response(
                    code="404",
                    message="会议不存在"
                )
            if meeting['user_id'] != current_user['user_id']:
                return create_api_response(
                    code="403",
                    message="无权限操作此会议"
                )

        # 2. The MIME type must be supported.
        try:
            validate_mime_type(request.mime_type)
        except ValueError as e:
            return create_api_response(
                code="400",
                message=str(e)
            )

        # 3. Generate the session_id: millisecond timestamp + random suffix.
        timestamp = int(datetime.now().timestamp() * 1000)
        random_str = uuid.uuid4().hex[:8]
        session_id = f"sess_{timestamp}_{random_str}"

        # 4. Create the session directory.
        session_dir = get_session_dir(session_id)
        session_dir.mkdir(parents=True, exist_ok=True)

        # 5. Create and save the metadata. If this fails, remove the directory
        #    so a metadata-less session directory is not left behind.
        try:
            metadata = create_session_metadata(
                session_id=session_id,
                meeting_id=request.meeting_id,
                mime_type=request.mime_type,
                user_id=current_user['user_id']
            )
            save_session_metadata(session_id, metadata)
        except Exception:
            shutil.rmtree(session_dir, ignore_errors=True)
            raise

        # 6. Sweep expired sessions. This is housekeeping: the new session
        #    already exists, so a failure here must not produce an error
        #    response (the client would retry and orphan another session).
        try:
            cleanup_expired_sessions()
        except Exception as e:
            print(f"Error cleaning up expired sessions: {e}")

        return create_api_response(
            code="200",
            message="上传会话初始化成功",
            data={
                "session_id": session_id,
                "chunk_size": MAX_CHUNK_SIZE,
                "max_chunks": 1000
            }
        )
    except Exception as e:
        print(f"Error initializing upload session: {e}")
        return create_api_response(
            code="500",
            message=f"初始化上传会话失败: {str(e)}"
        )
@router.post("/audio/stream/chunk")
async def upload_audio_chunk(
    session_id: str = Form(...),
    chunk_index: int = Form(...),
    chunk: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """
    Upload one audio chunk of a streaming upload session.

    The chunk is stored as ``chunk_<index:04d>.webm`` in the session
    directory and its index is recorded in the session metadata. Uploading
    the same index again (e.g. a client retry) replaces the earlier chunk.
    """
    try:
        # 1. Validate the session_id format.
        try:
            validate_session_id(session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # 2. Reject negative indices. Merging only reads chunks
        #    0..total_chunks-1, so a negative chunk would never be used but
        #    would still count against the size limit.
        if chunk_index < 0:
            return create_api_response(
                code="400",
                message="Invalid chunk_index"
            )

        # 3. Load the session metadata.
        try:
            metadata = load_session_metadata(session_id)
        except FileNotFoundError:
            return create_api_response(
                code="404",
                message="Session not found"
            )

        # 4. Only the session's owner may upload to it.
        if metadata['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="Permission denied"
            )

        # 5. Chunk size. Read at most one byte past the limit, so an oversized
        #    upload is rejected without being loaded into memory in full.
        chunk_data = await chunk.read(MAX_CHUNK_SIZE + 1)
        if len(chunk_data) > MAX_CHUNK_SIZE:
            return create_api_response(
                code="400",
                message=f"Chunk size exceeds {MAX_CHUNK_SIZE // (1024*1024)}MB limit"
            )

        # 6. Session total size. A re-uploaded chunk replaces the existing
        #    file, so its previous size must not be counted twice.
        session_dir = get_session_dir(session_id)
        chunk_path = session_dir / f"chunk_{chunk_index:04d}.webm"
        previous_size = chunk_path.stat().st_size if chunk_path.exists() else 0
        session_total = get_session_total_size(session_id) - previous_size
        if session_total + len(chunk_data) > MAX_TOTAL_SIZE:
            return create_api_response(
                code="400",
                message=f"Total size exceeds {MAX_TOTAL_SIZE // (1024*1024)}MB limit"
            )

        # 7. Write under a temporary name and move it into place. A failed
        #    write therefore cannot leave a truncated chunk that the merge step
        #    would accept as present. A leftover ".part" file does not match
        #    "chunk_*.webm" and is deleted together with the session directory.
        partial_path = chunk_path.with_name(chunk_path.name + ".part")
        with open(partial_path, 'wb') as f:
            f.write(chunk_data)
        os.replace(partial_path, chunk_path)

        # 8. Record the chunk in the metadata.
        update_session_chunks(session_id, chunk_index)

        # 9. Report how many distinct chunks have been received so far.
        metadata = load_session_metadata(session_id)
        total_received = len(metadata['received_chunks'])
        return create_api_response(
            code="200",
            message="分片上传成功",
            data={
                "session_id": session_id,
                "chunk_index": chunk_index,
                "received": True,
                "total_received": total_received
            }
        )
    except Exception as e:
        print(f"Error uploading chunk: {e}")
        return create_api_response(
            code="500",
            message=f"分片上传失败: {str(e)}",
            data={
                "session_id": session_id,
                "chunk_index": chunk_index,
                "should_retry": True
            }
        )
@router.post("/audio/stream/complete")
async def complete_upload(
    request: CompleteUploadRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """
    Finish an upload: merge the chunks and hand the file to the audio service.

    Verifies the session (ID format, existence, ownership, meeting ID), then
    merges chunks 0..total_chunks-1 into the final audio file. Finally it
    calls handle_audio_upload, which per its caller's notes updates the
    database and optionally starts transcription and summarization.
    """
    try:
        # 1. Validate the session_id format.
        try:
            validate_session_id(request.session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        # 2. Load the session metadata.
        try:
            metadata = load_session_metadata(request.session_id)
        except FileNotFoundError:
            return create_api_response(
                code="404",
                message="Session not found"
            )

        # 3. Only the session's owner may complete it.
        if metadata['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="Permission denied"
            )

        # 4. The meeting must be the one the session was created for.
        if metadata['meeting_id'] != request.meeting_id:
            return create_api_response(
                code="400",
                message="Meeting ID mismatch"
            )

        # 5. Merge the chunks. The MIME type recorded (and validated) at init
        #    decides the file extension. The request's value is only a
        #    fallback for metadata that lacks it, so a different value at
        #    completion cannot relabel the data.
        mime_type = metadata.get('mime_type') or request.mime_type
        try:
            file_path = merge_audio_chunks(
                session_id=request.session_id,
                meeting_id=request.meeting_id,
                total_chunks=request.total_chunks,
                mime_type=mime_type
            )
        except ValueError as e:
            # Typically missing chunks; the client can upload them and retry.
            return create_api_response(
                code="500",
                message=f"音频合并失败:{str(e)}",
                data={
                    "should_retry": True
                }
            )

        # 6. File information of the merged file.
        full_path = BASE_DIR / file_path.lstrip('/')
        file_size = full_path.stat().st_size
        file_name = full_path.name

        # 7. Hand off to audio_service.
        # NOTE(review): request.auto_transcribe is not passed on here; confirm
        # whether handle_audio_upload is expected to honour it.
        result = handle_audio_upload(
            file_path=file_path,
            file_name=file_name,
            file_size=file_size,
            meeting_id=request.meeting_id,
            current_user=current_user,
            auto_summarize=request.auto_summarize,
            background_tasks=background_tasks,
            prompt_id=request.prompt_id  # prompt template ID
        )

        # Propagate the service's error response unchanged.
        # NOTE(review): the merged file stays on disk in this case - confirm
        # whether it should be removed.
        if not result["success"]:
            return result["response"]

        # 8. Success response.
        transcription_task_id = result["transcription_task_id"]
        message_suffix = ""
        if transcription_task_id:
            if request.auto_summarize:
                message_suffix = ",正在进行转录和总结"
            else:
                message_suffix = ",正在进行转录"
        return create_api_response(
            code="200",
            message="音频上传完成" + message_suffix,
            data={
                "meeting_id": request.meeting_id,
                "file_path": file_path,
                "file_size": file_size,
                "duration": None,  # could be obtained with ffprobe, but not required
                "task_id": transcription_task_id,
                "task_status": "pending" if transcription_task_id else None,
                "auto_summarize": request.auto_summarize
            }
        )
    except Exception as e:
        print(f"Error completing upload: {e}")
        return create_api_response(
            code="500",
            message=f"完成上传失败: {str(e)}"
        )
@router.delete("/audio/stream/cancel")
async def cancel_upload(
    request: CancelUploadRequest,
    current_user: dict = Depends(get_current_user)
):
    """
    Cancel an upload session and delete its temporary files.

    An unknown session is treated as already cleaned up and is also reported
    as cancelled; a session owned by another user is refused with 403.
    """
    def cancelled_response():
        # Identical success payload for "cleaned now" and "already gone".
        return create_api_response(
            code="200",
            message="上传会话已取消",
            data={
                "session_id": request.session_id,
                "cleaned": True
            }
        )

    try:
        try:
            validate_session_id(request.session_id)
        except ValueError:
            return create_api_response(
                code="400",
                message="Invalid session_id format"
            )

        try:
            metadata = load_session_metadata(request.session_id)
        except FileNotFoundError:
            return cancelled_response()

        if metadata['user_id'] != current_user['user_id']:
            return create_api_response(
                code="403",
                message="Permission denied"
            )

        cleanup_session(request.session_id)
        return cancelled_response()
    except Exception as e:
        print(f"Error canceling upload: {e}")
        return create_api_response(
            code="500",
            message=f"取消上传失败: {str(e)}"
        )