imetting_backend/test/test_kb_prompt_id_feature.py

167 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
测试知识库提示词模版选择功能
"""
import sys
sys.path.insert(0, 'app')
from app.services.llm_service import LLMService
from app.services.async_knowledge_base_service import AsyncKnowledgeBaseService
from app.core.database import get_db_connection
def test_get_active_knowledge_prompts():
"""测试获取启用的知识库提示词列表"""
print("\n=== 测试1: 获取启用的知识库提示词列表 ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 获取KNOWLEDGE_TASK类型的启用模版
query = """
SELECT id, name, is_default
FROM prompts
WHERE task_type = %s AND is_active = TRUE
ORDER BY is_default DESC, created_at DESC
"""
cursor.execute(query, ('KNOWLEDGE_TASK',))
prompts = cursor.fetchall()
print(f"✓ 找到 {len(prompts)} 个启用的知识库任务模版:")
for p in prompts:
default_flag = " [默认]" if p['is_default'] else ""
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
return prompts
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
return []
def test_get_task_prompt_with_id(prompts):
"""测试通过prompt_id获取知识库提示词内容"""
print("\n=== 测试2: 通过prompt_id获取知识库提示词内容 ===")
if not prompts:
print("⚠ 没有可用的提示词模版,跳过测试")
return
llm_service = LLMService()
# 测试获取第一个提示词
test_prompt = prompts[0]
try:
content = llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=test_prompt['id'])
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
print(f" 内容长度: {len(content)} 字符")
print(f" 内容预览: {content[:100]}...")
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
# 测试获取默认提示词不指定prompt_id
try:
default_content = llm_service.get_task_prompt('KNOWLEDGE_TASK')
print(f"✓ 成功获取默认提示词")
print(f" 内容长度: {len(default_content)} 字符")
except Exception as e:
print(f"✗ 获取默认提示词失败: {e}")
def test_async_kb_service_signature():
"""测试async_knowledge_base_service的方法签名"""
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
import inspect
async_service = AsyncKnowledgeBaseService()
# 检查start_generation方法签名
sig = inspect.signature(async_service.start_generation)
params = list(sig.parameters.keys())
if 'prompt_id' in params:
print(f"✓ start_generation 方法支持 prompt_id 参数")
print(f" 参数列表: {params}")
else:
print(f"✗ start_generation 方法缺少 prompt_id 参数")
print(f" 参数列表: {params}")
# 检查_build_prompt方法签名
sig2 = inspect.signature(async_service._build_prompt)
params2 = list(sig2.parameters.keys())
if 'prompt_id' in params2:
print(f"✓ _build_prompt 方法支持 prompt_id 参数")
print(f" 参数列表: {params2}")
else:
print(f"✗ _build_prompt 方法缺少 prompt_id 参数")
print(f" 参数列表: {params2}")
def test_database_schema():
"""测试数据库schema是否包含prompt_id列"""
print("\n=== 测试4: 验证数据库schema ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 检查knowledge_base_tasks表是否有prompt_id列
cursor.execute("""
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = 'knowledge_base_tasks'
AND COLUMN_NAME = 'prompt_id'
""")
result = cursor.fetchone()
if result:
print(f"✓ knowledge_base_tasks 表包含 prompt_id 列")
print(f" 类型: {result['DATA_TYPE']}")
print(f" 可空: {result['IS_NULLABLE']}")
print(f" 默认值: {result['COLUMN_DEFAULT']}")
else:
print(f"✗ knowledge_base_tasks 表缺少 prompt_id 列")
except Exception as e:
print(f"✗ 数据库检查失败: {e}")
import traceback
traceback.print_exc()
def test_api_model():
"""测试API模型定义"""
print("\n=== 测试5: 验证API模型定义 ===")
try:
from app.models.models import CreateKnowledgeBaseRequest
import inspect
# 检查CreateKnowledgeBaseRequest模型
fields = CreateKnowledgeBaseRequest.model_fields
if 'prompt_id' in fields:
print(f"✓ CreateKnowledgeBaseRequest 包含 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
else:
print(f"✗ CreateKnowledgeBaseRequest 缺少 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
except Exception as e:
print(f"✗ API模型检查失败: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
print("=" * 60)
print("开始测试知识库提示词模版选择功能")
print("=" * 60)
# 运行所有测试
prompts = test_get_active_knowledge_prompts()
test_get_task_prompt_with_id(prompts)
test_async_kb_service_signature()
test_database_schema()
test_api_model()
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)