# File-viewer residue (not code): 284 lines · 10 KiB · Python
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||
from typing import Optional, List
|
||
from app.models.models import KnowledgeBase, KnowledgeBaseListResponse, CreateKnowledgeBaseRequest, UpdateKnowledgeBaseRequest, Tag
|
||
from app.core.database import get_db_connection
|
||
from app.core.auth import get_current_user
|
||
from app.core.response import create_api_response
|
||
from app.services.async_knowledge_base_service import async_kb_service
|
||
import datetime
|
||
|
||
# Router that collects every knowledge-base endpoint declared in this module;
# presumably included into the FastAPI app elsewhere — not visible here.
router: APIRouter = APIRouter()
|
||
|
||
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
    """Resolve a comma-separated tag string into ``Tag`` models.

    Names are split on ``,``, stripped, and blanks are discarded. Duplicate
    names are collapsed, and the first occurrence wins. When ``creator_id`` is
    given, names not yet in the ``tags`` table are inserted first with
    ``INSERT IGNORE`` so that they can be returned. Without ``creator_id``,
    unknown names are simply skipped.

    Args:
        cursor: Open DB cursor. Rows are unpacked with ``Tag(**row)`` and
            indexed by ``row['name']``, so it must yield dict rows (for
            example, a cursor created with ``dictionary=True``).
        tag_string: Comma-separated tag names, or ``None``/empty.
        creator_id: Owner recorded for newly created tags. ``None`` disables
            tag creation.

    Returns:
        The matching tags, ordered by where their names first appear in
        ``tag_string``. A row whose name does not exactly equal an input name
        (for example, one matched through a case-insensitive collation) is
        still returned, after the exactly matched ones.
    """
    if not tag_string:
        return []

    # dict.fromkeys removes duplicates while keeping first-seen order.
    tag_names = list(dict.fromkeys(
        name.strip() for name in tag_string.split(',') if name.strip()
    ))
    if not tag_names:
        return []

    # Create missing tags only when an owner is known. Compare against None
    # explicitly so that a falsy id such as 0 still counts as "provided".
    if creator_id is not None:
        insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
        cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])

    # Only "%s" placeholders are formatted into the SQL; the tag names
    # themselves are sent as bound parameters.
    placeholders = ', '.join(['%s'] * len(tag_names))
    cursor.execute(
        f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})",
        tuple(tag_names),
    )
    rows = list(cursor.fetchall())

    # IN (...) gives no ordering guarantee, so restore the caller's order so
    # that tags come back deterministically. list.sort is stable, and names
    # not found in `position` sort last without being dropped.
    position = {name: index for index, name in enumerate(tag_names)}
    rows.sort(key=lambda row: position.get(row['name'], len(position)))
    return [Tag(**row) for row in rows]
|
||
|
||
|
||
@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
def get_knowledge_bases(
    page: int = 1,
    size: int = 10,
    is_shared: Optional[bool] = None,
    current_user: dict = Depends(get_current_user)
):
    """List knowledge bases visible to the current user, most recently updated first.

    Args:
        page: 1-based page number.
        size: Number of rows per page.
        is_shared: Visibility filter.
            - ``True``: only shared entries.
            - ``False``: only the caller's own non-shared entries.
            - ``None``: shared entries plus everything the caller created.
        current_user: Authenticated user from ``get_current_user``. Its
            ``user_id`` scopes the query.

    Returns:
        ``KnowledgeBaseListResponse`` with the requested page of entries and
        the total number of matching rows.

    Raises:
        HTTPException: 400 when ``page`` or ``size`` is below 1. Such values
            would otherwise produce a negative OFFSET and a database error.
    """
    # Validate pagination before touching the DB: (page - 1) * size must not
    # be negative, or MySQL rejects the LIMIT/OFFSET clause.
    if page < 1 or size < 1:
        raise HTTPException(status_code=400, detail="page and size must be >= 1")

    # NOTE(review): `size` has no upper bound — consider capping it.
    user_id = current_user['user_id']

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        base_query = "FROM knowledge_bases kb JOIN users u ON kb.creator_id = u.user_id"
        where_clauses = []
        params = []

        if is_shared is None:
            # Both personal and shared
            where_clauses.append("(kb.is_shared = 1 OR kb.creator_id = %s)")
            params.append(user_id)
        elif is_shared:
            where_clauses.append("kb.is_shared = 1")
        else:
            # Personal only
            where_clauses.append("kb.is_shared = 0 AND kb.creator_id = %s")
            params.append(user_id)

        # The parentheses make explicit that the conditional covers the whole
        # concatenation.
        where_sql = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""

        count_query = "SELECT COUNT(*) as total " + base_query + where_sql
        cursor.execute(count_query, tuple(params))
        total = cursor.fetchone()['total']

        offset = (page - 1) * size
        query = f"""
            SELECT
                kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
                kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at
            {base_query}
            {where_sql}
            ORDER BY kb.updated_at DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, tuple(params + [size, offset]))
        kbs_data = cursor.fetchall()

        kb_list = []
        for kb_data in kbs_data:
            # The list view does not resolve tags: kb_data['tags'] stays the raw
            # comma-separated tag-name string.
            raw_ids = kb_data.get('source_meeting_ids')
            # Count the stored meeting ids, ignoring empty segments.
            kb_data['source_meeting_count'] = (
                sum(1 for mid in raw_ids.split(',') if mid.strip()) if raw_ids else 0
            )
            # NOTE(review): the detail endpoint fills created_by_name from
            # u.username, while this one uses u.caption — confirm which is intended.
            kb_data['created_by_name'] = kb_data.get('creator_caption')
            kb_list.append(KnowledgeBase(**kb_data))

        return KnowledgeBaseListResponse(kbs=kb_list, total=total)
|
||
|
||
@router.post("/knowledge-bases")
def create_knowledge_base(
    request: CreateKnowledgeBaseRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Insert a new knowledge-base row and schedule its content generation.

    The row is inserted without ``content``. ``async_kb_service.start_generation``
    is then called with the new ``kb_id``, and ``async_kb_service._process_task``
    is queued as a FastAPI background task for the returned task id.

    Args:
        request: Creation payload. A missing, empty or whitespace-only title is
            replaced with ``"<YYYY-MM-DD> 知识条目"`` ("knowledge entry").
        background_tasks: FastAPI background-task queue.
        current_user: Authenticated user. ``user_id`` becomes ``creator_id``.

    Returns:
        ``{"task_id": ..., "kb_id": ...}``.
    """
    # Compute the title locally instead of mutating the request model.
    # NOTE(review): the date in the default title uses server local time,
    # while created_at/updated_at below are UTC.
    title = request.title
    if not title or not title.strip():
        title = datetime.datetime.now().strftime("%Y-%m-%d") + " 知识条目"

    # Naive UTC timestamp: the same value datetime.utcnow() returns, without
    # the call that is deprecated since Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    with get_db_connection() as connection:
        cursor = connection.cursor()

        # Create the knowledge base entry first so that its id can be handed
        # to the generation task.
        insert_kb_query = """
            INSERT INTO knowledge_bases (title, creator_id, is_shared, source_meeting_ids, user_prompt, tags, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        """
        cursor.execute(insert_kb_query, (
            title,
            current_user['user_id'],
            request.is_shared,
            request.source_meeting_ids,
            request.user_prompt,
            # Tags are expected to be None or empty at creation time. They are
            # stored as-is and not registered in the `tags` table here.
            request.tags,
            now,
            now,
        ))
        kb_id = cursor.lastrowid

        # start_generation receives this same cursor, so anything it writes
        # through it is presumably committed with the new row by the commit below.
        task_id = async_kb_service.start_generation(
            user_id=current_user['user_id'],
            kb_id=kb_id,
            user_prompt=request.user_prompt,
            source_meeting_ids=request.source_meeting_ids,
            cursor=cursor
        )

        connection.commit()

        # Queue the actual generation work to run after the response is sent.
        background_tasks.add_task(async_kb_service._process_task, task_id)

        return {"task_id": task_id, "kb_id": kb_id}
|
||
|
||
@router.get("/knowledge-bases/{kb_id}")
def get_knowledge_base_detail(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Return one knowledge base with resolved tags and its source meetings.

    The returned dict contains the selected ``knowledge_bases``/``users``
    columns plus three replaced or added keys:
    - ``tags``: list of ``Tag`` objects.
    - ``source_meeting_count``: number of stored meeting ids that matched a row
      in ``meetings``.
    - ``source_meetings``: list of ``{'meeting_id', 'title'}`` dicts.

    Raises:
        HTTPException: 404 if no row has this ``kb_id``; 403 if the entry is
            not shared and the caller is not its creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        detail_sql = """
            SELECT
                kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
                kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at,
                u.username as created_by_name
            FROM knowledge_bases kb
            JOIN users u ON kb.creator_id = u.user_id
            WHERE kb.kb_id = %s
        """
        cursor.execute(detail_sql, (kb_id,))
        record = cursor.fetchone()

        if not record:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        # Readable when shared, or when the caller is the creator.
        if not (record['is_shared'] or record['creator_id'] == current_user['user_id']):
            raise HTTPException(status_code=403, detail="Access denied")

        # Resolve tag names to full Tag objects (including colour). No
        # creator_id is passed, so no new tags are created here.
        record['tags'] = _process_tags(cursor, record.get('tags'))

        raw_ids = record.get('source_meeting_ids')
        ids = [part.strip() for part in raw_ids.split(',') if part.strip()] if raw_ids else []

        meetings = []
        if ids:
            marks = ','.join(['%s'] * len(ids))
            cursor.execute(f"""
                SELECT meeting_id, title
                FROM meetings
                WHERE meeting_id IN ({marks})
            """, tuple(ids))
            meetings = [
                {'meeting_id': row['meeting_id'], 'title': row['title']}
                for row in cursor.fetchall()
            ]

        # Only ids that matched a meetings row are counted.
        record['source_meeting_count'] = len(meetings)
        record['source_meetings'] = meetings

        return record
|
||
|
||
@router.put("/knowledge-bases/{kb_id}")
def update_knowledge_base(
    kb_id: int,
    request: UpdateKnowledgeBaseRequest,
    current_user: dict = Depends(get_current_user)
):
    """Overwrite title, content and tags of a knowledge base owned by the caller.

    New tag names in ``request.tags`` are registered in the ``tags`` table,
    owned by the caller, before the row is updated. The raw ``request.tags``
    string is stored on the knowledge base.

    Raises:
        HTTPException: 404 if the knowledge base does not exist; 403 if the
            caller is not its creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        cursor.execute(
            "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
            (kb_id,)
        )
        existing = cursor.fetchone()

        if not existing:
            raise HTTPException(status_code=404, detail="Knowledge base not found")
        if existing['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")

        # Register any new tag names. The returned Tag list is not needed here.
        if request.tags:
            _process_tags(cursor, request.tags, current_user['user_id'])

        # Naive UTC timestamp: the same value datetime.utcnow() produces.
        updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        cursor.execute(
            """
            UPDATE knowledge_bases
            SET title = %s, content = %s, tags = %s, updated_at = %s
            WHERE kb_id = %s
            """,
            (request.title, request.content, request.tags, updated_at, kb_id),
        )
        connection.commit()

    return {"message": "Knowledge base updated successfully"}
|
||
|
||
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Delete a knowledge base; only its creator is allowed to do so.

    Raises:
        HTTPException: 404 if the knowledge base does not exist; 403 if the
            caller is not its creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        lookup_sql = "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s"
        cursor.execute(lookup_sql, (kb_id,))
        target = cursor.fetchone()

        if not target:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        is_creator = target['creator_id'] == current_user['user_id']
        if not is_creator:
            raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")

        cursor.execute("DELETE FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        connection.commit()

    return {"message": "Knowledge base deleted successfully"}
|
||
|
||
@router.get("/knowledge-bases/tasks/{task_id}")
def get_task_status(task_id: str):
    """Report the progress of a knowledge-base generation task.

    Returns:
        ``{"status", "progress", "error"}``, taken from the service's
        ``status``, ``progress`` (defaulting to 0) and ``error_message`` fields.

    Raises:
        HTTPException: 404 when the service reports status ``'not_found'``.
    """
    # NOTE(review): unlike the other endpoints in this module, this one has no
    # get_current_user dependency, so anyone holding a task_id can read its
    # status and error text — confirm that is intended.
    info = async_kb_service.get_task_status(task_id)
    status = info.get('status')

    if status == 'not_found':
        raise HTTPException(status_code=404, detail="Task not found")

    return {
        "status": status,
        "progress": info.get('progress', 0),
        "error": info.get('error_message'),
    }
|
||
|