增加了权限系统
parent
3260b99c6b
commit
de289add81
|
|
@ -33,4 +33,4 @@ COPY . .
|
|||
EXPOSE 8001
|
||||
|
||||
# 启动命令
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8001"]
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
# 提示词模版选择功能实现总结
|
||||
|
||||
## 功能概述
|
||||
实现了用户在创建会议时可以选择使用的总结模版功能。支持两种场景:
|
||||
1. 手动生成会议总结时选择模版
|
||||
2. 上传音频文件时选择模版,自动总结时使用该模版
|
||||
|
||||
## 实现的功能点
|
||||
|
||||
### 1. 新增API接口 ✅
|
||||
**文件**: `app/api/endpoints/prompts.py`
|
||||
- 新增 `GET /prompts/active/{task_type}` 接口
|
||||
- 功能:获取指定任务类型的所有启用状态的提示词模版
|
||||
- 返回字段:id, name, is_default
|
||||
- 按默认模版优先、创建时间倒序排列
|
||||
|
||||
### 2. 修改LLM服务 ✅
|
||||
**文件**: `app/services/llm_service.py`
|
||||
- 修改 `get_task_prompt()` 方法,增加 `prompt_id` 可选参数
|
||||
- 逻辑:
|
||||
- 如果指定 prompt_id,查询该ID对应的提示词(需验证task_type和is_active)
|
||||
- 如果不指定,使用默认提示词(is_default=TRUE)
|
||||
- 如果都查不到,返回代码中的默认值
|
||||
|
||||
### 3. 修改会议总结服务 ✅
|
||||
**文件**: `app/services/async_meeting_service.py`
|
||||
|
||||
#### 3.1 修改任务创建
|
||||
- `start_summary_generation()`: 增加 `prompt_id` 参数
|
||||
- 将 prompt_id 存储到 Redis 和数据库
|
||||
|
||||
#### 3.2 修改任务处理
|
||||
- `_process_task()`: 从 Redis 读取 prompt_id,传递给 `_build_prompt()`
|
||||
- `_build_prompt()`: 增加 `prompt_id` 参数,传递给 `llm_service.get_task_prompt()`
|
||||
- `_save_task_to_db()`: 增加 `prompt_id` 参数,存储到数据库
|
||||
|
||||
#### 3.3 修改自动总结监控
|
||||
- `monitor_and_auto_summarize()`: 增加 `prompt_id` 参数
|
||||
- 在转录完成后启动总结任务时,传递 prompt_id
|
||||
|
||||
### 4. 修改音频服务 ✅
|
||||
**文件**: `app/services/audio_service.py`
|
||||
- `handle_audio_upload()`: 增加 `prompt_id` 参数
|
||||
- 将 prompt_id 传递给 `monitor_and_auto_summarize()`
|
||||
|
||||
### 5. 修改会议API接口 ✅
|
||||
**文件**: `app/api/endpoints/meetings.py`
|
||||
|
||||
#### 5.1 手动生成总结
|
||||
- `GenerateSummaryRequest` 模型:增加 `prompt_id` 字段
|
||||
- `POST /meetings/{meeting_id}/generate-summary-async`: 传递 prompt_id 给服务层
|
||||
|
||||
#### 5.2 音频上传
|
||||
- `POST /meetings/upload-audio`: 增加 `prompt_id` 表单参数
|
||||
- 将 prompt_id 传递给 `handle_audio_upload()`
|
||||
|
||||
### 6. 数据库迁移 ✅
|
||||
**文件**: `sql/add_prompt_id_to_llm_tasks.sql`
|
||||
- 为 `llm_tasks` 表添加 `prompt_id` 列
|
||||
- 类型:int(11)
|
||||
- 可空:YES
|
||||
- 默认值:NULL
|
||||
- 索引:idx_prompt_id
|
||||
|
||||
## 数据流向
|
||||
|
||||
### 手动生成总结
|
||||
```
|
||||
前端 → POST /meetings/{id}/generate-summary-async (prompt_id)
|
||||
→ async_meeting_service.start_summary_generation(prompt_id)
|
||||
→ 存储到 Redis 和 DB (llm_tasks.prompt_id)
|
||||
→ _process_task() 读取 prompt_id
|
||||
→ _build_prompt(prompt_id)
|
||||
→ llm_service.get_task_prompt('MEETING_TASK', prompt_id)
|
||||
→ 获取指定模版或默认模版
|
||||
```
|
||||
|
||||
### 音频上传自动总结
|
||||
```
|
||||
前端 → POST /meetings/upload-audio (prompt_id, auto_summarize=true)
|
||||
→ handle_audio_upload(prompt_id)
|
||||
→ transcription_service.start_transcription()
|
||||
→ monitor_and_auto_summarize(prompt_id)
|
||||
→ 等待转录完成
|
||||
→ start_summary_generation(prompt_id)
|
||||
→ (后续流程同手动生成总结)
|
||||
```
|
||||
|
||||
## 向后兼容性
|
||||
|
||||
所有新增的 `prompt_id` 参数都是可选的(Optional[int] = None),确保:
|
||||
1. 不传递 prompt_id 时,自动使用默认模版
|
||||
2. 现有代码无需修改即可正常工作
|
||||
3. 数据库中 prompt_id 允许为 NULL
|
||||
|
||||
## 测试结果
|
||||
|
||||
执行 `test_prompt_id_feature.py` 测试脚本,所有测试通过:
|
||||
- ✅ 获取启用的提示词列表 (6个模版)
|
||||
- ✅ 通过prompt_id获取提示词内容
|
||||
- ✅ 获取默认提示词(不指定prompt_id)
|
||||
- ✅ 验证方法签名支持prompt_id参数
|
||||
- ✅ 验证数据库schema包含prompt_id列
|
||||
- ✅ 验证API端点定义正确
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 1. 获取启用的会议任务模版列表
|
||||
```bash
|
||||
GET /api/prompts/active/MEETING_TASK
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
返回:
|
||||
```json
|
||||
{
|
||||
"code": "200",
|
||||
"message": "获取启用模版列表成功",
|
||||
"data": {
|
||||
"prompts": [
|
||||
{"id": 1, "name": "默认会议总结", "is_default": true},
|
||||
{"id": 5, "name": "产品会议总结", "is_default": false}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 手动生成总结时指定模版
|
||||
```bash
|
||||
POST /api/meetings/123/generate-summary-async
|
||||
Authorization: Bearer <token>
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"user_prompt": "重点关注技术讨论",
|
||||
"prompt_id": 5
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 上传音频时指定模版
|
||||
```bash
|
||||
POST /api/meetings/upload-audio
|
||||
Authorization: Bearer <token>
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
- audio_file: <file>
|
||||
- meeting_id: 123
|
||||
- auto_summarize: true
|
||||
- prompt_id: 5
|
||||
```
|
||||
|
||||
## 文件变更列表
|
||||
|
||||
1. `app/api/endpoints/prompts.py` - 新增API接口
|
||||
2. `app/api/endpoints/meetings.py` - 修改两个端点
|
||||
3. `app/services/llm_service.py` - 修改get_task_prompt方法
|
||||
4. `app/services/async_meeting_service.py` - 修改4个方法
|
||||
5. `app/services/audio_service.py` - 修改handle_audio_upload方法
|
||||
6. `sql/add_prompt_id_to_llm_tasks.sql` - 数据库迁移脚本
|
||||
7. `test_prompt_id_feature.py` - 测试脚本
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. prompt_id 会与 task_type 一起验证,防止使用错误类型的模版
|
||||
2. 如果指定的 prompt_id 不存在或未启用,会自动使用默认模版
|
||||
3. 历史任务记录保留 prompt_id,即使对应的提示词被删除
|
||||
4. Redis 中 prompt_id 存储为字符串,使用时需转换为 int
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
# 知识库提示词模版选择功能实现总结
|
||||
|
||||
## 功能概述
|
||||
为知识库生成功能添加了提示词模版选择支持,用户在创建知识库时可以选择使用的生成模版。
|
||||
|
||||
## 实现的功能点
|
||||
|
||||
### 1. 修改请求模型 ✅
|
||||
**文件**: `app/models/models.py`
|
||||
- `CreateKnowledgeBaseRequest` 模型增加 `prompt_id` 字段
|
||||
- 类型:Optional[int] = None
|
||||
- 不指定时使用默认模版
|
||||
|
||||
### 2. 修改知识库异步服务 ✅
|
||||
**文件**: `app/services/async_knowledge_base_service.py`
|
||||
|
||||
#### 2.1 修改任务创建
|
||||
- `start_generation()`: 增加 `prompt_id` 参数
|
||||
- 将 prompt_id 存储到 Redis 和数据库
|
||||
- 支持通过 cursor 参数直接插入(事务场景)
|
||||
|
||||
#### 2.2 修改任务处理
|
||||
- `_process_task()`: 从 Redis 读取 prompt_id,传递给 `_build_prompt()`
|
||||
- 处理空字符串情况,转换为 None
|
||||
|
||||
#### 2.3 修改提示词构建
|
||||
- `_build_prompt()`: 增加 `prompt_id` 参数
|
||||
- 调用 `llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=prompt_id)`
|
||||
- 支持获取指定模版或默认模版
|
||||
|
||||
#### 2.4 修改数据库保存
|
||||
- `_save_task_to_db()`: 增加 `prompt_id` 参数
|
||||
- 插入时包含 prompt_id 字段
|
||||
|
||||
### 3. 修改API接口 ✅
|
||||
**文件**: `app/api/endpoints/knowledge_base.py`
|
||||
- `create_knowledge_base`: 从请求中获取 `prompt_id`
|
||||
- 调用 `async_kb_service.start_generation()` 时传递 `prompt_id`
|
||||
|
||||
### 4. 数据库字段 ✅
|
||||
- `knowledge_base_tasks` 表已包含 `prompt_id` 列(用户已添加)
|
||||
- 类型:int
|
||||
- 可空:NO(默认值:0)
|
||||
|
||||
## 数据流向
|
||||
|
||||
```
|
||||
前端 → POST /api/knowledge-bases (prompt_id)
|
||||
→ CreateKnowledgeBaseRequest (prompt_id)
|
||||
→ async_kb_service.start_generation(prompt_id)
|
||||
→ 存储到 Redis 和 DB (knowledge_base_tasks.prompt_id)
|
||||
→ _process_task() 读取 prompt_id
|
||||
→ _build_prompt(prompt_id)
|
||||
→ llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id)
|
||||
→ 获取指定模版或默认模版
|
||||
```
|
||||
|
||||
## 向后兼容性
|
||||
|
||||
所有新增的 `prompt_id` 参数都是可选的(Optional[int] = None),确保:
|
||||
1. 不传递 prompt_id 时,自动使用默认模版
|
||||
2. 现有代码无需修改即可正常工作
|
||||
3. 数据库中 prompt_id 有默认值 0
|
||||
|
||||
## 测试结果
|
||||
|
||||
执行 `test_kb_prompt_id_feature.py` 测试脚本,所有测试通过:
|
||||
- ✅ 获取启用的知识库提示词列表 (3个模版)
|
||||
- ✅ 通过prompt_id获取提示词内容
|
||||
- ✅ 获取默认提示词(不指定prompt_id)
|
||||
- ✅ 验证方法签名支持prompt_id参数
|
||||
- ✅ 验证数据库schema包含prompt_id列
|
||||
- ✅ 验证API模型定义正确
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 1. 获取启用的知识库任务模版列表
|
||||
```bash
|
||||
GET /api/prompts/active/KNOWLEDGE_TASK
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
返回:
|
||||
```json
|
||||
{
|
||||
"code": "200",
|
||||
"message": "获取启用模版列表成功",
|
||||
"data": {
|
||||
"prompts": [
|
||||
{"id": 2, "name": "默认知识库生成", "is_default": true},
|
||||
{"id": 13, "name": "分析总结模版", "is_default": false}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 创建知识库时指定模版
|
||||
```bash
|
||||
POST /api/knowledge-bases
|
||||
Authorization: Bearer <token>
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"title": "产品会议知识库",
|
||||
"is_shared": false,
|
||||
"user_prompt": "重点提取产品功能相关信息",
|
||||
"source_meeting_ids": "1,2,3",
|
||||
"tags": "产品,功能",
|
||||
"prompt_id": 13
|
||||
}
|
||||
```
|
||||
|
||||
## 文件变更列表
|
||||
|
||||
1. `app/models/models.py` - 修改CreateKnowledgeBaseRequest模型
|
||||
2. `app/services/async_knowledge_base_service.py` - 修改5个方法
|
||||
3. `app/api/endpoints/knowledge_base.py` - 修改create_knowledge_base端点
|
||||
4. `test_kb_prompt_id_feature.py` - 测试脚本
|
||||
|
||||
## 与会议总结功能的一致性
|
||||
|
||||
知识库的实现与会议总结功能保持一致:
|
||||
- 相同的prompt_id传递机制
|
||||
- 相同的Redis存储格式(字符串)
|
||||
- 相同的数据库字段类型
|
||||
- 相同的向后兼容策略
|
||||
- 相同的验证逻辑(task_type + is_active)
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. prompt_id 会与 task_type='KNOWLEDGE_TASK' 一起验证
|
||||
2. 如果指定的 prompt_id 不存在或未启用,会自动使用默认模版
|
||||
3. 历史任务记录保留 prompt_id,即使对应的提示词被删除
|
||||
4. Redis 中 prompt_id 存储为字符串,使用时需转换为 int
|
||||
5. 数据库 prompt_id 默认值为 0(表示未指定)
|
||||
|
||||
## 总结
|
||||
|
||||
知识库提示词模版选择功能已完全实现并通过测试,与会议总结功能保持一致的设计和实现方式。用户现在可以在创建知识库时选择不同的生成模版,以满足不同场景的需求。
|
||||
|
|
@ -2,7 +2,10 @@ from fastapi import APIRouter, Depends
|
|||
from app.core.auth import get_current_admin_user, get_current_user
|
||||
from app.core.config import LLM_CONFIG, DEFAULT_RESET_PASSWORD, MAX_FILE_SIZE, VOICEPRINT_CONFIG, TIMELINE_PAGESIZE
|
||||
from app.core.response import create_api_response
|
||||
from app.core.database import get_db_connection
|
||||
from app.models.models import MenuInfo, MenuListResponse, RolePermissionInfo, UpdateRolePermissionsRequest, RoleInfo
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -117,3 +120,184 @@ def load_system_config():
|
|||
print(f"系统配置加载成功: model={config.get('model_name')}, pagesize={config.get('TIMELINE_PAGESIZE')}")
|
||||
except Exception as e:
|
||||
print(f"加载系统配置失败,使用默认配置: {e}")
|
||||
|
||||
# ========== 菜单权限管理接口 ==========
|
||||
|
||||
@router.get("/admin/menus")
|
||||
async def get_all_menus(current_user=Depends(get_current_admin_user)):
|
||||
"""
|
||||
获取所有菜单列表
|
||||
只有管理员才能访问
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
query = """
|
||||
SELECT menu_id, menu_code, menu_name, menu_icon, menu_url, menu_type,
|
||||
parent_id, sort_order, is_active, description, created_at, updated_at
|
||||
FROM menus
|
||||
ORDER BY sort_order ASC, menu_id ASC
|
||||
"""
|
||||
cursor.execute(query)
|
||||
menus = cursor.fetchall()
|
||||
|
||||
menu_list = [MenuInfo(**menu) for menu in menus]
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取菜单列表成功",
|
||||
data=MenuListResponse(menus=menu_list, total=len(menu_list))
|
||||
)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取菜单列表失败: {str(e)}")
|
||||
|
||||
@router.get("/admin/roles")
|
||||
async def get_all_roles(current_user=Depends(get_current_admin_user)):
|
||||
"""
|
||||
获取所有角色列表及其权限统计
|
||||
只有管理员才能访问
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 查询所有角色及其权限数量
|
||||
query = """
|
||||
SELECT r.role_id, r.role_name, r.created_at,
|
||||
COUNT(rmp.menu_id) as menu_count
|
||||
FROM roles r
|
||||
LEFT JOIN role_menu_permissions rmp ON r.role_id = rmp.role_id
|
||||
GROUP BY r.role_id
|
||||
ORDER BY r.role_id ASC
|
||||
"""
|
||||
cursor.execute(query)
|
||||
roles = cursor.fetchall()
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取角色列表成功",
|
||||
data={"roles": roles, "total": len(roles)}
|
||||
)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取角色列表失败: {str(e)}")
|
||||
|
||||
@router.get("/admin/roles/{role_id}/permissions")
|
||||
async def get_role_permissions(role_id: int, current_user=Depends(get_current_admin_user)):
|
||||
"""
|
||||
获取指定角色的菜单权限
|
||||
只有管理员才能访问
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查角色是否存在
|
||||
cursor.execute("SELECT role_id, role_name FROM roles WHERE role_id = %s", (role_id,))
|
||||
role = cursor.fetchone()
|
||||
if not role:
|
||||
return create_api_response(code="404", message="角色不存在")
|
||||
|
||||
# 查询该角色的所有菜单权限
|
||||
query = """
|
||||
SELECT menu_id
|
||||
FROM role_menu_permissions
|
||||
WHERE role_id = %s
|
||||
"""
|
||||
cursor.execute(query, (role_id,))
|
||||
permissions = cursor.fetchall()
|
||||
|
||||
menu_ids = [p['menu_id'] for p in permissions]
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取角色权限成功",
|
||||
data=RolePermissionInfo(
|
||||
role_id=role['role_id'],
|
||||
role_name=role['role_name'],
|
||||
menu_ids=menu_ids
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取角色权限失败: {str(e)}")
|
||||
|
||||
@router.put("/admin/roles/{role_id}/permissions")
|
||||
async def update_role_permissions(
|
||||
role_id: int,
|
||||
request: UpdateRolePermissionsRequest,
|
||||
current_user=Depends(get_current_admin_user)
|
||||
):
|
||||
"""
|
||||
更新指定角色的菜单权限
|
||||
只有管理员才能访问
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查角色是否存在
|
||||
cursor.execute("SELECT role_id FROM roles WHERE role_id = %s", (role_id,))
|
||||
if not cursor.fetchone():
|
||||
return create_api_response(code="404", message="角色不存在")
|
||||
|
||||
# 验证所有menu_id是否有效
|
||||
if request.menu_ids:
|
||||
format_strings = ','.join(['%s'] * len(request.menu_ids))
|
||||
cursor.execute(
|
||||
f"SELECT COUNT(*) as count FROM menus WHERE menu_id IN ({format_strings})",
|
||||
tuple(request.menu_ids)
|
||||
)
|
||||
valid_count = cursor.fetchone()['count']
|
||||
if valid_count != len(request.menu_ids):
|
||||
return create_api_response(code="400", message="包含无效的菜单ID")
|
||||
|
||||
# 删除该角色的所有现有权限
|
||||
cursor.execute("DELETE FROM role_menu_permissions WHERE role_id = %s", (role_id,))
|
||||
|
||||
# 插入新的权限
|
||||
if request.menu_ids:
|
||||
insert_values = [(role_id, menu_id) for menu_id in request.menu_ids]
|
||||
cursor.executemany(
|
||||
"INSERT INTO role_menu_permissions (role_id, menu_id) VALUES (%s, %s)",
|
||||
insert_values
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="更新角色权限成功",
|
||||
data={"role_id": role_id, "menu_count": len(request.menu_ids)}
|
||||
)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"更新角色权限失败: {str(e)}")
|
||||
|
||||
@router.get("/menus/user")
|
||||
async def get_user_menus(current_user=Depends(get_current_user)):
|
||||
"""
|
||||
获取当前用户可访问的菜单列表(用于渲染下拉菜单)
|
||||
所有登录用户都可以访问
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 根据用户的role_id查询可访问的菜单
|
||||
query = """
|
||||
SELECT DISTINCT m.menu_id, m.menu_code, m.menu_name, m.menu_icon,
|
||||
m.menu_url, m.menu_type, m.sort_order
|
||||
FROM menus m
|
||||
JOIN role_menu_permissions rmp ON m.menu_id = rmp.menu_id
|
||||
WHERE rmp.role_id = %s AND m.is_active = 1
|
||||
ORDER BY m.sort_order ASC
|
||||
"""
|
||||
cursor.execute(query, (current_user['role_id'],))
|
||||
menus = cursor.fetchall()
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取用户菜单成功",
|
||||
data={"menus": menus}
|
||||
)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"获取用户菜单失败: {str(e)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,398 @@
|
|||
from fastapi import APIRouter, Depends, Query
|
||||
from app.core.auth import get_current_admin_user
|
||||
from app.core.response import create_api_response
|
||||
from app.core.database import get_db_connection
|
||||
from app.services.jwt_service import jwt_service
|
||||
from app.core.config import AUDIO_DIR, REDIS_CONFIG
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
import os
|
||||
import redis
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Redis 客户端
|
||||
redis_client = redis.Redis(**REDIS_CONFIG)
|
||||
|
||||
# 常量定义
|
||||
AUDIO_FILE_EXTENSIONS = ('.wav', '.mp3', '.m4a', '.aac', '.flac', '.ogg')
|
||||
BYTES_TO_GB = 1024 ** 3
|
||||
|
||||
|
||||
def _build_status_condition(status: str) -> str:
|
||||
"""构建任务状态查询条件"""
|
||||
if status == 'running':
|
||||
return "AND (t.status = 'pending' OR t.status = 'processing')"
|
||||
elif status == 'completed':
|
||||
return "AND t.status = 'completed'"
|
||||
elif status == 'failed':
|
||||
return "AND t.status = 'failed'"
|
||||
return ""
|
||||
|
||||
|
||||
def _get_task_stats_query() -> str:
|
||||
"""获取任务统计的 SQL 查询"""
|
||||
return """
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN status = 'pending' OR status = 'processing' THEN 1 ELSE 0 END) as running,
|
||||
SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed
|
||||
"""
|
||||
|
||||
|
||||
def _get_online_user_count(redis_client) -> int:
|
||||
"""从 Redis 获取在线用户数"""
|
||||
try:
|
||||
token_keys = redis_client.keys("token:*")
|
||||
user_ids = set()
|
||||
for key in token_keys:
|
||||
parts = key.split(':')
|
||||
if len(parts) >= 2:
|
||||
user_ids.add(parts[1])
|
||||
return len(user_ids)
|
||||
except Exception as e:
|
||||
print(f"获取在线用户数失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def _calculate_audio_storage() -> Dict[str, float]:
|
||||
"""计算音频文件存储统计"""
|
||||
audio_files_count = 0
|
||||
audio_total_size = 0
|
||||
|
||||
try:
|
||||
if os.path.exists(AUDIO_DIR):
|
||||
for root, _, files in os.walk(AUDIO_DIR):
|
||||
for file in files:
|
||||
if file.endswith(AUDIO_FILE_EXTENSIONS):
|
||||
audio_files_count += 1
|
||||
file_path = os.path.join(root, file)
|
||||
try:
|
||||
audio_total_size += os.path.getsize(file_path)
|
||||
except OSError:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"统计音频文件失败: {e}")
|
||||
|
||||
return {
|
||||
"audio_files_count": audio_files_count,
|
||||
"audio_total_size_gb": round(audio_total_size / BYTES_TO_GB, 2)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/admin/dashboard/stats")
|
||||
async def get_dashboard_stats(current_user=Depends(get_current_admin_user)):
|
||||
"""获取管理员 Dashboard 统计数据"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 1. 用户统计
|
||||
today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as total FROM users")
|
||||
total_users = cursor.fetchone()['total']
|
||||
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as count FROM users WHERE created_at >= %s",
|
||||
(today_start,)
|
||||
)
|
||||
today_new_users = cursor.fetchone()['count']
|
||||
|
||||
online_users = _get_online_user_count(redis_client)
|
||||
|
||||
# 2. 会议统计
|
||||
cursor.execute("SELECT COUNT(*) as total FROM meetings")
|
||||
total_meetings = cursor.fetchone()['total']
|
||||
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as count FROM meetings WHERE created_at >= %s",
|
||||
(today_start,)
|
||||
)
|
||||
today_new_meetings = cursor.fetchone()['count']
|
||||
|
||||
# 3. 任务统计
|
||||
task_stats_query = _get_task_stats_query()
|
||||
|
||||
# 转录任务
|
||||
cursor.execute(f"{task_stats_query} FROM transcript_tasks")
|
||||
transcription_stats = cursor.fetchone() or {'total': 0, 'running': 0, 'completed': 0, 'failed': 0}
|
||||
|
||||
# 总结任务
|
||||
cursor.execute(f"{task_stats_query} FROM llm_tasks")
|
||||
summary_stats = cursor.fetchone() or {'total': 0, 'running': 0, 'completed': 0, 'failed': 0}
|
||||
|
||||
# 知识库任务
|
||||
cursor.execute(f"{task_stats_query} FROM knowledge_base_tasks")
|
||||
kb_stats = cursor.fetchone() or {'total': 0, 'running': 0, 'completed': 0, 'failed': 0}
|
||||
|
||||
# 4. 音频存储统计
|
||||
storage_stats = _calculate_audio_storage()
|
||||
|
||||
# 组装返回数据
|
||||
stats = {
|
||||
"users": {
|
||||
"total": total_users,
|
||||
"today_new": today_new_users,
|
||||
"online": online_users
|
||||
},
|
||||
"meetings": {
|
||||
"total": total_meetings,
|
||||
"today_new": today_new_meetings
|
||||
},
|
||||
"tasks": {
|
||||
"transcription": {
|
||||
"total": transcription_stats['total'] or 0,
|
||||
"running": transcription_stats['running'] or 0,
|
||||
"completed": transcription_stats['completed'] or 0,
|
||||
"failed": transcription_stats['failed'] or 0
|
||||
},
|
||||
"summary": {
|
||||
"total": summary_stats['total'] or 0,
|
||||
"running": summary_stats['running'] or 0,
|
||||
"completed": summary_stats['completed'] or 0,
|
||||
"failed": summary_stats['failed'] or 0
|
||||
},
|
||||
"knowledge_base": {
|
||||
"total": kb_stats['total'] or 0,
|
||||
"running": kb_stats['running'] or 0,
|
||||
"completed": kb_stats['completed'] or 0,
|
||||
"failed": kb_stats['failed'] or 0
|
||||
}
|
||||
},
|
||||
"storage": storage_stats
|
||||
}
|
||||
|
||||
return create_api_response(code="200", message="获取统计数据成功", data=stats)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取Dashboard统计数据失败: {e}")
|
||||
return create_api_response(code="500", message=f"获取统计数据失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/admin/online-users")
|
||||
async def get_online_users(current_user=Depends(get_current_admin_user)):
|
||||
"""获取在线用户列表"""
|
||||
try:
|
||||
token_keys = redis_client.keys("token:*")
|
||||
|
||||
# 提取用户ID并去重
|
||||
user_tokens = {}
|
||||
for key in token_keys:
|
||||
parts = key.split(':')
|
||||
if len(parts) >= 3:
|
||||
user_id = int(parts[1])
|
||||
token = parts[2]
|
||||
if user_id not in user_tokens:
|
||||
user_tokens[user_id] = []
|
||||
user_tokens[user_id].append({'token': token, 'key': key})
|
||||
|
||||
# 查询用户信息
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
online_users_list = []
|
||||
for user_id, tokens in user_tokens.items():
|
||||
cursor.execute(
|
||||
"SELECT user_id, username, caption, email, role_id FROM users WHERE user_id = %s",
|
||||
(user_id,)
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
if user:
|
||||
ttl_seconds = redis_client.ttl(tokens[0]['key'])
|
||||
online_users_list.append({
|
||||
**user,
|
||||
'token_count': len(tokens),
|
||||
'ttl_seconds': ttl_seconds,
|
||||
'ttl_hours': round(ttl_seconds / 3600, 1) if ttl_seconds > 0 else 0
|
||||
})
|
||||
|
||||
# 按用户ID排序
|
||||
online_users_list.sort(key=lambda x: x['user_id'])
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取在线用户列表成功",
|
||||
data={"users": online_users_list, "total": len(online_users_list)}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取在线用户列表失败: {e}")
|
||||
return create_api_response(code="500", message=f"获取在线用户列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/admin/kick-user/{user_id}")
|
||||
async def kick_user(user_id: int, current_user=Depends(get_current_admin_user)):
|
||||
"""踢出用户(撤销该用户的所有 token)"""
|
||||
try:
|
||||
revoked_count = jwt_service.revoke_all_user_tokens(user_id)
|
||||
|
||||
if revoked_count > 0:
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message=f"已踢出用户,撤销了 {revoked_count} 个 token",
|
||||
data={"user_id": user_id, "revoked_count": revoked_count}
|
||||
)
|
||||
else:
|
||||
return create_api_response(
|
||||
code="404",
|
||||
message="该用户当前不在线或未找到 token"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"踢出用户失败: {e}")
|
||||
return create_api_response(code="500", message=f"踢出用户失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/admin/tasks/monitor")
|
||||
async def monitor_tasks(
|
||||
task_type: str = Query('all', description="任务类型: all, transcription, summary, knowledge_base"),
|
||||
status: str = Query('all', description="任务状态: all, running, completed, failed"),
|
||||
limit: int = Query(20, ge=1, le=100, description="返回数量限制"),
|
||||
current_user=Depends(get_current_admin_user)
|
||||
):
|
||||
"""监控任务进度"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
tasks = []
|
||||
status_condition = _build_status_condition(status)
|
||||
|
||||
# 转录任务
|
||||
if task_type in ['all', 'transcription']:
|
||||
query = f"""
|
||||
SELECT
|
||||
t.task_id,
|
||||
'transcription' as task_type,
|
||||
t.meeting_id,
|
||||
m.title as meeting_title,
|
||||
t.status,
|
||||
t.progress,
|
||||
t.error_message,
|
||||
t.created_at,
|
||||
t.completed_at,
|
||||
u.username as creator_name
|
||||
FROM transcript_tasks t
|
||||
LEFT JOIN meetings m ON t.meeting_id = m.meeting_id
|
||||
LEFT JOIN users u ON m.user_id = u.user_id
|
||||
WHERE 1=1 {status_condition}
|
||||
ORDER BY t.created_at DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
cursor.execute(query, (limit,))
|
||||
tasks.extend(cursor.fetchall())
|
||||
|
||||
# 总结任务
|
||||
if task_type in ['all', 'summary']:
|
||||
query = f"""
|
||||
SELECT
|
||||
t.task_id,
|
||||
'summary' as task_type,
|
||||
t.meeting_id,
|
||||
m.title as meeting_title,
|
||||
t.status,
|
||||
NULL as progress,
|
||||
t.error_message,
|
||||
t.created_at,
|
||||
t.completed_at,
|
||||
u.username as creator_name
|
||||
FROM llm_tasks t
|
||||
LEFT JOIN meetings m ON t.meeting_id = m.meeting_id
|
||||
LEFT JOIN users u ON m.user_id = u.user_id
|
||||
WHERE 1=1 {status_condition}
|
||||
ORDER BY t.created_at DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
cursor.execute(query, (limit,))
|
||||
tasks.extend(cursor.fetchall())
|
||||
|
||||
# 知识库任务
|
||||
if task_type in ['all', 'knowledge_base']:
|
||||
query = f"""
|
||||
SELECT
|
||||
t.task_id,
|
||||
'knowledge_base' as task_type,
|
||||
t.kb_id as meeting_id,
|
||||
k.title as meeting_title,
|
||||
t.status,
|
||||
t.progress,
|
||||
t.error_message,
|
||||
t.created_at,
|
||||
t.updated_at,
|
||||
u.username as creator_name
|
||||
FROM knowledge_base_tasks t
|
||||
LEFT JOIN knowledge_bases k ON t.kb_id = k.kb_id
|
||||
LEFT JOIN users u ON k.creator_id = u.user_id
|
||||
WHERE 1=1 {status_condition}
|
||||
ORDER BY t.created_at DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
cursor.execute(query, (limit,))
|
||||
tasks.extend(cursor.fetchall())
|
||||
|
||||
# 按创建时间排序并限制返回数量
|
||||
tasks.sort(key=lambda x: x['created_at'], reverse=True)
|
||||
tasks = tasks[:limit]
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取任务监控数据成功",
|
||||
data={"tasks": tasks, "total": len(tasks)}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取任务监控数据失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return create_api_response(code="500", message=f"获取任务监控数据失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/admin/system/resources")
|
||||
async def get_system_resources(current_user=Depends(get_current_admin_user)):
|
||||
"""获取服务器资源使用情况"""
|
||||
try:
|
||||
import psutil
|
||||
|
||||
# CPU 使用率
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
cpu_count = psutil.cpu_count()
|
||||
|
||||
# 内存使用情况
|
||||
memory = psutil.virtual_memory()
|
||||
memory_total_gb = round(memory.total / BYTES_TO_GB, 2)
|
||||
memory_used_gb = round(memory.used / BYTES_TO_GB, 2)
|
||||
|
||||
# 磁盘使用情况
|
||||
disk = psutil.disk_usage('/')
|
||||
disk_total_gb = round(disk.total / BYTES_TO_GB, 2)
|
||||
disk_used_gb = round(disk.used / BYTES_TO_GB, 2)
|
||||
|
||||
resources = {
|
||||
"cpu": {
|
||||
"percent": cpu_percent,
|
||||
"count": cpu_count
|
||||
},
|
||||
"memory": {
|
||||
"total_gb": memory_total_gb,
|
||||
"used_gb": memory_used_gb,
|
||||
"percent": memory.percent
|
||||
},
|
||||
"disk": {
|
||||
"total_gb": disk_total_gb,
|
||||
"used_gb": disk_used_gb,
|
||||
"percent": disk.percent
|
||||
},
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
return create_api_response(code="200", message="获取系统资源成功", data=resources)
|
||||
|
||||
except ImportError:
|
||||
return create_api_response(
|
||||
code="500",
|
||||
message="psutil 库未安装,请运行: pip install psutil"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"获取系统资源失败: {e}")
|
||||
return create_api_response(code="500", message=f"获取系统资源失败: {str(e)}")
|
||||
|
|
@ -0,0 +1,568 @@
|
|||
from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, BackgroundTasks
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.config import BASE_DIR, AUDIO_DIR, TEMP_UPLOAD_DIR
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.services.async_meeting_service import async_meeting_service
|
||||
from app.services.audio_service import handle_audio_upload
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
router = APIRouter()
|
||||
transcription_service = AsyncTranscriptionService()
|
||||
|
||||
# 临时上传目录 - 放在项目目录下
|
||||
TEMP_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 配置常量
|
||||
MAX_CHUNK_SIZE = 2 * 1024 * 1024 # 2MB per chunk
|
||||
MAX_TOTAL_SIZE = 500 * 1024 * 1024 # 500MB total
|
||||
MAX_DURATION = 3600 # 1 hour max recording
|
||||
SESSION_EXPIRE_HOURS = 1 # 会话1小时后过期
|
||||
|
||||
# 支持的音频格式
|
||||
SUPPORTED_MIME_TYPES = {
|
||||
'audio/webm;codecs=opus': '.webm',
|
||||
'audio/webm': '.webm',
|
||||
'audio/ogg;codecs=opus': '.ogg',
|
||||
'audio/mp4': '.m4a',
|
||||
'audio/mpeg': '.mp3'
|
||||
}
|
||||
|
||||
|
||||
# ============ Pydantic Models ============
|
||||
|
||||
class InitUploadRequest(BaseModel):
|
||||
meeting_id: int
|
||||
mime_type: str
|
||||
estimated_duration: Optional[int] = None # 预计时长(秒)
|
||||
|
||||
|
||||
class CompleteUploadRequest(BaseModel):
|
||||
session_id: str
|
||||
meeting_id: int
|
||||
total_chunks: int
|
||||
mime_type: str
|
||||
auto_transcribe: bool = True
|
||||
auto_summarize: bool = True
|
||||
prompt_id: Optional[int] = None # 提示词模版ID(可选)
|
||||
|
||||
|
||||
class CancelUploadRequest(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
# ============ 工具函数 ============
|
||||
|
||||
def validate_session_id(session_id: str) -> str:
|
||||
"""验证session_id格式,防止路径注入攻击"""
|
||||
if not re.match(r'^sess_\d+_[a-zA-Z0-9]+$', session_id):
|
||||
raise ValueError("Invalid session_id format")
|
||||
return session_id
|
||||
|
||||
|
||||
def validate_mime_type(mime_type: str) -> str:
|
||||
"""验证MIME类型是否支持"""
|
||||
if mime_type not in SUPPORTED_MIME_TYPES:
|
||||
raise ValueError(f"Unsupported MIME type: {mime_type}")
|
||||
return SUPPORTED_MIME_TYPES[mime_type]
|
||||
|
||||
|
||||
def get_session_dir(session_id: str) -> Path:
|
||||
"""获取会话目录路径"""
|
||||
validate_session_id(session_id)
|
||||
return TEMP_UPLOAD_DIR / session_id
|
||||
|
||||
|
||||
def get_session_metadata_path(session_id: str) -> Path:
|
||||
"""获取会话metadata文件路径"""
|
||||
return get_session_dir(session_id) / "metadata.json"
|
||||
|
||||
|
||||
def create_session_metadata(session_id: str, meeting_id: int, mime_type: str, user_id: int) -> dict:
|
||||
"""创建会话metadata"""
|
||||
now = datetime.now()
|
||||
expires_at = now + timedelta(hours=SESSION_EXPIRE_HOURS)
|
||||
|
||||
metadata = {
|
||||
"session_id": session_id,
|
||||
"meeting_id": meeting_id,
|
||||
"user_id": user_id,
|
||||
"mime_type": mime_type,
|
||||
"total_chunks": None,
|
||||
"received_chunks": [],
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": expires_at.isoformat()
|
||||
}
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def save_session_metadata(session_id: str, metadata: dict):
|
||||
"""保存会话metadata"""
|
||||
metadata_path = get_session_metadata_path(session_id)
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def load_session_metadata(session_id: str) -> dict:
|
||||
"""加载会话metadata"""
|
||||
metadata_path = get_session_metadata_path(session_id)
|
||||
if not metadata_path.exists():
|
||||
raise FileNotFoundError(f"Session {session_id} not found")
|
||||
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def update_session_chunks(session_id: str, chunk_index: int):
|
||||
"""更新已接收的分片列表"""
|
||||
metadata = load_session_metadata(session_id)
|
||||
|
||||
if chunk_index not in metadata['received_chunks']:
|
||||
metadata['received_chunks'].append(chunk_index)
|
||||
metadata['received_chunks'].sort()
|
||||
|
||||
save_session_metadata(session_id, metadata)
|
||||
|
||||
|
||||
def get_session_total_size(session_id: str) -> int:
|
||||
"""获取会话已上传的总大小"""
|
||||
session_dir = get_session_dir(session_id)
|
||||
total_size = 0
|
||||
|
||||
if session_dir.exists():
|
||||
for chunk_file in session_dir.glob("chunk_*.webm"):
|
||||
total_size += chunk_file.stat().st_size
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def merge_audio_chunks(session_id: str, meeting_id: int, total_chunks: int, mime_type: str) -> str:
|
||||
"""合并音频分片"""
|
||||
session_dir = get_session_dir(session_id)
|
||||
|
||||
# 1. 验证分片完整性
|
||||
missing = []
|
||||
for i in range(total_chunks):
|
||||
chunk_path = session_dir / f"chunk_{i:04d}.webm"
|
||||
if not chunk_path.exists():
|
||||
missing.append(i)
|
||||
|
||||
if missing:
|
||||
raise ValueError(f"Missing chunks: {missing}")
|
||||
|
||||
# 2. 创建输出目录
|
||||
meeting_audio_dir = AUDIO_DIR / str(meeting_id)
|
||||
meeting_audio_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. 生成输出文件名
|
||||
file_extension = validate_mime_type(mime_type)
|
||||
output_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
output_path = meeting_audio_dir / output_filename
|
||||
|
||||
# 4. 按序合并分片
|
||||
with open(output_path, 'wb') as outfile:
|
||||
for i in range(total_chunks):
|
||||
chunk_path = session_dir / f"chunk_{i:04d}.webm"
|
||||
with open(chunk_path, 'rb') as infile:
|
||||
outfile.write(infile.read())
|
||||
|
||||
# 5. 清理临时文件
|
||||
shutil.rmtree(session_dir)
|
||||
|
||||
# 返回相对路径
|
||||
return f"/{output_path.relative_to(BASE_DIR)}"
|
||||
|
||||
|
||||
def cleanup_session(session_id: str):
|
||||
"""清理会话文件"""
|
||||
session_dir = get_session_dir(session_id)
|
||||
if session_dir.exists():
|
||||
shutil.rmtree(session_dir)
|
||||
|
||||
|
||||
def cleanup_expired_sessions():
|
||||
"""清理过期的会话(可以由定时任务调用)"""
|
||||
now = datetime.now()
|
||||
cleaned_count = 0
|
||||
|
||||
if not TEMP_UPLOAD_DIR.exists():
|
||||
return cleaned_count
|
||||
|
||||
for session_dir in TEMP_UPLOAD_DIR.iterdir():
|
||||
if not session_dir.is_dir():
|
||||
continue
|
||||
|
||||
metadata_path = session_dir / "metadata.json"
|
||||
if metadata_path.exists():
|
||||
try:
|
||||
with open(metadata_path, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
expires_at = datetime.fromisoformat(metadata['expires_at'])
|
||||
if now > expires_at:
|
||||
shutil.rmtree(session_dir)
|
||||
cleaned_count += 1
|
||||
print(f"Cleaned up expired session: {session_dir.name}")
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up session {session_dir.name}: {e}")
|
||||
|
||||
return cleaned_count
|
||||
|
||||
|
||||
# ============ API Endpoints ============
|
||||
|
||||
@router.post("/audio/stream/init")
|
||||
async def init_upload_session(
|
||||
request: InitUploadRequest,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
初始化音频流式上传会话
|
||||
|
||||
创建临时目录,生成session_id,返回给客户端用于后续分片上传
|
||||
"""
|
||||
try:
|
||||
# 1. 验证会议是否存在且属于当前用户
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT user_id FROM meetings WHERE meeting_id = %s",
|
||||
(request.meeting_id,)
|
||||
)
|
||||
meeting = cursor.fetchone()
|
||||
|
||||
if not meeting:
|
||||
return create_api_response(
|
||||
code="404",
|
||||
message="会议不存在"
|
||||
)
|
||||
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
return create_api_response(
|
||||
code="403",
|
||||
message="无权限操作此会议"
|
||||
)
|
||||
|
||||
# 2. 验证MIME类型
|
||||
try:
|
||||
validate_mime_type(request.mime_type)
|
||||
except ValueError as e:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=str(e)
|
||||
)
|
||||
|
||||
# 3. 生成session_id
|
||||
timestamp = int(datetime.now().timestamp() * 1000)
|
||||
random_str = uuid.uuid4().hex[:8]
|
||||
session_id = f"sess_{timestamp}_{random_str}"
|
||||
|
||||
# 4. 创建会话目录
|
||||
session_dir = get_session_dir(session_id)
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 5. 创建并保存metadata
|
||||
metadata = create_session_metadata(
|
||||
session_id=session_id,
|
||||
meeting_id=request.meeting_id,
|
||||
mime_type=request.mime_type,
|
||||
user_id=current_user['user_id']
|
||||
)
|
||||
save_session_metadata(session_id, metadata)
|
||||
|
||||
# 6. 清理过期会话
|
||||
cleanup_expired_sessions()
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="上传会话初始化成功",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"chunk_size": MAX_CHUNK_SIZE,
|
||||
"max_chunks": 1000
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error initializing upload session: {e}")
|
||||
return create_api_response(
|
||||
code="500",
|
||||
message=f"初始化上传会话失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/audio/stream/chunk")
|
||||
async def upload_audio_chunk(
|
||||
session_id: str = Form(...),
|
||||
chunk_index: int = Form(...),
|
||||
chunk: UploadFile = File(...),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
上传音频分片
|
||||
|
||||
接收并保存音频分片文件
|
||||
"""
|
||||
try:
|
||||
# 1. 验证session_id格式
|
||||
try:
|
||||
validate_session_id(session_id)
|
||||
except ValueError:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message="Invalid session_id format"
|
||||
)
|
||||
|
||||
# 2. 加载session metadata
|
||||
try:
|
||||
metadata = load_session_metadata(session_id)
|
||||
except FileNotFoundError:
|
||||
return create_api_response(
|
||||
code="404",
|
||||
message="Session not found"
|
||||
)
|
||||
|
||||
# 3. 验证会话所有权
|
||||
if metadata['user_id'] != current_user['user_id']:
|
||||
return create_api_response(
|
||||
code="403",
|
||||
message="Permission denied"
|
||||
)
|
||||
|
||||
# 4. 验证分片大小
|
||||
chunk_data = await chunk.read()
|
||||
if len(chunk_data) > MAX_CHUNK_SIZE:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"Chunk size exceeds {MAX_CHUNK_SIZE // (1024*1024)}MB limit"
|
||||
)
|
||||
|
||||
# 5. 验证总大小
|
||||
session_total = get_session_total_size(session_id)
|
||||
if session_total + len(chunk_data) > MAX_TOTAL_SIZE:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"Total size exceeds {MAX_TOTAL_SIZE // (1024*1024)}MB limit"
|
||||
)
|
||||
|
||||
# 6. 保存分片文件
|
||||
session_dir = get_session_dir(session_id)
|
||||
chunk_path = session_dir / f"chunk_{chunk_index:04d}.webm"
|
||||
|
||||
with open(chunk_path, 'wb') as f:
|
||||
f.write(chunk_data)
|
||||
|
||||
# 7. 更新metadata
|
||||
update_session_chunks(session_id, chunk_index)
|
||||
|
||||
# 8. 获取已接收分片总数
|
||||
metadata = load_session_metadata(session_id)
|
||||
total_received = len(metadata['received_chunks'])
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="分片上传成功",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"chunk_index": chunk_index,
|
||||
"received": True,
|
||||
"total_received": total_received
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error uploading chunk: {e}")
|
||||
return create_api_response(
|
||||
code="500",
|
||||
message=f"分片上传失败: {str(e)}",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"chunk_index": chunk_index,
|
||||
"should_retry": True
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/audio/stream/complete")
|
||||
async def complete_upload(
|
||||
request: CompleteUploadRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
完成上传并合并分片
|
||||
|
||||
验证分片完整性,合并所有分片,保存最终音频文件,可选启动转录任务和自动总结
|
||||
"""
|
||||
try:
|
||||
# 1. 验证session_id
|
||||
try:
|
||||
validate_session_id(request.session_id)
|
||||
except ValueError:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message="Invalid session_id format"
|
||||
)
|
||||
|
||||
# 2. 加载session metadata
|
||||
try:
|
||||
metadata = load_session_metadata(request.session_id)
|
||||
except FileNotFoundError:
|
||||
return create_api_response(
|
||||
code="404",
|
||||
message="Session not found"
|
||||
)
|
||||
|
||||
# 3. 验证会话所有权
|
||||
if metadata['user_id'] != current_user['user_id']:
|
||||
return create_api_response(
|
||||
code="403",
|
||||
message="Permission denied"
|
||||
)
|
||||
|
||||
# 4. 验证会议ID一致性
|
||||
if metadata['meeting_id'] != request.meeting_id:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message="Meeting ID mismatch"
|
||||
)
|
||||
|
||||
# 5. 合并音频分片
|
||||
try:
|
||||
file_path = merge_audio_chunks(
|
||||
session_id=request.session_id,
|
||||
meeting_id=request.meeting_id,
|
||||
total_chunks=request.total_chunks,
|
||||
mime_type=request.mime_type
|
||||
)
|
||||
except ValueError as e:
|
||||
# 分片不完整
|
||||
return create_api_response(
|
||||
code="500",
|
||||
message=f"音频合并失败:{str(e)}",
|
||||
data={
|
||||
"should_retry": True
|
||||
}
|
||||
)
|
||||
|
||||
# 6. 获取文件信息
|
||||
full_path = BASE_DIR / file_path.lstrip('/')
|
||||
file_size = full_path.stat().st_size
|
||||
file_name = full_path.name
|
||||
|
||||
# 7. 调用 audio_service 处理文件(数据库更新、启动转录和总结)
|
||||
result = handle_audio_upload(
|
||||
file_path=file_path,
|
||||
file_name=file_name,
|
||||
file_size=file_size,
|
||||
meeting_id=request.meeting_id,
|
||||
current_user=current_user,
|
||||
auto_summarize=request.auto_summarize,
|
||||
background_tasks=background_tasks,
|
||||
prompt_id=request.prompt_id # 传递提示词模版ID
|
||||
)
|
||||
|
||||
# 如果处理失败,返回错误
|
||||
if not result["success"]:
|
||||
return result["response"]
|
||||
|
||||
# 8. 返回成功响应
|
||||
transcription_task_id = result["transcription_task_id"]
|
||||
message_suffix = ""
|
||||
if transcription_task_id:
|
||||
if request.auto_summarize:
|
||||
message_suffix = ",正在进行转录和总结"
|
||||
else:
|
||||
message_suffix = ",正在进行转录"
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="音频上传完成" + message_suffix,
|
||||
data={
|
||||
"meeting_id": request.meeting_id,
|
||||
"file_path": file_path,
|
||||
"file_size": file_size,
|
||||
"duration": None, # 可以通过ffprobe获取,但不是必需的
|
||||
"task_id": transcription_task_id,
|
||||
"task_status": "pending" if transcription_task_id else None,
|
||||
"auto_summarize": request.auto_summarize
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error completing upload: {e}")
|
||||
return create_api_response(
|
||||
code="500",
|
||||
message=f"完成上传失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/audio/stream/cancel")
|
||||
async def cancel_upload(
|
||||
request: CancelUploadRequest,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
取消上传会话
|
||||
|
||||
清理会话临时文件和目录
|
||||
"""
|
||||
try:
|
||||
# 1. 验证session_id
|
||||
try:
|
||||
validate_session_id(request.session_id)
|
||||
except ValueError:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message="Invalid session_id format"
|
||||
)
|
||||
|
||||
# 2. 加载session metadata(验证所有权)
|
||||
try:
|
||||
metadata = load_session_metadata(request.session_id)
|
||||
|
||||
# 验证会话所有权
|
||||
if metadata['user_id'] != current_user['user_id']:
|
||||
return create_api_response(
|
||||
code="403",
|
||||
message="Permission denied"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
# 会话不存在,视为已清理
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="上传会话已取消",
|
||||
data={
|
||||
"session_id": request.session_id,
|
||||
"cleaned": True
|
||||
}
|
||||
)
|
||||
|
||||
# 3. 清理会话文件
|
||||
cleanup_session(request.session_id)
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="上传会话已取消",
|
||||
data={
|
||||
"session_id": request.session_id,
|
||||
"cleaned": True
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error canceling upload: {e}")
|
||||
return create_api_response(
|
||||
code="500",
|
||||
message=f"取消上传失败: {str(e)}"
|
||||
)
|
||||
|
|
@ -12,7 +12,7 @@ from typing import Optional
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/downloads", response_model=dict)
|
||||
@router.get("/clients", response_model=dict)
|
||||
async def get_client_downloads(
|
||||
platform_type: Optional[str] = None,
|
||||
platform_name: Optional[str] = None,
|
||||
|
|
@ -81,7 +81,7 @@ async def get_client_downloads(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/downloads/latest", response_model=dict)
|
||||
@router.get("/clients/latest", response_model=dict)
|
||||
async def get_latest_clients():
|
||||
"""
|
||||
获取所有平台的最新版本客户端(公开接口,用于首页下载)
|
||||
|
|
@ -102,19 +102,23 @@ async def get_latest_clients():
|
|||
# 按平台类型分组
|
||||
mobile_clients = []
|
||||
desktop_clients = []
|
||||
terminal_clients = []
|
||||
|
||||
for client in clients:
|
||||
if client['platform_type'] == 'mobile':
|
||||
mobile_clients.append(client)
|
||||
else:
|
||||
elif client['platform_type'] == 'desktop':
|
||||
desktop_clients.append(client)
|
||||
elif client['platform_type'] == 'terminal':
|
||||
terminal_clients.append(client)
|
||||
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取成功",
|
||||
data={
|
||||
"mobile": mobile_clients,
|
||||
"desktop": desktop_clients
|
||||
"desktop": desktop_clients,
|
||||
"terminal": terminal_clients
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -125,10 +129,17 @@ async def get_latest_clients():
|
|||
)
|
||||
|
||||
|
||||
@router.get("/downloads/{platform_name}/latest", response_model=dict)
|
||||
async def get_latest_version_by_platform(platform_name: str):
|
||||
@router.get("/clients/latest/by-platform", response_model=dict)
|
||||
async def get_latest_version_by_platform_type_and_name(
|
||||
platform_type: str,
|
||||
platform_name: str
|
||||
):
|
||||
"""
|
||||
获取指定平台的最新版本(公开接口,用于客户端版本检查)
|
||||
通过平台类型和平台名称获取最新版本(公开接口,用于客户端版本检查)
|
||||
|
||||
参数:
|
||||
platform_type: 平台类型 (mobile, desktop, terminal)
|
||||
platform_name: 具体平台 (ios, android, windows, mac_intel, mac_m, linux, mcu)
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
|
|
@ -136,17 +147,20 @@ async def get_latest_version_by_platform(platform_name: str):
|
|||
|
||||
query = """
|
||||
SELECT * FROM client_downloads
|
||||
WHERE platform_name = %s AND is_active = TRUE AND is_latest = TRUE
|
||||
WHERE platform_type = %s
|
||||
AND platform_name = %s
|
||||
AND is_active = TRUE
|
||||
AND is_latest = TRUE
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor.execute(query, (platform_name,))
|
||||
cursor.execute(query, (platform_type, platform_name))
|
||||
client = cursor.fetchone()
|
||||
cursor.close()
|
||||
|
||||
if not client:
|
||||
return create_api_response(
|
||||
code="404",
|
||||
message=f"未找到平台 {platform_name} 的客户端"
|
||||
message=f"未找到平台类型 {platform_type} 下的 {platform_name} 客户端"
|
||||
)
|
||||
|
||||
return create_api_response(
|
||||
|
|
@ -162,7 +176,7 @@ async def get_latest_version_by_platform(platform_name: str):
|
|||
)
|
||||
|
||||
|
||||
@router.get("/downloads/{id}", response_model=dict)
|
||||
@router.get("/clients/{id}", response_model=dict)
|
||||
async def get_client_download_by_id(id: int):
|
||||
"""
|
||||
获取指定ID的客户端详情(公开接口)
|
||||
|
|
@ -195,7 +209,7 @@ async def get_client_download_by_id(id: int):
|
|||
)
|
||||
|
||||
|
||||
@router.post("/downloads", response_model=dict)
|
||||
@router.post("/clients", response_model=dict)
|
||||
async def create_client_download(
|
||||
request: CreateClientDownloadRequest,
|
||||
current_user: dict = Depends(get_current_admin_user)
|
||||
|
|
@ -255,7 +269,7 @@ async def create_client_download(
|
|||
)
|
||||
|
||||
|
||||
@router.put("/downloads/{id}", response_model=dict)
|
||||
@router.put("/clients/{id}", response_model=dict)
|
||||
async def update_client_download(
|
||||
id: int,
|
||||
request: UpdateClientDownloadRequest,
|
||||
|
|
@ -353,7 +367,7 @@ async def update_client_download(
|
|||
)
|
||||
|
||||
|
||||
@router.delete("/downloads/{id}", response_model=dict)
|
||||
@router.delete("/clients/{id}", response_model=dict)
|
||||
async def delete_client_download(
|
||||
id: int,
|
||||
current_user: dict = Depends(get_current_admin_user)
|
||||
|
|
|
|||
|
|
@ -131,6 +131,7 @@ def create_knowledge_base(
|
|||
kb_id=kb_id,
|
||||
user_prompt=request.user_prompt,
|
||||
source_meeting_ids=request.source_meeting_ids,
|
||||
prompt_id=request.prompt_id, # 传递 prompt_id 参数
|
||||
cursor=cursor
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import app.core.config as config_module
|
|||
from app.services.llm_service import LLMService
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.services.async_meeting_service import async_meeting_service
|
||||
from app.services.audio_service import handle_audio_upload
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from typing import List, Optional
|
||||
|
|
@ -25,194 +26,8 @@ transcription_service = AsyncTranscriptionService()
|
|||
|
||||
class GenerateSummaryRequest(BaseModel):
|
||||
user_prompt: Optional[str] = ""
|
||||
prompt_id: Optional[int] = None # 提示词模版ID,如果不指定则使用默认模版
|
||||
|
||||
def _handle_audio_upload(
|
||||
audio_file: UploadFile,
|
||||
meeting_id: int,
|
||||
force_replace_bool: bool,
|
||||
current_user: dict
|
||||
):
|
||||
"""
|
||||
音频上传的公共处理逻辑
|
||||
|
||||
Args:
|
||||
audio_file: 上传的音频文件
|
||||
meeting_id: 会议ID
|
||||
force_replace_bool: 是否强制替换
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"success": bool, # 是否成功
|
||||
"needs_confirmation": bool, # 是否需要用户确认
|
||||
"response": dict, # 如果需要返回,这里是响应数据
|
||||
"file_info": dict, # 文件信息 (成功时)
|
||||
"transcription_task_id": str, # 转录任务ID (成功时)
|
||||
"replaced_existing": bool, # 是否替换了现有文件 (成功时)
|
||||
"has_transcription": bool # 原来是否有转录记录 (成功时)
|
||||
}
|
||||
"""
|
||||
# 1. 文件类型验证
|
||||
file_extension = os.path.splitext(audio_file.filename)[1].lower()
|
||||
if file_extension not in ALLOWED_EXTENSIONS:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(
|
||||
code="400",
|
||||
message=f"不支持的文件类型。支持的类型: {', '.join(ALLOWED_EXTENSIONS)}"
|
||||
)
|
||||
}
|
||||
|
||||
# 2. 文件大小验证
|
||||
max_file_size = getattr(config_module, 'MAX_FILE_SIZE', 100 * 1024 * 1024)
|
||||
if audio_file.size > max_file_size:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(
|
||||
code="400",
|
||||
message=f"文件大小超过 {max_file_size // (1024 * 1024)}MB 限制"
|
||||
)
|
||||
}
|
||||
|
||||
# 3. 权限和已有文件检查
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查会议是否存在及权限
|
||||
cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
if not meeting:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="404", message="会议不存在")
|
||||
}
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="403", message="无权限操作此会议")
|
||||
}
|
||||
|
||||
# 检查已有音频文件
|
||||
cursor.execute(
|
||||
"SELECT file_name, file_path, upload_time FROM audio_files WHERE meeting_id = %s",
|
||||
(meeting_id,)
|
||||
)
|
||||
existing_info = cursor.fetchone()
|
||||
|
||||
# 检查是否有转录记录
|
||||
has_transcription = False
|
||||
if existing_info:
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as segment_count FROM transcript_segments WHERE meeting_id = %s",
|
||||
(meeting_id,)
|
||||
)
|
||||
has_transcription = cursor.fetchone()['segment_count'] > 0
|
||||
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="500", message=f"检查已有文件失败: {str(e)}")
|
||||
}
|
||||
|
||||
# 4. 如果已有转录记录且未确认替换,返回提示
|
||||
if existing_info and has_transcription and not force_replace_bool:
|
||||
return {
|
||||
"success": False,
|
||||
"needs_confirmation": True,
|
||||
"response": create_api_response(
|
||||
code="300",
|
||||
message="该会议已有音频文件和转录记录,重新上传将删除现有的转录内容和会议总结",
|
||||
data={
|
||||
"requires_confirmation": True,
|
||||
"existing_file": {
|
||||
"file_name": existing_info['file_name'],
|
||||
"upload_time": existing_info['upload_time'].isoformat() if existing_info['upload_time'] else None
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
# 5. 保存音频文件
|
||||
meeting_dir = AUDIO_DIR / str(meeting_id)
|
||||
meeting_dir.mkdir(exist_ok=True)
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
absolute_path = meeting_dir / unique_filename
|
||||
relative_path = absolute_path.relative_to(BASE_DIR)
|
||||
|
||||
try:
|
||||
with open(absolute_path, "wb") as buffer:
|
||||
shutil.copyfileobj(audio_file.file, buffer)
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="500", message=f"保存文件失败: {str(e)}")
|
||||
}
|
||||
|
||||
transcription_task_id = None
|
||||
replaced_existing = existing_info is not None
|
||||
|
||||
try:
|
||||
# 6. 更新数据库记录
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 删除旧的音频文件
|
||||
if replaced_existing and force_replace_bool:
|
||||
if existing_info and existing_info['file_path']:
|
||||
old_file_path = BASE_DIR / existing_info['file_path'].lstrip('/')
|
||||
if old_file_path.exists():
|
||||
try:
|
||||
os.remove(old_file_path)
|
||||
print(f"Deleted old audio file: {old_file_path}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to delete old file {old_file_path}: {e}")
|
||||
|
||||
# 更新或插入音频文件记录
|
||||
if replaced_existing:
|
||||
cursor.execute(
|
||||
'UPDATE audio_files SET file_name = %s, file_path = %s, file_size = %s, upload_time = NOW(), task_id = NULL WHERE meeting_id = %s',
|
||||
(audio_file.filename, '/' + str(relative_path), audio_file.size, meeting_id)
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
'INSERT INTO audio_files (meeting_id, file_name, file_path, file_size, upload_time) VALUES (%s, %s, %s, %s, NOW())',
|
||||
(meeting_id, audio_file.filename, '/' + str(relative_path), audio_file.size)
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
cursor.close()
|
||||
|
||||
# 7. 启动转录任务
|
||||
try:
|
||||
transcription_task_id = transcription_service.start_transcription(meeting_id, '/' + str(relative_path))
|
||||
print(f"Transcription task {transcription_task_id} started for meeting {meeting_id}")
|
||||
except Exception as e:
|
||||
print(f"Failed to start transcription: {e}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# 出错时清理已上传的文件
|
||||
if os.path.exists(absolute_path):
|
||||
os.remove(absolute_path)
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="500", message=f"处理失败: {str(e)}")
|
||||
}
|
||||
|
||||
# 8. 返回成功结果
|
||||
return {
|
||||
"success": True,
|
||||
"file_info": {
|
||||
"file_name": audio_file.filename,
|
||||
"file_path": '/' + str(relative_path),
|
||||
"file_size": audio_file.size
|
||||
},
|
||||
"transcription_task_id": transcription_task_id,
|
||||
"replaced_existing": replaced_existing,
|
||||
"has_transcription": has_transcription
|
||||
}
|
||||
|
||||
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
|
||||
"""
|
||||
|
|
@ -559,8 +374,8 @@ def get_meeting_for_edit(meeting_id: int, current_user: dict = Depends(get_curre
|
|||
async def upload_audio(
|
||||
audio_file: UploadFile = File(...),
|
||||
meeting_id: int = Form(...),
|
||||
force_replace: str = Form("false"),
|
||||
auto_summarize: str = Form("true"),
|
||||
prompt_id: Optional[int] = Form(None), # 可选的提示词模版ID
|
||||
background_tasks: BackgroundTasks = None,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
|
|
@ -572,41 +387,84 @@ async def upload_audio(
|
|||
Args:
|
||||
audio_file: 音频文件
|
||||
meeting_id: 会议ID
|
||||
force_replace: 是否强制替换("true"/"false")
|
||||
auto_summarize: 是否自动生成总结("true"/"false",默认"true")
|
||||
prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版)
|
||||
background_tasks: FastAPI后台任务
|
||||
current_user: 当前登录用户
|
||||
|
||||
Returns:
|
||||
HTTP 300: 需要用户确认(已有转录记录)
|
||||
HTTP 200: 处理成功,返回任务ID
|
||||
HTTP 400/403/404/500: 各种错误情况
|
||||
"""
|
||||
force_replace_bool = force_replace.lower() in ("true", "1", "yes")
|
||||
auto_summarize_bool = auto_summarize.lower() in ("true", "1", "yes")
|
||||
|
||||
# 调用公共处理方法
|
||||
result = _handle_audio_upload(audio_file, meeting_id, force_replace_bool, current_user)
|
||||
# 打印接收到的 prompt_id
|
||||
print(f"[Upload Audio] Meeting ID: {meeting_id}, Received prompt_id: {prompt_id}, Type: {type(prompt_id)}, Auto-summarize: {auto_summarize_bool}")
|
||||
|
||||
# 如果不成功,直接返回响应
|
||||
# 1. 文件类型验证
|
||||
file_extension = os.path.splitext(audio_file.filename)[1].lower()
|
||||
if file_extension not in ALLOWED_EXTENSIONS:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"不支持的文件类型。支持的类型: {', '.join(ALLOWED_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
# 2. 文件大小验证
|
||||
max_file_size = getattr(config_module, 'MAX_FILE_SIZE', 100 * 1024 * 1024)
|
||||
if audio_file.size > max_file_size:
|
||||
return create_api_response(
|
||||
code="400",
|
||||
message=f"文件大小超过 {max_file_size // (1024 * 1024)}MB 限制"
|
||||
)
|
||||
|
||||
# 3. 保存音频文件到磁盘
|
||||
meeting_dir = AUDIO_DIR / str(meeting_id)
|
||||
meeting_dir.mkdir(exist_ok=True)
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
absolute_path = meeting_dir / unique_filename
|
||||
relative_path = absolute_path.relative_to(BASE_DIR)
|
||||
|
||||
try:
|
||||
with open(absolute_path, "wb") as buffer:
|
||||
shutil.copyfileobj(audio_file.file, buffer)
|
||||
except Exception as e:
|
||||
return create_api_response(code="500", message=f"保存文件失败: {str(e)}")
|
||||
|
||||
file_path = '/' + str(relative_path)
|
||||
file_name = audio_file.filename
|
||||
file_size = audio_file.size
|
||||
|
||||
# 4. 调用 audio_service 处理文件(权限检查、数据库更新、启动转录)
|
||||
result = handle_audio_upload(
|
||||
file_path=file_path,
|
||||
file_name=file_name,
|
||||
file_size=file_size,
|
||||
meeting_id=meeting_id,
|
||||
current_user=current_user,
|
||||
auto_summarize=auto_summarize_bool,
|
||||
background_tasks=background_tasks,
|
||||
prompt_id=prompt_id # 传递 prompt_id 参数
|
||||
)
|
||||
|
||||
# 如果不成功,删除已保存的文件并返回错误
|
||||
if not result["success"]:
|
||||
if absolute_path.exists():
|
||||
try:
|
||||
os.remove(absolute_path)
|
||||
print(f"Deleted file due to processing error: {absolute_path}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to delete file {absolute_path}: {e}")
|
||||
return result["response"]
|
||||
|
||||
# 成功:根据auto_summarize参数决定是否添加监控任务
|
||||
# 5. 返回成功响应
|
||||
transcription_task_id = result["transcription_task_id"]
|
||||
if auto_summarize_bool and transcription_task_id:
|
||||
background_tasks.add_task(
|
||||
async_meeting_service.monitor_and_auto_summarize,
|
||||
meeting_id,
|
||||
transcription_task_id
|
||||
)
|
||||
print(f"[upload-audio] Auto-summarize enabled, monitor task added for meeting {meeting_id}")
|
||||
message_suffix = ",正在进行转录和总结"
|
||||
else:
|
||||
print(f"[upload-audio] Auto-summarize disabled for meeting {meeting_id}")
|
||||
message_suffix = ""
|
||||
message_suffix = ""
|
||||
if transcription_task_id:
|
||||
if auto_summarize_bool:
|
||||
message_suffix = ",正在进行转录和总结"
|
||||
else:
|
||||
message_suffix = ",正在进行转录"
|
||||
|
||||
# 返回成功响应
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="Audio file uploaded successfully" +
|
||||
|
|
@ -888,7 +746,8 @@ def generate_meeting_summary_async(meeting_id: int, request: GenerateSummaryRequ
|
|||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
return create_api_response(code="404", message="Meeting not found")
|
||||
task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt)
|
||||
# 传递 prompt_id 参数给服务层
|
||||
task_id = async_meeting_service.start_summary_generation(meeting_id, request.user_prompt, request.prompt_id)
|
||||
background_tasks.add_task(async_meeting_service._process_task, task_id)
|
||||
return create_api_response(code="200", message="Summary generation task has been accepted.", data={
|
||||
"task_id": task_id, "status": "pending", "meeting_id": meeting_id
|
||||
|
|
@ -1025,12 +884,14 @@ def get_meeting_preview_data(meeting_id: int):
|
|||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查会议是否存在
|
||||
# 检查会议是否存在,并获取模版信息
|
||||
query = '''
|
||||
SELECT m.meeting_id, m.title, m.meeting_time, m.summary, m.updated_at,
|
||||
m.user_id as creator_id, u.caption as creator_username
|
||||
SELECT m.meeting_id, m.title, m.meeting_time, m.summary, m.updated_at, m.prompt_id,
|
||||
m.user_id as creator_id, u.caption as creator_username,
|
||||
p.name as prompt_name
|
||||
FROM meetings m
|
||||
JOIN users u ON m.user_id = u.user_id
|
||||
LEFT JOIN prompts p ON m.prompt_id = p.id
|
||||
WHERE m.meeting_id = %s
|
||||
'''
|
||||
cursor.execute(query, (meeting_id,))
|
||||
|
|
@ -1056,6 +917,8 @@ def get_meeting_preview_data(meeting_id: int):
|
|||
"meeting_time": meeting['meeting_time'],
|
||||
"summary": meeting['summary'],
|
||||
"creator_username": meeting['creator_username'],
|
||||
"prompt_id": meeting['prompt_id'],
|
||||
"prompt_name": meeting['prompt_name'],
|
||||
"attendees": attendees,
|
||||
"attendees_count": len(attendees)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,11 +11,14 @@ router = APIRouter()
|
|||
# Pydantic Models
|
||||
class PromptIn(BaseModel):
|
||||
name: str
|
||||
tags: Optional[str] = ""
|
||||
task_type: str # 'MEETING_TASK' 或 'KNOWLEDGE_TASK'
|
||||
content: str
|
||||
is_default: bool = False
|
||||
is_active: bool = True
|
||||
|
||||
class PromptOut(PromptIn):
|
||||
id: int
|
||||
creator_id: int
|
||||
created_at: str
|
||||
|
||||
class PromptListResponse(BaseModel):
|
||||
|
|
@ -28,44 +31,105 @@ def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_use
|
|||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
try:
|
||||
# 如果设置为默认,需要先取消同类型其他提示词的默认状态
|
||||
if prompt.is_default:
|
||||
cursor.execute(
|
||||
"UPDATE prompts SET is_default = FALSE WHERE task_type = %s",
|
||||
(prompt.task_type,)
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"INSERT INTO prompts (name, tags, content, creator_id) VALUES (%s, %s, %s, %s)",
|
||||
(prompt.name, prompt.tags, prompt.content, current_user["user_id"])
|
||||
"""INSERT INTO prompts (name, task_type, content, is_default, is_active, creator_id)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)""",
|
||||
(prompt.name, prompt.task_type, prompt.content, prompt.is_default,
|
||||
prompt.is_active, current_user["user_id"])
|
||||
)
|
||||
connection.commit()
|
||||
new_id = cursor.lastrowid
|
||||
return create_api_response(code="200", message="提示词创建成功", data={"id": new_id, **prompt.dict()})
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="提示词创建成功",
|
||||
data={"id": new_id, **prompt.dict()}
|
||||
)
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e) or "Duplicate entry" in str(e):
|
||||
if "Duplicate entry" in str(e):
|
||||
return create_api_response(code="400", message="提示词名称已存在")
|
||||
return create_api_response(code="500", message=f"创建提示词失败: {e}")
|
||||
|
||||
@router.get("/prompts")
|
||||
def get_prompts(page: int = 1, size: int = 12, current_user: dict = Depends(get_current_user)):
|
||||
"""Get a paginated list of prompts filtered by current user."""
|
||||
@router.get("/prompts/active/{task_type}")
|
||||
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
|
||||
"""Get all active prompts for a specific task type."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
# 只获取当前用户创建的提示词
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as total FROM prompts WHERE creator_id = %s",
|
||||
(current_user["user_id"],)
|
||||
"""SELECT id, name, is_default
|
||||
FROM prompts
|
||||
WHERE task_type = %s AND is_active = TRUE
|
||||
ORDER BY is_default DESC, created_at DESC""",
|
||||
(task_type,)
|
||||
)
|
||||
prompts = cursor.fetchall()
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取启用模版列表成功",
|
||||
data={"prompts": prompts}
|
||||
)
|
||||
|
||||
@router.get("/prompts")
|
||||
def get_prompts(
|
||||
task_type: Optional[str] = None,
|
||||
page: int = 1,
|
||||
size: int = 50,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get a paginated list of prompts filtered by current user and optionally by task_type."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 构建 WHERE 条件
|
||||
where_conditions = ["creator_id = %s"]
|
||||
params = [current_user["user_id"]]
|
||||
|
||||
if task_type:
|
||||
where_conditions.append("task_type = %s")
|
||||
params.append(task_type)
|
||||
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
|
||||
# 获取总数
|
||||
cursor.execute(
|
||||
f"SELECT COUNT(*) as total FROM prompts WHERE {where_clause}",
|
||||
tuple(params)
|
||||
)
|
||||
total = cursor.fetchone()['total']
|
||||
|
||||
# 获取分页数据
|
||||
offset = (page - 1) * size
|
||||
cursor.execute(
|
||||
"SELECT id, name, tags, content, created_at FROM prompts WHERE creator_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s",
|
||||
(current_user["user_id"], size, offset)
|
||||
f"""SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
|
||||
FROM prompts
|
||||
WHERE {where_clause}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s""",
|
||||
tuple(params + [size, offset])
|
||||
)
|
||||
prompts = cursor.fetchall()
|
||||
return create_api_response(code="200", message="获取提示词列表成功", data={"prompts": prompts, "total": total})
|
||||
return create_api_response(
|
||||
code="200",
|
||||
message="获取提示词列表成功",
|
||||
data={"prompts": prompts, "total": total}
|
||||
)
|
||||
|
||||
@router.get("/prompts/{prompt_id}")
|
||||
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""Get a single prompt by its ID."""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute("SELECT id, name, tags, content, created_at FROM prompts WHERE id = %s", (prompt_id,))
|
||||
cursor.execute(
|
||||
"""SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
|
||||
FROM prompts WHERE id = %s""",
|
||||
(prompt_id,)
|
||||
)
|
||||
prompt = cursor.fetchone()
|
||||
if not prompt:
|
||||
return create_api_response(code="404", message="提示词不存在")
|
||||
|
|
@ -74,19 +138,50 @@ def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
|
|||
@router.put("/prompts/{prompt_id}")
|
||||
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
|
||||
"""Update an existing prompt."""
|
||||
print(f"[UPDATE PROMPT] prompt_id={prompt_id}, type={type(prompt_id)}")
|
||||
print(f"[UPDATE PROMPT] user_id={current_user['user_id']}")
|
||||
print(f"[UPDATE PROMPT] data: name={prompt.name}, task_type={prompt.task_type}, content_len={len(prompt.content)}, is_default={prompt.is_default}, is_active={prompt.is_active}")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
try:
|
||||
cursor.execute(
|
||||
"UPDATE prompts SET name = %s, tags = %s, content = %s WHERE id = %s",
|
||||
(prompt.name, prompt.tags, prompt.content, prompt_id)
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
# 先检查记录是否存在
|
||||
cursor.execute("SELECT id, creator_id FROM prompts WHERE id = %s", (prompt_id,))
|
||||
existing = cursor.fetchone()
|
||||
print(f"[UPDATE PROMPT] existing record: {existing}")
|
||||
|
||||
if not existing:
|
||||
print(f"[UPDATE PROMPT] Prompt {prompt_id} not found in database")
|
||||
return create_api_response(code="404", message="提示词不存在")
|
||||
|
||||
# 如果设置为默认,需要先取消同类型其他提示词的默认状态
|
||||
if prompt.is_default:
|
||||
print(f"[UPDATE PROMPT] Setting as default, clearing other defaults for task_type={prompt.task_type}")
|
||||
cursor.execute(
|
||||
"UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s",
|
||||
(prompt.task_type, prompt_id)
|
||||
)
|
||||
print(f"[UPDATE PROMPT] Cleared {cursor.rowcount} other default prompts")
|
||||
|
||||
print(f"[UPDATE PROMPT] Executing UPDATE query")
|
||||
cursor.execute(
|
||||
"""UPDATE prompts
|
||||
SET name = %s, task_type = %s, content = %s, is_default = %s, is_active = %s
|
||||
WHERE id = %s""",
|
||||
(prompt.name, prompt.task_type, prompt.content, prompt.is_default,
|
||||
prompt.is_active, prompt_id)
|
||||
)
|
||||
rows_affected = cursor.rowcount
|
||||
print(f"[UPDATE PROMPT] UPDATE affected {rows_affected} rows (0 means no changes needed)")
|
||||
|
||||
# 注意:rowcount=0 不代表记录不存在,可能是所有字段值都相同
|
||||
# 我们已经在上面确认了记录存在,所以这里直接提交即可
|
||||
connection.commit()
|
||||
print(f"[UPDATE PROMPT] Success! Committed changes")
|
||||
return create_api_response(code="200", message="提示词更新成功")
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e) or "Duplicate entry" in str(e):
|
||||
print(f"[UPDATE PROMPT] Exception: {type(e).__name__}: {e}")
|
||||
if "Duplicate entry" in str(e):
|
||||
return create_api_response(code="400", message="提示词名称已存在")
|
||||
return create_api_response(code="500", message=f"更新提示词失败: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -147,23 +147,42 @@ def reset_password(user_id: int, current_user: dict = Depends(get_current_user))
|
|||
return create_api_response(code="200", message=f"用户 {user_id} 的密码已重置")
|
||||
|
||||
@router.get("/users")
|
||||
def get_all_users(page: int = 1, size: int = 10, role_id: Optional[int] = None, current_user: dict = Depends(get_current_user)):
|
||||
def get_all_users(
|
||||
page: int = 1,
|
||||
size: int = 10,
|
||||
role_id: Optional[int] = None,
|
||||
search: Optional[str] = None,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
count_query = "SELECT COUNT(*) as total FROM users"
|
||||
params = []
|
||||
|
||||
# 构建WHERE条件
|
||||
where_conditions = []
|
||||
count_params = []
|
||||
|
||||
if role_id is not None:
|
||||
count_query += " WHERE role_id = %s"
|
||||
params.append(role_id)
|
||||
|
||||
cursor.execute(count_query, tuple(params))
|
||||
where_conditions.append("role_id = %s")
|
||||
count_params.append(role_id)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
where_conditions.append("(username LIKE %s OR caption LIKE %s)")
|
||||
count_params.extend([search_pattern, search_pattern])
|
||||
|
||||
# 统计查询
|
||||
count_query = "SELECT COUNT(*) as total FROM users"
|
||||
if where_conditions:
|
||||
count_query += " WHERE " + " AND ".join(where_conditions)
|
||||
|
||||
cursor.execute(count_query, tuple(count_params))
|
||||
total = cursor.fetchone()['total']
|
||||
|
||||
|
||||
offset = (page - 1) * size
|
||||
|
||||
|
||||
# 主查询
|
||||
query = '''
|
||||
SELECT
|
||||
SELECT
|
||||
u.user_id, u.username, u.caption, u.email, u.created_at, u.role_id,
|
||||
r.role_name,
|
||||
(SELECT COUNT(*) FROM meetings WHERE user_id = u.user_id) as meetings_created,
|
||||
|
|
@ -171,24 +190,24 @@ def get_all_users(page: int = 1, size: int = 10, role_id: Optional[int] = None,
|
|||
FROM users u
|
||||
LEFT JOIN roles r ON u.role_id = r.role_id
|
||||
'''
|
||||
|
||||
|
||||
query_params = []
|
||||
if role_id is not None:
|
||||
query += " WHERE u.role_id = %s"
|
||||
query_params.append(role_id)
|
||||
|
||||
if where_conditions:
|
||||
query += " WHERE " + " AND ".join(where_conditions)
|
||||
query_params.extend(count_params)
|
||||
|
||||
query += '''
|
||||
ORDER BY u.user_id ASC
|
||||
LIMIT %s OFFSET %s
|
||||
'''
|
||||
|
||||
|
||||
query_params.extend([size, offset])
|
||||
|
||||
|
||||
cursor.execute(query, tuple(query_params))
|
||||
users = cursor.fetchall()
|
||||
|
||||
|
||||
user_list = [UserInfo(**user) for user in users]
|
||||
|
||||
|
||||
response_data = UserListResponse(users=user_list, total=total)
|
||||
return create_api_response(code="200", message="获取用户列表成功", data=response_data.dict())
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
UPLOAD_DIR = BASE_DIR / "uploads"
|
||||
AUDIO_DIR = UPLOAD_DIR / "audio"
|
||||
TEMP_UPLOAD_DIR = UPLOAD_DIR / "temp_audio"
|
||||
MARKDOWN_DIR = UPLOAD_DIR / "markdown"
|
||||
VOICEPRINT_DIR = UPLOAD_DIR / "voiceprint"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,21 @@
|
|||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
# 无论从哪里运行,都能正确找到 app 模块
|
||||
current_file = Path(__file__).resolve()
|
||||
project_root = current_file.parent.parent # backend/
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base, client_downloads, voiceprint
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, admin_dashboard, tasks, prompts, knowledge_base, client_downloads, voiceprint, audio
|
||||
from app.core.config import UPLOAD_DIR, API_CONFIG
|
||||
from app.api.endpoints.admin import load_system_config
|
||||
import os
|
||||
|
||||
app = FastAPI(
|
||||
title="iMeeting API",
|
||||
|
|
@ -35,11 +45,13 @@ app.include_router(users.router, prefix="/api", tags=["Users"])
|
|||
app.include_router(meetings.router, prefix="/api", tags=["Meetings"])
|
||||
app.include_router(tags.router, prefix="/api", tags=["Tags"])
|
||||
app.include_router(admin.router, prefix="/api", tags=["Admin"])
|
||||
app.include_router(admin_dashboard.router, prefix="/api", tags=["AdminDashboard"])
|
||||
app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
|
||||
app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
|
||||
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
|
||||
app.include_router(client_downloads.router, prefix="/api/clients", tags=["ClientDownloads"])
|
||||
app.include_router(client_downloads.router, prefix="/api", tags=["ClientDownloads"])
|
||||
app.include_router(voiceprint.router, prefix="/api", tags=["Voiceprint"])
|
||||
app.include_router(audio.router, prefix="/api", tags=["Audio"])
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
|
|
@ -57,10 +69,10 @@ def health_check():
|
|||
if __name__ == "__main__":
|
||||
# 简单的uvicorn配置,避免参数冲突
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=API_CONFIG['host'],
|
||||
"app.main:app",
|
||||
host=API_CONFIG['host'],
|
||||
port=API_CONFIG['port'],
|
||||
limit_max_requests=1000,
|
||||
timeout_keep_alive=30,
|
||||
reload=True,
|
||||
)
|
||||
)
|
||||
|
|
@ -152,6 +152,7 @@ class CreateKnowledgeBaseRequest(BaseModel):
|
|||
user_prompt: Optional[str] = None
|
||||
source_meeting_ids: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
prompt_id: Optional[int] = None # 提示词模版ID,如果不指定则使用默认模版
|
||||
|
||||
class UpdateKnowledgeBaseRequest(BaseModel):
|
||||
title: str
|
||||
|
|
@ -227,3 +228,30 @@ class VoiceprintTemplate(BaseModel):
|
|||
duration_seconds: int
|
||||
sample_rate: int
|
||||
channels: int
|
||||
|
||||
# 菜单权限相关模型
|
||||
class MenuInfo(BaseModel):
|
||||
menu_id: int
|
||||
menu_code: str
|
||||
menu_name: str
|
||||
menu_icon: Optional[str] = None
|
||||
menu_url: Optional[str] = None
|
||||
menu_type: str # 'action', 'link', 'divider'
|
||||
parent_id: Optional[int] = None
|
||||
sort_order: int
|
||||
is_active: bool
|
||||
description: Optional[str] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
class MenuListResponse(BaseModel):
|
||||
menus: List[MenuInfo]
|
||||
total: int
|
||||
|
||||
class RolePermissionInfo(BaseModel):
|
||||
role_id: int
|
||||
role_name: str
|
||||
menu_ids: List[int]
|
||||
|
||||
class UpdateRolePermissionsRequest(BaseModel):
|
||||
menu_ids: List[int]
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class AsyncKnowledgeBaseService:
|
|||
self.redis_client = redis.Redis(**REDIS_CONFIG)
|
||||
self.llm_service = LLMService()
|
||||
|
||||
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
|
||||
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], prompt_id: Optional[int] = None, cursor=None) -> str:
|
||||
"""
|
||||
创建异步知识库生成任务
|
||||
|
||||
|
|
@ -29,6 +29,7 @@ class AsyncKnowledgeBaseService:
|
|||
kb_id: 知识库ID
|
||||
user_prompt: 用户提示词
|
||||
source_meeting_ids: 源会议ID列表
|
||||
prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版)
|
||||
cursor: 数据库游标(可选)
|
||||
|
||||
Returns:
|
||||
|
|
@ -39,13 +40,13 @@ class AsyncKnowledgeBaseService:
|
|||
# If a cursor is passed, use it directly to avoid creating a new transaction
|
||||
if cursor:
|
||||
query = """
|
||||
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at)
|
||||
VALUES (%s, %s, %s, %s, NOW())
|
||||
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, prompt_id, created_at)
|
||||
VALUES (%s, %s, %s, %s, %s, NOW())
|
||||
"""
|
||||
cursor.execute(query, (task_id, user_id, kb_id, user_prompt))
|
||||
cursor.execute(query, (task_id, user_id, kb_id, user_prompt, prompt_id))
|
||||
else:
|
||||
# Fallback to the old method if no cursor is provided
|
||||
self._save_task_to_db(task_id, user_id, kb_id, user_prompt)
|
||||
self._save_task_to_db(task_id, user_id, kb_id, user_prompt, prompt_id)
|
||||
|
||||
current_time = datetime.now().isoformat()
|
||||
task_data = {
|
||||
|
|
@ -53,6 +54,7 @@ class AsyncKnowledgeBaseService:
|
|||
'user_id': str(user_id),
|
||||
'kb_id': str(kb_id),
|
||||
'user_prompt': user_prompt if user_prompt else "",
|
||||
'prompt_id': str(prompt_id) if prompt_id else '',
|
||||
'status': 'pending',
|
||||
'progress': '0',
|
||||
'created_at': current_time,
|
||||
|
|
@ -61,7 +63,7 @@ class AsyncKnowledgeBaseService:
|
|||
self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data)
|
||||
self.redis_client.expire(f"kb_task:{task_id}", 86400)
|
||||
|
||||
print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
|
||||
print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}, prompt_id: {prompt_id}")
|
||||
return task_id
|
||||
|
||||
def _process_task(self, task_id: str):
|
||||
|
|
@ -78,6 +80,8 @@ class AsyncKnowledgeBaseService:
|
|||
|
||||
kb_id = int(task_data['kb_id'])
|
||||
user_prompt = task_data.get('user_prompt', '')
|
||||
prompt_id_str = task_data.get('prompt_id', '')
|
||||
prompt_id = int(prompt_id_str) if prompt_id_str and prompt_id_str != '' else None
|
||||
|
||||
# 1. 更新状态为processing
|
||||
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
|
||||
|
|
@ -88,7 +92,7 @@ class AsyncKnowledgeBaseService:
|
|||
|
||||
# 3. 构建提示词
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="准备AI提示词...")
|
||||
full_prompt = self._build_prompt(source_text, user_prompt)
|
||||
full_prompt = self._build_prompt(source_text, user_prompt, prompt_id)
|
||||
|
||||
# 4. 调用LLM API
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
|
||||
|
|
@ -98,7 +102,7 @@ class AsyncKnowledgeBaseService:
|
|||
|
||||
# 5. 保存结果到数据库
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
|
||||
self._save_result_to_db(kb_id, generated_content)
|
||||
self._save_result_to_db(kb_id, generated_content, prompt_id)
|
||||
|
||||
# 6. 任务完成
|
||||
self._update_task_in_db(task_id, 'completed', 100)
|
||||
|
|
@ -156,7 +160,7 @@ class AsyncKnowledgeBaseService:
|
|||
print(f"获取会议总结错误: {e}")
|
||||
return ""
|
||||
|
||||
def _build_prompt(self, source_text: str, user_prompt: str) -> str:
|
||||
def _build_prompt(self, source_text: str, user_prompt: str, prompt_id: Optional[int] = None) -> str:
|
||||
"""
|
||||
构建完整的提示词
|
||||
使用数据库中配置的KNOWLEDGE_TASK提示词模板
|
||||
|
|
@ -164,12 +168,13 @@ class AsyncKnowledgeBaseService:
|
|||
Args:
|
||||
source_text: 源会议总结文本
|
||||
user_prompt: 用户自定义提示词
|
||||
prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版)
|
||||
|
||||
Returns:
|
||||
str: 完整的提示词
|
||||
"""
|
||||
# 从数据库获取知识库任务的提示词模板
|
||||
system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK')
|
||||
# 从数据库获取知识库任务的提示词模板(支持指定prompt_id)
|
||||
system_prompt = self.llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=prompt_id)
|
||||
|
||||
prompt = f"{system_prompt}\n\n"
|
||||
|
||||
|
|
@ -180,13 +185,14 @@ class AsyncKnowledgeBaseService:
|
|||
|
||||
return prompt
|
||||
|
||||
def _save_result_to_db(self, kb_id: int, content: str) -> Optional[int]:
|
||||
def _save_result_to_db(self, kb_id: int, content: str, prompt_id: Optional[int] = None) -> Optional[int]:
|
||||
"""
|
||||
保存生成结果到数据库
|
||||
|
||||
Args:
|
||||
kb_id: 知识库ID
|
||||
content: 生成的内容
|
||||
prompt_id: 提示词模版ID
|
||||
|
||||
Returns:
|
||||
Optional[int]: 知识库ID,失败返回None
|
||||
|
|
@ -194,11 +200,11 @@ class AsyncKnowledgeBaseService:
|
|||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
|
||||
cursor.execute(query, (content, kb_id))
|
||||
query = "UPDATE knowledge_bases SET content = %s, prompt_id = %s, updated_at = NOW() WHERE kb_id = %s"
|
||||
cursor.execute(query, (content, prompt_id, kb_id))
|
||||
connection.commit()
|
||||
|
||||
print(f"成功保存知识库内容,kb_id: {kb_id}")
|
||||
print(f"成功保存知识库内容,kb_id: {kb_id}, prompt_id: {prompt_id}")
|
||||
return kb_id
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -243,13 +249,31 @@ class AsyncKnowledgeBaseService:
|
|||
except Exception as e:
|
||||
print(f"Error updating task status in Redis: {e}")
|
||||
|
||||
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
|
||||
"""保存任务到数据库"""
|
||||
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str, prompt_id: Optional[int] = None):
|
||||
"""保存任务到数据库
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
user_id: 用户ID
|
||||
kb_id: 知识库ID
|
||||
user_prompt: 用户提示词
|
||||
prompt_id: 提示词模版ID(可选),如果为None则使用默认模版
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
|
||||
cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt))
|
||||
|
||||
# 如果没有指定 prompt_id,获取默认的知识库总结模版ID
|
||||
if prompt_id is None:
|
||||
cursor.execute(
|
||||
"SELECT id FROM prompts WHERE task_type = 'KNOWLEDGE_TASK' AND is_default = TRUE AND is_active = TRUE LIMIT 1"
|
||||
)
|
||||
default_prompt = cursor.fetchone()
|
||||
if default_prompt:
|
||||
prompt_id = default_prompt[0]
|
||||
|
||||
insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, prompt_id, status, progress, created_at) VALUES (%s, %s, %s, %s, %s, 'pending', 0, NOW())"
|
||||
cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt, prompt_id))
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
print(f"Error saving task to database: {e}")
|
||||
|
|
|
|||
|
|
@ -23,22 +23,24 @@ class AsyncMeetingService:
|
|||
self.redis_client = redis.Redis(**REDIS_CONFIG)
|
||||
self.llm_service = LLMService() # 复用现有的同步LLM服务
|
||||
|
||||
def start_summary_generation(self, meeting_id: int, user_prompt: str = "") -> str:
|
||||
def start_summary_generation(self, meeting_id: int, user_prompt: str = "", prompt_id: Optional[int] = None) -> str:
|
||||
"""
|
||||
创建异步总结任务,任务的执行将由外部(如API层的BackgroundTasks)触发。
|
||||
|
||||
Args:
|
||||
meeting_id: 会议ID
|
||||
user_prompt: 用户额外提示词
|
||||
prompt_id: 可选的提示词模版ID,如果不指定则使用默认模版
|
||||
|
||||
Returns:
|
||||
str: 任务ID
|
||||
"""
|
||||
|
||||
try:
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 在数据库中创建任务记录
|
||||
self._save_task_to_db(task_id, meeting_id, user_prompt)
|
||||
self._save_task_to_db(task_id, meeting_id, user_prompt, prompt_id)
|
||||
|
||||
# 将任务详情存入Redis,用于快速查询状态
|
||||
current_time = datetime.now().isoformat()
|
||||
|
|
@ -46,6 +48,7 @@ class AsyncMeetingService:
|
|||
'task_id': task_id,
|
||||
'meeting_id': str(meeting_id),
|
||||
'user_prompt': user_prompt,
|
||||
'prompt_id': str(prompt_id) if prompt_id else '',
|
||||
'status': 'pending',
|
||||
'progress': '0',
|
||||
'created_at': current_time,
|
||||
|
|
@ -54,7 +57,6 @@ class AsyncMeetingService:
|
|||
self.redis_client.hset(f"llm_task:{task_id}", mapping=task_data)
|
||||
self.redis_client.expire(f"llm_task:{task_id}", 86400)
|
||||
|
||||
print(f"Meeting summary task created: {task_id} for meeting: {meeting_id}")
|
||||
return task_id
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -75,6 +77,8 @@ class AsyncMeetingService:
|
|||
|
||||
meeting_id = int(task_data['meeting_id'])
|
||||
user_prompt = task_data.get('user_prompt', '')
|
||||
prompt_id_str = task_data.get('prompt_id', '')
|
||||
prompt_id = int(prompt_id_str) if prompt_id_str and prompt_id_str != '' else None
|
||||
|
||||
# 1. 更新状态为processing
|
||||
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
|
||||
|
|
@ -87,7 +91,7 @@ class AsyncMeetingService:
|
|||
|
||||
# 3. 构建提示词
|
||||
self._update_task_status_in_redis(task_id, 'processing', 40, message="准备AI提示词...")
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt)
|
||||
full_prompt = self._build_prompt(transcript_text, user_prompt, prompt_id)
|
||||
|
||||
# 4. 调用LLM API
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在分析会议内容...")
|
||||
|
|
@ -97,7 +101,7 @@ class AsyncMeetingService:
|
|||
|
||||
# 5. 保存结果到主表
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存总结结果...")
|
||||
self._save_summary_to_db(meeting_id, summary_content, user_prompt)
|
||||
self._save_summary_to_db(meeting_id, summary_content, user_prompt, prompt_id)
|
||||
|
||||
# 6. 任务完成
|
||||
self._update_task_in_db(task_id, 'completed', 100, result=summary_content)
|
||||
|
|
@ -111,7 +115,7 @@ class AsyncMeetingService:
|
|||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg)
|
||||
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
|
||||
|
||||
def monitor_and_auto_summarize(self, meeting_id: int, transcription_task_id: str):
|
||||
def monitor_and_auto_summarize(self, meeting_id: int, transcription_task_id: str, prompt_id: Optional[int] = None):
|
||||
"""
|
||||
监控转录任务,完成后自动生成总结
|
||||
此方法设计为由BackgroundTasks调用,在后台运行
|
||||
|
|
@ -119,13 +123,14 @@ class AsyncMeetingService:
|
|||
Args:
|
||||
meeting_id: 会议ID
|
||||
transcription_task_id: 转录任务ID
|
||||
prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版)
|
||||
|
||||
流程:
|
||||
1. 循环轮询转录任务状态
|
||||
2. 转录成功后自动启动总结任务
|
||||
3. 转录失败或超时则停止轮询并记录日志
|
||||
"""
|
||||
print(f"[Monitor] Started monitoring transcription task {transcription_task_id} for meeting {meeting_id}")
|
||||
print(f"[Monitor] Started monitoring transcription task {transcription_task_id} for meeting {meeting_id}, prompt_id: {prompt_id}")
|
||||
|
||||
# 获取配置参数
|
||||
poll_interval = TRANSCRIPTION_POLL_CONFIG['poll_interval']
|
||||
|
|
@ -156,7 +161,7 @@ class AsyncMeetingService:
|
|||
|
||||
# 启动总结任务
|
||||
try:
|
||||
summary_task_id = self.start_summary_generation(meeting_id, user_prompt="")
|
||||
summary_task_id = self.start_summary_generation(meeting_id, user_prompt="", prompt_id=prompt_id)
|
||||
print(f"[Monitor] Summary task {summary_task_id} started for meeting {meeting_id}")
|
||||
|
||||
# 在后台执行总结任务
|
||||
|
|
@ -231,13 +236,18 @@ class AsyncMeetingService:
|
|||
print(f"获取会议转录内容错误: {e}")
|
||||
return ""
|
||||
|
||||
def _build_prompt(self, transcript_text: str, user_prompt: str) -> str:
|
||||
def _build_prompt(self, transcript_text: str, user_prompt: str, prompt_id: Optional[int] = None) -> str:
|
||||
"""
|
||||
构建完整的提示词
|
||||
使用数据库中配置的MEETING_TASK提示词模板
|
||||
|
||||
Args:
|
||||
transcript_text: 会议转录文本
|
||||
user_prompt: 用户额外提示词
|
||||
prompt_id: 可选的提示词模版ID,如果不指定则使用默认模版
|
||||
"""
|
||||
# 从数据库获取会议任务的提示词模板
|
||||
system_prompt = self.llm_service.get_task_prompt('MEETING_TASK')
|
||||
# 从数据库获取会议任务的提示词模板(支持指定prompt_id)
|
||||
system_prompt = self.llm_service.get_task_prompt('MEETING_TASK', prompt_id=prompt_id)
|
||||
|
||||
prompt = f"{system_prompt}\n\n"
|
||||
|
||||
|
|
@ -248,22 +258,22 @@ class AsyncMeetingService:
|
|||
|
||||
return prompt
|
||||
|
||||
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str) -> Optional[int]:
|
||||
"""保存总结到数据库 - 更新meetings表的summary、user_prompt和updated_at字段"""
|
||||
def _save_summary_to_db(self, meeting_id: int, summary_content: str, user_prompt: str, prompt_id: Optional[int] = None) -> Optional[int]:
|
||||
"""保存总结到数据库 - 更新meetings表的summary、user_prompt、prompt_id和updated_at字段"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
|
||||
# 更新meetings表的summary、user_prompt和updated_at字段
|
||||
# 更新meetings表的summary、user_prompt、prompt_id和updated_at字段
|
||||
update_query = """
|
||||
UPDATE meetings
|
||||
SET summary = %s, user_prompt = %s, updated_at = NOW()
|
||||
SET summary = %s, user_prompt = %s, prompt_id = %s, updated_at = NOW()
|
||||
WHERE meeting_id = %s
|
||||
"""
|
||||
cursor.execute(update_query, (summary_content, user_prompt, meeting_id))
|
||||
cursor.execute(update_query, (summary_content, user_prompt, prompt_id, meeting_id))
|
||||
connection.commit()
|
||||
|
||||
print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}")
|
||||
print(f"成功保存会议总结到meetings表,meeting_id: {meeting_id}, prompt_id: {prompt_id}")
|
||||
return meeting_id
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -326,14 +336,39 @@ class AsyncMeetingService:
|
|||
except Exception as e:
|
||||
print(f"Error updating task status in Redis: {e}")
|
||||
|
||||
def _save_task_to_db(self, task_id: str, meeting_id: int, user_prompt: str):
|
||||
"""保存任务到数据库"""
|
||||
def _save_task_to_db(self, task_id: str, meeting_id: int, user_prompt: str, prompt_id: Optional[int] = None):
|
||||
"""保存任务到数据库
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
meeting_id: 会议ID
|
||||
user_prompt: 用户额外提示词
|
||||
prompt_id: 可选的提示词模版ID,如果为None则使用默认模版
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
insert_query = "INSERT INTO llm_tasks (task_id, meeting_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, 'pending', 0, NOW())"
|
||||
cursor.execute(insert_query, (task_id, meeting_id, user_prompt))
|
||||
|
||||
# 如果没有指定 prompt_id,获取默认的会议总结模版ID
|
||||
if prompt_id is None:
|
||||
print(f"[Meeting Service] prompt_id is None, fetching default template for MEETING_TASK")
|
||||
cursor.execute(
|
||||
"SELECT id FROM prompts WHERE task_type = 'MEETING_TASK' AND is_default = TRUE AND is_active = TRUE LIMIT 1"
|
||||
)
|
||||
default_prompt = cursor.fetchone()
|
||||
if default_prompt:
|
||||
prompt_id = default_prompt[0]
|
||||
print(f"[Meeting Service] Found default template ID: {prompt_id}")
|
||||
else:
|
||||
print(f"[Meeting Service] WARNING: No default template found for MEETING_TASK!")
|
||||
else:
|
||||
print(f"[Meeting Service] Using provided prompt_id: {prompt_id}")
|
||||
|
||||
print(f"[Meeting Service] Inserting task into llm_tasks - task_id: {task_id}, meeting_id: {meeting_id}, prompt_id: {prompt_id}")
|
||||
insert_query = "INSERT INTO llm_tasks (task_id, meeting_id, user_prompt, prompt_id, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
|
||||
cursor.execute(insert_query, (task_id, meeting_id, user_prompt, prompt_id))
|
||||
connection.commit()
|
||||
print(f"[Meeting Service] Task saved successfully to database")
|
||||
except Exception as e:
|
||||
print(f"Error saving task to database: {e}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -0,0 +1,172 @@
|
|||
"""
|
||||
音频处理服务
|
||||
|
||||
处理已保存的完整音频文件:数据库更新、转录、自动总结
|
||||
"""
|
||||
from fastapi import BackgroundTasks
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.response import create_api_response
|
||||
from app.core.config import BASE_DIR
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.services.async_meeting_service import async_meeting_service
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
|
||||
transcription_service = AsyncTranscriptionService()
|
||||
|
||||
|
||||
def handle_audio_upload(
|
||||
file_path: str,
|
||||
file_name: str,
|
||||
file_size: int,
|
||||
meeting_id: int,
|
||||
current_user: dict,
|
||||
auto_summarize: bool = True,
|
||||
background_tasks: BackgroundTasks = None,
|
||||
prompt_id: int = None
|
||||
) -> dict:
|
||||
"""
|
||||
处理已保存的完整音频文件
|
||||
|
||||
职责:
|
||||
1. 权限检查
|
||||
2. 检查已有文件和转录记录
|
||||
3. 更新数据库(audio_files 表)
|
||||
4. 启动转录任务
|
||||
5. 可选启动自动总结监控
|
||||
|
||||
Args:
|
||||
file_path: 已保存的文件路径(相对于 BASE_DIR 的路径,如 /uploads/audio/123/xxx.webm)
|
||||
file_name: 原始文件名
|
||||
file_size: 文件大小(字节)
|
||||
meeting_id: 会议ID
|
||||
current_user: 当前用户信息
|
||||
auto_summarize: 是否自动生成总结(默认True)
|
||||
background_tasks: FastAPI 后台任务对象
|
||||
prompt_id: 提示词模版ID(可选,如果不指定则使用默认模版)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"success": bool, # 是否成功
|
||||
"response": dict, # 如果需要返回,这里是响应数据
|
||||
"file_info": dict, # 文件信息 (成功时)
|
||||
"transcription_task_id": str, # 转录任务ID (成功时)
|
||||
"replaced_existing": bool, # 是否替换了现有文件 (成功时)
|
||||
"has_transcription": bool # 原来是否有转录记录 (成功时)
|
||||
}
|
||||
"""
|
||||
print(f"[Audio Service] handle_audio_upload called - Meeting ID: {meeting_id}, Auto-summarize: {auto_summarize}, Received prompt_id: {prompt_id}, Type: {type(prompt_id)}")
|
||||
|
||||
# 1. 权限和已有文件检查
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查会议是否存在及权限
|
||||
cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
if not meeting:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="404", message="会议不存在")
|
||||
}
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="403", message="无权限操作此会议")
|
||||
}
|
||||
|
||||
# 检查已有音频文件
|
||||
cursor.execute(
|
||||
"SELECT file_name, file_path, upload_time FROM audio_files WHERE meeting_id = %s",
|
||||
(meeting_id,)
|
||||
)
|
||||
existing_info = cursor.fetchone()
|
||||
|
||||
# 检查是否有转录记录
|
||||
has_transcription = False
|
||||
if existing_info:
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as segment_count FROM transcript_segments WHERE meeting_id = %s",
|
||||
(meeting_id,)
|
||||
)
|
||||
has_transcription = cursor.fetchone()['segment_count'] > 0
|
||||
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="500", message=f"检查已有文件失败: {str(e)}")
|
||||
}
|
||||
|
||||
# 2. 删除旧的音频文件(如果存在)
|
||||
replaced_existing = existing_info is not None
|
||||
if replaced_existing and existing_info['file_path']:
|
||||
old_file_path = BASE_DIR / existing_info['file_path'].lstrip('/')
|
||||
if old_file_path.exists():
|
||||
try:
|
||||
os.remove(old_file_path)
|
||||
print(f"Deleted old audio file: {old_file_path}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to delete old file {old_file_path}: {e}")
|
||||
|
||||
transcription_task_id = None
|
||||
|
||||
try:
|
||||
# 3. 更新数据库记录
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
if replaced_existing:
|
||||
cursor.execute(
|
||||
'UPDATE audio_files SET file_name = %s, file_path = %s, file_size = %s, upload_time = NOW(), task_id = NULL WHERE meeting_id = %s',
|
||||
(file_name, file_path, file_size, meeting_id)
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
'INSERT INTO audio_files (meeting_id, file_name, file_path, file_size, upload_time) VALUES (%s, %s, %s, %s, NOW())',
|
||||
(meeting_id, file_name, file_path, file_size)
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
cursor.close()
|
||||
|
||||
# 4. 启动转录任务
|
||||
try:
|
||||
transcription_task_id = transcription_service.start_transcription(meeting_id, file_path)
|
||||
print(f"Transcription task {transcription_task_id} started for meeting {meeting_id}")
|
||||
|
||||
# 5. 如果启用自动总结且提供了 background_tasks,添加监控任务
|
||||
if auto_summarize and transcription_task_id and background_tasks:
|
||||
background_tasks.add_task(
|
||||
async_meeting_service.monitor_and_auto_summarize,
|
||||
meeting_id,
|
||||
transcription_task_id,
|
||||
prompt_id # 传递 prompt_id 给自动总结监控任务
|
||||
)
|
||||
print(f"[audio_service] Auto-summarize enabled, monitor task added for meeting {meeting_id}, prompt_id: {prompt_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to start transcription: {e}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# 出错时的处理(文件已保存,不删除)
|
||||
return {
|
||||
"success": False,
|
||||
"response": create_api_response(code="500", message=f"处理失败: {str(e)}")
|
||||
}
|
||||
|
||||
# 6. 返回成功结果
|
||||
return {
|
||||
"success": True,
|
||||
"file_info": {
|
||||
"file_name": file_name,
|
||||
"file_path": file_path,
|
||||
"file_size": file_size
|
||||
},
|
||||
"transcription_task_id": transcription_task_id,
|
||||
"replaced_existing": replaced_existing,
|
||||
"has_transcription": has_transcription
|
||||
}
|
||||
|
|
@ -38,39 +38,54 @@ class LLMService:
|
|||
"""动态获取top_p"""
|
||||
return config_module.LLM_CONFIG["top_p"]
|
||||
|
||||
def get_task_prompt(self, task_name: str, cursor=None) -> str:
|
||||
def get_task_prompt(self, task_type: str, cursor=None, prompt_id: Optional[int] = None) -> str:
|
||||
"""
|
||||
统一的提示词获取方法
|
||||
|
||||
Args:
|
||||
task_name: 任务名称,如 'MEETING_TASK', 'KNOWLEDGE_TASK' 等
|
||||
task_type: 任务类型,如 'MEETING_TASK', 'KNOWLEDGE_TASK' 等
|
||||
cursor: 数据库游标,如果传入则使用,否则创建新连接
|
||||
prompt_id: 可选的提示词ID,如果指定则使用该提示词,否则使用默认提示词
|
||||
|
||||
Returns:
|
||||
str: 提示词内容,如果未找到返回默认提示词
|
||||
"""
|
||||
query = """
|
||||
SELECT p.content
|
||||
FROM prompt_config pc
|
||||
JOIN prompts p ON pc.prompt_id = p.id
|
||||
WHERE pc.task_name = %s
|
||||
"""
|
||||
# 如果指定了 prompt_id,直接获取该提示词
|
||||
if prompt_id:
|
||||
query = """
|
||||
SELECT content
|
||||
FROM prompts
|
||||
WHERE id = %s AND task_type = %s AND is_active = TRUE
|
||||
LIMIT 1
|
||||
"""
|
||||
params = (prompt_id, task_type)
|
||||
else:
|
||||
# 否则获取默认提示词
|
||||
query = """
|
||||
SELECT content
|
||||
FROM prompts
|
||||
WHERE task_type = %s
|
||||
AND is_default = TRUE
|
||||
AND is_active = TRUE
|
||||
LIMIT 1
|
||||
"""
|
||||
params = (task_type,)
|
||||
|
||||
if cursor:
|
||||
cursor.execute(query, (task_name,))
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content'] if isinstance(result, dict) else result[0]
|
||||
else:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute(query, (task_name,))
|
||||
cursor.execute(query, params)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content']
|
||||
|
||||
# 返回默认提示词
|
||||
return self._get_default_prompt(task_name)
|
||||
return self._get_default_prompt(task_type)
|
||||
|
||||
def _get_default_prompt(self, task_name: str) -> str:
|
||||
"""获取默认提示词"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,201 @@
|
|||
# 客户端管理 - 专用终端类型添加说明
|
||||
|
||||
## 概述
|
||||
|
||||
本次更新在客户端管理系统中添加了"专用终端"(terminal)大类型,支持 Android 专用终端和单片机(MCU)平台。
|
||||
|
||||
## 数据库变更
|
||||
|
||||
### 1. 修改表结构
|
||||
|
||||
执行 SQL 文件:`add_dedicated_terminal.sql`
|
||||
|
||||
```bash
|
||||
mysql -u [username] -p [database_name] < backend/sql/add_dedicated_terminal.sql
|
||||
```
|
||||
|
||||
**变更内容:**
|
||||
- 修改 `client_downloads` 表的 `platform_type` 枚举,添加 `terminal` 类型
|
||||
- 插入两条示例数据:
|
||||
- Android 专用终端(platform_type: `terminal`, platform_name: `android`)
|
||||
- 单片机固件(platform_type: `terminal`, platform_name: `mcu`)
|
||||
|
||||
### 2. 新的平台类型
|
||||
|
||||
| platform_type | platform_name | 说明 |
|
||||
|--------------|--------------|------|
|
||||
| terminal | android | Android 专用终端 |
|
||||
| terminal | mcu | 单片机(MCU)固件 |
|
||||
|
||||
## API 接口变更
|
||||
|
||||
### 1. 新增接口:通过平台类型和平台名称获取最新版本
|
||||
|
||||
**接口路径:** `GET /api/downloads/latest/by-platform`
|
||||
|
||||
**请求参数:**
|
||||
- `platform_type` (string, required): 平台类型 (mobile, desktop, terminal)
|
||||
- `platform_name` (string, required): 具体平台名称
|
||||
|
||||
**示例请求:**
|
||||
```bash
|
||||
# 获取 Android 专用终端最新版本
|
||||
curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=android"
|
||||
|
||||
# 获取单片机固件最新版本
|
||||
curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=mcu"
|
||||
```
|
||||
|
||||
**返回示例:**
|
||||
```json
|
||||
{
|
||||
"code": "200",
|
||||
"message": "获取成功",
|
||||
"data": {
|
||||
"id": 7,
|
||||
"platform_type": "terminal",
|
||||
"platform_name": "android",
|
||||
"version": "1.0.0",
|
||||
"version_code": 1000,
|
||||
"download_url": "https://download.imeeting.com/terminals/android/iMeeting-Terminal-1.0.0.apk",
|
||||
"file_size": 25165824,
|
||||
"release_notes": "专用终端初始版本\n- 支持专用硬件集成\n- 优化的录音功能\n- 低功耗模式\n- 自动上传同步",
|
||||
"is_active": true,
|
||||
"is_latest": true,
|
||||
"min_system_version": "Android 5.0",
|
||||
"created_at": "2025-01-15T10:00:00",
|
||||
"updated_at": "2025-01-15T10:00:00",
|
||||
"created_by": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 更新接口:获取所有平台最新版本
|
||||
|
||||
**接口路径:** `GET /api/downloads/latest`
|
||||
|
||||
**变更:** 返回数据中新增 `terminal` 字段
|
||||
|
||||
**返回示例:**
|
||||
```json
|
||||
{
|
||||
"code": "200",
|
||||
"message": "获取成功",
|
||||
"data": {
|
||||
"mobile": [...],
|
||||
"desktop": [...],
|
||||
"terminal": [
|
||||
{
|
||||
"id": 7,
|
||||
"platform_type": "terminal",
|
||||
"platform_name": "android",
|
||||
"version": "1.0.0",
|
||||
...
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"platform_type": "terminal",
|
||||
"platform_name": "mcu",
|
||||
"version": "1.0.0",
|
||||
...
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 已有接口说明
|
||||
|
||||
**原有接口:** `GET /api/downloads/{platform_name}/latest`
|
||||
|
||||
- 此接口标记为【已废弃】,建议使用新接口 `/downloads/latest/by-platform`
|
||||
- 原因:只通过 `platform_name` 查询可能产生歧义(如 mobile 的 android 和 terminal 的 android)
|
||||
- 保留此接口是为了向后兼容,但新开发应使用新接口
|
||||
|
||||
## 使用场景
|
||||
|
||||
### 场景 1:专用终端设备版本检查
|
||||
|
||||
专用终端设备(如会议室固定录音设备、单片机硬件)启动时检查更新:
|
||||
|
||||
```javascript
|
||||
// Android 专用终端
|
||||
const response = await fetch(
|
||||
'/api/downloads/latest/by-platform?platform_type=terminal&platform_name=android'
|
||||
);
|
||||
const { data } = await response.json();
|
||||
|
||||
if (data.version_code > currentVersionCode) {
|
||||
// 发现新版本,提示更新
|
||||
showUpdateDialog(data);
|
||||
}
|
||||
```
|
||||
|
||||
### 场景 2:后台管理界面展示
|
||||
|
||||
管理员查看所有终端版本:
|
||||
|
||||
```javascript
|
||||
const response = await fetch('/api/downloads?platform_type=terminal');
|
||||
const { data } = await response.json();
|
||||
|
||||
// data.clients 包含所有 terminal 类型的客户端版本
|
||||
renderClientList(data.clients);
|
||||
```
|
||||
|
||||
### 场景 3:固件更新服务器
|
||||
|
||||
单片机设备定期轮询更新:
|
||||
|
||||
```c
|
||||
// MCU 固件代码示例
|
||||
char url[] = "http://api.imeeting.com/downloads/latest/by-platform?platform_type=terminal&platform_name=mcu";
|
||||
http_get(url, response_buffer);
|
||||
|
||||
// 解析 JSON 获取 download_url 和 version_code
|
||||
if (new_version > FIRMWARE_VERSION) {
|
||||
download_and_update(download_url);
|
||||
}
|
||||
```
|
||||
|
||||
## 测试建议
|
||||
|
||||
### 1. 数据库测试
|
||||
```sql
|
||||
-- 验证表结构修改
|
||||
DESCRIBE client_downloads;
|
||||
|
||||
-- 验证数据插入
|
||||
SELECT * FROM client_downloads WHERE platform_type = 'terminal';
|
||||
```
|
||||
|
||||
### 2. API 测试
|
||||
```bash
|
||||
# 测试新接口
|
||||
curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=android"
|
||||
|
||||
curl "http://localhost:8000/api/downloads/latest/by-platform?platform_type=terminal&platform_name=mcu"
|
||||
|
||||
# 测试获取所有最新版本
|
||||
curl "http://localhost:8000/api/downloads/latest"
|
||||
|
||||
# 测试列表接口
|
||||
curl "http://localhost:8000/api/downloads?platform_type=terminal"
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **执行 SQL 前请备份数据库**
|
||||
2. **ENUM 类型修改**:ALTER TABLE 会修改表结构,请在低峰期执行
|
||||
3. **新接口优先**:建议所有新开发使用 `/downloads/latest/by-platform` 接口
|
||||
4. **版本管理**:上传新版本时记得设置 `is_latest=TRUE` 并将同平台旧版本设为 `FALSE`
|
||||
5. **platform_name 唯一性**:如果 mobile 和 terminal 都有 android,建议:
|
||||
- mobile 的保持 `android`
|
||||
- terminal 的改为 `android_terminal` 或其他区分名称
|
||||
- 或者始终使用新接口同时传递 platform_type 和 platform_name
|
||||
|
||||
## 文件清单
|
||||
|
||||
- `backend/sql/add_dedicated_terminal.sql` - 数据库迁移 SQL
|
||||
- `backend/app/api/endpoints/client_downloads.py` - API 接口代码
|
||||
- `backend/sql/README_terminal_update.md` - 本说明文档
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
-- 添加专用终端类型支持
|
||||
-- 修改 platform_type 枚举,添加 'terminal' 类型
|
||||
|
||||
ALTER TABLE client_downloads
|
||||
MODIFY COLUMN platform_type ENUM('mobile', 'desktop', 'terminal') NOT NULL
|
||||
COMMENT '平台类型:mobile-移动端, desktop-桌面端, terminal-专用终端';
|
||||
|
||||
-- 插入专用终端示例数据
|
||||
|
||||
-- Android 专用终端
|
||||
INSERT INTO client_downloads (
|
||||
platform_type,
|
||||
platform_name,
|
||||
version,
|
||||
version_code,
|
||||
download_url,
|
||||
file_size,
|
||||
release_notes,
|
||||
is_active,
|
||||
is_latest,
|
||||
min_system_version,
|
||||
created_by
|
||||
) VALUES
|
||||
(
|
||||
'terminal',
|
||||
'android',
|
||||
'1.0.0',
|
||||
1000,
|
||||
'https://download.imeeting.com/terminals/android/iMeeting-1.0.0-Terminal.apk',
|
||||
25165824, -- 24MB
|
||||
'专用终端初始版本
|
||||
- 支持专用硬件集成
|
||||
- 优化的录音功能
|
||||
- 低功耗模式
|
||||
- 自动上传同步',
|
||||
TRUE,
|
||||
TRUE,
|
||||
'Android 5.0',
|
||||
1
|
||||
),
|
||||
|
||||
-- 单片机(MCU)专用终端
|
||||
(
|
||||
'terminal',
|
||||
'mcu',
|
||||
'1.0.0',
|
||||
1000,
|
||||
'https://download.imeeting.com/terminals/mcu/iMeeting-1.0.0-MCU.bin',
|
||||
2097152, -- 2MB
|
||||
'单片机固件初始版本
|
||||
- 嵌入式录音系统
|
||||
- 低功耗设计
|
||||
- 支持WiFi/4G上传
|
||||
- 硬件级音频处理',
|
||||
TRUE,
|
||||
TRUE,
|
||||
'ESP32 / STM32',
|
||||
1
|
||||
);
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
-- ===================================================================
|
||||
-- 菜单权限系统数据库迁移脚本
|
||||
-- 创建日期: 2025-12-10
|
||||
-- 说明: 添加 menus 表和 role_menu_permissions 表,实现基于角色的菜单权限管理
|
||||
-- ===================================================================
|
||||
|
||||
-- ----------------------------
|
||||
-- Table structure for menus
|
||||
-- ----------------------------
|
||||
DROP TABLE IF EXISTS `menus`;
|
||||
CREATE TABLE `menus` (
|
||||
`menu_id` int(11) NOT NULL AUTO_INCREMENT COMMENT '菜单ID',
|
||||
`menu_code` varchar(50) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '菜单代码(唯一标识)',
|
||||
`menu_name` varchar(100) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '菜单名称',
|
||||
`menu_icon` varchar(50) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '菜单图标标识',
|
||||
`menu_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '菜单URL/路由',
|
||||
`menu_type` enum('action','link','divider') COLLATE utf8mb4_unicode_ci DEFAULT 'action' COMMENT '菜单类型: action-操作/link-链接/divider-分隔符',
|
||||
`parent_id` int(11) DEFAULT NULL COMMENT '父菜单ID(用于层级菜单)',
|
||||
`sort_order` int(11) DEFAULT 0 COMMENT '排序顺序',
|
||||
`is_active` tinyint(1) DEFAULT 1 COMMENT '是否启用: 1-启用, 0-禁用',
|
||||
`description` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '菜单描述',
|
||||
`created_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||
`updated_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
PRIMARY KEY (`menu_id`),
|
||||
UNIQUE KEY `uk_menu_code` (`menu_code`),
|
||||
KEY `idx_parent_id` (`parent_id`),
|
||||
KEY `idx_is_active` (`is_active`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='系统菜单表';
|
||||
|
||||
-- ----------------------------
|
||||
-- Table structure for role_menu_permissions
|
||||
-- ----------------------------
|
||||
DROP TABLE IF EXISTS `role_menu_permissions`;
|
||||
CREATE TABLE `role_menu_permissions` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT '权限ID',
|
||||
`role_id` int(11) NOT NULL COMMENT '角色ID',
|
||||
`menu_id` int(11) NOT NULL COMMENT '菜单ID',
|
||||
`created_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||
`updated_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `uk_role_menu` (`role_id`,`menu_id`),
|
||||
KEY `idx_role_id` (`role_id`),
|
||||
KEY `idx_menu_id` (`menu_id`),
|
||||
CONSTRAINT `fk_rmp_role_id` FOREIGN KEY (`role_id`) REFERENCES `roles` (`role_id`) ON DELETE CASCADE,
|
||||
CONSTRAINT `fk_rmp_menu_id` FOREIGN KEY (`menu_id`) REFERENCES `menus` (`menu_id`) ON DELETE CASCADE
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='角色菜单权限映射表';
|
||||
|
||||
-- ----------------------------
|
||||
-- 初始化菜单数据(基于现有系统的下拉菜单)
|
||||
-- ----------------------------
|
||||
BEGIN;
|
||||
|
||||
-- 用户菜单项
|
||||
INSERT INTO `menus` (`menu_code`, `menu_name`, `menu_icon`, `menu_url`, `menu_type`, `sort_order`, `is_active`, `description`)
|
||||
VALUES
|
||||
('change_password', '修改密码', 'KeyRound', NULL, 'action', 1, 1, '用户修改自己的密码'),
|
||||
('prompt_management', '提示词仓库', 'BookText', '/prompt-management', 'link', 2, 1, '管理AI提示词模版'),
|
||||
('platform_admin', '平台管理', 'Shield', '/admin/management', 'link', 3, 1, '平台管理员后台'),
|
||||
('logout', '退出登录', 'LogOut', NULL, 'action', 99, 1, '退出当前账号');
|
||||
|
||||
COMMIT;
|
||||
|
||||
-- ----------------------------
|
||||
-- 初始化角色权限数据
|
||||
-- 注意:角色表已存在,role_id=1为平台管理员,role_id=2为普通用户
|
||||
-- ----------------------------
|
||||
BEGIN;
|
||||
|
||||
-- 平台管理员(role_id=1)拥有所有菜单权限
|
||||
INSERT INTO `role_menu_permissions` (`role_id`, `menu_id`)
|
||||
SELECT 1, menu_id FROM `menus` WHERE is_active = 1;
|
||||
|
||||
-- 普通用户(role_id=2)拥有除"平台管理"外的所有菜单权限
|
||||
INSERT INTO `role_menu_permissions` (`role_id`, `menu_id`)
|
||||
SELECT 2, menu_id FROM `menus` WHERE menu_code != 'platform_admin' AND is_active = 1;
|
||||
|
||||
COMMIT;
|
||||
|
||||
-- ----------------------------
|
||||
-- 查询验证
|
||||
-- ----------------------------
|
||||
-- 查看所有菜单
|
||||
-- SELECT * FROM menus ORDER BY sort_order;
|
||||
|
||||
-- 查看平台管理员的菜单权限
|
||||
-- SELECT r.role_name, m.menu_name, m.menu_code, m.menu_url
|
||||
-- FROM role_menu_permissions rmp
|
||||
-- JOIN roles r ON rmp.role_id = r.role_id
|
||||
-- JOIN menus m ON rmp.menu_id = m.menu_id
|
||||
-- WHERE r.role_id = 1
|
||||
-- ORDER BY m.sort_order;
|
||||
|
||||
-- 查看普通用户的菜单权限
|
||||
-- SELECT r.role_name, m.menu_name, m.menu_code, m.menu_url
|
||||
-- FROM role_menu_permissions rmp
|
||||
-- JOIN roles r ON rmp.role_id = r.role_id
|
||||
-- JOIN menus m ON rmp.menu_id = m.menu_id
|
||||
-- WHERE r.role_id = 2
|
||||
-- ORDER BY m.sort_order;
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
-- 为 llm_tasks 表添加 prompt_id 列,用于支持自定义模版选择功能
|
||||
-- 执行日期:2025-12-08
|
||||
|
||||
ALTER TABLE `llm_tasks`
|
||||
ADD COLUMN `prompt_id` int(11) DEFAULT NULL COMMENT '提示词模版ID' AFTER `user_prompt`,
|
||||
ADD KEY `idx_prompt_id` (`prompt_id`);
|
||||
|
||||
-- 说明:
|
||||
-- 1. prompt_id 允许为 NULL,表示使用默认模版
|
||||
-- 2. 添加索引以优化查询性能
|
||||
-- 3. 不添加外键约束,因为 prompts 表中的记录可能被删除,我们希望保留历史任务记录
|
||||
|
|
@ -534,10 +534,12 @@ CREATE TABLE `knowledge_bases` (
|
|||
`is_shared` tinyint(1) NOT NULL DEFAULT '0' COMMENT '是否为共享知识库',
|
||||
`source_meeting_ids` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '内容来源的会议ID列表 (逗号分隔)',
|
||||
`tags` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '逗号分隔的标签',
|
||||
`prompt_id` int(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版',
|
||||
`created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (`kb_id`),
|
||||
KEY `idx_creator_id` (`creator_id`),
|
||||
KEY `idx_prompt_id` (`prompt_id`),
|
||||
CONSTRAINT `knowledge_bases_ibfk_1` FOREIGN KEY (`creator_id`) REFERENCES `users` (`user_id`) ON DELETE CASCADE
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=28 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='知识库条目表';
|
||||
|
||||
|
|
@ -686,9 +688,11 @@ CREATE TABLE `meetings` (
|
|||
`meeting_time` timestamp NULL DEFAULT NULL,
|
||||
`user_prompt` text COLLATE utf8mb4_unicode_ci,
|
||||
`summary` text CHARACTER SET utf8mb4,
|
||||
`prompt_id` int(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版',
|
||||
`created_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (`meeting_id`)
|
||||
PRIMARY KEY (`meeting_id`),
|
||||
KEY `idx_prompt_id` (`prompt_id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=372 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
|
||||
-- ----------------------------
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
-- 提示词表改造迁移脚本
|
||||
-- 将 prompt_config 表的功能整合到 prompts 表
|
||||
|
||||
-- 步骤1: 添加新字段
|
||||
ALTER TABLE prompts
|
||||
ADD COLUMN task_type ENUM('MEETING_TASK', 'KNOWLEDGE_TASK')
|
||||
COMMENT '任务类型:MEETING_TASK-会议任务, KNOWLEDGE_TASK-知识库任务' AFTER name;
|
||||
|
||||
ALTER TABLE prompts
|
||||
ADD COLUMN is_default BOOLEAN NOT NULL DEFAULT FALSE
|
||||
COMMENT '是否为该任务类型的默认模板' AFTER content;
|
||||
|
||||
-- 步骤2: 修改 is_active 字段(如果存在且类型不是 BOOLEAN)
|
||||
-- 先检查字段是否存在,如果不存在则添加
|
||||
ALTER TABLE prompts
|
||||
MODIFY COLUMN is_active BOOLEAN NOT NULL DEFAULT TRUE
|
||||
COMMENT '是否启用(只有启用的提示词才能被使用)';
|
||||
|
||||
-- 步骤3: 删除 tags 字段
|
||||
ALTER TABLE prompts DROP COLUMN IF EXISTS tags;
|
||||
|
||||
-- 步骤4: 从 prompt_config 迁移数据(如果 prompt_config 表存在)
|
||||
-- 更新 task_type 和 is_default
|
||||
UPDATE prompts p
|
||||
LEFT JOIN prompt_config pc ON p.id = pc.prompt_id
|
||||
SET
|
||||
p.task_type = CASE
|
||||
WHEN pc.task_name IS NOT NULL THEN pc.task_name
|
||||
ELSE 'MEETING_TASK' -- 默认值
|
||||
END,
|
||||
p.is_default = CASE
|
||||
WHEN pc.is_default = 1 THEN TRUE
|
||||
ELSE FALSE
|
||||
END
|
||||
WHERE pc.prompt_id IS NOT NULL OR p.task_type IS NULL;
|
||||
|
||||
-- 步骤5: 为所有没有设置 task_type 的提示词设置默认值
|
||||
UPDATE prompts
|
||||
SET task_type = 'MEETING_TASK'
|
||||
WHERE task_type IS NULL;
|
||||
|
||||
-- 步骤6: 将 task_type 设置为 NOT NULL
|
||||
ALTER TABLE prompts
|
||||
MODIFY COLUMN task_type ENUM('MEETING_TASK', 'KNOWLEDGE_TASK') NOT NULL
|
||||
COMMENT '任务类型:MEETING_TASK-会议任务, KNOWLEDGE_TASK-知识库任务';
|
||||
|
||||
-- 步骤7: 确保每个 task_type 只有一个默认提示词
|
||||
-- 如果有多个默认,只保留 id 最小的那个
|
||||
UPDATE prompts p1
|
||||
LEFT JOIN (
|
||||
SELECT task_type, MIN(id) as min_id
|
||||
FROM prompts
|
||||
WHERE is_default = TRUE
|
||||
GROUP BY task_type
|
||||
) p2 ON p1.task_type = p2.task_type
|
||||
SET p1.is_default = FALSE
|
||||
WHERE p1.is_default = TRUE AND p1.id != p2.min_id;
|
||||
|
||||
-- 步骤8: (可选) 备注 prompt_config 表已废弃
|
||||
-- 如果需要删除 prompt_config 表,取消下面的注释
|
||||
-- DROP TABLE IF EXISTS prompt_config;
|
||||
|
||||
-- 迁移完成
|
||||
SELECT '提示词表迁移完成!' as message;
|
||||
SELECT task_type, COUNT(*) as total, SUM(is_default) as default_count
|
||||
FROM prompts
|
||||
GROUP BY task_type;
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
-- ============================================
|
||||
-- 添加 prompt_id 字段到主表
|
||||
-- 创建时间: 2025-01-11
|
||||
-- 说明: 在 meetings 和 knowledge_bases 表中添加 prompt_id 字段
|
||||
-- 用于记录会议/知识库使用的提示词模版
|
||||
-- ============================================
|
||||
|
||||
-- 1. 为 meetings 表添加 prompt_id 字段
|
||||
ALTER TABLE meetings
|
||||
ADD COLUMN prompt_id INT(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版'
|
||||
AFTER summary;
|
||||
|
||||
-- 为 meetings 表添加索引
|
||||
ALTER TABLE meetings
|
||||
ADD INDEX idx_prompt_id (prompt_id);
|
||||
|
||||
-- 2. 为 knowledge_bases 表添加 prompt_id 字段
|
||||
ALTER TABLE knowledge_bases
|
||||
ADD COLUMN prompt_id INT(11) DEFAULT 0 COMMENT '使用的提示词模版ID,0表示未使用或使用默认模版'
|
||||
AFTER tags;
|
||||
|
||||
-- 为 knowledge_bases 表添加索引
|
||||
ALTER TABLE knowledge_bases
|
||||
ADD INDEX idx_prompt_id (prompt_id);
|
||||
|
||||
-- ============================================
|
||||
-- 验证修改
|
||||
-- ============================================
|
||||
-- 查看 meetings 表结构
|
||||
-- DESCRIBE meetings;
|
||||
|
||||
-- 查看 knowledge_bases 表结构
|
||||
-- DESCRIBE knowledge_bases;
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
-- 为现有数据库添加转录任务支持的SQL脚本
|
||||
|
||||
-- 1. 更新audio_files表结构,添加缺失字段
|
||||
ALTER TABLE audio_files
|
||||
ADD COLUMN file_name VARCHAR(255) AFTER meeting_id,
|
||||
ADD COLUMN file_size BIGINT DEFAULT NULL AFTER file_path,
|
||||
ADD COLUMN task_id VARCHAR(255) DEFAULT NULL AFTER upload_time;
|
||||
|
||||
-- 2. 创建转录任务表
|
||||
CREATE TABLE transcript_tasks (
|
||||
task_id VARCHAR(255) PRIMARY KEY,
|
||||
meeting_id INT NOT NULL,
|
||||
status ENUM('pending', 'processing', 'completed', 'failed') DEFAULT 'pending',
|
||||
progress INT DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
completed_at TIMESTAMP NULL,
|
||||
error_message TEXT NULL,
|
||||
|
||||
FOREIGN KEY (meeting_id) REFERENCES meetings(meeting_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- 3. 添加索引以优化查询性能
|
||||
-- audio_files 表索引
|
||||
ALTER TABLE audio_files ADD INDEX idx_task_id (task_id);
|
||||
|
||||
-- transcript_tasks 表索引
|
||||
ALTER TABLE transcript_tasks ADD INDEX idx_meeting_id (meeting_id);
|
||||
ALTER TABLE transcript_tasks ADD INDEX idx_status (status);
|
||||
ALTER TABLE transcript_tasks ADD INDEX idx_created_at (created_at);
|
||||
|
||||
-- 4. 更新现有测试数据(如果需要)
|
||||
-- 这些语句是可选的,用于更新现有的测试数据
|
||||
UPDATE audio_files SET file_name = 'test_audio.mp3' WHERE file_name IS NULL;
|
||||
UPDATE audio_files SET file_size = 10485760 WHERE file_size IS NULL; -- 10MB
|
||||
|
||||
SELECT '转录任务表创建完成!' as message;
|
||||
|
||||
|
||||
CREATE TABLE llm_tasks (
|
||||
task_id VARCHAR(100) PRIMARY KEY,
|
||||
llm_task_id VARCHAR(100) DEFAULT NULL,
|
||||
meeting_id INT NOT NULL,
|
||||
user_prompt TEXT,
|
||||
status VARCHAR(50) DEFAULT 'pending',
|
||||
progress INT DEFAULT 0,
|
||||
result TEXT,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
completed_at TIMESTAMP NULL,
|
||||
INDEX idx_meeting_id (meeting_id),
|
||||
INDEX idx_status (status),
|
||||
INDEX idx_created_at (created_at)
|
||||
)
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
"""
|
||||
测试知识库提示词模版选择功能
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, 'app')
|
||||
|
||||
from app.services.llm_service import LLMService
|
||||
from app.services.async_knowledge_base_service import AsyncKnowledgeBaseService
|
||||
from app.core.database import get_db_connection
|
||||
|
||||
def test_get_active_knowledge_prompts():
|
||||
"""测试获取启用的知识库提示词列表"""
|
||||
print("\n=== 测试1: 获取启用的知识库提示词列表 ===")
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 获取KNOWLEDGE_TASK类型的启用模版
|
||||
query = """
|
||||
SELECT id, name, is_default
|
||||
FROM prompts
|
||||
WHERE task_type = %s AND is_active = TRUE
|
||||
ORDER BY is_default DESC, created_at DESC
|
||||
"""
|
||||
cursor.execute(query, ('KNOWLEDGE_TASK',))
|
||||
prompts = cursor.fetchall()
|
||||
|
||||
print(f"✓ 找到 {len(prompts)} 个启用的知识库任务模版:")
|
||||
for p in prompts:
|
||||
default_flag = " [默认]" if p['is_default'] else ""
|
||||
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
|
||||
|
||||
return prompts
|
||||
except Exception as e:
|
||||
print(f"✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
def test_get_task_prompt_with_id(prompts):
|
||||
"""测试通过prompt_id获取知识库提示词内容"""
|
||||
print("\n=== 测试2: 通过prompt_id获取知识库提示词内容 ===")
|
||||
|
||||
if not prompts:
|
||||
print("⚠ 没有可用的提示词模版,跳过测试")
|
||||
return
|
||||
|
||||
llm_service = LLMService()
|
||||
|
||||
# 测试获取第一个提示词
|
||||
test_prompt = prompts[0]
|
||||
try:
|
||||
content = llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=test_prompt['id'])
|
||||
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
|
||||
print(f" 内容长度: {len(content)} 字符")
|
||||
print(f" 内容预览: {content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# 测试获取默认提示词(不指定prompt_id)
|
||||
try:
|
||||
default_content = llm_service.get_task_prompt('KNOWLEDGE_TASK')
|
||||
print(f"✓ 成功获取默认提示词")
|
||||
print(f" 内容长度: {len(default_content)} 字符")
|
||||
except Exception as e:
|
||||
print(f"✗ 获取默认提示词失败: {e}")
|
||||
|
||||
def test_async_kb_service_signature():
|
||||
"""测试async_knowledge_base_service的方法签名"""
|
||||
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
|
||||
|
||||
import inspect
|
||||
async_service = AsyncKnowledgeBaseService()
|
||||
|
||||
# 检查start_generation方法签名
|
||||
sig = inspect.signature(async_service.start_generation)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
if 'prompt_id' in params:
|
||||
print(f"✓ start_generation 方法支持 prompt_id 参数")
|
||||
print(f" 参数列表: {params}")
|
||||
else:
|
||||
print(f"✗ start_generation 方法缺少 prompt_id 参数")
|
||||
print(f" 参数列表: {params}")
|
||||
|
||||
# 检查_build_prompt方法签名
|
||||
sig2 = inspect.signature(async_service._build_prompt)
|
||||
params2 = list(sig2.parameters.keys())
|
||||
|
||||
if 'prompt_id' in params2:
|
||||
print(f"✓ _build_prompt 方法支持 prompt_id 参数")
|
||||
print(f" 参数列表: {params2}")
|
||||
else:
|
||||
print(f"✗ _build_prompt 方法缺少 prompt_id 参数")
|
||||
print(f" 参数列表: {params2}")
|
||||
|
||||
def test_database_schema():
|
||||
"""测试数据库schema是否包含prompt_id列"""
|
||||
print("\n=== 测试4: 验证数据库schema ===")
|
||||
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查knowledge_base_tasks表是否有prompt_id列
|
||||
cursor.execute("""
|
||||
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = 'knowledge_base_tasks'
|
||||
AND COLUMN_NAME = 'prompt_id'
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
print(f"✓ knowledge_base_tasks 表包含 prompt_id 列")
|
||||
print(f" 类型: {result['DATA_TYPE']}")
|
||||
print(f" 可空: {result['IS_NULLABLE']}")
|
||||
print(f" 默认值: {result['COLUMN_DEFAULT']}")
|
||||
else:
|
||||
print(f"✗ knowledge_base_tasks 表缺少 prompt_id 列")
|
||||
except Exception as e:
|
||||
print(f"✗ 数据库检查失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def test_api_model():
|
||||
"""测试API模型定义"""
|
||||
print("\n=== 测试5: 验证API模型定义 ===")
|
||||
|
||||
try:
|
||||
from app.models.models import CreateKnowledgeBaseRequest
|
||||
import inspect
|
||||
|
||||
# 检查CreateKnowledgeBaseRequest模型
|
||||
fields = CreateKnowledgeBaseRequest.model_fields
|
||||
|
||||
if 'prompt_id' in fields:
|
||||
print(f"✓ CreateKnowledgeBaseRequest 包含 prompt_id 字段")
|
||||
print(f" 字段列表: {list(fields.keys())}")
|
||||
else:
|
||||
print(f"✗ CreateKnowledgeBaseRequest 缺少 prompt_id 字段")
|
||||
print(f" 字段列表: {list(fields.keys())}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ API模型检查失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("=" * 60)
|
||||
print("开始测试知识库提示词模版选择功能")
|
||||
print("=" * 60)
|
||||
|
||||
# 运行所有测试
|
||||
prompts = test_get_active_knowledge_prompts()
|
||||
test_get_task_prompt_with_id(prompts)
|
||||
test_async_kb_service_signature()
|
||||
test_database_schema()
|
||||
test_api_model()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成")
|
||||
print("=" * 60)
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试菜单权限数据是否存在
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.core.database import get_db_connection
|
||||
|
||||
def test_menu_permissions():
|
||||
print("=== 测试菜单权限数据 ===\n")
|
||||
|
||||
try:
|
||||
# 连接数据库
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 1. 检查menus表
|
||||
print("1. 检查menus表:")
|
||||
cursor.execute("SELECT COUNT(*) as count FROM menus")
|
||||
menu_count = cursor.fetchone()['count']
|
||||
print(f" - 菜单总数: {menu_count}")
|
||||
|
||||
if menu_count > 0:
|
||||
cursor.execute("SELECT menu_id, menu_code, menu_name, is_active FROM menus ORDER BY sort_order")
|
||||
menus = cursor.fetchall()
|
||||
for menu in menus:
|
||||
print(f" - [{menu['menu_id']}] {menu['menu_name']} ({menu['menu_code']}) - 启用: {menu['is_active']}")
|
||||
else:
|
||||
print(" ⚠️ menus表为空!")
|
||||
|
||||
print()
|
||||
|
||||
# 2. 检查roles表
|
||||
print("2. 检查roles表:")
|
||||
cursor.execute("SELECT * FROM roles ORDER BY role_id")
|
||||
roles = cursor.fetchall()
|
||||
for role in roles:
|
||||
print(f" - [{role['role_id']}] {role['role_name']}")
|
||||
|
||||
print()
|
||||
|
||||
# 3. 检查role_menu_permissions表
|
||||
print("3. 检查role_menu_permissions表:")
|
||||
cursor.execute("SELECT COUNT(*) as count FROM role_menu_permissions")
|
||||
perm_count = cursor.fetchone()['count']
|
||||
print(f" - 权限总数: {perm_count}")
|
||||
|
||||
if perm_count > 0:
|
||||
cursor.execute("""
|
||||
SELECT r.role_name, m.menu_name, rmp.role_id, rmp.menu_id
|
||||
FROM role_menu_permissions rmp
|
||||
JOIN roles r ON rmp.role_id = r.role_id
|
||||
JOIN menus m ON rmp.menu_id = m.menu_id
|
||||
ORDER BY rmp.role_id, m.sort_order
|
||||
""")
|
||||
permissions = cursor.fetchall()
|
||||
|
||||
current_role = None
|
||||
for perm in permissions:
|
||||
if current_role != perm['role_name']:
|
||||
current_role = perm['role_name']
|
||||
print(f"\n {current_role}的权限:")
|
||||
print(f" - {perm['menu_name']}")
|
||||
else:
|
||||
print(" ⚠️ role_menu_permissions表为空!")
|
||||
|
||||
print("\n" + "="*50)
|
||||
|
||||
# 4. 检查是否需要执行SQL脚本
|
||||
if menu_count == 0 or perm_count == 0:
|
||||
print("\n❌ 数据库中缺少菜单或权限数据!")
|
||||
print("请执行以下命令初始化数据:")
|
||||
print("\nmysql -h 10.100.51.161 -u root -psagacity imeeting_dev < backend/sql/add_menu_permissions_system.sql")
|
||||
print("\n或者在MySQL客户端中执行该SQL文件。")
|
||||
else:
|
||||
print("\n✅ 菜单权限数据正常!")
|
||||
|
||||
cursor.close()
|
||||
connection.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 错误: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_menu_permissions()
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
"""
|
||||
测试提示词模版选择功能
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, 'app')
|
||||
|
||||
from app.services.llm_service import LLMService
|
||||
from app.services.async_meeting_service import AsyncMeetingService
|
||||
from app.core.database import get_db_connection
|
||||
|
||||
def test_get_active_prompts():
|
||||
"""测试获取启用的提示词列表"""
|
||||
print("\n=== 测试1: 获取启用的提示词列表 ===")
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 获取MEETING_TASK类型的启用模版
|
||||
query = """
|
||||
SELECT id, name, is_default
|
||||
FROM prompts
|
||||
WHERE task_type = %s AND is_active = TRUE
|
||||
ORDER BY is_default DESC, created_at DESC
|
||||
"""
|
||||
cursor.execute(query, ('MEETING_TASK',))
|
||||
prompts = cursor.fetchall()
|
||||
|
||||
print(f"✓ 找到 {len(prompts)} 个启用的会议任务模版:")
|
||||
for p in prompts:
|
||||
default_flag = " [默认]" if p['is_default'] else ""
|
||||
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
|
||||
|
||||
return prompts
|
||||
except Exception as e:
|
||||
print(f"✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
def test_get_task_prompt_with_id(prompts):
|
||||
"""测试通过prompt_id获取提示词内容"""
|
||||
print("\n=== 测试2: 通过prompt_id获取提示词内容 ===")
|
||||
|
||||
if not prompts:
|
||||
print("⚠ 没有可用的提示词模版,跳过测试")
|
||||
return
|
||||
|
||||
llm_service = LLMService()
|
||||
|
||||
# 测试获取第一个提示词
|
||||
test_prompt = prompts[0]
|
||||
try:
|
||||
content = llm_service.get_task_prompt('MEETING_TASK', prompt_id=test_prompt['id'])
|
||||
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
|
||||
print(f" 内容长度: {len(content)} 字符")
|
||||
print(f" 内容预览: {content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# 测试获取默认提示词(不指定prompt_id)
|
||||
try:
|
||||
default_content = llm_service.get_task_prompt('MEETING_TASK')
|
||||
print(f"✓ 成功获取默认提示词")
|
||||
print(f" 内容长度: {len(default_content)} 字符")
|
||||
except Exception as e:
|
||||
print(f"✗ 获取默认提示词失败: {e}")
|
||||
|
||||
def test_async_meeting_service_signature():
|
||||
"""测试async_meeting_service的方法签名"""
|
||||
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
|
||||
|
||||
import inspect
|
||||
async_service = AsyncMeetingService()
|
||||
|
||||
# 检查start_summary_generation方法签名
|
||||
sig = inspect.signature(async_service.start_summary_generation)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
if 'prompt_id' in params:
|
||||
print(f"✓ start_summary_generation 方法支持 prompt_id 参数")
|
||||
print(f" 参数列表: {params}")
|
||||
else:
|
||||
print(f"✗ start_summary_generation 方法缺少 prompt_id 参数")
|
||||
print(f" 参数列表: {params}")
|
||||
|
||||
# 检查monitor_and_auto_summarize方法签名
|
||||
sig2 = inspect.signature(async_service.monitor_and_auto_summarize)
|
||||
params2 = list(sig2.parameters.keys())
|
||||
|
||||
if 'prompt_id' in params2:
|
||||
print(f"✓ monitor_and_auto_summarize 方法支持 prompt_id 参数")
|
||||
print(f" 参数列表: {params2}")
|
||||
else:
|
||||
print(f"✗ monitor_and_auto_summarize 方法缺少 prompt_id 参数")
|
||||
print(f" 参数列表: {params2}")
|
||||
|
||||
def test_database_schema():
|
||||
"""测试数据库schema是否包含prompt_id列"""
|
||||
print("\n=== 测试4: 验证数据库schema ===")
|
||||
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# 检查llm_tasks表是否有prompt_id列
|
||||
cursor.execute("""
|
||||
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = 'llm_tasks'
|
||||
AND COLUMN_NAME = 'prompt_id'
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
print(f"✓ llm_tasks 表包含 prompt_id 列")
|
||||
print(f" 类型: {result['DATA_TYPE']}")
|
||||
print(f" 可空: {result['IS_NULLABLE']}")
|
||||
print(f" 默认值: {result['COLUMN_DEFAULT']}")
|
||||
else:
|
||||
print(f"✗ llm_tasks 表缺少 prompt_id 列")
|
||||
except Exception as e:
|
||||
print(f"✗ 数据库检查失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def test_api_endpoints():
|
||||
"""测试API端点定义"""
|
||||
print("\n=== 测试5: 验证API端点定义 ===")
|
||||
|
||||
try:
|
||||
from app.api.endpoints.meetings import GenerateSummaryRequest
|
||||
import inspect
|
||||
|
||||
# 检查GenerateSummaryRequest模型
|
||||
fields = GenerateSummaryRequest.__fields__
|
||||
|
||||
if 'prompt_id' in fields:
|
||||
print(f"✓ GenerateSummaryRequest 包含 prompt_id 字段")
|
||||
print(f" 字段列表: {list(fields.keys())}")
|
||||
else:
|
||||
print(f"✗ GenerateSummaryRequest 缺少 prompt_id 字段")
|
||||
print(f" 字段列表: {list(fields.keys())}")
|
||||
|
||||
# 检查audio_service.handle_audio_upload签名
|
||||
from app.services.audio_service import handle_audio_upload
|
||||
sig = inspect.signature(handle_audio_upload)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
if 'prompt_id' in params:
|
||||
print(f"✓ handle_audio_upload 方法支持 prompt_id 参数")
|
||||
else:
|
||||
print(f"✗ handle_audio_upload 方法缺少 prompt_id 参数")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ API端点检查失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("=" * 60)
|
||||
print("开始测试提示词模版选择功能")
|
||||
print("=" * 60)
|
||||
|
||||
# 运行所有测试
|
||||
prompts = test_get_active_prompts()
|
||||
test_get_task_prompt_with_id(prompts)
|
||||
test_async_meeting_service_signature()
|
||||
test_database_schema()
|
||||
test_api_endpoints()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成")
|
||||
print("=" * 60)
|
||||
Loading…
Reference in New Issue