276 lines
9.7 KiB
Python
276 lines
9.7 KiB
Python
"""
|
||
测试 upload_audio 接口的 auto_summarize 参数
|
||
"""
|
||
import requests
|
||
import time
|
||
import json
|
||
|
||
# 配置
|
||
BASE_URL = "http://localhost:8000/api"
|
||
# 请替换为你的有效token
|
||
AUTH_TOKEN = "your_auth_token_here"
|
||
# 请替换为你的测试会议ID
|
||
TEST_MEETING_ID = 1
|
||
|
||
# 请求头
|
||
headers = {
|
||
"Authorization": f"Bearer {AUTH_TOKEN}"
|
||
}
|
||
|
||
|
||
def test_upload_audio(auto_summarize=True):
|
||
"""测试音频上传接口"""
|
||
print("=" * 60)
|
||
print(f"测试: upload_audio 接口 (auto_summarize={auto_summarize})")
|
||
print("=" * 60)
|
||
|
||
# 准备测试文件
|
||
audio_file_path = "test_audio.mp3" # 请替换为实际的音频文件路径
|
||
|
||
try:
|
||
with open(audio_file_path, 'rb') as audio_file:
|
||
files = {
|
||
'audio_file': ('test_audio.mp3', audio_file, 'audio/mpeg')
|
||
}
|
||
data = {
|
||
'force_replace': 'false',
|
||
'auto_summarize': 'true' if auto_summarize else 'false'
|
||
}
|
||
|
||
# 发送请求
|
||
url = f"{BASE_URL}/meetings/upload-audio"
|
||
print(f"\n发送请求到: {url}")
|
||
print(f"参数: auto_summarize={data['auto_summarize']}")
|
||
response = requests.post(url, headers=headers, files=files, data=data)
|
||
|
||
print(f"状态码: {response.status_code}")
|
||
print(f"响应内容:")
|
||
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
|
||
|
||
# 如果上传成功,获取任务ID
|
||
if response.status_code == 200:
|
||
response_data = response.json()
|
||
if response_data.get('code') == '200':
|
||
task_id = response_data['data'].get('task_id')
|
||
auto_sum = response_data['data'].get('auto_summarize')
|
||
print(f"\n✓ 上传成功! 转录任务ID: {task_id}")
|
||
print(f" 自动总结: {'开启' if auto_sum else '关闭'}")
|
||
if auto_sum:
|
||
print(f" 提示: 音频已上传,后台正在自动进行转录和总结")
|
||
else:
|
||
print(f" 提示: 音频已上传,正在进行转录(不会自动总结)")
|
||
print(f"\n 可以通过以下接口查询状态:")
|
||
print(f" - 转录状态: GET /meetings/{TEST_MEETING_ID}/transcription/status")
|
||
print(f" - 总结任务: GET /meetings/{TEST_MEETING_ID}/llm-tasks")
|
||
print(f" - 会议详情: GET /meetings/{TEST_MEETING_ID}")
|
||
return True
|
||
elif response_data.get('code') == '300':
|
||
print("\n⚠ 需要确认替换现有文件")
|
||
return False
|
||
else:
|
||
print(f"\n✗ 上传失败")
|
||
return False
|
||
|
||
except FileNotFoundError:
|
||
print(f"\n✗ 错误: 找不到测试音频文件 {audio_file_path}")
|
||
print("请创建一个测试音频文件或修改 audio_file_path 变量")
|
||
return False
|
||
except Exception as e:
|
||
print(f"\n✗ 错误: {e}")
|
||
return False
|
||
|
||
|
||
def test_get_transcription_status():
|
||
"""测试获取转录状态接口"""
|
||
print("\n" + "=" * 60)
|
||
print("测试: 获取转录状态")
|
||
print("=" * 60)
|
||
|
||
url = f"{BASE_URL}/meetings/{TEST_MEETING_ID}/transcription/status"
|
||
print(f"\n发送请求到: {url}")
|
||
|
||
try:
|
||
response = requests.get(url, headers=headers)
|
||
print(f"状态码: {response.status_code}")
|
||
print(f"响应内容:")
|
||
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
|
||
|
||
if response.status_code == 200:
|
||
response_data = response.json()
|
||
if response_data.get('code') == '200':
|
||
data = response_data['data']
|
||
print(f"\n✓ 获取转录状态成功!")
|
||
print(f" - 任务ID: {data.get('task_id')}")
|
||
print(f" - 状态: {data.get('status')}")
|
||
print(f" - 进度: {data.get('progress')}%")
|
||
return data.get('status'), data.get('progress')
|
||
else:
|
||
print(f"\n✗ 获取状态失败")
|
||
return None, None
|
||
|
||
except Exception as e:
|
||
print(f"\n✗ 错误: {e}")
|
||
return None, None
|
||
|
||
|
||
def test_get_llm_tasks():
|
||
"""测试获取LLM任务列表"""
|
||
print("\n" + "=" * 60)
|
||
print("测试: 获取LLM任务列表")
|
||
print("=" * 60)
|
||
|
||
url = f"{BASE_URL}/meetings/{TEST_MEETING_ID}/llm-tasks"
|
||
print(f"\n发送请求到: {url}")
|
||
|
||
try:
|
||
response = requests.get(url, headers=headers)
|
||
print(f"状态码: {response.status_code}")
|
||
print(f"响应内容:")
|
||
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
|
||
|
||
if response.status_code == 200:
|
||
response_data = response.json()
|
||
if response_data.get('code') == '200':
|
||
tasks = response_data['data'].get('tasks', [])
|
||
print(f"\n✓ 获取LLM任务成功! 共 {len(tasks)} 个任务")
|
||
if tasks:
|
||
latest_task = tasks[0]
|
||
print(f" 最新任务:")
|
||
print(f" - 任务ID: {latest_task.get('task_id')}")
|
||
print(f" - 状态: {latest_task.get('status')}")
|
||
print(f" - 进度: {latest_task.get('progress')}%")
|
||
return latest_task.get('status'), latest_task.get('progress')
|
||
return None, None
|
||
else:
|
||
print(f"\n✗ 获取任务失败")
|
||
return None, None
|
||
|
||
except Exception as e:
|
||
print(f"\n✗ 错误: {e}")
|
||
return None, None
|
||
|
||
|
||
def monitor_progress():
|
||
"""持续监控处理进度"""
|
||
print("\n" + "=" * 60)
|
||
print("持续监控处理进度 (每10秒查询一次)")
|
||
print("按 Ctrl+C 停止监控")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
transcription_completed = False
|
||
summary_completed = False
|
||
|
||
while True:
|
||
print(f"\n[{time.strftime('%H:%M:%S')}] 查询状态...")
|
||
|
||
# 查询转录状态
|
||
trans_status, trans_progress = test_get_transcription_status()
|
||
|
||
# 如果转录完成,查询总结状态
|
||
if trans_status == 'completed' and not transcription_completed:
|
||
print(f"\n✓ 转录已完成!")
|
||
transcription_completed = True
|
||
|
||
if transcription_completed:
|
||
summ_status, summ_progress = test_get_llm_tasks()
|
||
if summ_status == 'completed' and not summary_completed:
|
||
print(f"\n✓ 总结已完成!")
|
||
summary_completed = True
|
||
break
|
||
elif summ_status == 'failed':
|
||
print(f"\n✗ 总结失败")
|
||
break
|
||
|
||
# 检查转录是否失败
|
||
if trans_status == 'failed':
|
||
print(f"\n✗ 转录失败")
|
||
break
|
||
|
||
# 如果全部完成,退出
|
||
if transcription_completed and summary_completed:
|
||
print(f"\n✓ 全部完成!")
|
||
break
|
||
|
||
time.sleep(10)
|
||
|
||
except KeyboardInterrupt:
|
||
print("\n\n⚠ 用户中断监控")
|
||
except Exception as e:
|
||
print(f"\n✗ 监控出错: {e}")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("\n")
|
||
print("╔" + "═" * 58 + "╗")
|
||
print("║" + " " * 12 + "upload_audio 接口测试" + " " * 23 + "║")
|
||
print("║" + " " * 10 + "(测试 auto_summarize 参数)" + " " * 17 + "║")
|
||
print("╚" + "═" * 58 + "╝")
|
||
|
||
print("\n请确保:")
|
||
print("1. 后端服务正在运行 (http://localhost:8000)")
|
||
print("2. 已修改脚本中的 AUTH_TOKEN 和 TEST_MEETING_ID")
|
||
print("3. 已准备好测试音频文件")
|
||
|
||
input("\n按回车键开始测试...")
|
||
|
||
# 测试1: 查看当前转录状态
|
||
test_get_transcription_status()
|
||
|
||
# 测试2: 查看当前LLM任务
|
||
test_get_llm_tasks()
|
||
|
||
# 询问要测试哪种模式
|
||
print("\n" + "-" * 60)
|
||
print("请选择测试模式:")
|
||
print("1. 仅转录 (auto_summarize=false)")
|
||
print("2. 转录+自动总结 (auto_summarize=true)")
|
||
print("3. 两种模式都测试")
|
||
choice = input("请输入选项 (1/2/3): ")
|
||
|
||
if choice == '1':
|
||
# 测试:仅转录
|
||
if test_upload_audio(auto_summarize=False):
|
||
print("\n⚠ 注意: 此模式下不会自动生成总结")
|
||
print("如需生成总结,请手动调用: POST /meetings/{meeting_id}/generate-summary-async")
|
||
elif choice == '2':
|
||
# 测试:转录+自动总结
|
||
if test_upload_audio(auto_summarize=True):
|
||
print("\n" + "-" * 60)
|
||
choice = input("是否要持续监控处理进度? (y/n): ")
|
||
if choice.lower() == 'y':
|
||
monitor_progress()
|
||
elif choice == '3':
|
||
# 两种模式都测试
|
||
print("\n" + "=" * 60)
|
||
print("测试模式1: 仅转录 (auto_summarize=false)")
|
||
print("=" * 60)
|
||
test_upload_audio(auto_summarize=False)
|
||
|
||
input("\n按回车键继续测试模式2...")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("测试模式2: 转录+自动总结 (auto_summarize=true)")
|
||
print("=" * 60)
|
||
if test_upload_audio(auto_summarize=True):
|
||
print("\n" + "-" * 60)
|
||
choice = input("是否要持续监控处理进度? (y/n): ")
|
||
if choice.lower() == 'y':
|
||
monitor_progress()
|
||
else:
|
||
print("\n✗ 无效选项")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("测试完成!")
|
||
print("=" * 60)
|
||
print("\n总结:")
|
||
print("- auto_summarize=false: 只执行转录,不自动生成总结")
|
||
print("- auto_summarize=true: 执行转录后自动生成总结")
|
||
print("- 默认值: true (向前兼容)")
|
||
print("- 现有页面建议设置: auto_summarize=false")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|