""" 测试知识库提示词模版选择功能 """ import sys sys.path.insert(0, 'app') from app.services.llm_service import LLMService from app.services.async_knowledge_base_service import AsyncKnowledgeBaseService from app.core.database import get_db_connection def test_get_active_knowledge_prompts(): """测试获取启用的知识库提示词列表""" print("\n=== 测试1: 获取启用的知识库提示词列表 ===") try: with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) # 获取KNOWLEDGE_TASK类型的启用模版 query = """ SELECT id, name, is_default FROM prompts WHERE task_type = %s AND is_active = TRUE ORDER BY is_default DESC, created_at DESC """ cursor.execute(query, ('KNOWLEDGE_TASK',)) prompts = cursor.fetchall() print(f"✓ 找到 {len(prompts)} 个启用的知识库任务模版:") for p in prompts: default_flag = " [默认]" if p['is_default'] else "" print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}") return prompts except Exception as e: print(f"✗ 测试失败: {e}") import traceback traceback.print_exc() return [] def test_get_task_prompt_with_id(prompts): """测试通过prompt_id获取知识库提示词内容""" print("\n=== 测试2: 通过prompt_id获取知识库提示词内容 ===") if not prompts: print("⚠ 没有可用的提示词模版,跳过测试") return llm_service = LLMService() # 测试获取第一个提示词 test_prompt = prompts[0] try: content = llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=test_prompt['id']) print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}") print(f" 内容长度: {len(content)} 字符") print(f" 内容预览: {content[:100]}...") except Exception as e: print(f"✗ 测试失败: {e}") import traceback traceback.print_exc() # 测试获取默认提示词(不指定prompt_id) try: default_content = llm_service.get_task_prompt('KNOWLEDGE_TASK') print(f"✓ 成功获取默认提示词") print(f" 内容长度: {len(default_content)} 字符") except Exception as e: print(f"✗ 获取默认提示词失败: {e}") def test_async_kb_service_signature(): """测试async_knowledge_base_service的方法签名""" print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===") import inspect async_service = AsyncKnowledgeBaseService() # 检查start_generation方法签名 sig = inspect.signature(async_service.start_generation) params = list(sig.parameters.keys()) if 'prompt_id' in params: print(f"✓ start_generation 方法支持 prompt_id 参数") print(f" 参数列表: {params}") else: print(f"✗ start_generation 方法缺少 prompt_id 参数") print(f" 参数列表: {params}") # 检查_build_prompt方法签名 sig2 = inspect.signature(async_service._build_prompt) params2 = list(sig2.parameters.keys()) if 'prompt_id' in params2: print(f"✓ _build_prompt 方法支持 prompt_id 参数") print(f" 参数列表: {params2}") else: print(f"✗ _build_prompt 方法缺少 prompt_id 参数") print(f" 参数列表: {params2}") def test_database_schema(): """测试数据库schema是否包含prompt_id列""" print("\n=== 测试4: 验证数据库schema ===") try: with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) # 检查knowledge_base_tasks表是否有prompt_id列 cursor.execute(""" SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'knowledge_base_tasks' AND COLUMN_NAME = 'prompt_id' """) result = cursor.fetchone() if result: print(f"✓ knowledge_base_tasks 表包含 prompt_id 列") print(f" 类型: {result['DATA_TYPE']}") print(f" 可空: {result['IS_NULLABLE']}") print(f" 默认值: {result['COLUMN_DEFAULT']}") else: print(f"✗ knowledge_base_tasks 表缺少 prompt_id 列") except Exception as e: print(f"✗ 数据库检查失败: {e}") import traceback traceback.print_exc() def test_api_model(): """测试API模型定义""" print("\n=== 测试5: 验证API模型定义 ===") try: from app.models.models import CreateKnowledgeBaseRequest import inspect # 检查CreateKnowledgeBaseRequest模型 fields = CreateKnowledgeBaseRequest.model_fields if 'prompt_id' in fields: print(f"✓ CreateKnowledgeBaseRequest 包含 prompt_id 字段") print(f" 字段列表: {list(fields.keys())}") else: print(f"✗ CreateKnowledgeBaseRequest 缺少 prompt_id 字段") print(f" 字段列表: {list(fields.keys())}") except Exception as e: print(f"✗ API模型检查失败: {e}") import traceback traceback.print_exc() if __name__ == '__main__': print("=" * 60) print("开始测试知识库提示词模版选择功能") print("=" * 60) # 运行所有测试 prompts = test_get_active_knowledge_prompts() test_get_task_prompt_with_id(prompts) test_async_kb_service_signature() test_database_schema() test_api_model() print("\n" + "=" * 60) print("测试完成") print("=" * 60)