diff --git a/.DS_Store b/.DS_Store index acab377..d8f9254 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/Dockerfile b/Dockerfile index ae9c31d..9845d96 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,4 +33,4 @@ COPY . . EXPOSE 8001 # 启动命令 -CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"] \ No newline at end of file +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8001"] \ No newline at end of file diff --git a/IMPLEMENTATION_SUMMARY.md b/IMPLEMENTATION_SUMMARY.md new file mode 100644 index 0000000..2ef37c1 --- /dev/null +++ b/IMPLEMENTATION_SUMMARY.md @@ -0,0 +1,167 @@ +# 提示词模版选择功能实现总结 + +## 功能概述 +实现了用户在创建会议时可以选择使用的总结模版功能。支持两种场景: +1. 手动生成会议总结时选择模版 +2. 上传音频文件时选择模版,自动总结时使用该模版 + +## 实现的功能点 + +### 1. 新增API接口 ✅ +**文件**: `app/api/endpoints/prompts.py` +- 新增 `GET /prompts/active/{task_type}` 接口 +- 功能:获取指定任务类型的所有启用状态的提示词模版 +- 返回字段:id, name, is_default +- 按默认模版优先、创建时间倒序排列 + +### 2. 修改LLM服务 ✅ +**文件**: `app/services/llm_service.py` +- 修改 `get_task_prompt()` 方法,增加 `prompt_id` 可选参数 +- 逻辑: + - 如果指定 prompt_id,查询该ID对应的提示词(需验证task_type和is_active) + - 如果不指定,使用默认提示词(is_default=TRUE) + - 如果都查不到,返回代码中的默认值 + +### 3. 修改会议总结服务 ✅ +**文件**: `app/services/async_meeting_service.py` + +#### 3.1 修改任务创建 +- `start_summary_generation()`: 增加 `prompt_id` 参数 +- 将 prompt_id 存储到 Redis 和数据库 + +#### 3.2 修改任务处理 +- `_process_task()`: 从 Redis 读取 prompt_id,传递给 `_build_prompt()` +- `_build_prompt()`: 增加 `prompt_id` 参数,传递给 `llm_service.get_task_prompt()` +- `_save_task_to_db()`: 增加 `prompt_id` 参数,存储到数据库 + +#### 3.3 修改自动总结监控 +- `monitor_and_auto_summarize()`: 增加 `prompt_id` 参数 +- 在转录完成后启动总结任务时,传递 prompt_id + +### 4. 修改音频服务 ✅ +**文件**: `app/services/audio_service.py` +- `handle_audio_upload()`: 增加 `prompt_id` 参数 +- 将 prompt_id 传递给 `monitor_and_auto_summarize()` + +### 5. 修改会议API接口 ✅ +**文件**: `app/api/endpoints/meetings.py` + +#### 5.1 手动生成总结 +- `GenerateSummaryRequest` 模型:增加 `prompt_id` 字段 +- `POST /meetings/{meeting_id}/generate-summary-async`: 传递 prompt_id 给服务层 + +#### 5.2 音频上传 +- `POST /meetings/upload-audio`: 增加 `prompt_id` 表单参数 +- 将 prompt_id 传递给 `handle_audio_upload()` + +### 6. 数据库迁移 ✅ +**文件**: `sql/add_prompt_id_to_llm_tasks.sql` +- 为 `llm_tasks` 表添加 `prompt_id` 列 +- 类型:int(11) +- 可空:YES +- 默认值:NULL +- 索引:idx_prompt_id + +## 数据流向 + +### 手动生成总结 +``` +前端 → POST /meetings/{id}/generate-summary-async (prompt_id) + → async_meeting_service.start_summary_generation(prompt_id) + → 存储到 Redis 和 DB (llm_tasks.prompt_id) + → _process_task() 读取 prompt_id + → _build_prompt(prompt_id) + → llm_service.get_task_prompt('MEETING_TASK', prompt_id) + → 获取指定模版或默认模版 +``` + +### 音频上传自动总结 +``` +前端 → POST /meetings/upload-audio (prompt_id, auto_summarize=true) + → handle_audio_upload(prompt_id) + → transcription_service.start_transcription() + → monitor_and_auto_summarize(prompt_id) + → 等待转录完成 + → start_summary_generation(prompt_id) + → (后续流程同手动生成总结) +``` + +## 向后兼容性 + +所有新增的 `prompt_id` 参数都是可选的(Optional[int] = None),确保: +1. 不传递 prompt_id 时,自动使用默认模版 +2. 现有代码无需修改即可正常工作 +3. 数据库中 prompt_id 允许为 NULL + +## 测试结果 + +执行 `test_prompt_id_feature.py` 测试脚本,所有测试通过: +- ✅ 获取启用的提示词列表 (6个模版) +- ✅ 通过prompt_id获取提示词内容 +- ✅ 获取默认提示词(不指定prompt_id) +- ✅ 验证方法签名支持prompt_id参数 +- ✅ 验证数据库schema包含prompt_id列 +- ✅ 验证API端点定义正确 + +## 使用示例 + +### 1. 获取启用的会议任务模版列表 +```bash +GET /api/prompts/active/MEETING_TASK +Authorization: Bearer +``` + +返回: +```json +{ + "code": "200", + "message": "获取启用模版列表成功", + "data": { + "prompts": [ + {"id": 1, "name": "默认会议总结", "is_default": true}, + {"id": 5, "name": "产品会议总结", "is_default": false} + ] + } +} +``` + +### 2. 手动生成总结时指定模版 +```bash +POST /api/meetings/123/generate-summary-async +Authorization: Bearer +Content-Type: application/json + +{ + "user_prompt": "重点关注技术讨论", + "prompt_id": 5 +} +``` + +### 3. 上传音频时指定模版 +```bash +POST /api/meetings/upload-audio +Authorization: Bearer +Content-Type: multipart/form-data + +- audio_file: +- meeting_id: 123 +- auto_summarize: true +- prompt_id: 5 +``` + +## 文件变更列表 + +1. `app/api/endpoints/prompts.py` - 新增API接口 +2. `app/api/endpoints/meetings.py` - 修改两个端点 +3. `app/services/llm_service.py` - 修改get_task_prompt方法 +4. `app/services/async_meeting_service.py` - 修改4个方法 +5. `app/services/audio_service.py` - 修改handle_audio_upload方法 +6. `sql/add_prompt_id_to_llm_tasks.sql` - 数据库迁移脚本 +7. `test_prompt_id_feature.py` - 测试脚本 + +## 注意事项 + +1. prompt_id 会与 task_type 一起验证,防止使用错误类型的模版 +2. 如果指定的 prompt_id 不存在或未启用,会自动使用默认模版 +3. 历史任务记录保留 prompt_id,即使对应的提示词被删除 +4. Redis 中 prompt_id 存储为字符串,使用时需转换为 int diff --git a/KB_PROMPT_ID_IMPLEMENTATION_SUMMARY.md b/KB_PROMPT_ID_IMPLEMENTATION_SUMMARY.md new file mode 100644 index 0000000..cbc9505 --- /dev/null +++ b/KB_PROMPT_ID_IMPLEMENTATION_SUMMARY.md @@ -0,0 +1,139 @@ +# 知识库提示词模版选择功能实现总结 + +## 功能概述 +为知识库生成功能添加了提示词模版选择支持,用户在创建知识库时可以选择使用的生成模版。 + +## 实现的功能点 + +### 1. 修改请求模型 ✅ +**文件**: `app/models/models.py` +- `CreateKnowledgeBaseRequest` 模型增加 `prompt_id` 字段 +- 类型:Optional[int] = None +- 不指定时使用默认模版 + +### 2. 修改知识库异步服务 ✅ +**文件**: `app/services/async_knowledge_base_service.py` + +#### 2.1 修改任务创建 +- `start_generation()`: 增加 `prompt_id` 参数 +- 将 prompt_id 存储到 Redis 和数据库 +- 支持通过 cursor 参数直接插入(事务场景) + +#### 2.2 修改任务处理 +- `_process_task()`: 从 Redis 读取 prompt_id,传递给 `_build_prompt()` +- 处理空字符串情况,转换为 None + +#### 2.3 修改提示词构建 +- `_build_prompt()`: 增加 `prompt_id` 参数 +- 调用 `llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=prompt_id)` +- 支持获取指定模版或默认模版 + +#### 2.4 修改数据库保存 +- `_save_task_to_db()`: 增加 `prompt_id` 参数 +- 插入时包含 prompt_id 字段 + +### 3. 修改API接口 ✅ +**文件**: `app/api/endpoints/knowledge_base.py` +- `create_knowledge_base`: 从请求中获取 `prompt_id` +- 调用 `async_kb_service.start_generation()` 时传递 `prompt_id` + +### 4. 数据库字段 ✅ +- `knowledge_base_tasks` 表已包含 `prompt_id` 列(用户已添加) +- 类型:int +- 可空:NO(默认值:0) + +## 数据流向 + +``` +前端 → POST /api/knowledge-bases (prompt_id) + → CreateKnowledgeBaseRequest (prompt_id) + → async_kb_service.start_generation(prompt_id) + → 存储到 Redis 和 DB (knowledge_base_tasks.prompt_id) + → _process_task() 读取 prompt_id + → _build_prompt(prompt_id) + → llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id) + → 获取指定模版或默认模版 +``` + +## 向后兼容性 + +所有新增的 `prompt_id` 参数都是可选的(Optional[int] = None),确保: +1. 不传递 prompt_id 时,自动使用默认模版 +2. 现有代码无需修改即可正常工作 +3. 数据库中 prompt_id 有默认值 0 + +## 测试结果 + +执行 `test_kb_prompt_id_feature.py` 测试脚本,所有测试通过: +- ✅ 获取启用的知识库提示词列表 (3个模版) +- ✅ 通过prompt_id获取提示词内容 +- ✅ 获取默认提示词(不指定prompt_id) +- ✅ 验证方法签名支持prompt_id参数 +- ✅ 验证数据库schema包含prompt_id列 +- ✅ 验证API模型定义正确 + +## 使用示例 + +### 1. 获取启用的知识库任务模版列表 +```bash +GET /api/prompts/active/KNOWLEDGE_TASK +Authorization: Bearer +``` + +返回: +```json +{ + "code": "200", + "message": "获取启用模版列表成功", + "data": { + "prompts": [ + {"id": 2, "name": "默认知识库生成", "is_default": true}, + {"id": 13, "name": "分析总结模版", "is_default": false} + ] + } +} +``` + +### 2. 创建知识库时指定模版 +```bash +POST /api/knowledge-bases +Authorization: Bearer +Content-Type: application/json + +{ + "title": "产品会议知识库", + "is_shared": false, + "user_prompt": "重点提取产品功能相关信息", + "source_meeting_ids": "1,2,3", + "tags": "产品,功能", + "prompt_id": 13 +} +``` + +## 文件变更列表 + +1. `app/models/models.py` - 修改CreateKnowledgeBaseRequest模型 +2. `app/services/async_knowledge_base_service.py` - 修改5个方法 +3. `app/api/endpoints/knowledge_base.py` - 修改create_knowledge_base端点 +4. `test_kb_prompt_id_feature.py` - 测试脚本 + +## 与会议总结功能的一致性 + +知识库的实现与会议总结功能保持一致: +- 相同的prompt_id传递机制 +- 相同的Redis存储格式(字符串) +- 相同的数据库字段类型 +- 相同的向后兼容策略 +- 相同的验证逻辑(task_type + is_active) + +## 注意事项 + +1. prompt_id 会与 task_type='KNOWLEDGE_TASK' 一起验证 +2. 如果指定的 prompt_id 不存在或未启用,会自动使用默认模版 +3. 历史任务记录保留 prompt_id,即使对应的提示词被删除 +4. Redis 中 prompt_id 存储为字符串,使用时需转换为 int +5. 数据库 prompt_id 默认值为 0(表示未指定) + +## 总结 + +知识库提示词模版选择功能已完全实现并通过测试,与会议总结功能保持一致的设计和实现方式。用户现在可以在创建知识库时选择不同的生成模版,以满足不同场景的需求。 diff --git a/app.zip b/app.zip index 37c3f90..4bb133b 100644 Binary files a/app.zip and b/app.zip differ diff --git a/app/api/endpoints/admin.py b/app/api/endpoints/admin.py index 45271e7..e350b71 100644 --- a/app/api/endpoints/admin.py +++ b/app/api/endpoints/admin.py @@ -2,7 +2,10 @@ from fastapi import APIRouter, Depends from app.core.auth import get_current_admin_user, get_current_user from app.core.config import LLM_CONFIG, DEFAULT_RESET_PASSWORD, MAX_FILE_SIZE, VOICEPRINT_CONFIG, TIMELINE_PAGESIZE from app.core.response import create_api_response +from app.core.database import get_db_connection +from app.models.models import MenuInfo, MenuListResponse, RolePermissionInfo, UpdateRolePermissionsRequest, RoleInfo from pydantic import BaseModel +from typing import List import json from pathlib import Path @@ -117,3 +120,184 @@ def load_system_config(): print(f"系统配置加载成功: model={config.get('model_name')}, pagesize={config.get('TIMELINE_PAGESIZE')}") except Exception as e: print(f"加载系统配置失败,使用默认配置: {e}") + +# ========== 菜单权限管理接口 ========== + +@router.get("/admin/menus") +async def get_all_menus(current_user=Depends(get_current_admin_user)): + """ + 获取所有菜单列表 + 只有管理员才能访问 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + query = """ + SELECT menu_id, menu_code, menu_name, menu_icon, menu_url, menu_type, + parent_id, sort_order, is_active, description, created_at, updated_at + FROM menus + ORDER BY sort_order ASC, menu_id ASC + """ + cursor.execute(query) + menus = cursor.fetchall() + + menu_list = [MenuInfo(**menu) for menu in menus] + + return create_api_response( + code="200", + message="获取菜单列表成功", + data=MenuListResponse(menus=menu_list, total=len(menu_list)) + ) + except Exception as e: + return create_api_response(code="500", message=f"获取菜单列表失败: {str(e)}") + +@router.get("/admin/roles") +async def get_all_roles(current_user=Depends(get_current_admin_user)): + """ + 获取所有角色列表及其权限统计 + 只有管理员才能访问 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 查询所有角色及其权限数量 + query = """ + SELECT r.role_id, r.role_name, r.created_at, + COUNT(rmp.menu_id) as menu_count + FROM roles r + LEFT JOIN role_menu_permissions rmp ON r.role_id = rmp.role_id + GROUP BY r.role_id + ORDER BY r.role_id ASC + """ + cursor.execute(query) + roles = cursor.fetchall() + + return create_api_response( + code="200", + message="获取角色列表成功", + data={"roles": roles, "total": len(roles)} + ) + except Exception as e: + return create_api_response(code="500", message=f"获取角色列表失败: {str(e)}") + +@router.get("/admin/roles/{role_id}/permissions") +async def get_role_permissions(role_id: int, current_user=Depends(get_current_admin_user)): + """ + 获取指定角色的菜单权限 + 只有管理员才能访问 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 检查角色是否存在 + cursor.execute("SELECT role_id, role_name FROM roles WHERE role_id = %s", (role_id,)) + role = cursor.fetchone() + if not role: + return create_api_response(code="404", message="角色不存在") + + # 查询该角色的所有菜单权限 + query = """ + SELECT menu_id + FROM role_menu_permissions + WHERE role_id = %s + """ + cursor.execute(query, (role_id,)) + permissions = cursor.fetchall() + + menu_ids = [p['menu_id'] for p in permissions] + + return create_api_response( + code="200", + message="获取角色权限成功", + data=RolePermissionInfo( + role_id=role['role_id'], + role_name=role['role_name'], + menu_ids=menu_ids + ) + ) + except Exception as e: + return create_api_response(code="500", message=f"获取角色权限失败: {str(e)}") + +@router.put("/admin/roles/{role_id}/permissions") +async def update_role_permissions( + role_id: int, + request: UpdateRolePermissionsRequest, + current_user=Depends(get_current_admin_user) +): + """ + 更新指定角色的菜单权限 + 只有管理员才能访问 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 检查角色是否存在 + cursor.execute("SELECT role_id FROM roles WHERE role_id = %s", (role_id,)) + if not cursor.fetchone(): + return create_api_response(code="404", message="角色不存在") + + # 验证所有menu_id是否有效 + if request.menu_ids: + format_strings = ','.join(['%s'] * len(request.menu_ids)) + cursor.execute( + f"SELECT COUNT(*) as count FROM menus WHERE menu_id IN ({format_strings})", + tuple(request.menu_ids) + ) + valid_count = cursor.fetchone()['count'] + if valid_count != len(request.menu_ids): + return create_api_response(code="400", message="包含无效的菜单ID") + + # 删除该角色的所有现有权限 + cursor.execute("DELETE FROM role_menu_permissions WHERE role_id = %s", (role_id,)) + + # 插入新的权限 + if request.menu_ids: + insert_values = [(role_id, menu_id) for menu_id in request.menu_ids] + cursor.executemany( + "INSERT INTO role_menu_permissions (role_id, menu_id) VALUES (%s, %s)", + insert_values + ) + + connection.commit() + + return create_api_response( + code="200", + message="更新角色权限成功", + data={"role_id": role_id, "menu_count": len(request.menu_ids)} + ) + except Exception as e: + return create_api_response(code="500", message=f"更新角色权限失败: {str(e)}") + +@router.get("/menus/user") +async def get_user_menus(current_user=Depends(get_current_user)): + """ + 获取当前用户可访问的菜单列表(用于渲染下拉菜单) + 所有登录用户都可以访问 + """ + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 根据用户的role_id查询可访问的菜单 + query = """ + SELECT DISTINCT m.menu_id, m.menu_code, m.menu_name, m.menu_icon, + m.menu_url, m.menu_type, m.sort_order + FROM menus m + JOIN role_menu_permissions rmp ON m.menu_id = rmp.menu_id + WHERE rmp.role_id = %s AND m.is_active = 1 + ORDER BY m.sort_order ASC + """ + cursor.execute(query, (current_user['role_id'],)) + menus = cursor.fetchall() + + return create_api_response( + code="200", + message="获取用户菜单成功", + data={"menus": menus} + ) + except Exception as e: + return create_api_response(code="500", message=f"获取用户菜单失败: {str(e)}") + diff --git a/app/api/endpoints/admin_dashboard.py b/app/api/endpoints/admin_dashboard.py new file mode 100644 index 0000000..79dc4ed --- /dev/null +++ b/app/api/endpoints/admin_dashboard.py @@ -0,0 +1,398 @@ +from fastapi import APIRouter, Depends, Query +from app.core.auth import get_current_admin_user +from app.core.response import create_api_response +from app.core.database import get_db_connection +from app.services.jwt_service import jwt_service +from app.core.config import AUDIO_DIR, REDIS_CONFIG +from datetime import datetime +from typing import Dict, List +import os +import redis + +router = APIRouter() + +# Redis 客户端 +redis_client = redis.Redis(**REDIS_CONFIG) + +# 常量定义 +AUDIO_FILE_EXTENSIONS = ('.wav', '.mp3', '.m4a', '.aac', '.flac', '.ogg') +BYTES_TO_GB = 1024 ** 3 + + +def _build_status_condition(status: str) -> str: + """构建任务状态查询条件""" + if status == 'running': + return "AND (t.status = 'pending' OR t.status = 'processing')" + elif status == 'completed': + return "AND t.status = 'completed'" + elif status == 'failed': + return "AND t.status = 'failed'" + return "" + + +def _get_task_stats_query() -> str: + """获取任务统计的 SQL 查询""" + return """ + SELECT + COUNT(*) as total, + SUM(CASE WHEN status = 'pending' OR status = 'processing' THEN 1 ELSE 0 END) as running, + SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed, + SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed + """ + + +def _get_online_user_count(redis_client) -> int: + """从 Redis 获取在线用户数""" + try: + token_keys = redis_client.keys("token:*") + user_ids = set() + for key in token_keys: + parts = key.split(':') + if len(parts) >= 2: + user_ids.add(parts[1]) + return len(user_ids) + except Exception as e: + print(f"获取在线用户数失败: {e}") + return 0 + + +def _calculate_audio_storage() -> Dict[str, float]: + """计算音频文件存储统计""" + audio_files_count = 0 + audio_total_size = 0 + + try: + if os.path.exists(AUDIO_DIR): + for root, _, files in os.walk(AUDIO_DIR): + for file in files: + if file.endswith(AUDIO_FILE_EXTENSIONS): + audio_files_count += 1 + file_path = os.path.join(root, file) + try: + audio_total_size += os.path.getsize(file_path) + except OSError: + continue + except Exception as e: + print(f"统计音频文件失败: {e}") + + return { + "audio_files_count": audio_files_count, + "audio_total_size_gb": round(audio_total_size / BYTES_TO_GB, 2) + } + + +@router.get("/admin/dashboard/stats") +async def get_dashboard_stats(current_user=Depends(get_current_admin_user)): + """获取管理员 Dashboard 统计数据""" + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 1. 用户统计 + today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + + cursor.execute("SELECT COUNT(*) as total FROM users") + total_users = cursor.fetchone()['total'] + + cursor.execute( + "SELECT COUNT(*) as count FROM users WHERE created_at >= %s", + (today_start,) + ) + today_new_users = cursor.fetchone()['count'] + + online_users = _get_online_user_count(redis_client) + + # 2. 会议统计 + cursor.execute("SELECT COUNT(*) as total FROM meetings") + total_meetings = cursor.fetchone()['total'] + + cursor.execute( + "SELECT COUNT(*) as count FROM meetings WHERE created_at >= %s", + (today_start,) + ) + today_new_meetings = cursor.fetchone()['count'] + + # 3. 任务统计 + task_stats_query = _get_task_stats_query() + + # 转录任务 + cursor.execute(f"{task_stats_query} FROM transcript_tasks") + transcription_stats = cursor.fetchone() or {'total': 0, 'running': 0, 'completed': 0, 'failed': 0} + + # 总结任务 + cursor.execute(f"{task_stats_query} FROM llm_tasks") + summary_stats = cursor.fetchone() or {'total': 0, 'running': 0, 'completed': 0, 'failed': 0} + + # 知识库任务 + cursor.execute(f"{task_stats_query} FROM knowledge_base_tasks") + kb_stats = cursor.fetchone() or {'total': 0, 'running': 0, 'completed': 0, 'failed': 0} + + # 4. 音频存储统计 + storage_stats = _calculate_audio_storage() + + # 组装返回数据 + stats = { + "users": { + "total": total_users, + "today_new": today_new_users, + "online": online_users + }, + "meetings": { + "total": total_meetings, + "today_new": today_new_meetings + }, + "tasks": { + "transcription": { + "total": transcription_stats['total'] or 0, + "running": transcription_stats['running'] or 0, + "completed": transcription_stats['completed'] or 0, + "failed": transcription_stats['failed'] or 0 + }, + "summary": { + "total": summary_stats['total'] or 0, + "running": summary_stats['running'] or 0, + "completed": summary_stats['completed'] or 0, + "failed": summary_stats['failed'] or 0 + }, + "knowledge_base": { + "total": kb_stats['total'] or 0, + "running": kb_stats['running'] or 0, + "completed": kb_stats['completed'] or 0, + "failed": kb_stats['failed'] or 0 + } + }, + "storage": storage_stats + } + + return create_api_response(code="200", message="获取统计数据成功", data=stats) + + except Exception as e: + print(f"获取Dashboard统计数据失败: {e}") + return create_api_response(code="500", message=f"获取统计数据失败: {str(e)}") + + +@router.get("/admin/online-users") +async def get_online_users(current_user=Depends(get_current_admin_user)): + """获取在线用户列表""" + try: + token_keys = redis_client.keys("token:*") + + # 提取用户ID并去重 + user_tokens = {} + for key in token_keys: + parts = key.split(':') + if len(parts) >= 3: + user_id = int(parts[1]) + token = parts[2] + if user_id not in user_tokens: + user_tokens[user_id] = [] + user_tokens[user_id].append({'token': token, 'key': key}) + + # 查询用户信息 + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + online_users_list = [] + for user_id, tokens in user_tokens.items(): + cursor.execute( + "SELECT user_id, username, caption, email, role_id FROM users WHERE user_id = %s", + (user_id,) + ) + user = cursor.fetchone() + if user: + ttl_seconds = redis_client.ttl(tokens[0]['key']) + online_users_list.append({ + **user, + 'token_count': len(tokens), + 'ttl_seconds': ttl_seconds, + 'ttl_hours': round(ttl_seconds / 3600, 1) if ttl_seconds > 0 else 0 + }) + + # 按用户ID排序 + online_users_list.sort(key=lambda x: x['user_id']) + + return create_api_response( + code="200", + message="获取在线用户列表成功", + data={"users": online_users_list, "total": len(online_users_list)} + ) + + except Exception as e: + print(f"获取在线用户列表失败: {e}") + return create_api_response(code="500", message=f"获取在线用户列表失败: {str(e)}") + + +@router.post("/admin/kick-user/{user_id}") +async def kick_user(user_id: int, current_user=Depends(get_current_admin_user)): + """踢出用户(撤销该用户的所有 token)""" + try: + revoked_count = jwt_service.revoke_all_user_tokens(user_id) + + if revoked_count > 0: + return create_api_response( + code="200", + message=f"已踢出用户,撤销了 {revoked_count} 个 token", + data={"user_id": user_id, "revoked_count": revoked_count} + ) + else: + return create_api_response( + code="404", + message="该用户当前不在线或未找到 token" + ) + + except Exception as e: + print(f"踢出用户失败: {e}") + return create_api_response(code="500", message=f"踢出用户失败: {str(e)}") + + +@router.get("/admin/tasks/monitor") +async def monitor_tasks( + task_type: str = Query('all', description="任务类型: all, transcription, summary, knowledge_base"), + status: str = Query('all', description="任务状态: all, running, completed, failed"), + limit: int = Query(20, ge=1, le=100, description="返回数量限制"), + current_user=Depends(get_current_admin_user) +): + """监控任务进度""" + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + tasks = [] + status_condition = _build_status_condition(status) + + # 转录任务 + if task_type in ['all', 'transcription']: + query = f""" + SELECT + t.task_id, + 'transcription' as task_type, + t.meeting_id, + m.title as meeting_title, + t.status, + t.progress, + t.error_message, + t.created_at, + t.completed_at, + u.username as creator_name + FROM transcript_tasks t + LEFT JOIN meetings m ON t.meeting_id = m.meeting_id + LEFT JOIN users u ON m.user_id = u.user_id + WHERE 1=1 {status_condition} + ORDER BY t.created_at DESC + LIMIT %s + """ + cursor.execute(query, (limit,)) + tasks.extend(cursor.fetchall()) + + # 总结任务 + if task_type in ['all', 'summary']: + query = f""" + SELECT + t.task_id, + 'summary' as task_type, + t.meeting_id, + m.title as meeting_title, + t.status, + NULL as progress, + t.error_message, + t.created_at, + t.completed_at, + u.username as creator_name + FROM llm_tasks t + LEFT JOIN meetings m ON t.meeting_id = m.meeting_id + LEFT JOIN users u ON m.user_id = u.user_id + WHERE 1=1 {status_condition} + ORDER BY t.created_at DESC + LIMIT %s + """ + cursor.execute(query, (limit,)) + tasks.extend(cursor.fetchall()) + + # 知识库任务 + if task_type in ['all', 'knowledge_base']: + query = f""" + SELECT + t.task_id, + 'knowledge_base' as task_type, + t.kb_id as meeting_id, + k.title as meeting_title, + t.status, + t.progress, + t.error_message, + t.created_at, + t.updated_at, + u.username as creator_name + FROM knowledge_base_tasks t + LEFT JOIN knowledge_bases k ON t.kb_id = k.kb_id + LEFT JOIN users u ON k.creator_id = u.user_id + WHERE 1=1 {status_condition} + ORDER BY t.created_at DESC + LIMIT %s + """ + cursor.execute(query, (limit,)) + tasks.extend(cursor.fetchall()) + + # 按创建时间排序并限制返回数量 + tasks.sort(key=lambda x: x['created_at'], reverse=True) + tasks = tasks[:limit] + + return create_api_response( + code="200", + message="获取任务监控数据成功", + data={"tasks": tasks, "total": len(tasks)} + ) + + except Exception as e: + print(f"获取任务监控数据失败: {e}") + import traceback + traceback.print_exc() + return create_api_response(code="500", message=f"获取任务监控数据失败: {str(e)}") + + +@router.get("/admin/system/resources") +async def get_system_resources(current_user=Depends(get_current_admin_user)): + """获取服务器资源使用情况""" + try: + import psutil + + # CPU 使用率 + cpu_percent = psutil.cpu_percent(interval=1) + cpu_count = psutil.cpu_count() + + # 内存使用情况 + memory = psutil.virtual_memory() + memory_total_gb = round(memory.total / BYTES_TO_GB, 2) + memory_used_gb = round(memory.used / BYTES_TO_GB, 2) + + # 磁盘使用情况 + disk = psutil.disk_usage('/') + disk_total_gb = round(disk.total / BYTES_TO_GB, 2) + disk_used_gb = round(disk.used / BYTES_TO_GB, 2) + + resources = { + "cpu": { + "percent": cpu_percent, + "count": cpu_count + }, + "memory": { + "total_gb": memory_total_gb, + "used_gb": memory_used_gb, + "percent": memory.percent + }, + "disk": { + "total_gb": disk_total_gb, + "used_gb": disk_used_gb, + "percent": disk.percent + }, + "timestamp": datetime.now().isoformat() + } + + return create_api_response(code="200", message="获取系统资源成功", data=resources) + + except ImportError: + return create_api_response( + code="500", + message="psutil 库未安装,请运行: pip install psutil" + ) + except Exception as e: + print(f"获取系统资源失败: {e}") + return create_api_response(code="500", message=f"获取系统资源失败: {str(e)}") diff --git a/app/api/endpoints/audio.py b/app/api/endpoints/audio.py new file mode 100644 index 0000000..9ab8a16 --- /dev/null +++ b/app/api/endpoints/audio.py @@ -0,0 +1,568 @@ +from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, BackgroundTasks +from app.core.database import get_db_connection +from app.core.config import BASE_DIR, AUDIO_DIR, TEMP_UPLOAD_DIR +from app.core.auth import get_current_user +from app.core.response import create_api_response +from app.services.async_transcription_service import AsyncTranscriptionService +from app.services.async_meeting_service import async_meeting_service +from app.services.audio_service import handle_audio_upload +from pydantic import BaseModel +from typing import Optional, List +from datetime import datetime, timedelta +import os +import uuid +import shutil +import json +import re +from pathlib import Path + +router = APIRouter() +transcription_service = AsyncTranscriptionService() + +# 临时上传目录 - 放在项目目录下 +TEMP_UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + +# 配置常量 +MAX_CHUNK_SIZE = 2 * 1024 * 1024 # 2MB per chunk +MAX_TOTAL_SIZE = 500 * 1024 * 1024 # 500MB total +MAX_DURATION = 3600 # 1 hour max recording +SESSION_EXPIRE_HOURS = 1 # 会话1小时后过期 + +# 支持的音频格式 +SUPPORTED_MIME_TYPES = { + 'audio/webm;codecs=opus': '.webm', + 'audio/webm': '.webm', + 'audio/ogg;codecs=opus': '.ogg', + 'audio/mp4': '.m4a', + 'audio/mpeg': '.mp3' +} + + +# ============ Pydantic Models ============ + +class InitUploadRequest(BaseModel): + meeting_id: int + mime_type: str + estimated_duration: Optional[int] = None # 预计时长(秒) + + +class CompleteUploadRequest(BaseModel): + session_id: str + meeting_id: int + total_chunks: int + mime_type: str + auto_transcribe: bool = True + auto_summarize: bool = True + prompt_id: Optional[int] = None # 提示词模版ID(可选) + + +class CancelUploadRequest(BaseModel): + session_id: str + + +# ============ 工具函数 ============ + +def validate_session_id(session_id: str) -> str: + """验证session_id格式,防止路径注入攻击""" + if not re.match(r'^sess_\d+_[a-zA-Z0-9]+$', session_id): + raise ValueError("Invalid session_id format") + return session_id + + +def validate_mime_type(mime_type: str) -> str: + """验证MIME类型是否支持""" + if mime_type not in SUPPORTED_MIME_TYPES: + raise ValueError(f"Unsupported MIME type: {mime_type}") + return SUPPORTED_MIME_TYPES[mime_type] + + +def get_session_dir(session_id: str) -> Path: + """获取会话目录路径""" + validate_session_id(session_id) + return TEMP_UPLOAD_DIR / session_id + + +def get_session_metadata_path(session_id: str) -> Path: + """获取会话metadata文件路径""" + return get_session_dir(session_id) / "metadata.json" + + +def create_session_metadata(session_id: str, meeting_id: int, mime_type: str, user_id: int) -> dict: + """创建会话metadata""" + now = datetime.now() + expires_at = now + timedelta(hours=SESSION_EXPIRE_HOURS) + + metadata = { + "session_id": session_id, + "meeting_id": meeting_id, + "user_id": user_id, + "mime_type": mime_type, + "total_chunks": None, + "received_chunks": [], + "created_at": now.isoformat(), + "expires_at": expires_at.isoformat() + } + + return metadata + + +def save_session_metadata(session_id: str, metadata: dict): + """保存会话metadata""" + metadata_path = get_session_metadata_path(session_id) + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, ensure_ascii=False, indent=2) + + +def load_session_metadata(session_id: str) -> dict: + """加载会话metadata""" + metadata_path = get_session_metadata_path(session_id) + if not metadata_path.exists(): + raise FileNotFoundError(f"Session {session_id} not found") + + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + + +def update_session_chunks(session_id: str, chunk_index: int): + """更新已接收的分片列表""" + metadata = load_session_metadata(session_id) + + if chunk_index not in metadata['received_chunks']: + metadata['received_chunks'].append(chunk_index) + metadata['received_chunks'].sort() + + save_session_metadata(session_id, metadata) + + +def get_session_total_size(session_id: str) -> int: + """获取会话已上传的总大小""" + session_dir = get_session_dir(session_id) + total_size = 0 + + if session_dir.exists(): + for chunk_file in session_dir.glob("chunk_*.webm"): + total_size += chunk_file.stat().st_size + + return total_size + + +def merge_audio_chunks(session_id: str, meeting_id: int, total_chunks: int, mime_type: str) -> str: + """合并音频分片""" + session_dir = get_session_dir(session_id) + + # 1. 验证分片完整性 + missing = [] + for i in range(total_chunks): + chunk_path = session_dir / f"chunk_{i:04d}.webm" + if not chunk_path.exists(): + missing.append(i) + + if missing: + raise ValueError(f"Missing chunks: {missing}") + + # 2. 创建输出目录 + meeting_audio_dir = AUDIO_DIR / str(meeting_id) + meeting_audio_dir.mkdir(parents=True, exist_ok=True) + + # 3. 生成输出文件名 + file_extension = validate_mime_type(mime_type) + output_filename = f"{uuid.uuid4()}{file_extension}" + output_path = meeting_audio_dir / output_filename + + # 4. 按序合并分片 + with open(output_path, 'wb') as outfile: + for i in range(total_chunks): + chunk_path = session_dir / f"chunk_{i:04d}.webm" + with open(chunk_path, 'rb') as infile: + outfile.write(infile.read()) + + # 5. 清理临时文件 + shutil.rmtree(session_dir) + + # 返回相对路径 + return f"/{output_path.relative_to(BASE_DIR)}" + + +def cleanup_session(session_id: str): + """清理会话文件""" + session_dir = get_session_dir(session_id) + if session_dir.exists(): + shutil.rmtree(session_dir) + + +def cleanup_expired_sessions(): + """清理过期的会话(可以由定时任务调用)""" + now = datetime.now() + cleaned_count = 0 + + if not TEMP_UPLOAD_DIR.exists(): + return cleaned_count + + for session_dir in TEMP_UPLOAD_DIR.iterdir(): + if not session_dir.is_dir(): + continue + + metadata_path = session_dir / "metadata.json" + if metadata_path.exists(): + try: + with open(metadata_path, 'r') as f: + metadata = json.load(f) + + expires_at = datetime.fromisoformat(metadata['expires_at']) + if now > expires_at: + shutil.rmtree(session_dir) + cleaned_count += 1 + print(f"Cleaned up expired session: {session_dir.name}") + except Exception as e: + print(f"Error cleaning up session {session_dir.name}: {e}") + + return cleaned_count + + +# ============ API Endpoints ============ + +@router.post("/audio/stream/init") +async def init_upload_session( + request: InitUploadRequest, + current_user: dict = Depends(get_current_user) +): + """ + 初始化音频流式上传会话 + + 创建临时目录,生成session_id,返回给客户端用于后续分片上传 + """ + try: + # 1. 验证会议是否存在且属于当前用户 + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + cursor.execute( + "SELECT user_id FROM meetings WHERE meeting_id = %s", + (request.meeting_id,) + ) + meeting = cursor.fetchone() + + if not meeting: + return create_api_response( + code="404", + message="会议不存在" + ) + + if meeting['user_id'] != current_user['user_id']: + return create_api_response( + code="403", + message="无权限操作此会议" + ) + + # 2. 验证MIME类型 + try: + validate_mime_type(request.mime_type) + except ValueError as e: + return create_api_response( + code="400", + message=str(e) + ) + + # 3. 生成session_id + timestamp = int(datetime.now().timestamp() * 1000) + random_str = uuid.uuid4().hex[:8] + session_id = f"sess_{timestamp}_{random_str}" + + # 4. 创建会话目录 + session_dir = get_session_dir(session_id) + session_dir.mkdir(parents=True, exist_ok=True) + + # 5. 创建并保存metadata + metadata = create_session_metadata( + session_id=session_id, + meeting_id=request.meeting_id, + mime_type=request.mime_type, + user_id=current_user['user_id'] + ) + save_session_metadata(session_id, metadata) + + # 6. 清理过期会话 + cleanup_expired_sessions() + + return create_api_response( + code="200", + message="上传会话初始化成功", + data={ + "session_id": session_id, + "chunk_size": MAX_CHUNK_SIZE, + "max_chunks": 1000 + } + ) + + except Exception as e: + print(f"Error initializing upload session: {e}") + return create_api_response( + code="500", + message=f"初始化上传会话失败: {str(e)}" + ) + + +@router.post("/audio/stream/chunk") +async def upload_audio_chunk( + session_id: str = Form(...), + chunk_index: int = Form(...), + chunk: UploadFile = File(...), + current_user: dict = Depends(get_current_user) +): + """ + 上传音频分片 + + 接收并保存音频分片文件 + """ + try: + # 1. 验证session_id格式 + try: + validate_session_id(session_id) + except ValueError: + return create_api_response( + code="400", + message="Invalid session_id format" + ) + + # 2. 加载session metadata + try: + metadata = load_session_metadata(session_id) + except FileNotFoundError: + return create_api_response( + code="404", + message="Session not found" + ) + + # 3. 验证会话所有权 + if metadata['user_id'] != current_user['user_id']: + return create_api_response( + code="403", + message="Permission denied" + ) + + # 4. 验证分片大小 + chunk_data = await chunk.read() + if len(chunk_data) > MAX_CHUNK_SIZE: + return create_api_response( + code="400", + message=f"Chunk size exceeds {MAX_CHUNK_SIZE // (1024*1024)}MB limit" + ) + + # 5. 验证总大小 + session_total = get_session_total_size(session_id) + if session_total + len(chunk_data) > MAX_TOTAL_SIZE: + return create_api_response( + code="400", + message=f"Total size exceeds {MAX_TOTAL_SIZE // (1024*1024)}MB limit" + ) + + # 6. 保存分片文件 + session_dir = get_session_dir(session_id) + chunk_path = session_dir / f"chunk_{chunk_index:04d}.webm" + + with open(chunk_path, 'wb') as f: + f.write(chunk_data) + + # 7. 更新metadata + update_session_chunks(session_id, chunk_index) + + # 8. 获取已接收分片总数 + metadata = load_session_metadata(session_id) + total_received = len(metadata['received_chunks']) + + return create_api_response( + code="200", + message="分片上传成功", + data={ + "session_id": session_id, + "chunk_index": chunk_index, + "received": True, + "total_received": total_received + } + ) + + except Exception as e: + print(f"Error uploading chunk: {e}") + return create_api_response( + code="500", + message=f"分片上传失败: {str(e)}", + data={ + "session_id": session_id, + "chunk_index": chunk_index, + "should_retry": True + } + ) + + +@router.post("/audio/stream/complete") +async def complete_upload( + request: CompleteUploadRequest, + background_tasks: BackgroundTasks, + current_user: dict = Depends(get_current_user) +): + """ + 完成上传并合并分片 + + 验证分片完整性,合并所有分片,保存最终音频文件,可选启动转录任务和自动总结 + """ + try: + # 1. 验证session_id + try: + validate_session_id(request.session_id) + except ValueError: + return create_api_response( + code="400", + message="Invalid session_id format" + ) + + # 2. 加载session metadata + try: + metadata = load_session_metadata(request.session_id) + except FileNotFoundError: + return create_api_response( + code="404", + message="Session not found" + ) + + # 3. 验证会话所有权 + if metadata['user_id'] != current_user['user_id']: + return create_api_response( + code="403", + message="Permission denied" + ) + + # 4. 验证会议ID一致性 + if metadata['meeting_id'] != request.meeting_id: + return create_api_response( + code="400", + message="Meeting ID mismatch" + ) + + # 5. 合并音频分片 + try: + file_path = merge_audio_chunks( + session_id=request.session_id, + meeting_id=request.meeting_id, + total_chunks=request.total_chunks, + mime_type=request.mime_type + ) + except ValueError as e: + # 分片不完整 + return create_api_response( + code="500", + message=f"音频合并失败:{str(e)}", + data={ + "should_retry": True + } + ) + + # 6. 获取文件信息 + full_path = BASE_DIR / file_path.lstrip('/') + file_size = full_path.stat().st_size + file_name = full_path.name + + # 7. 调用 audio_service 处理文件(数据库更新、启动转录和总结) + result = handle_audio_upload( + file_path=file_path, + file_name=file_name, + file_size=file_size, + meeting_id=request.meeting_id, + current_user=current_user, + auto_summarize=request.auto_summarize, + background_tasks=background_tasks, + prompt_id=request.prompt_id # 传递提示词模版ID + ) + + # 如果处理失败,返回错误 + if not result["success"]: + return result["response"] + + # 8. 返回成功响应 + transcription_task_id = result["transcription_task_id"] + message_suffix = "" + if transcription_task_id: + if request.auto_summarize: + message_suffix = ",正在进行转录和总结" + else: + message_suffix = ",正在进行转录" + + return create_api_response( + code="200", + message="音频上传完成" + message_suffix, + data={ + "meeting_id": request.meeting_id, + "file_path": file_path, + "file_size": file_size, + "duration": None, # 可以通过ffprobe获取,但不是必需的 + "task_id": transcription_task_id, + "task_status": "pending" if transcription_task_id else None, + "auto_summarize": request.auto_summarize + } + ) + + except Exception as e: + print(f"Error completing upload: {e}") + return create_api_response( + code="500", + message=f"完成上传失败: {str(e)}" + ) + + +@router.delete("/audio/stream/cancel") +async def cancel_upload( + request: CancelUploadRequest, + current_user: dict = Depends(get_current_user) +): + """ + 取消上传会话 + + 清理会话临时文件和目录 + """ + try: + # 1. 验证session_id + try: + validate_session_id(request.session_id) + except ValueError: + return create_api_response( + code="400", + message="Invalid session_id format" + ) + + # 2. 加载session metadata(验证所有权) + try: + metadata = load_session_metadata(request.session_id) + + # 验证会话所有权 + if metadata['user_id'] != current_user['user_id']: + return create_api_response( + code="403", + message="Permission denied" + ) + except FileNotFoundError: + # 会话不存在,视为已清理 + return create_api_response( + code="200", + message="上传会话已取消", + data={ + "session_id": request.session_id, + "cleaned": True + } + ) + + # 3. 清理会话文件 + cleanup_session(request.session_id) + + return create_api_response( + code="200", + message="上传会话已取消", + data={ + "session_id": request.session_id, + "cleaned": True + } + ) + + except Exception as e: + print(f"Error canceling upload: {e}") + return create_api_response( + code="500", + message=f"取消上传失败: {str(e)}" + ) diff --git a/app/api/endpoints/client_downloads.py b/app/api/endpoints/client_downloads.py index 0f7bcb7..1aaca1e 100644 --- a/app/api/endpoints/client_downloads.py +++ b/app/api/endpoints/client_downloads.py @@ -12,7 +12,7 @@ from typing import Optional router = APIRouter() -@router.get("/downloads", response_model=dict) +@router.get("/clients", response_model=dict) async def get_client_downloads( platform_type: Optional[str] = None, platform_name: Optional[str] = None, @@ -81,7 +81,7 @@ async def get_client_downloads( ) -@router.get("/downloads/latest", response_model=dict) +@router.get("/clients/latest", response_model=dict) async def get_latest_clients(): """ 获取所有平台的最新版本客户端(公开接口,用于首页下载) @@ -102,19 +102,23 @@ async def get_latest_clients(): # 按平台类型分组 mobile_clients = [] desktop_clients = [] + terminal_clients = [] for client in clients: if client['platform_type'] == 'mobile': mobile_clients.append(client) - else: + elif client['platform_type'] == 'desktop': desktop_clients.append(client) + elif client['platform_type'] == 'terminal': + terminal_clients.append(client) return create_api_response( code="200", message="获取成功", data={ "mobile": mobile_clients, - "desktop": desktop_clients + "desktop": desktop_clients, + "terminal": terminal_clients } ) @@ -125,10 +129,17 @@ async def get_latest_clients(): ) -@router.get("/downloads/{platform_name}/latest", response_model=dict) -async def get_latest_version_by_platform(platform_name: str): +@router.get("/clients/latest/by-platform", response_model=dict) +async def get_latest_version_by_platform_type_and_name( + platform_type: str, + platform_name: str +): """ - 获取指定平台的最新版本(公开接口,用于客户端版本检查) + 通过平台类型和平台名称获取最新版本(公开接口,用于客户端版本检查) + + 参数: + platform_type: 平台类型 (mobile, desktop, terminal) + platform_name: 具体平台 (ios, android, windows, mac_intel, mac_m, linux, mcu) """ try: with get_db_connection() as conn: @@ -136,17 +147,20 @@ async def get_latest_version_by_platform(platform_name: str): query = """ SELECT * FROM client_downloads - WHERE platform_name = %s AND is_active = TRUE AND is_latest = TRUE + WHERE platform_type = %s + AND platform_name = %s + AND is_active = TRUE + AND is_latest = TRUE LIMIT 1 """ - cursor.execute(query, (platform_name,)) + cursor.execute(query, (platform_type, platform_name)) client = cursor.fetchone() cursor.close() if not client: return create_api_response( code="404", - message=f"未找到平台 {platform_name} 的客户端" + message=f"未找到平台类型 {platform_type} 下的 {platform_name} 客户端" ) return create_api_response( @@ -162,7 +176,7 @@ async def get_latest_version_by_platform(platform_name: str): ) -@router.get("/downloads/{id}", response_model=dict) +@router.get("/clients/{id}", response_model=dict) async def get_client_download_by_id(id: int): """ 获取指定ID的客户端详情(公开接口) @@ -195,7 +209,7 @@ async def get_client_download_by_id(id: int): ) -@router.post("/downloads", response_model=dict) +@router.post("/clients", response_model=dict) async def create_client_download( request: CreateClientDownloadRequest, current_user: dict = Depends(get_current_admin_user) @@ -255,7 +269,7 @@ async def create_client_download( ) -@router.put("/downloads/{id}", response_model=dict) +@router.put("/clients/{id}", response_model=dict) async def update_client_download( id: int, request: UpdateClientDownloadRequest, @@ -353,7 +367,7 @@ async def update_client_download( ) -@router.delete("/downloads/{id}", response_model=dict) +@router.delete("/clients/{id}", response_model=dict) async def delete_client_download( id: int, current_user: dict = Depends(get_current_admin_user) diff --git a/app/api/endpoints/knowledge_base.py b/app/api/endpoints/knowledge_base.py index 93d129d..a198997 100644 --- a/app/api/endpoints/knowledge_base.py +++ b/app/api/endpoints/knowledge_base.py @@ -131,6 +131,7 @@ def create_knowledge_base( kb_id=kb_id, user_prompt=request.user_prompt, source_meeting_ids=request.source_meeting_ids, + prompt_id=request.prompt_id, # 传递 prompt_id 参数 cursor=cursor ) diff --git a/app/api/endpoints/meetings.py b/app/api/endpoints/meetings.py index 03d5821..68f2e43 100644 --- a/app/api/endpoints/meetings.py +++ b/app/api/endpoints/meetings.py @@ -7,6 +7,7 @@ import app.core.config as config_module from app.services.llm_service import LLMService from app.services.async_transcription_service import AsyncTranscriptionService from app.services.async_meeting_service import async_meeting_service +from app.services.audio_service import handle_audio_upload from app.core.auth import get_current_user from app.core.response import create_api_response from typing import List, Optional @@ -25,194 +26,8 @@ transcription_service = AsyncTranscriptionService() class GenerateSummaryRequest(BaseModel): user_prompt: Optional[str] = "" + prompt_id: Optional[int] = None # 提示词模版ID,如果不指定则使用默认模版 -def _handle_audio_upload( - audio_file: UploadFile, - meeting_id: int, - force_replace_bool: bool, - current_user: dict -): - """ - 音频上传的公共处理逻辑 - - Args: - audio_file: 上传的音频文件 - meeting_id: 会议ID - force_replace_bool: 是否强制替换 - current_user: 当前用户 - - Returns: - dict: { - "success": bool, # 是否成功 - "needs_confirmation": bool, # 是否需要用户确认 - "response": dict, # 如果需要返回,这里是响应数据 - "file_info": dict, # 文件信息 (成功时) - "transcription_task_id": str, # 转录任务ID (成功时) - "replaced_existing": bool, # 是否替换了现有文件 (成功时) - "has_transcription": bool # 原来是否有转录记录 (成功时) - } - """ - # 1. 文件类型验证 - file_extension = os.path.splitext(audio_file.filename)[1].lower() - if file_extension not in ALLOWED_EXTENSIONS: - return { - "success": False, - "response": create_api_response( - code="400", - message=f"不支持的文件类型。支持的类型: {', '.join(ALLOWED_EXTENSIONS)}" - ) - } - - # 2. 文件大小验证 - max_file_size = getattr(config_module, 'MAX_FILE_SIZE', 100 * 1024 * 1024) - if audio_file.size > max_file_size: - return { - "success": False, - "response": create_api_response( - code="400", - message=f"文件大小超过 {max_file_size // (1024 * 1024)}MB 限制" - ) - } - - # 3. 权限和已有文件检查 - try: - with get_db_connection() as connection: - cursor = connection.cursor(dictionary=True) - - # 检查会议是否存在及权限 - cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,)) - meeting = cursor.fetchone() - if not meeting: - return { - "success": False, - "response": create_api_response(code="404", message="会议不存在") - } - if meeting['user_id'] != current_user['user_id']: - return { - "success": False, - "response": create_api_response(code="403", message="无权限操作此会议") - } - - # 检查已有音频文件 - cursor.execute( - "SELECT file_name, file_path, upload_time FROM audio_files WHERE meeting_id = %s", - (meeting_id,) - ) - existing_info = cursor.fetchone() - - # 检查是否有转录记录 - has_transcription = False - if existing_info: - cursor.execute( - "SELECT COUNT(*) as segment_count FROM transcript_segments WHERE meeting_id = %s", - (meeting_id,) - ) - has_transcription = cursor.fetchone()['segment_count'] > 0 - - cursor.close() - except Exception as e: - return { - "success": False, - "response": create_api_response(code="500", message=f"检查已有文件失败: {str(e)}") - } - - # 4. 如果已有转录记录且未确认替换,返回提示 - if existing_info and has_transcription and not force_replace_bool: - return { - "success": False, - "needs_confirmation": True, - "response": create_api_response( - code="300", - message="该会议已有音频文件和转录记录,重新上传将删除现有的转录内容和会议总结", - data={ - "requires_confirmation": True, - "existing_file": { - "file_name": existing_info['file_name'], - "upload_time": existing_info['upload_time'].isoformat() if existing_info['upload_time'] else None - } - } - ) - } - - # 5. 保存音频文件 - meeting_dir = AUDIO_DIR / str(meeting_id) - meeting_dir.mkdir(exist_ok=True) - unique_filename = f"{uuid.uuid4()}{file_extension}" - absolute_path = meeting_dir / unique_filename - relative_path = absolute_path.relative_to(BASE_DIR) - - try: - with open(absolute_path, "wb") as buffer: - shutil.copyfileobj(audio_file.file, buffer) - except Exception as e: - return { - "success": False, - "response": create_api_response(code="500", message=f"保存文件失败: {str(e)}") - } - - transcription_task_id = None - replaced_existing = existing_info is not None - - try: - # 6. 更新数据库记录 - with get_db_connection() as connection: - cursor = connection.cursor(dictionary=True) - - # 删除旧的音频文件 - if replaced_existing and force_replace_bool: - if existing_info and existing_info['file_path']: - old_file_path = BASE_DIR / existing_info['file_path'].lstrip('/') - if old_file_path.exists(): - try: - os.remove(old_file_path) - print(f"Deleted old audio file: {old_file_path}") - except Exception as e: - print(f"Warning: Failed to delete old file {old_file_path}: {e}") - - # 更新或插入音频文件记录 - if replaced_existing: - cursor.execute( - 'UPDATE audio_files SET file_name = %s, file_path = %s, file_size = %s, upload_time = NOW(), task_id = NULL WHERE meeting_id = %s', - (audio_file.filename, '/' + str(relative_path), audio_file.size, meeting_id) - ) - else: - cursor.execute( - 'INSERT INTO audio_files (meeting_id, file_name, file_path, file_size, upload_time) VALUES (%s, %s, %s, %s, NOW())', - (meeting_id, audio_file.filename, '/' + str(relative_path), audio_file.size) - ) - - connection.commit() - cursor.close() - - # 7. 启动转录任务 - try: - transcription_task_id = transcription_service.start_transcription(meeting_id, '/' + str(relative_path)) - print(f"Transcription task {transcription_task_id} started for meeting {meeting_id}") - except Exception as e: - print(f"Failed to start transcription: {e}") - raise - - except Exception as e: - # 出错时清理已上传的文件 - if os.path.exists(absolute_path): - os.remove(absolute_path) - return { - "success": False, - "response": create_api_response(code="500", message=f"处理失败: {str(e)}") - } - - # 8. 返回成功结果 - return { - "success": True, - "file_info": { - "file_name": audio_file.filename, - "file_path": '/' + str(relative_path), - "file_size": audio_file.size - }, - "transcription_task_id": transcription_task_id, - "replaced_existing": replaced_existing, - "has_transcription": has_transcription - } def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]: """ @@ -559,8 +374,8 @@ def get_meeting_for_edit(meeting_id: int, current_user: dict = Depends(get_curre async def upload_audio( audio_file: UploadFile = File(...), meeting_id: int = Form(...), - force_replace: str = Form("false"), auto_summarize: str = Form("true"), + prompt_id: Optional[int] = Form(None), # 可选的提示词模版ID background_tasks: BackgroundTasks = None, current_user: dict = Depends(get_current_user) ): @@ -572,41 +387,84 @@ async def upload_audio( Args: audio_file: 音频文件 meeting_id: 会议ID - force_replace: 是否强制替换("true"/"false") auto_summarize: 是否自动生成总结("true"/"false",默认"true") + prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版) background_tasks: FastAPI后台任务 current_user: 当前登录用户 Returns: - HTTP 300: 需要用户确认(已有转录记录) HTTP 200: 处理成功,返回任务ID HTTP 400/403/404/500: 各种错误情况 """ - force_replace_bool = force_replace.lower() in ("true", "1", "yes") auto_summarize_bool = auto_summarize.lower() in ("true", "1", "yes") - # 调用公共处理方法 - result = _handle_audio_upload(audio_file, meeting_id, force_replace_bool, current_user) + # 打印接收到的 prompt_id + print(f"[Upload Audio] Meeting ID: {meeting_id}, Received prompt_id: {prompt_id}, Type: {type(prompt_id)}, Auto-summarize: {auto_summarize_bool}") - # 如果不成功,直接返回响应 + # 1. 文件类型验证 + file_extension = os.path.splitext(audio_file.filename)[1].lower() + if file_extension not in ALLOWED_EXTENSIONS: + return create_api_response( + code="400", + message=f"不支持的文件类型。支持的类型: {', '.join(ALLOWED_EXTENSIONS)}" + ) + + # 2. 文件大小验证 + max_file_size = getattr(config_module, 'MAX_FILE_SIZE', 100 * 1024 * 1024) + if audio_file.size > max_file_size: + return create_api_response( + code="400", + message=f"文件大小超过 {max_file_size // (1024 * 1024)}MB 限制" + ) + + # 3. 保存音频文件到磁盘 + meeting_dir = AUDIO_DIR / str(meeting_id) + meeting_dir.mkdir(exist_ok=True) + unique_filename = f"{uuid.uuid4()}{file_extension}" + absolute_path = meeting_dir / unique_filename + relative_path = absolute_path.relative_to(BASE_DIR) + + try: + with open(absolute_path, "wb") as buffer: + shutil.copyfileobj(audio_file.file, buffer) + except Exception as e: + return create_api_response(code="500", message=f"保存文件失败: {str(e)}") + + file_path = '/' + str(relative_path) + file_name = audio_file.filename + file_size = audio_file.size + + # 4. 调用 audio_service 处理文件(权限检查、数据库更新、启动转录) + result = handle_audio_upload( + file_path=file_path, + file_name=file_name, + file_size=file_size, + meeting_id=meeting_id, + current_user=current_user, + auto_summarize=auto_summarize_bool, + background_tasks=background_tasks, + prompt_id=prompt_id # 传递 prompt_id 参数 + ) + + # 如果不成功,删除已保存的文件并返回错误 if not result["success"]: + if absolute_path.exists(): + try: + os.remove(absolute_path) + print(f"Deleted file due to processing error: {absolute_path}") + except Exception as e: + print(f"Warning: Failed to delete file {absolute_path}: {e}") return result["response"] - # 成功:根据auto_summarize参数决定是否添加监控任务 + # 5. 返回成功响应 transcription_task_id = result["transcription_task_id"] - if auto_summarize_bool and transcription_task_id: - background_tasks.add_task( - async_meeting_service.monitor_and_auto_summarize, - meeting_id, - transcription_task_id - ) - print(f"[upload-audio] Auto-summarize enabled, monitor task added for meeting {meeting_id}") - message_suffix = ",正在进行转录和总结" - else: - print(f"[upload-audio] Auto-summarize disabled for meeting {meeting_id}") - message_suffix = "" + message_suffix = "" + if transcription_task_id: + if auto_summarize_bool: + message_suffix = ",正在进行转录和总结" + else: + message_suffix = ",正在进行转录" - # 返回成功响应 return create_api_response( code="200", message="Audio file uploaded successfully" + @@ -888,7 +746,8 @@ def generate_meeting_summary_async(meeting_id: int, request: GenerateSummaryRequ cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,)) if not cursor.fetchone(): return create_api_response(code="404", message="Meeting not found") - task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt) + # 传递 prompt_id 参数给服务层 + task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt, request.prompt_id) background_tasks.add_task(async_meeting_service._process_task, task_id) return create_api_response(code="200", message="Summary generation task has been accepted.", data={ "task_id": task_id, "status": "pending", "meeting_id": meeting_id @@ -1025,12 +884,14 @@ def get_meeting_preview_data(meeting_id: int): with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) - # 检查会议是否存在 + # 检查会议是否存在,并获取模版信息 query = ''' - SELECT m.meeting_id, m.title, m.meeting_time, m.summary, m.updated_at, - m.user_id as creator_id, u.caption as creator_username + SELECT m.meeting_id, m.title, m.meeting_time, m.summary, m.updated_at, m.prompt_id, + m.user_id as creator_id, u.caption as creator_username, + p.name as prompt_name FROM meetings m JOIN users u ON m.user_id = u.user_id + LEFT JOIN prompts p ON m.prompt_id = p.id WHERE m.meeting_id = %s ''' cursor.execute(query, (meeting_id,)) @@ -1056,6 +917,8 @@ def get_meeting_preview_data(meeting_id: int): "meeting_time": meeting['meeting_time'], "summary": meeting['summary'], "creator_username": meeting['creator_username'], + "prompt_id": meeting['prompt_id'], + "prompt_name": meeting['prompt_name'], "attendees": attendees, "attendees_count": len(attendees) } diff --git a/app/api/endpoints/prompts.py b/app/api/endpoints/prompts.py index a6a093b..0244d6f 100644 --- a/app/api/endpoints/prompts.py +++ b/app/api/endpoints/prompts.py @@ -11,11 +11,14 @@ router = APIRouter() # Pydantic Models class PromptIn(BaseModel): name: str - tags: Optional[str] = "" + task_type: str # 'MEETING_TASK' 或 'KNOWLEDGE_TASK' content: str + is_default: bool = False + is_active: bool = True class PromptOut(PromptIn): id: int + creator_id: int created_at: str class PromptListResponse(BaseModel): @@ -28,44 +31,105 @@ def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_use with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) try: + # 如果设置为默认,需要先取消同类型其他提示词的默认状态 + if prompt.is_default: + cursor.execute( + "UPDATE prompts SET is_default = FALSE WHERE task_type = %s", + (prompt.task_type,) + ) + cursor.execute( - "INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)", - (prompt.name, prompt.tags, prompt.content, current_user["user_id"]) + """INSERT INTO prompts (name, task_type, content, is_default, is_active, creator_id) + VALUES (%s, %s, %s, %s, %s, %s)""", + (prompt.name, prompt.task_type, prompt.content, prompt.is_default, + prompt.is_active, current_user["user_id"]) ) connection.commit() new_id = cursor.lastrowid - return create_api_response(code="200", message="提示词创建成功", data={"id": new_id, **prompt.dict()}) + return create_api_response( + code="200", + message="提示词创建成功", + data={"id": new_id, **prompt.dict()} + ) except Exception as e: - if "UNIQUE constraint failed" in str(e) or "Duplicate entry" in str(e): + if "Duplicate entry" in str(e): return create_api_response(code="400", message="提示词名称已存在") return create_api_response(code="500", message=f"创建提示词失败: {e}") -@router.get("/prompts") -def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)): - """Get a paginated list of prompts filtered by current user.""" +@router.get("/prompts/active/{task_type}") +def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)): + """Get all active prompts for a specific task type.""" with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) - # 只获取当前用户创建的提示词 cursor.execute( - "SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s", - (current_user["user_id"],) + """SELECT id, name, is_default + FROM prompts + WHERE task_type = %s AND is_active = TRUE + ORDER BY is_default DESC, created_at DESC""", + (task_type,) + ) + prompts = cursor.fetchall() + return create_api_response( + code="200", + message="获取启用模版列表成功", + data={"prompts": prompts} + ) + +@router.get("/prompts") +def get_prompts( + task_type: Optional[str] = None, + page: int = 1, + size: int = 50, + current_user: dict = Depends(get_current_user) +): + """Get a paginated list of prompts filtered by current user and optionally by task_type.""" + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 构建 WHERE 条件 + where_conditions = ["creator_id = %s"] + params = [current_user["user_id"]] + + if task_type: + where_conditions.append("task_type = %s") + params.append(task_type) + + where_clause = " AND ".join(where_conditions) + + # 获取总数 + cursor.execute( + f"SELECT COUNT(*) as total FROM prompts WHERE {where_clause}", + tuple(params) ) total = cursor.fetchone()['total'] + # 获取分页数据 offset = (page - 1) * size cursor.execute( - "SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s", - (current_user["user_id"], size, offset) + f"""SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at + FROM prompts + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT %s OFFSET %s""", + tuple(params + [size, offset]) ) prompts = cursor.fetchall() - return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total}) + return create_api_response( + code="200", + message="获取提示词列表成功", + data={"prompts": prompts, "total": total} + ) @router.get("/prompts/{prompt_id}") def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)): """Get a single prompt by its ID.""" with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) - cursor.execute("SELECT id, name, tags, content, created_at FROM prompts WHERE id = %s", (prompt_id,)) + cursor.execute( + """SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at + FROM prompts WHERE id = %s""", + (prompt_id,) + ) prompt = cursor.fetchone() if not prompt: return create_api_response(code="404", message="提示词不存在") @@ -74,19 +138,50 @@ def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)): @router.put("/prompts/{prompt_id}") def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)): """Update an existing prompt.""" + print(f"[UPDATE PROMPT] prompt_id={prompt_id}, type={type(prompt_id)}") + print(f"[UPDATE PROMPT] user_id={current_user['user_id']}") + print(f"[UPDATE PROMPT] data: name={prompt.name}, task_type={prompt.task_type}, content_len={len(prompt.content)}, is_default={prompt.is_default}, is_active={prompt.is_active}") + with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) try: - cursor.execute( - "UPDATE prompts SET name = %s, tags = %s, content = %s WHERE id = %s", - (prompt.name, prompt.tags, prompt.content, prompt_id) - ) - if cursor.rowcount == 0: + # 先检查记录是否存在 + cursor.execute("SELECT id, creator_id FROM prompts WHERE id = %s", (prompt_id,)) + existing = cursor.fetchone() + print(f"[UPDATE PROMPT] existing record: {existing}") + + if not existing: + print(f"[UPDATE PROMPT] Prompt {prompt_id} not found in database") return create_api_response(code="404", message="提示词不存在") + + # 如果设置为默认,需要先取消同类型其他提示词的默认状态 + if prompt.is_default: + print(f"[UPDATE PROMPT] Setting as default, clearing other defaults for task_type={prompt.task_type}") + cursor.execute( + "UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s", + (prompt.task_type, prompt_id) + ) + print(f"[UPDATE PROMPT] Cleared {cursor.rowcount} other default prompts") + + print(f"[UPDATE PROMPT] Executing UPDATE query") + cursor.execute( + """UPDATE prompts + SET name = %s, task_type = %s, content = %s, is_default = %s, is_active = %s + WHERE id = %s""", + (prompt.name, prompt.task_type, prompt.content, prompt.is_default, + prompt.is_active, prompt_id) + ) + rows_affected = cursor.rowcount + print(f"[UPDATE PROMPT] UPDATE affected {rows_affected} rows (0 means no changes needed)") + + # 注意:rowcount=0 不代表记录不存在,可能是所有字段值都相同 + # 我们已经在上面确认了记录存在,所以这里直接提交即可 connection.commit() + print(f"[UPDATE PROMPT] Success! Committed changes") return create_api_response(code="200", message="提示词更新成功") except Exception as e: - if "UNIQUE constraint failed" in str(e) or "Duplicate entry" in str(e): + print(f"[UPDATE PROMPT] Exception: {type(e).__name__}: {e}") + if "Duplicate entry" in str(e): return create_api_response(code="400", message="提示词名称已存在") return create_api_response(code="500", message=f"更新提示词失败: {e}") diff --git a/app/api/endpoints/users.py b/app/api/endpoints/users.py index d08a7f2..ff89524 100644 --- a/app/api/endpoints/users.py +++ b/app/api/endpoints/users.py @@ -147,23 +147,42 @@ def reset_password(user_id: int, current_user: dict = Depends(get_current_user)) return create_api_response(code="200", message=f"用户 {user_id} 的密码已重置") @router.get("/users") -def get_all_users(page: int = 1, size: int = 10, role_id: Optional[int] = None, current_user: dict = Depends(get_current_user)): +def get_all_users( + page: int = 1, + size: int = 10, + role_id: Optional[int] = None, + search: Optional[str] = None, + current_user: dict = Depends(get_current_user) +): with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) - - count_query = "SELECT COUNT(*) as total FROM users" - params = [] + + # 构建WHERE条件 + where_conditions = [] + count_params = [] + if role_id is not None: - count_query += " WHERE role_id = %s" - params.append(role_id) - - cursor.execute(count_query, tuple(params)) + where_conditions.append("role_id = %s") + count_params.append(role_id) + + if search: + search_pattern = f"%{search}%" + where_conditions.append("(username LIKE %s OR caption LIKE %s)") + count_params.extend([search_pattern, search_pattern]) + + # 统计查询 + count_query = "SELECT COUNT(*) as total FROM users" + if where_conditions: + count_query += " WHERE " + " AND ".join(where_conditions) + + cursor.execute(count_query, tuple(count_params)) total = cursor.fetchone()['total'] - + offset = (page - 1) * size - + + # 主查询 query = ''' - SELECT + SELECT u.user_id, u.username, u.caption, u.email, u.created_at, u.role_id, r.role_name, (SELECT COUNT(*) FROM meetings WHERE user_id = u.user_id) as meetings_created, @@ -171,24 +190,24 @@ def get_all_users(page: int = 1, size: int = 10, role_id: Optional[int] = None, FROM users u LEFT JOIN roles r ON u.role_id = r.role_id ''' - + query_params = [] - if role_id is not None: - query += " WHERE u.role_id = %s" - query_params.append(role_id) - + if where_conditions: + query += " WHERE " + " AND ".join(where_conditions) + query_params.extend(count_params) + query += ''' ORDER BY u.user_id ASC LIMIT %s OFFSET %s ''' - + query_params.extend([size, offset]) - + cursor.execute(query, tuple(query_params)) users = cursor.fetchall() - + user_list = [UserInfo(**user) for user in users] - + response_data = UserListResponse(users=user_list, total=total) return create_api_response(code="200", message="获取用户列表成功", data=response_data.dict()) diff --git a/app/core/config.py b/app/core/config.py index bd0949d..e906840 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -6,6 +6,7 @@ from pathlib import Path BASE_DIR = Path(__file__).parent.parent.parent UPLOAD_DIR = BASE_DIR / "uploads" AUDIO_DIR = UPLOAD_DIR / "audio" +TEMP_UPLOAD_DIR = UPLOAD_DIR / "temp_audio" MARKDOWN_DIR = UPLOAD_DIR / "markdown" VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint" diff --git a/main.py b/app/main.py similarity index 73% rename from main.py rename to app/main.py index 2e00e33..1d45b98 100644 --- a/main.py +++ b/app/main.py @@ -1,11 +1,21 @@ +import sys +import os +from pathlib import Path + +# 添加项目根目录到 Python 路径 +# 无论从哪里运行,都能正确找到 app 模块 +current_file = Path(__file__).resolve() +project_root = current_file.parent.parent # backend/ +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + import uvicorn from fastapi import FastAPI, Request, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint +from app.api.endpoints import auth, users, meetings, tags, admin, admin_dashboard, tasks, prompts, knowledge_base, client_downloads, voiceprint, audio from app.core.config import UPLOAD_DIR, API_CONFIG from app.api.endpoints.admin import load_system_config -import os app = FastAPI( title="iMeeting API", @@ -35,11 +45,13 @@ app.include_router(users.router, prefix="/api", tags=["Users"]) app.include_router(meetings.router, prefix="/api", tags=["Meetings"]) app.include_router(tags.router, prefix="/api", tags=["Tags"]) app.include_router(admin.router, prefix="/api", tags=["Admin"]) +app.include_router(admin_dashboard.router, prefix="/api", tags=["AdminDashboard"]) app.include_router(tasks.router, prefix="/api", tags=["Tasks"]) app.include_router(prompts.router, prefix="/api", tags=["Prompts"]) app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"]) -app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"]) +app.include_router(client_downloads.router, prefix="/api", tags=["ClientDownloads"]) app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"]) +app.include_router(audio.router, prefix="/api", tags=["Audio"]) @app.get("/") def read_root(): @@ -57,10 +69,10 @@ def health_check(): if __name__ == "__main__": # 简单的uvicorn配置,避免参数冲突 uvicorn.run( - "main:app", - host=API_CONFIG['host'], + "app.main:app", + host=API_CONFIG['host'], port=API_CONFIG['port'], limit_max_requests=1000, timeout_keep_alive=30, reload=True, - ) \ No newline at end of file + ) diff --git a/app/models/models.py b/app/models/models.py index bfd9b86..98d6c35 100644 --- a/app/models/models.py +++ b/app/models/models.py @@ -152,6 +152,7 @@ class CreateKnowledgeBaseRequest(BaseModel): user_prompt: Optional[str] = None source_meeting_ids: Optional[str] = None tags: Optional[str] = None + prompt_id: Optional[int] = None # 提示词模版ID,如果不指定则使用默认模版 class UpdateKnowledgeBaseRequest(BaseModel): title: str @@ -227,3 +228,30 @@ class VoiceprintTemplate(BaseModel): duration_seconds: int sample_rate: int channels: int + +# 菜单权限相关模型 +class MenuInfo(BaseModel): + menu_id: int + menu_code: str + menu_name: str + menu_icon: Optional[str] = None + menu_url: Optional[str] = None + menu_type: str # 'action', 'link', 'divider' + parent_id: Optional[int] = None + sort_order: int + is_active: bool + description: Optional[str] = None + created_at: datetime.datetime + updated_at: datetime.datetime + +class MenuListResponse(BaseModel): + menus: List[MenuInfo] + total: int + +class RolePermissionInfo(BaseModel): + role_id: int + role_name: str + menu_ids: List[int] + +class UpdateRolePermissionsRequest(BaseModel): + menu_ids: List[int] diff --git a/app/services/async_knowledge_base_service.py b/app/services/async_knowledge_base_service.py index 4f19838..8a743fd 100644 --- a/app/services/async_knowledge_base_service.py +++ b/app/services/async_knowledge_base_service.py @@ -20,7 +20,7 @@ class AsyncKnowledgeBaseService: self.redis_client = redis.Redis(**REDIS_CONFIG) self.llm_service = LLMService() - def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str: + def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], prompt_id: Optional[int] = None, cursor=None) -> str: """ 创建异步知识库生成任务 @@ -29,6 +29,7 @@ class AsyncKnowledgeBaseService: kb_id: 知识库ID user_prompt: 用户提示词 source_meeting_ids: 源会议ID列表 + prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版) cursor: 数据库游标(可选) Returns: @@ -39,13 +40,13 @@ class AsyncKnowledgeBaseService: # If a cursor is passed, use it directly to avoid creating a new transaction if cursor: query = """ - INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at) - VALUES (%s, %s, %s, %s, NOW()) + INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, prompt_id, created_at) + VALUES (%s, %s, %s, %s, %s, NOW()) """ - cursor.execute(query, (task_id, user_id, kb_id, user_prompt)) + cursor.execute(query, (task_id, user_id, kb_id, user_prompt, prompt_id)) else: # Fallback to the old method if no cursor is provided - self._save_task_to_db(task_id, user_id, kb_id, user_prompt) + self._save_task_to_db(task_id, user_id, kb_id, user_prompt, prompt_id) current_time = datetime.now().isoformat() task_data = { @@ -53,6 +54,7 @@ class AsyncKnowledgeBaseService: 'user_id': str(user_id), 'kb_id': str(kb_id), 'user_prompt': user_prompt if user_prompt else "", + 'prompt_id': str(prompt_id) if prompt_id else '', 'status': 'pending', 'progress': '0', 'created_at': current_time, @@ -61,7 +63,7 @@ class AsyncKnowledgeBaseService: self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data) self.redis_client.expire(f"kb_task:{task_id}", 86400) - print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}") + print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}, prompt_id: {prompt_id}") return task_id def _process_task(self, task_id: str): @@ -78,6 +80,8 @@ class AsyncKnowledgeBaseService: kb_id = int(task_data['kb_id']) user_prompt = task_data.get('user_prompt', '') + prompt_id_str = task_data.get('prompt_id', '') + prompt_id = int(prompt_id_str) if prompt_id_str and prompt_id_str != '' else None # 1. 更新状态为processing self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...") @@ -88,7 +92,7 @@ class AsyncKnowledgeBaseService: # 3. 构建提示词 self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...") - full_prompt = self._build_prompt(source_text, user_prompt) + full_prompt = self._build_prompt(source_text, user_prompt, prompt_id) # 4. 调用LLM API self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...") @@ -98,7 +102,7 @@ class AsyncKnowledgeBaseService: # 5. 保存结果到数据库 self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...") - self._save_result_to_db(kb_id, generated_content) + self._save_result_to_db(kb_id, generated_content, prompt_id) # 6. 任务完成 self._update_task_in_db(task_id, 'completed', 100) @@ -156,7 +160,7 @@ class AsyncKnowledgeBaseService: print(f"获取会议总结错误: {e}") return "" - def _build_prompt(self, source_text: str, user_prompt: str) -> str: + def _build_prompt(self, source_text: str, user_prompt: str, prompt_id: Optional[int] = None) -> str: """ 构建完整的提示词 使用数据库中配置的KNOWLEDGE_TASK提示词模板 @@ -164,12 +168,13 @@ class AsyncKnowledgeBaseService: Args: source_text: 源会议总结文本 user_prompt: 用户自定义提示词 + prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版) Returns: str: 完整的提示词 """ - # 从数据库获取知识库任务的提示词模板 - system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK') + # 从数据库获取知识库任务的提示词模板(支持指定prompt_id) + system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=prompt_id) prompt = f"{system_prompt}\n\n" @@ -180,13 +185,14 @@ class AsyncKnowledgeBaseService: return prompt - def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]: + def _save_result_to_db(self, kb_id: int, content: str, prompt_id: Optional[int] = None) -> Optional[int]: """ 保存生成结果到数据库 Args: kb_id: 知识库ID content: 生成的内容 + prompt_id: 提示词模版ID Returns: Optional[int]: 知识库ID,失败返回None @@ -194,11 +200,11 @@ class AsyncKnowledgeBaseService: try: with get_db_connection() as connection: cursor = connection.cursor() - query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s" - cursor.execute(query, (content, kb_id)) + query = "UPDATE knowledge_bases SET content = %s, prompt_id = %s, updated_at = NOW() WHERE kb_id = %s" + cursor.execute(query, (content, prompt_id, kb_id)) connection.commit() - print(f"成功保存知识库内容,kb_id: {kb_id}") + print(f"成功保存知识库内容,kb_id: {kb_id}, prompt_id: {prompt_id}") return kb_id except Exception as e: @@ -243,13 +249,31 @@ class AsyncKnowledgeBaseService: except Exception as e: print(f"Error updating task status in Redis: {e}") - def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str): - """保存任务到数据库""" + def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str, prompt_id: Optional[int] = None): + """保存任务到数据库 + + Args: + task_id: 任务ID + user_id: 用户ID + kb_id: 知识库ID + user_prompt: 用户提示词 + prompt_id: 提示词模版ID(可选),如果为None则使用默认模版 + """ try: with get_db_connection() as connection: cursor = connection.cursor() - insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())" - cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt)) + + # 如果没有指定 prompt_id,获取默认的知识库总结模版ID + if prompt_id is None: + cursor.execute( + "SELECT id FROM prompts WHERE task_type = 'KNOWLEDGE_TASK' AND is_default = TRUE AND is_active = TRUE LIMIT 1" + ) + default_prompt = cursor.fetchone() + if default_prompt: + prompt_id = default_prompt[0] + + insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, prompt_id, status, progress, created_at) VALUES (%s, %s, %s, %s, %s, 'pending', 0, NOW())" + cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt, prompt_id)) connection.commit() except Exception as e: print(f"Error saving task to database: {e}") diff --git a/app/services/async_meeting_service.py b/app/services/async_meeting_service.py index 50358e9..808d840 100644 --- a/app/services/async_meeting_service.py +++ b/app/services/async_meeting_service.py @@ -23,22 +23,24 @@ class AsyncMeetingService: self.redis_client = redis.Redis(**REDIS_CONFIG) self.llm_service = LLMService() # 复用现有的同步LLM服务 - def start_summary_generation(self, meeting_id: int, user_prompt: str = "") -> str: + def start_summary_generation(self, meeting_id: int, user_prompt: str = "", prompt_id: Optional[int] = None) -> str: """ 创建异步总结任务,任务的执行将由外部(如API层的BackgroundTasks)触发。 Args: meeting_id: 会议ID user_prompt: 用户额外提示词 + prompt_id: 可选的提示词模版ID,如果不指定则使用默认模版 Returns: str: 任务ID """ + try: task_id = str(uuid.uuid4()) # 在数据库中创建任务记录 - self._save_task_to_db(task_id, meeting_id, user_prompt) + self._save_task_to_db(task_id, meeting_id, user_prompt, prompt_id) # 将任务详情存入Redis,用于快速查询状态 current_time = datetime.now().isoformat() @@ -46,6 +48,7 @@ class AsyncMeetingService: 'task_id': task_id, 'meeting_id': str(meeting_id), 'user_prompt': user_prompt, + 'prompt_id': str(prompt_id) if prompt_id else '', 'status': 'pending', 'progress': '0', 'created_at': current_time, @@ -54,7 +57,6 @@ class AsyncMeetingService: self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data) self.redis_client.expire(f"llm_task:{task_id}", 86400) - print(f"Meeting summary task created: {task_id} for meeting: {meeting_id}") return task_id except Exception as e: @@ -75,6 +77,8 @@ class AsyncMeetingService: meeting_id = int(task_data['meeting_id']) user_prompt = task_data.get('user_prompt', '') + prompt_id_str = task_data.get('prompt_id', '') + prompt_id = int(prompt_id_str) if prompt_id_str and prompt_id_str != '' else None # 1. 更新状态为processing self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...") @@ -87,7 +91,7 @@ class AsyncMeetingService: # 3. 构建提示词 self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...") - full_prompt = self._build_prompt(transcript_text, user_prompt) + full_prompt = self._build_prompt(transcript_text, user_prompt, prompt_id) # 4. 调用LLM API self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...") @@ -97,7 +101,7 @@ class AsyncMeetingService: # 5. 保存结果到主表 self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...") - self._save_summary_to_db(meeting_id, summary_content, user_prompt) + self._save_summary_to_db(meeting_id, summary_content, user_prompt, prompt_id) # 6. 任务完成 self._update_task_in_db(task_id, 'completed', 100, result=summary_content) @@ -111,7 +115,7 @@ class AsyncMeetingService: self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg) self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg) - def monitor_and_auto_summarize(self, meeting_id: int, transcription_task_id: str): + def monitor_and_auto_summarize(self, meeting_id: int, transcription_task_id: str, prompt_id: Optional[int] = None): """ 监控转录任务,完成后自动生成总结 此方法设计为由BackgroundTasks调用,在后台运行 @@ -119,13 +123,14 @@ class AsyncMeetingService: Args: meeting_id: 会议ID transcription_task_id: 转录任务ID + prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版) 流程: 1. 循环轮询转录任务状态 2. 转录成功后自动启动总结任务 3. 转录失败或超时则停止轮询并记录日志 """ - print(f"[Monitor] Started monitoring transcription task {transcription_task_id} for meeting {meeting_id}") + print(f"[Monitor] Started monitoring transcription task {transcription_task_id} for meeting {meeting_id}, prompt_id: {prompt_id}") # 获取配置参数 poll_interval = TRANSCRIPTION_POLL_CONFIG['poll_interval'] @@ -156,7 +161,7 @@ class AsyncMeetingService: # 启动总结任务 try: - summary_task_id = self.start_summary_generation(meeting_id, user_prompt="") + summary_task_id = self.start_summary_generation(meeting_id, user_prompt="", prompt_id=prompt_id) print(f"[Monitor] Summary task {summary_task_id} started for meeting {meeting_id}") # 在后台执行总结任务 @@ -231,13 +236,18 @@ class AsyncMeetingService: print(f"获取会议转录内容错误: {e}") return "" - def _build_prompt(self, transcript_text: str, user_prompt: str) -> str: + def _build_prompt(self, transcript_text: str, user_prompt: str, prompt_id: Optional[int] = None) -> str: """ 构建完整的提示词 使用数据库中配置的MEETING_TASK提示词模板 + + Args: + transcript_text: 会议转录文本 + user_prompt: 用户额外提示词 + prompt_id: 可选的提示词模版ID,如果不指定则使用默认模版 """ - # 从数据库获取会议任务的提示词模板 - system_prompt = self.llm_service.get_task_prompt('MEETING_TASK') + # 从数据库获取会议任务的提示词模板(支持指定prompt_id) + system_prompt = self.llm_service.get_task_prompt('MEETING_TASK', prompt_id=prompt_id) prompt = f"{system_prompt}\n\n" @@ -248,22 +258,22 @@ class AsyncMeetingService: return prompt - def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]: - """保存总结到数据库 - 更新meetings表的summary、user_prompt和updated_at字段""" + def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str, prompt_id: Optional[int] = None) -> Optional[int]: + """保存总结到数据库 - 更新meetings表的summary、user_prompt、prompt_id和updated_at字段""" try: with get_db_connection() as connection: cursor = connection.cursor() - # 更新meetings表的summary、user_prompt和updated_at字段 + # 更新meetings表的summary、user_prompt、prompt_id和updated_at字段 update_query = """ UPDATE meetings - SET summary = %s, user_prompt = %s, updated_at = NOW() + SET summary = %s, user_prompt = %s, prompt_id = %s, updated_at = NOW() WHERE meeting_id = %s """ - cursor.execute(update_query, (summary_content, user_prompt, meeting_id)) + cursor.execute(update_query, (summary_content, user_prompt, prompt_id, meeting_id)) connection.commit() - print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}") + print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}, prompt_id: {prompt_id}") return meeting_id except Exception as e: @@ -326,14 +336,39 @@ class AsyncMeetingService: except Exception as e: print(f"Error updating task status in Redis: {e}") - def _save_task_to_db(self, task_id: str, meeting_id: int, user_prompt: str): - """保存任务到数据库""" + def _save_task_to_db(self, task_id: str, meeting_id: int, user_prompt: str, prompt_id: Optional[int] = None): + """保存任务到数据库 + + Args: + task_id: 任务ID + meeting_id: 会议ID + user_prompt: 用户额外提示词 + prompt_id: 可选的提示词模版ID,如果为None则使用默认模版 + """ try: with get_db_connection() as connection: cursor = connection.cursor() - insert_query = "INSERT INTO llm_tasks (task_id, meeting_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, 'pending', 0, NOW())" - cursor.execute(insert_query, (task_id, meeting_id, user_prompt)) + + # 如果没有指定 prompt_id,获取默认的会议总结模版ID + if prompt_id is None: + print(f"[Meeting Service] prompt_id is None, fetching default template for MEETING_TASK") + cursor.execute( + "SELECT id FROM prompts WHERE task_type = 'MEETING_TASK' AND is_default = TRUE AND is_active = TRUE LIMIT 1" + ) + default_prompt = cursor.fetchone() + if default_prompt: + prompt_id = default_prompt[0] + print(f"[Meeting Service] Found default template ID: {prompt_id}") + else: + print(f"[Meeting Service] WARNING: No default template found for MEETING_TASK!") + else: + print(f"[Meeting Service] Using provided prompt_id: {prompt_id}") + + print(f"[Meeting Service] Inserting task into llm_tasks - task_id: {task_id}, meeting_id: {meeting_id}, prompt_id: {prompt_id}") + insert_query = "INSERT INTO llm_tasks (task_id, meeting_id, user_prompt, prompt_id, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())" + cursor.execute(insert_query, (task_id, meeting_id, user_prompt, prompt_id)) connection.commit() + print(f"[Meeting Service] Task saved successfully to database") except Exception as e: print(f"Error saving task to database: {e}") raise diff --git a/app/services/audio_service.py b/app/services/audio_service.py new file mode 100644 index 0000000..5142699 --- /dev/null +++ b/app/services/audio_service.py @@ -0,0 +1,172 @@ +""" +音频处理服务 + +处理已保存的完整音频文件:数据库更新、转录、自动总结 +""" +from fastapi import BackgroundTasks +from app.core.database import get_db_connection +from app.core.response import create_api_response +from app.core.config import BASE_DIR +from app.services.async_transcription_service import AsyncTranscriptionService +from app.services.async_meeting_service import async_meeting_service +from pathlib import Path +import os + + +transcription_service = AsyncTranscriptionService() + + +def handle_audio_upload( + file_path: str, + file_name: str, + file_size: int, + meeting_id: int, + current_user: dict, + auto_summarize: bool = True, + background_tasks: BackgroundTasks = None, + prompt_id: int = None +) -> dict: + """ + 处理已保存的完整音频文件 + + 职责: + 1. 权限检查 + 2. 检查已有文件和转录记录 + 3. 更新数据库(audio_files 表) + 4. 启动转录任务 + 5. 可选启动自动总结监控 + + Args: + file_path: 已保存的文件路径(相对于 BASE_DIR 的路径,如 /uploads/audio/123/xxx.webm) + file_name: 原始文件名 + file_size: 文件大小(字节) + meeting_id: 会议ID + current_user: 当前用户信息 + auto_summarize: 是否自动生成总结(默认True) + background_tasks: FastAPI 后台任务对象 + prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版) + + Returns: + dict: { + "success": bool, # 是否成功 + "response": dict, # 如果需要返回,这里是响应数据 + "file_info": dict, # 文件信息 (成功时) + "transcription_task_id": str, # 转录任务ID (成功时) + "replaced_existing": bool, # 是否替换了现有文件 (成功时) + "has_transcription": bool # 原来是否有转录记录 (成功时) + } + """ + print(f"[Audio Service] handle_audio_upload called - Meeting ID: {meeting_id}, Auto-summarize: {auto_summarize}, Received prompt_id: {prompt_id}, Type: {type(prompt_id)}") + + # 1. 权限和已有文件检查 + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 检查会议是否存在及权限 + cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,)) + meeting = cursor.fetchone() + if not meeting: + return { + "success": False, + "response": create_api_response(code="404", message="会议不存在") + } + if meeting['user_id'] != current_user['user_id']: + return { + "success": False, + "response": create_api_response(code="403", message="无权限操作此会议") + } + + # 检查已有音频文件 + cursor.execute( + "SELECT file_name, file_path, upload_time FROM audio_files WHERE meeting_id = %s", + (meeting_id,) + ) + existing_info = cursor.fetchone() + + # 检查是否有转录记录 + has_transcription = False + if existing_info: + cursor.execute( + "SELECT COUNT(*) as segment_count FROM transcript_segments WHERE meeting_id = %s", + (meeting_id,) + ) + has_transcription = cursor.fetchone()['segment_count'] > 0 + + cursor.close() + except Exception as e: + return { + "success": False, + "response": create_api_response(code="500", message=f"检查已有文件失败: {str(e)}") + } + + # 2. 删除旧的音频文件(如果存在) + replaced_existing = existing_info is not None + if replaced_existing and existing_info['file_path']: + old_file_path = BASE_DIR / existing_info['file_path'].lstrip('/') + if old_file_path.exists(): + try: + os.remove(old_file_path) + print(f"Deleted old audio file: {old_file_path}") + except Exception as e: + print(f"Warning: Failed to delete old file {old_file_path}: {e}") + + transcription_task_id = None + + try: + # 3. 更新数据库记录 + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + if replaced_existing: + cursor.execute( + 'UPDATE audio_files SET file_name = %s, file_path = %s, file_size = %s, upload_time = NOW(), task_id = NULL WHERE meeting_id = %s', + (file_name, file_path, file_size, meeting_id) + ) + else: + cursor.execute( + 'INSERT INTO audio_files (meeting_id, file_name, file_path, file_size, upload_time) VALUES (%s, %s, %s, %s, NOW())', + (meeting_id, file_name, file_path, file_size) + ) + + connection.commit() + cursor.close() + + # 4. 启动转录任务 + try: + transcription_task_id = transcription_service.start_transcription(meeting_id, file_path) + print(f"Transcription task {transcription_task_id} started for meeting {meeting_id}") + + # 5. 如果启用自动总结且提供了 background_tasks,添加监控任务 + if auto_summarize and transcription_task_id and background_tasks: + background_tasks.add_task( + async_meeting_service.monitor_and_auto_summarize, + meeting_id, + transcription_task_id, + prompt_id # 传递 prompt_id 给自动总结监控任务 + ) + print(f"[audio_service] Auto-summarize enabled, monitor task added for meeting {meeting_id}, prompt_id: {prompt_id}") + + except Exception as e: + print(f"Failed to start transcription: {e}") + raise + + except Exception as e: + # 出错时的处理(文件已保存,不删除) + return { + "success": False, + "response": create_api_response(code="500", message=f"处理失败: {str(e)}") + } + + # 6. 返回成功结果 + return { + "success": True, + "file_info": { + "file_name": file_name, + "file_path": file_path, + "file_size": file_size + }, + "transcription_task_id": transcription_task_id, + "replaced_existing": replaced_existing, + "has_transcription": has_transcription + } diff --git a/app/services/llm_service.py b/app/services/llm_service.py index 6018862..a08ee0d 100644 --- a/app/services/llm_service.py +++ b/app/services/llm_service.py @@ -38,39 +38,54 @@ class LLMService: """动态获取top_p""" return config_module.LLM_CONFIG["top_p"] - def get_task_prompt(self, task_name: str, cursor=None) -> str: + def get_task_prompt(self, task_type: str, cursor=None, prompt_id: Optional[int] = None) -> str: """ 统一的提示词获取方法 Args: - task_name: 任务名称,如 'MEETING_TASK', 'KNOWLEDGE_TASK' 等 + task_type: 任务类型,如 'MEETING_TASK', 'KNOWLEDGE_TASK' 等 cursor: 数据库游标,如果传入则使用,否则创建新连接 + prompt_id: 可选的提示词ID,如果指定则使用该提示词,否则使用默认提示词 Returns: str: 提示词内容,如果未找到返回默认提示词 """ - query = """ - SELECT p.content - FROM prompt_config pc - JOIN prompts p ON pc.prompt_id = p.id - WHERE pc.task_name = %s - """ + # 如果指定了 prompt_id,直接获取该提示词 + if prompt_id: + query = """ + SELECT content + FROM prompts + WHERE id = %s AND task_type = %s AND is_active = TRUE + LIMIT 1 + """ + params = (prompt_id, task_type) + else: + # 否则获取默认提示词 + query = """ + SELECT content + FROM prompts + WHERE task_type = %s + AND is_default = TRUE + AND is_active = TRUE + LIMIT 1 + """ + params = (task_type,) if cursor: - cursor.execute(query, (task_name,)) + cursor.execute(query, params) result = cursor.fetchone() if result: return result['content'] if isinstance(result, dict) else result[0] else: with get_db_connection() as connection: cursor = connection.cursor(dictionary=True) - cursor.execute(query, (task_name,)) + cursor.execute(query, params) result = cursor.fetchone() if result: return result['content'] # 返回默认提示词 - return self._get_default_prompt(task_name) + return self._get_default_prompt(task_type) def _get_default_prompt(self, task_name: str) -> str: """获取默认提示词""" diff --git a/sql/README_terminal_update.md b/sql/README_terminal_update.md new file mode 100644 index 0000000..6e934c8 --- /dev/null +++ b/sql/README_terminal_update.md @@ -0,0 +1,201 @@ +# 客户端管理 - 专用终端类型添加说明 + +## 概述 + +本次更新在客户端管理系统中添加了"专用终端"(terminal)大类型,支持 Android 专用终端和单片机(MCU)平台。 + +## 数据库变更 + +### 1. 修改表结构 + +执行 SQL 文件:`add_dedicated_terminal.sql` + +```bash +mysql -u [username] -p [database_name] < backend/sql/add_dedicated_terminal.sql +``` + +**变更内容:** +- 修改 `client_downloads` 表的 `platform_type` 枚举,添加 `terminal` 类型 +- 插入两条示例数据: + - Android 专用终端(platform_type: `terminal`, platform_name: `android`) + - 单片机固件(platform_type: `terminal`, platform_name: `mcu`) + +### 2. 新的平台类型 + +| platform_type | platform_name | 说明 | +|--------------|--------------|------| +| terminal | android | Android 专用终端 | +| terminal | mcu | 单片机(MCU)固件 | + +## API 接口变更 + +### 1. 新增接口:通过平台类型和平台名称获取最新版本 + +**接口路径:** `GET /api/downloads/latest/by-platform` + +**请求参数:** +- `platform_type` (string, required): 平台类型 (mobile, desktop, terminal) +- `platform_name` (string, required): 具体平台名称 + +**示例请求:** +```bash +# 获取 Android 专用终端最新版本 +curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=android" + +# 获取单片机固件最新版本 +curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=mcu" +``` + +**返回示例:** +```json +{ + "code": "200", + "message": "获取成功", + "data": { + "id": 7, + "platform_type": "terminal", + "platform_name": "android", + "version": "1.0.0", + "version_code": 1000, + "download_url": "https://download.imeeting.com/terminals/android/iMeeting-Terminal-1.0.0.apk", + "file_size": 25165824, + "release_notes": "专用终端初始版本\n- 支持专用硬件集成\n- 优化的录音功能\n- 低功耗模式\n- 自动上传同步", + "is_active": true, + "is_latest": true, + "min_system_version": "Android 5.0", + "created_at": "2025-01-15T10:00:00", + "updated_at": "2025-01-15T10:00:00", + "created_by": 1 + } +} +``` + +### 2. 更新接口:获取所有平台最新版本 + +**接口路径:** `GET /api/downloads/latest` + +**变更:** 返回数据中新增 `terminal` 字段 + +**返回示例:** +```json +{ + "code": "200", + "message": "获取成功", + "data": { + "mobile": [...], + "desktop": [...], + "terminal": [ + { + "id": 7, + "platform_type": "terminal", + "platform_name": "android", + "version": "1.0.0", + ... + }, + { + "id": 8, + "platform_type": "terminal", + "platform_name": "mcu", + "version": "1.0.0", + ... + } + ] + } +} +``` + +### 3. 已有接口说明 + +**原有接口:** `GET /api/downloads/{platform_name}/latest` + +- 此接口标记为【已废弃】,建议使用新接口 `/downloads/latest/by-platform` +- 原因:只通过 `platform_name` 查询可能产生歧义(如 mobile 的 android 和 terminal 的 android) +- 保留此接口是为了向后兼容,但新开发应使用新接口 + +## 使用场景 + +### 场景 1:专用终端设备版本检查 + +专用终端设备(如会议室固定录音设备、单片机硬件)启动时检查更新: + +```javascript +// Android 专用终端 +const response = await fetch( + '/api/downloads/latest/by-platform?platform_type=terminal&platform_name=android' +); +const { data } = await response.json(); + +if (data.version_code > currentVersionCode) { + // 发现新版本,提示更新 + showUpdateDialog(data); +} +``` + +### 场景 2:后台管理界面展示 + +管理员查看所有终端版本: + +```javascript +const response = await fetch('/api/downloads?platform_type=terminal'); +const { data } = await response.json(); + +// data.clients 包含所有 terminal 类型的客户端版本 +renderClientList(data.clients); +``` + +### 场景 3:固件更新服务器 + +单片机设备定期轮询更新: + +```c +// MCU 固件代码示例 +char url[] = "http://api.imeeting.com/downloads/latest/by-platform?platform_type=terminal&platform_name=mcu"; +http_get(url, response_buffer); + +// 解析 JSON 获取 download_url 和 version_code +if (new_version > FIRMWARE_VERSION) { + download_and_update(download_url); +} +``` + +## 测试建议 + +### 1. 数据库测试 +```sql +-- 验证表结构修改 +DESCRIBE client_downloads; + +-- 验证数据插入 +SELECT * FROM client_downloads WHERE platform_type = 'terminal'; +``` + +### 2. API 测试 +```bash +# 测试新接口 +curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=android" + +curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=mcu" + +# 测试获取所有最新版本 +curl "http://localhost:8000/api/downloads/latest" + +# 测试列表接口 +curl "http://localhost:8000/api/downloads?platform_type=terminal" +``` + +## 注意事项 + +1. **执行 SQL 前请备份数据库** +2. **ENUM 类型修改**:ALTER TABLE 会修改表结构,请在低峰期执行 +3. **新接口优先**:建议所有新开发使用 `/downloads/latest/by-platform` 接口 +4. **版本管理**:上传新版本时记得设置 `is_latest=TRUE` 并将同平台旧版本设为 `FALSE` +5. **platform_name 唯一性**:如果 mobile 和 terminal 都有 android,建议: + - mobile 的保持 `android` + - terminal 的改为 `android_terminal` 或其他区分名称 + - 或者始终使用新接口同时传递 platform_type 和 platform_name + +## 文件清单 + +- `backend/sql/add_dedicated_terminal.sql` - 数据库迁移 SQL +- `backend/app/api/endpoints/client_downloads.py` - API 接口代码 +- `backend/sql/README_terminal_update.md` - 本说明文档 diff --git a/sql/add_dedicated_terminal.sql b/sql/add_dedicated_terminal.sql new file mode 100644 index 0000000..e0140ad --- /dev/null +++ b/sql/add_dedicated_terminal.sql @@ -0,0 +1,59 @@ +-- 添加专用终端类型支持 +-- 修改 platform_type 枚举,添加 'terminal' 类型 + +ALTER TABLE client_downloads +MODIFY COLUMN platform_type ENUM('mobile', 'desktop', 'terminal') NOT NULL +COMMENT '平台类型:mobile-移动端, desktop-桌面端, terminal-专用终端'; + +-- 插入专用终端示例数据 + +-- Android 专用终端 +INSERT INTO client_downloads ( + platform_type, + platform_name, + version, + version_code, + download_url, + file_size, + release_notes, + is_active, + is_latest, + min_system_version, + created_by +) VALUES +( + 'terminal', + 'android', + '1.0.0', + 1000, + 'https://download.imeeting.com/terminals/android/iMeeting-1.0.0-Terminal.apk', + 25165824, -- 24MB + '专用终端初始版本 +- 支持专用硬件集成 +- 优化的录音功能 +- 低功耗模式 +- 自动上传同步', + TRUE, + TRUE, + 'Android 5.0', + 1 +), + +-- 单片机(MCU)专用终端 +( + 'terminal', + 'mcu', + '1.0.0', + 1000, + 'https://download.imeeting.com/terminals/mcu/iMeeting-1.0.0-MCU.bin', + 2097152, -- 2MB + '单片机固件初始版本 +- 嵌入式录音系统 +- 低功耗设计 +- 支持WiFi/4G上传 +- 硬件级音频处理', + TRUE, + TRUE, + 'ESP32 / STM32', + 1 +); diff --git a/sql/add_menu_permissions_system.sql b/sql/add_menu_permissions_system.sql new file mode 100644 index 0000000..7fbad14 --- /dev/null +++ b/sql/add_menu_permissions_system.sql @@ -0,0 +1,99 @@ +-- =================================================================== +-- 菜单权限系统数据库迁移脚本 +-- 创建日期: 2025-12-10 +-- 说明: 添加 menus 表和 role_menu_permissions 表,实现基于角色的菜单权限管理 +-- =================================================================== + +-- ---------------------------- +-- Table structure for menus +-- ---------------------------- +DROP TABLE IF EXISTS `menus`; +CREATE TABLE `menus` ( + `menu_id` int(11) NOT NULL AUTO_INCREMENT COMMENT '菜单ID', + `menu_code` varchar(50) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '菜单代码(唯一标识)', + `menu_name` varchar(100) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '菜单名称', + `menu_icon` varchar(50) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '菜单图标标识', + `menu_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '菜单URL/路由', + `menu_type` enum('action','link','divider') COLLATE utf8mb4_unicode_ci DEFAULT 'action' COMMENT '菜单类型: action-操作/link-链接/divider-分隔符', + `parent_id` int(11) DEFAULT NULL COMMENT '父菜单ID(用于层级菜单)', + `sort_order` int(11) DEFAULT 0 COMMENT '排序顺序', + `is_active` tinyint(1) DEFAULT 1 COMMENT '是否启用: 1-启用, 0-禁用', + `description` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '菜单描述', + `created_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + `updated_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + PRIMARY KEY (`menu_id`), + UNIQUE KEY `uk_menu_code` (`menu_code`), + KEY `idx_parent_id` (`parent_id`), + KEY `idx_is_active` (`is_active`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='系统菜单表'; + +-- ---------------------------- +-- Table structure for role_menu_permissions +-- ---------------------------- +DROP TABLE IF EXISTS `role_menu_permissions`; +CREATE TABLE `role_menu_permissions` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT '权限ID', + `role_id` int(11) NOT NULL COMMENT '角色ID', + `menu_id` int(11) NOT NULL COMMENT '菜单ID', + `created_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + `updated_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_role_menu` (`role_id`,`menu_id`), + KEY `idx_role_id` (`role_id`), + KEY `idx_menu_id` (`menu_id`), + CONSTRAINT `fk_rmp_role_id` FOREIGN KEY (`role_id`) REFERENCES `roles` (`role_id`) ON DELETE CASCADE, + CONSTRAINT `fk_rmp_menu_id` FOREIGN KEY (`menu_id`) REFERENCES `menus` (`menu_id`) ON DELETE CASCADE +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='角色菜单权限映射表'; + +-- ---------------------------- +-- 初始化菜单数据(基于现有系统的下拉菜单) +-- ---------------------------- +BEGIN; + +-- 用户菜单项 +INSERT INTO `menus` (`menu_code`, `menu_name`, `menu_icon`, `menu_url`, `menu_type`, `sort_order`, `is_active`, `description`) +VALUES +('change_password', '修改密码', 'KeyRound', NULL, 'action', 1, 1, '用户修改自己的密码'), +('prompt_management', '提示词仓库', 'BookText', '/prompt-management', 'link', 2, 1, '管理AI提示词模版'), +('platform_admin', '平台管理', 'Shield', '/admin/management', 'link', 3, 1, '平台管理员后台'), +('logout', '退出登录', 'LogOut', NULL, 'action', 99, 1, '退出当前账号'); + +COMMIT; + +-- ---------------------------- +-- 初始化角色权限数据 +-- 注意:角色表已存在,role_id=1为平台管理员,role_id=2为普通用户 +-- ---------------------------- +BEGIN; + +-- 平台管理员(role_id=1)拥有所有菜单权限 +INSERT INTO `role_menu_permissions` (`role_id`, `menu_id`) +SELECT 1, menu_id FROM `menus` WHERE is_active = 1; + +-- 普通用户(role_id=2)拥有除"平台管理"外的所有菜单权限 +INSERT INTO `role_menu_permissions` (`role_id`, `menu_id`) +SELECT 2, menu_id FROM `menus` WHERE menu_code != 'platform_admin' AND is_active = 1; + +COMMIT; + +-- ---------------------------- +-- 查询验证 +-- ---------------------------- +-- 查看所有菜单 +-- SELECT * FROM menus ORDER BY sort_order; + +-- 查看平台管理员的菜单权限 +-- SELECT r.role_name, m.menu_name, m.menu_code, m.menu_url +-- FROM role_menu_permissions rmp +-- JOIN roles r ON rmp.role_id = r.role_id +-- JOIN menus m ON rmp.menu_id = m.menu_id +-- WHERE r.role_id = 1 +-- ORDER BY m.sort_order; + +-- 查看普通用户的菜单权限 +-- SELECT r.role_name, m.menu_name, m.menu_code, m.menu_url +-- FROM role_menu_permissions rmp +-- JOIN roles r ON rmp.role_id = r.role_id +-- JOIN menus m ON rmp.menu_id = m.menu_id +-- WHERE r.role_id = 2 +-- ORDER BY m.sort_order; diff --git a/sql/add_prompt_id_to_llm_tasks.sql b/sql/add_prompt_id_to_llm_tasks.sql new file mode 100644 index 0000000..7a57b99 --- /dev/null +++ b/sql/add_prompt_id_to_llm_tasks.sql @@ -0,0 +1,11 @@ +-- 为 llm_tasks 表添加 prompt_id 列,用于支持自定义模版选择功能 +-- 执行日期:2025-12-08 + +ALTER TABLE `llm_tasks` +ADD COLUMN `prompt_id` int(11) DEFAULT NULL COMMENT '提示词模版ID' AFTER `user_prompt`, +ADD KEY `idx_prompt_id` (`prompt_id`); + +-- 说明: +-- 1. prompt_id 允许为 NULL,表示使用默认模版 +-- 2. 添加索引以优化查询性能 +-- 3. 不添加外键约束,因为 prompts 表中的记录可能被删除,我们希望保留历史任务记录 diff --git a/sql/imeeting.sql b/sql/imeeting.sql index 0b767fa..b5ae92e 100644 --- a/sql/imeeting.sql +++ b/sql/imeeting.sql @@ -534,10 +534,12 @@ CREATE TABLE `knowledge_bases` ( `is_shared` tinyint(1) NOT NULL DEFAULT '0' COMMENT '是否为共享知识库', `source_meeting_ids` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '内容来源的会议ID列表 (逗号分隔)', `tags` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '逗号分隔的标签', + `prompt_id` int(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版', `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, PRIMARY KEY (`kb_id`), KEY `idx_creator_id` (`creator_id`), + KEY `idx_prompt_id` (`prompt_id`), CONSTRAINT `knowledge_bases_ibfk_1` FOREIGN KEY (`creator_id`) REFERENCES `users` (`user_id`) ON DELETE CASCADE ) ENGINE=InnoDB AUTO_INCREMENT=28 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='知识库条目表'; @@ -686,9 +688,11 @@ CREATE TABLE `meetings` ( `meeting_time` timestamp NULL DEFAULT NULL, `user_prompt` text COLLATE utf8mb4_unicode_ci, `summary` text CHARACTER SET utf8mb4, + `prompt_id` int(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版', `created_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (`meeting_id`) + PRIMARY KEY (`meeting_id`), + KEY `idx_prompt_id` (`prompt_id`) ) ENGINE=InnoDB AUTO_INCREMENT=372 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; -- ---------------------------- diff --git a/sql/migrate_prompts_table.sql b/sql/migrate_prompts_table.sql new file mode 100644 index 0000000..92cdbde --- /dev/null +++ b/sql/migrate_prompts_table.sql @@ -0,0 +1,67 @@ +-- 提示词表改造迁移脚本 +-- 将 prompt_config 表的功能整合到 prompts 表 + +-- 步骤1: 添加新字段 +ALTER TABLE prompts +ADD COLUMN task_type ENUM('MEETING_TASK', 'KNOWLEDGE_TASK') + COMMENT '任务类型:MEETING_TASK-会议任务, KNOWLEDGE_TASK-知识库任务' AFTER name; + +ALTER TABLE prompts +ADD COLUMN is_default BOOLEAN NOT NULL DEFAULT FALSE + COMMENT '是否为该任务类型的默认模板' AFTER content; + +-- 步骤2: 修改 is_active 字段(如果存在且类型不是 BOOLEAN) +-- 先检查字段是否存在,如果不存在则添加 +ALTER TABLE prompts +MODIFY COLUMN is_active BOOLEAN NOT NULL DEFAULT TRUE + COMMENT '是否启用(只有启用的提示词才能被使用)'; + +-- 步骤3: 删除 tags 字段 +ALTER TABLE prompts DROP COLUMN IF EXISTS tags; + +-- 步骤4: 从 prompt_config 迁移数据(如果 prompt_config 表存在) +-- 更新 task_type 和 is_default +UPDATE prompts p +LEFT JOIN prompt_config pc ON p.id = pc.prompt_id +SET + p.task_type = CASE + WHEN pc.task_name IS NOT NULL THEN pc.task_name + ELSE 'MEETING_TASK' -- 默认值 + END, + p.is_default = CASE + WHEN pc.is_default = 1 THEN TRUE + ELSE FALSE + END +WHERE pc.prompt_id IS NOT NULL OR p.task_type IS NULL; + +-- 步骤5: 为所有没有设置 task_type 的提示词设置默认值 +UPDATE prompts +SET task_type = 'MEETING_TASK' +WHERE task_type IS NULL; + +-- 步骤6: 将 task_type 设置为 NOT NULL +ALTER TABLE prompts +MODIFY COLUMN task_type ENUM('MEETING_TASK', 'KNOWLEDGE_TASK') NOT NULL + COMMENT '任务类型:MEETING_TASK-会议任务, KNOWLEDGE_TASK-知识库任务'; + +-- 步骤7: 确保每个 task_type 只有一个默认提示词 +-- 如果有多个默认,只保留 id 最小的那个 +UPDATE prompts p1 +LEFT JOIN ( + SELECT task_type, MIN(id) as min_id + FROM prompts + WHERE is_default = TRUE + GROUP BY task_type +) p2 ON p1.task_type = p2.task_type +SET p1.is_default = FALSE +WHERE p1.is_default = TRUE AND p1.id != p2.min_id; + +-- 步骤8: (可选) 备注 prompt_config 表已废弃 +-- 如果需要删除 prompt_config 表,取消下面的注释 +-- DROP TABLE IF EXISTS prompt_config; + +-- 迁移完成 +SELECT '提示词表迁移完成!' as message; +SELECT task_type, COUNT(*) as total, SUM(is_default) as default_count +FROM prompts +GROUP BY task_type; diff --git a/sql/migrations/add_prompt_id_to_main_tables.sql b/sql/migrations/add_prompt_id_to_main_tables.sql new file mode 100644 index 0000000..e0ba82f --- /dev/null +++ b/sql/migrations/add_prompt_id_to_main_tables.sql @@ -0,0 +1,33 @@ +-- ============================================ +-- 添加 prompt_id 字段到主表 +-- 创建时间: 2025-01-11 +-- 说明: 在 meetings 和 knowledge_bases 表中添加 prompt_id 字段 +-- 用于记录会议/知识库使用的提示词模版 +-- ============================================ + +-- 1. 为 meetings 表添加 prompt_id 字段 +ALTER TABLE meetings +ADD COLUMN prompt_id INT(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版' +AFTER summary; + +-- 为 meetings 表添加索引 +ALTER TABLE meetings +ADD INDEX idx_prompt_id (prompt_id); + +-- 2. 为 knowledge_bases 表添加 prompt_id 字段 +ALTER TABLE knowledge_bases +ADD COLUMN prompt_id INT(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版' +AFTER tags; + +-- 为 knowledge_bases 表添加索引 +ALTER TABLE knowledge_bases +ADD INDEX idx_prompt_id (prompt_id); + +-- ============================================ +-- 验证修改 +-- ============================================ +-- 查看 meetings 表结构 +-- DESCRIBE meetings; + +-- 查看 knowledge_bases 表结构 +-- DESCRIBE knowledge_bases; diff --git a/sql/transcript_tasks_setup.sql b/sql/transcript_tasks_setup.sql new file mode 100644 index 0000000..135c751 --- /dev/null +++ b/sql/transcript_tasks_setup.sql @@ -0,0 +1,53 @@ +-- 为现有数据库添加转录任务支持的SQL脚本 + +-- 1. 更新audio_files表结构,添加缺失字段 +ALTER TABLE audio_files +ADD COLUMN file_name VARCHAR(255) AFTER meeting_id, +ADD COLUMN file_size BIGINT DEFAULT NULL AFTER file_path, +ADD COLUMN task_id VARCHAR(255) DEFAULT NULL AFTER upload_time; + +-- 2. 创建转录任务表 +CREATE TABLE transcript_tasks ( + task_id VARCHAR(255) PRIMARY KEY, + meeting_id INT NOT NULL, + status ENUM('pending', 'processing', 'completed', 'failed') DEFAULT 'pending', + progress INT DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP NULL, + error_message TEXT NULL, + + FOREIGN KEY (meeting_id) REFERENCES meetings(meeting_id) ON DELETE CASCADE +); + +-- 3. 添加索引以优化查询性能 +-- audio_files 表索引 +ALTER TABLE audio_files ADD INDEX idx_task_id (task_id); + +-- transcript_tasks 表索引 +ALTER TABLE transcript_tasks ADD INDEX idx_meeting_id (meeting_id); +ALTER TABLE transcript_tasks ADD INDEX idx_status (status); +ALTER TABLE transcript_tasks ADD INDEX idx_created_at (created_at); + +-- 4. 更新现有测试数据(如果需要) +-- 这些语句是可选的,用于更新现有的测试数据 +UPDATE audio_files SET file_name = 'test_audio.mp3' WHERE file_name IS NULL; +UPDATE audio_files SET file_size = 10485760 WHERE file_size IS NULL; -- 10MB + +SELECT '转录任务表创建完成!' as message; + + +CREATE TABLE llm_tasks ( + task_id VARCHAR(100) PRIMARY KEY, + llm_task_id VARCHAR(100) DEFAULT NULL, + meeting_id INT NOT NULL, + user_prompt TEXT, + status VARCHAR(50) DEFAULT 'pending', + progress INT DEFAULT 0, + result TEXT, + error_message TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP NULL, + INDEX idx_meeting_id (meeting_id), + INDEX idx_status (status), + INDEX idx_created_at (created_at) +) \ No newline at end of file diff --git a/test/test_kb_prompt_id_feature.py b/test/test_kb_prompt_id_feature.py new file mode 100644 index 0000000..7625ffc --- /dev/null +++ b/test/test_kb_prompt_id_feature.py @@ -0,0 +1,166 @@ +""" +测试知识库提示词模版选择功能 +""" +import sys +sys.path.insert(0, 'app') + +from app.services.llm_service import LLMService +from app.services.async_knowledge_base_service import AsyncKnowledgeBaseService +from app.core.database import get_db_connection + +def test_get_active_knowledge_prompts(): + """测试获取启用的知识库提示词列表""" + print("\n=== 测试1: 获取启用的知识库提示词列表 ===") + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 获取KNOWLEDGE_TASK类型的启用模版 + query = """ + SELECT id, name, is_default + FROM prompts + WHERE task_type = %s AND is_active = TRUE + ORDER BY is_default DESC, created_at DESC + """ + cursor.execute(query, ('KNOWLEDGE_TASK',)) + prompts = cursor.fetchall() + + print(f"✓ 找到 {len(prompts)} 个启用的知识库任务模版:") + for p in prompts: + default_flag = " [默认]" if p['is_default'] else "" + print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}") + + return prompts + except Exception as e: + print(f"✗ 测试失败: {e}") + import traceback + traceback.print_exc() + return [] + +def test_get_task_prompt_with_id(prompts): + """测试通过prompt_id获取知识库提示词内容""" + print("\n=== 测试2: 通过prompt_id获取知识库提示词内容 ===") + + if not prompts: + print("⚠ 没有可用的提示词模版,跳过测试") + return + + llm_service = LLMService() + + # 测试获取第一个提示词 + test_prompt = prompts[0] + try: + content = llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=test_prompt['id']) + print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}") + print(f" 内容长度: {len(content)} 字符") + print(f" 内容预览: {content[:100]}...") + except Exception as e: + print(f"✗ 测试失败: {e}") + import traceback + traceback.print_exc() + + # 测试获取默认提示词(不指定prompt_id) + try: + default_content = llm_service.get_task_prompt('KNOWLEDGE_TASK') + print(f"✓ 成功获取默认提示词") + print(f" 内容长度: {len(default_content)} 字符") + except Exception as e: + print(f"✗ 获取默认提示词失败: {e}") + +def test_async_kb_service_signature(): + """测试async_knowledge_base_service的方法签名""" + print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===") + + import inspect + async_service = AsyncKnowledgeBaseService() + + # 检查start_generation方法签名 + sig = inspect.signature(async_service.start_generation) + params = list(sig.parameters.keys()) + + if 'prompt_id' in params: + print(f"✓ start_generation 方法支持 prompt_id 参数") + print(f" 参数列表: {params}") + else: + print(f"✗ start_generation 方法缺少 prompt_id 参数") + print(f" 参数列表: {params}") + + # 检查_build_prompt方法签名 + sig2 = inspect.signature(async_service._build_prompt) + params2 = list(sig2.parameters.keys()) + + if 'prompt_id' in params2: + print(f"✓ _build_prompt 方法支持 prompt_id 参数") + print(f" 参数列表: {params2}") + else: + print(f"✗ _build_prompt 方法缺少 prompt_id 参数") + print(f" 参数列表: {params2}") + +def test_database_schema(): + """测试数据库schema是否包含prompt_id列""" + print("\n=== 测试4: 验证数据库schema ===") + + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 检查knowledge_base_tasks表是否有prompt_id列 + cursor.execute(""" + SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'knowledge_base_tasks' + AND COLUMN_NAME = 'prompt_id' + """) + result = cursor.fetchone() + + if result: + print(f"✓ knowledge_base_tasks 表包含 prompt_id 列") + print(f" 类型: {result['DATA_TYPE']}") + print(f" 可空: {result['IS_NULLABLE']}") + print(f" 默认值: {result['COLUMN_DEFAULT']}") + else: + print(f"✗ knowledge_base_tasks 表缺少 prompt_id 列") + except Exception as e: + print(f"✗ 数据库检查失败: {e}") + import traceback + traceback.print_exc() + +def test_api_model(): + """测试API模型定义""" + print("\n=== 测试5: 验证API模型定义 ===") + + try: + from app.models.models import CreateKnowledgeBaseRequest + import inspect + + # 检查CreateKnowledgeBaseRequest模型 + fields = CreateKnowledgeBaseRequest.model_fields + + if 'prompt_id' in fields: + print(f"✓ CreateKnowledgeBaseRequest 包含 prompt_id 字段") + print(f" 字段列表: {list(fields.keys())}") + else: + print(f"✗ CreateKnowledgeBaseRequest 缺少 prompt_id 字段") + print(f" 字段列表: {list(fields.keys())}") + + except Exception as e: + print(f"✗ API模型检查失败: {e}") + import traceback + traceback.print_exc() + +if __name__ == '__main__': + print("=" * 60) + print("开始测试知识库提示词模版选择功能") + print("=" * 60) + + # 运行所有测试 + prompts = test_get_active_knowledge_prompts() + test_get_task_prompt_with_id(prompts) + test_async_kb_service_signature() + test_database_schema() + test_api_model() + + print("\n" + "=" * 60) + print("测试完成") + print("=" * 60) diff --git a/test/test_menu_permissions.py b/test/test_menu_permissions.py new file mode 100644 index 0000000..e3ca959 --- /dev/null +++ b/test/test_menu_permissions.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +""" +测试菜单权限数据是否存在 +""" +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from app.core.database import get_db_connection + +def test_menu_permissions(): + print("=== 测试菜单权限数据 ===\n") + + try: + # 连接数据库 + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 1. 检查menus表 + print("1. 检查menus表:") + cursor.execute("SELECT COUNT(*) as count FROM menus") + menu_count = cursor.fetchone()['count'] + print(f" - 菜单总数: {menu_count}") + + if menu_count > 0: + cursor.execute("SELECT menu_id, menu_code, menu_name, is_active FROM menus ORDER BY sort_order") + menus = cursor.fetchall() + for menu in menus: + print(f" - [{menu['menu_id']}] {menu['menu_name']} ({menu['menu_code']}) - 启用: {menu['is_active']}") + else: + print(" ⚠️ menus表为空!") + + print() + + # 2. 检查roles表 + print("2. 检查roles表:") + cursor.execute("SELECT * FROM roles ORDER BY role_id") + roles = cursor.fetchall() + for role in roles: + print(f" - [{role['role_id']}] {role['role_name']}") + + print() + + # 3. 检查role_menu_permissions表 + print("3. 检查role_menu_permissions表:") + cursor.execute("SELECT COUNT(*) as count FROM role_menu_permissions") + perm_count = cursor.fetchone()['count'] + print(f" - 权限总数: {perm_count}") + + if perm_count > 0: + cursor.execute(""" + SELECT r.role_name, m.menu_name, rmp.role_id, rmp.menu_id + FROM role_menu_permissions rmp + JOIN roles r ON rmp.role_id = r.role_id + JOIN menus m ON rmp.menu_id = m.menu_id + ORDER BY rmp.role_id, m.sort_order + """) + permissions = cursor.fetchall() + + current_role = None + for perm in permissions: + if current_role != perm['role_name']: + current_role = perm['role_name'] + print(f"\n {current_role}的权限:") + print(f" - {perm['menu_name']}") + else: + print(" ⚠️ role_menu_permissions表为空!") + + print("\n" + "="*50) + + # 4. 检查是否需要执行SQL脚本 + if menu_count == 0 or perm_count == 0: + print("\n❌ 数据库中缺少菜单或权限数据!") + print("请执行以下命令初始化数据:") + print("\nmysql -h 10.100.51.161 -u root -psagacity imeeting_dev < backend/sql/add_menu_permissions_system.sql") + print("\n或者在MySQL客户端中执行该SQL文件。") + else: + print("\n✅ 菜单权限数据正常!") + + cursor.close() + connection.close() + + except Exception as e: + print(f"❌ 错误: {str(e)}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + test_menu_permissions() diff --git a/test/test_prompt_id_feature.py b/test/test_prompt_id_feature.py new file mode 100644 index 0000000..b8ce390 --- /dev/null +++ b/test/test_prompt_id_feature.py @@ -0,0 +1,176 @@ +""" +测试提示词模版选择功能 +""" +import sys +sys.path.insert(0, 'app') + +from app.services.llm_service import LLMService +from app.services.async_meeting_service import AsyncMeetingService +from app.core.database import get_db_connection + +def test_get_active_prompts(): + """测试获取启用的提示词列表""" + print("\n=== 测试1: 获取启用的提示词列表 ===") + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 获取MEETING_TASK类型的启用模版 + query = """ + SELECT id, name, is_default + FROM prompts + WHERE task_type = %s AND is_active = TRUE + ORDER BY is_default DESC, created_at DESC + """ + cursor.execute(query, ('MEETING_TASK',)) + prompts = cursor.fetchall() + + print(f"✓ 找到 {len(prompts)} 个启用的会议任务模版:") + for p in prompts: + default_flag = " [默认]" if p['is_default'] else "" + print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}") + + return prompts + except Exception as e: + print(f"✗ 测试失败: {e}") + import traceback + traceback.print_exc() + return [] + +def test_get_task_prompt_with_id(prompts): + """测试通过prompt_id获取提示词内容""" + print("\n=== 测试2: 通过prompt_id获取提示词内容 ===") + + if not prompts: + print("⚠ 没有可用的提示词模版,跳过测试") + return + + llm_service = LLMService() + + # 测试获取第一个提示词 + test_prompt = prompts[0] + try: + content = llm_service.get_task_prompt('MEETING_TASK', prompt_id=test_prompt['id']) + print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}") + print(f" 内容长度: {len(content)} 字符") + print(f" 内容预览: {content[:100]}...") + except Exception as e: + print(f"✗ 测试失败: {e}") + import traceback + traceback.print_exc() + + # 测试获取默认提示词(不指定prompt_id) + try: + default_content = llm_service.get_task_prompt('MEETING_TASK') + print(f"✓ 成功获取默认提示词") + print(f" 内容长度: {len(default_content)} 字符") + except Exception as e: + print(f"✗ 获取默认提示词失败: {e}") + +def test_async_meeting_service_signature(): + """测试async_meeting_service的方法签名""" + print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===") + + import inspect + async_service = AsyncMeetingService() + + # 检查start_summary_generation方法签名 + sig = inspect.signature(async_service.start_summary_generation) + params = list(sig.parameters.keys()) + + if 'prompt_id' in params: + print(f"✓ start_summary_generation 方法支持 prompt_id 参数") + print(f" 参数列表: {params}") + else: + print(f"✗ start_summary_generation 方法缺少 prompt_id 参数") + print(f" 参数列表: {params}") + + # 检查monitor_and_auto_summarize方法签名 + sig2 = inspect.signature(async_service.monitor_and_auto_summarize) + params2 = list(sig2.parameters.keys()) + + if 'prompt_id' in params2: + print(f"✓ monitor_and_auto_summarize 方法支持 prompt_id 参数") + print(f" 参数列表: {params2}") + else: + print(f"✗ monitor_and_auto_summarize 方法缺少 prompt_id 参数") + print(f" 参数列表: {params2}") + +def test_database_schema(): + """测试数据库schema是否包含prompt_id列""" + print("\n=== 测试4: 验证数据库schema ===") + + try: + with get_db_connection() as connection: + cursor = connection.cursor(dictionary=True) + + # 检查llm_tasks表是否有prompt_id列 + cursor.execute(""" + SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT + FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'llm_tasks' + AND COLUMN_NAME = 'prompt_id' + """) + result = cursor.fetchone() + + if result: + print(f"✓ llm_tasks 表包含 prompt_id 列") + print(f" 类型: {result['DATA_TYPE']}") + print(f" 可空: {result['IS_NULLABLE']}") + print(f" 默认值: {result['COLUMN_DEFAULT']}") + else: + print(f"✗ llm_tasks 表缺少 prompt_id 列") + except Exception as e: + print(f"✗ 数据库检查失败: {e}") + import traceback + traceback.print_exc() + +def test_api_endpoints(): + """测试API端点定义""" + print("\n=== 测试5: 验证API端点定义 ===") + + try: + from app.api.endpoints.meetings import GenerateSummaryRequest + import inspect + + # 检查GenerateSummaryRequest模型 + fields = GenerateSummaryRequest.__fields__ + + if 'prompt_id' in fields: + print(f"✓ GenerateSummaryRequest 包含 prompt_id 字段") + print(f" 字段列表: {list(fields.keys())}") + else: + print(f"✗ GenerateSummaryRequest 缺少 prompt_id 字段") + print(f" 字段列表: {list(fields.keys())}") + + # 检查audio_service.handle_audio_upload签名 + from app.services.audio_service import handle_audio_upload + sig = inspect.signature(handle_audio_upload) + params = list(sig.parameters.keys()) + + if 'prompt_id' in params: + print(f"✓ handle_audio_upload 方法支持 prompt_id 参数") + else: + print(f"✗ handle_audio_upload 方法缺少 prompt_id 参数") + + except Exception as e: + print(f"✗ API端点检查失败: {e}") + import traceback + traceback.print_exc() + +if __name__ == '__main__': + print("=" * 60) + print("开始测试提示词模版选择功能") + print("=" * 60) + + # 运行所有测试 + prompts = test_get_active_prompts() + test_get_task_prompt_with_id(prompts) + test_async_meeting_service_signature() + test_database_schema() + test_api_endpoints() + + print("\n" + "=" * 60) + print("测试完成") + print("=" * 60)