167 lines
5.8 KiB
Python
167 lines
5.8 KiB
Python
"""
|
||
测试知识库提示词模版选择功能
|
||
"""
|
||
import sys
|
||
sys.path.insert(0, 'app')
|
||
|
||
from app.services.llm_service import LLMService
|
||
from app.services.async_knowledge_base_service import AsyncKnowledgeBaseService
|
||
from app.core.database import get_db_connection
|
||
|
||
def test_get_active_knowledge_prompts():
|
||
"""测试获取启用的知识库提示词列表"""
|
||
print("\n=== 测试1: 获取启用的知识库提示词列表 ===")
|
||
try:
|
||
with get_db_connection() as connection:
|
||
cursor = connection.cursor(dictionary=True)
|
||
|
||
# 获取KNOWLEDGE_TASK类型的启用模版
|
||
query = """
|
||
SELECT id, name, is_default
|
||
FROM prompts
|
||
WHERE task_type = %s AND is_active = TRUE
|
||
ORDER BY is_default DESC, created_at DESC
|
||
"""
|
||
cursor.execute(query, ('KNOWLEDGE_TASK',))
|
||
prompts = cursor.fetchall()
|
||
|
||
print(f"✓ 找到 {len(prompts)} 个启用的知识库任务模版:")
|
||
for p in prompts:
|
||
default_flag = " [默认]" if p['is_default'] else ""
|
||
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
|
||
|
||
return prompts
|
||
except Exception as e:
|
||
print(f"✗ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return []
|
||
|
||
def test_get_task_prompt_with_id(prompts):
|
||
"""测试通过prompt_id获取知识库提示词内容"""
|
||
print("\n=== 测试2: 通过prompt_id获取知识库提示词内容 ===")
|
||
|
||
if not prompts:
|
||
print("⚠ 没有可用的提示词模版,跳过测试")
|
||
return
|
||
|
||
llm_service = LLMService()
|
||
|
||
# 测试获取第一个提示词
|
||
test_prompt = prompts[0]
|
||
try:
|
||
content = llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=test_prompt['id'])
|
||
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
|
||
print(f" 内容长度: {len(content)} 字符")
|
||
print(f" 内容预览: {content[:100]}...")
|
||
except Exception as e:
|
||
print(f"✗ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
# 测试获取默认提示词(不指定prompt_id)
|
||
try:
|
||
default_content = llm_service.get_task_prompt('KNOWLEDGE_TASK')
|
||
print(f"✓ 成功获取默认提示词")
|
||
print(f" 内容长度: {len(default_content)} 字符")
|
||
except Exception as e:
|
||
print(f"✗ 获取默认提示词失败: {e}")
|
||
|
||
def test_async_kb_service_signature():
|
||
"""测试async_knowledge_base_service的方法签名"""
|
||
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
|
||
|
||
import inspect
|
||
async_service = AsyncKnowledgeBaseService()
|
||
|
||
# 检查start_generation方法签名
|
||
sig = inspect.signature(async_service.start_generation)
|
||
params = list(sig.parameters.keys())
|
||
|
||
if 'prompt_id' in params:
|
||
print(f"✓ start_generation 方法支持 prompt_id 参数")
|
||
print(f" 参数列表: {params}")
|
||
else:
|
||
print(f"✗ start_generation 方法缺少 prompt_id 参数")
|
||
print(f" 参数列表: {params}")
|
||
|
||
# 检查_build_prompt方法签名
|
||
sig2 = inspect.signature(async_service._build_prompt)
|
||
params2 = list(sig2.parameters.keys())
|
||
|
||
if 'prompt_id' in params2:
|
||
print(f"✓ _build_prompt 方法支持 prompt_id 参数")
|
||
print(f" 参数列表: {params2}")
|
||
else:
|
||
print(f"✗ _build_prompt 方法缺少 prompt_id 参数")
|
||
print(f" 参数列表: {params2}")
|
||
|
||
def test_database_schema():
|
||
"""测试数据库schema是否包含prompt_id列"""
|
||
print("\n=== 测试4: 验证数据库schema ===")
|
||
|
||
try:
|
||
with get_db_connection() as connection:
|
||
cursor = connection.cursor(dictionary=True)
|
||
|
||
# 检查knowledge_base_tasks表是否有prompt_id列
|
||
cursor.execute("""
|
||
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
|
||
FROM information_schema.COLUMNS
|
||
WHERE TABLE_SCHEMA = DATABASE()
|
||
AND TABLE_NAME = 'knowledge_base_tasks'
|
||
AND COLUMN_NAME = 'prompt_id'
|
||
""")
|
||
result = cursor.fetchone()
|
||
|
||
if result:
|
||
print(f"✓ knowledge_base_tasks 表包含 prompt_id 列")
|
||
print(f" 类型: {result['DATA_TYPE']}")
|
||
print(f" 可空: {result['IS_NULLABLE']}")
|
||
print(f" 默认值: {result['COLUMN_DEFAULT']}")
|
||
else:
|
||
print(f"✗ knowledge_base_tasks 表缺少 prompt_id 列")
|
||
except Exception as e:
|
||
print(f"✗ 数据库检查失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
def test_api_model():
|
||
"""测试API模型定义"""
|
||
print("\n=== 测试5: 验证API模型定义 ===")
|
||
|
||
try:
|
||
from app.models.models import CreateKnowledgeBaseRequest
|
||
import inspect
|
||
|
||
# 检查CreateKnowledgeBaseRequest模型
|
||
fields = CreateKnowledgeBaseRequest.model_fields
|
||
|
||
if 'prompt_id' in fields:
|
||
print(f"✓ CreateKnowledgeBaseRequest 包含 prompt_id 字段")
|
||
print(f" 字段列表: {list(fields.keys())}")
|
||
else:
|
||
print(f"✗ CreateKnowledgeBaseRequest 缺少 prompt_id 字段")
|
||
print(f" 字段列表: {list(fields.keys())}")
|
||
|
||
except Exception as e:
|
||
print(f"✗ API模型检查失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
if __name__ == '__main__':
|
||
print("=" * 60)
|
||
print("开始测试知识库提示词模版选择功能")
|
||
print("=" * 60)
|
||
|
||
# 运行所有测试
|
||
prompts = test_get_active_knowledge_prompts()
|
||
test_get_task_prompt_with_id(prompts)
|
||
test_async_kb_service_signature()
|
||
test_database_schema()
|
||
test_api_model()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("测试完成")
|
||
print("=" * 60)
|