imetting_backend/test/test_prompt_id_feature.py

177 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
测试提示词模版选择功能
"""
import sys
sys.path.insert(0, 'app')
from app.services.llm_service import LLMService
from app.services.async_meeting_service import AsyncMeetingService
from app.core.database import get_db_connection
def test_get_active_prompts():
"""测试获取启用的提示词列表"""
print("\n=== 测试1: 获取启用的提示词列表 ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 获取MEETING_TASK类型的启用模版
query = """
SELECT id, name, is_default
FROM prompts
WHERE task_type = %s AND is_active = TRUE
ORDER BY is_default DESC, created_at DESC
"""
cursor.execute(query, ('MEETING_TASK',))
prompts = cursor.fetchall()
print(f"✓ 找到 {len(prompts)} 个启用的会议任务模版:")
for p in prompts:
default_flag = " [默认]" if p['is_default'] else ""
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
return prompts
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
return []
def test_get_task_prompt_with_id(prompts):
"""测试通过prompt_id获取提示词内容"""
print("\n=== 测试2: 通过prompt_id获取提示词内容 ===")
if not prompts:
print("⚠ 没有可用的提示词模版,跳过测试")
return
llm_service = LLMService()
# 测试获取第一个提示词
test_prompt = prompts[0]
try:
content = llm_service.get_task_prompt('MEETING_TASK', prompt_id=test_prompt['id'])
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
print(f" 内容长度: {len(content)} 字符")
print(f" 内容预览: {content[:100]}...")
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
# 测试获取默认提示词不指定prompt_id
try:
default_content = llm_service.get_task_prompt('MEETING_TASK')
print(f"✓ 成功获取默认提示词")
print(f" 内容长度: {len(default_content)} 字符")
except Exception as e:
print(f"✗ 获取默认提示词失败: {e}")
def test_async_meeting_service_signature():
"""测试async_meeting_service的方法签名"""
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
import inspect
async_service = AsyncMeetingService()
# 检查start_summary_generation方法签名
sig = inspect.signature(async_service.start_summary_generation)
params = list(sig.parameters.keys())
if 'prompt_id' in params:
print(f"✓ start_summary_generation 方法支持 prompt_id 参数")
print(f" 参数列表: {params}")
else:
print(f"✗ start_summary_generation 方法缺少 prompt_id 参数")
print(f" 参数列表: {params}")
# 检查monitor_and_auto_summarize方法签名
sig2 = inspect.signature(async_service.monitor_and_auto_summarize)
params2 = list(sig2.parameters.keys())
if 'prompt_id' in params2:
print(f"✓ monitor_and_auto_summarize 方法支持 prompt_id 参数")
print(f" 参数列表: {params2}")
else:
print(f"✗ monitor_and_auto_summarize 方法缺少 prompt_id 参数")
print(f" 参数列表: {params2}")
def test_database_schema():
"""测试数据库schema是否包含prompt_id列"""
print("\n=== 测试4: 验证数据库schema ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 检查llm_tasks表是否有prompt_id列
cursor.execute("""
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = 'llm_tasks'
AND COLUMN_NAME = 'prompt_id'
""")
result = cursor.fetchone()
if result:
print(f"✓ llm_tasks 表包含 prompt_id 列")
print(f" 类型: {result['DATA_TYPE']}")
print(f" 可空: {result['IS_NULLABLE']}")
print(f" 默认值: {result['COLUMN_DEFAULT']}")
else:
print(f"✗ llm_tasks 表缺少 prompt_id 列")
except Exception as e:
print(f"✗ 数据库检查失败: {e}")
import traceback
traceback.print_exc()
def test_api_endpoints():
"""测试API端点定义"""
print("\n=== 测试5: 验证API端点定义 ===")
try:
from app.api.endpoints.meetings import GenerateSummaryRequest
import inspect
# 检查GenerateSummaryRequest模型
fields = GenerateSummaryRequest.__fields__
if 'prompt_id' in fields:
print(f"✓ GenerateSummaryRequest 包含 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
else:
print(f"✗ GenerateSummaryRequest 缺少 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
# 检查audio_service.handle_audio_upload签名
from app.services.audio_service import handle_audio_upload
sig = inspect.signature(handle_audio_upload)
params = list(sig.parameters.keys())
if 'prompt_id' in params:
print(f"✓ handle_audio_upload 方法支持 prompt_id 参数")
else:
print(f"✗ handle_audio_upload 方法缺少 prompt_id 参数")
except Exception as e:
print(f"✗ API端点检查失败: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
print("=" * 60)
print("开始测试提示词模版选择功能")
print("=" * 60)
# 运行所有测试
prompts = test_get_active_prompts()
test_get_task_prompt_with_id(prompts)
test_async_meeting_service_signature()
test_database_schema()
test_api_endpoints()
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)