177 lines
6.1 KiB
Python
177 lines
6.1 KiB
Python
"""
|
||
测试提示词模版选择功能
|
||
"""
|
||
import sys
|
||
sys.path.insert(0, 'app')
|
||
|
||
from app.services.llm_service import LLMService
|
||
from app.services.async_meeting_service import AsyncMeetingService
|
||
from app.core.database import get_db_connection
|
||
|
||
def test_get_active_prompts():
|
||
"""测试获取启用的提示词列表"""
|
||
print("\n=== 测试1: 获取启用的提示词列表 ===")
|
||
try:
|
||
with get_db_connection() as connection:
|
||
cursor = connection.cursor(dictionary=True)
|
||
|
||
# 获取MEETING_TASK类型的启用模版
|
||
query = """
|
||
SELECT id, name, is_default
|
||
FROM prompts
|
||
WHERE task_type = %s AND is_active = TRUE
|
||
ORDER BY is_default DESC, created_at DESC
|
||
"""
|
||
cursor.execute(query, ('MEETING_TASK',))
|
||
prompts = cursor.fetchall()
|
||
|
||
print(f"✓ 找到 {len(prompts)} 个启用的会议任务模版:")
|
||
for p in prompts:
|
||
default_flag = " [默认]" if p['is_default'] else ""
|
||
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
|
||
|
||
return prompts
|
||
except Exception as e:
|
||
print(f"✗ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return []
|
||
|
||
def test_get_task_prompt_with_id(prompts):
|
||
"""测试通过prompt_id获取提示词内容"""
|
||
print("\n=== 测试2: 通过prompt_id获取提示词内容 ===")
|
||
|
||
if not prompts:
|
||
print("⚠ 没有可用的提示词模版,跳过测试")
|
||
return
|
||
|
||
llm_service = LLMService()
|
||
|
||
# 测试获取第一个提示词
|
||
test_prompt = prompts[0]
|
||
try:
|
||
content = llm_service.get_task_prompt('MEETING_TASK', prompt_id=test_prompt['id'])
|
||
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
|
||
print(f" 内容长度: {len(content)} 字符")
|
||
print(f" 内容预览: {content[:100]}...")
|
||
except Exception as e:
|
||
print(f"✗ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
# 测试获取默认提示词(不指定prompt_id)
|
||
try:
|
||
default_content = llm_service.get_task_prompt('MEETING_TASK')
|
||
print(f"✓ 成功获取默认提示词")
|
||
print(f" 内容长度: {len(default_content)} 字符")
|
||
except Exception as e:
|
||
print(f"✗ 获取默认提示词失败: {e}")
|
||
|
||
def test_async_meeting_service_signature():
|
||
"""测试async_meeting_service的方法签名"""
|
||
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
|
||
|
||
import inspect
|
||
async_service = AsyncMeetingService()
|
||
|
||
# 检查start_summary_generation方法签名
|
||
sig = inspect.signature(async_service.start_summary_generation)
|
||
params = list(sig.parameters.keys())
|
||
|
||
if 'prompt_id' in params:
|
||
print(f"✓ start_summary_generation 方法支持 prompt_id 参数")
|
||
print(f" 参数列表: {params}")
|
||
else:
|
||
print(f"✗ start_summary_generation 方法缺少 prompt_id 参数")
|
||
print(f" 参数列表: {params}")
|
||
|
||
# 检查monitor_and_auto_summarize方法签名
|
||
sig2 = inspect.signature(async_service.monitor_and_auto_summarize)
|
||
params2 = list(sig2.parameters.keys())
|
||
|
||
if 'prompt_id' in params2:
|
||
print(f"✓ monitor_and_auto_summarize 方法支持 prompt_id 参数")
|
||
print(f" 参数列表: {params2}")
|
||
else:
|
||
print(f"✗ monitor_and_auto_summarize 方法缺少 prompt_id 参数")
|
||
print(f" 参数列表: {params2}")
|
||
|
||
def test_database_schema():
|
||
"""测试数据库schema是否包含prompt_id列"""
|
||
print("\n=== 测试4: 验证数据库schema ===")
|
||
|
||
try:
|
||
with get_db_connection() as connection:
|
||
cursor = connection.cursor(dictionary=True)
|
||
|
||
# 检查llm_tasks表是否有prompt_id列
|
||
cursor.execute("""
|
||
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
|
||
FROM information_schema.COLUMNS
|
||
WHERE TABLE_SCHEMA = DATABASE()
|
||
AND TABLE_NAME = 'llm_tasks'
|
||
AND COLUMN_NAME = 'prompt_id'
|
||
""")
|
||
result = cursor.fetchone()
|
||
|
||
if result:
|
||
print(f"✓ llm_tasks 表包含 prompt_id 列")
|
||
print(f" 类型: {result['DATA_TYPE']}")
|
||
print(f" 可空: {result['IS_NULLABLE']}")
|
||
print(f" 默认值: {result['COLUMN_DEFAULT']}")
|
||
else:
|
||
print(f"✗ llm_tasks 表缺少 prompt_id 列")
|
||
except Exception as e:
|
||
print(f"✗ 数据库检查失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
def test_api_endpoints():
|
||
"""测试API端点定义"""
|
||
print("\n=== 测试5: 验证API端点定义 ===")
|
||
|
||
try:
|
||
from app.api.endpoints.meetings import GenerateSummaryRequest
|
||
import inspect
|
||
|
||
# 检查GenerateSummaryRequest模型
|
||
fields = GenerateSummaryRequest.__fields__
|
||
|
||
if 'prompt_id' in fields:
|
||
print(f"✓ GenerateSummaryRequest 包含 prompt_id 字段")
|
||
print(f" 字段列表: {list(fields.keys())}")
|
||
else:
|
||
print(f"✗ GenerateSummaryRequest 缺少 prompt_id 字段")
|
||
print(f" 字段列表: {list(fields.keys())}")
|
||
|
||
# 检查audio_service.handle_audio_upload签名
|
||
from app.services.audio_service import handle_audio_upload
|
||
sig = inspect.signature(handle_audio_upload)
|
||
params = list(sig.parameters.keys())
|
||
|
||
if 'prompt_id' in params:
|
||
print(f"✓ handle_audio_upload 方法支持 prompt_id 参数")
|
||
else:
|
||
print(f"✗ handle_audio_upload 方法缺少 prompt_id 参数")
|
||
|
||
except Exception as e:
|
||
print(f"✗ API端点检查失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
if __name__ == '__main__':
|
||
print("=" * 60)
|
||
print("开始测试提示词模版选择功能")
|
||
print("=" * 60)
|
||
|
||
# 运行所有测试
|
||
prompts = test_get_active_prompts()
|
||
test_get_task_prompt_with_id(prompts)
|
||
test_async_meeting_service_signature()
|
||
test_database_schema()
|
||
test_api_endpoints()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("测试完成")
|
||
print("=" * 60)
|