imetting_backend/app/api/endpoints/knowledge_base.py

284 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import Optional, List
from app.models.models import KnowledgeBase, KnowledgeBaseListResponse, CreateKnowledgeBaseRequest, UpdateKnowledgeBaseRequest, Tag
from app.core.database import get_db_connection
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.async_knowledge_base_service import async_kb_service
import datetime
# Router that the knowledge-base endpoints below register themselves on via
# @router.get / .post / .put / .delete.
router = APIRouter()
def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
    """
    Resolve a comma-separated tag string into Tag objects.

    Tag names are stripped, empty entries are dropped, and repeated names are
    collapsed (first occurrence wins). When ``creator_id`` is truthy, every
    name is first inserted into ``tags`` with ``INSERT IGNORE`` so missing tags
    get created; rows the database rejects (presumably existing names via a
    unique key on ``tags.name`` -- confirm against the schema) are skipped.

    Args:
        cursor: Open DB cursor. ``Tag(**row)`` is built from each fetched row,
            so the cursor must yield dict rows (``dictionary=True``).
        tag_string: Comma-separated tag names; ``None``/empty yields ``[]``.
        creator_id: If truthy, missing tags are created with this creator.
            If omitted, existing tags are only looked up.

    Returns:
        Tag objects for the names found in ``tags``, ordered by each name's
        first position in ``tag_string``. Names without a matching row are
        omitted.
    """
    if not tag_string:
        return []

    # dict.fromkeys de-duplicates while keeping the order the user typed.
    tag_names = list(dict.fromkeys(
        name.strip() for name in tag_string.split(',') if name.strip()
    ))
    if not tag_names:
        return []

    if creator_id:
        insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
        cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])

    placeholders = ', '.join(['%s'] * len(tag_names))
    cursor.execute(f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})", tuple(tag_names))
    tags_data = cursor.fetchall()

    # `IN (...)` without ORDER BY has no guaranteed row order, so restore the
    # input order here. The lookup is casefolded because the column collation
    # may match names case-insensitively. Rows that still cannot be mapped are
    # kept (sorted last) rather than dropped.
    position = {}
    for index, name in enumerate(tag_names):
        position.setdefault(name.casefold(), index)
    fallback = len(position)
    tags_data = sorted(
        tags_data,
        key=lambda row: position.get(str(row['name']).casefold(), fallback),
    )
    return [Tag(**tag) for tag in tags_data]
@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
def get_knowledge_bases(
    page: int = 1,
    size: int = 10,
    is_shared: Optional[bool] = None,
    current_user: dict = Depends(get_current_user)
):
    """
    List knowledge bases visible to the current user, most recently updated first.

    Args:
        page: 1-based page number; must be >= 1.
        size: Page size; must be >= 0.
        is_shared: ``True`` -> only shared entries; ``False`` -> only the
            caller's own non-shared entries; ``None`` -> shared entries plus
            every entry the caller created.
        current_user: Authenticated user dict (``user_id`` is read).

    Returns:
        KnowledgeBaseListResponse with the requested page and the total number
        of matching rows. In this view ``tags`` stays the raw stored string.
        ``source_meeting_count`` is the number of non-empty ids in
        ``source_meeting_ids``, and ``created_by_name`` is the creator's caption.

    Raises:
        HTTPException(400): if ``page < 1`` or ``size < 0``.
    """
    # offset = (page - 1) * size below. page < 1 would yield a negative OFFSET,
    # and size < 0 a negative LIMIT. Neither is valid SQL, so reject them here
    # instead of letting the query fail.
    if page < 1 or size < 0:
        raise HTTPException(status_code=400, detail="page must be >= 1 and size must be >= 0")

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        base_query = "FROM knowledge_bases kb JOIN users u ON kb.creator_id = u.user_id"
        where_clauses = []
        params = []
        if is_shared is None:
            # Both shared entries and the caller's own entries.
            where_clauses.append("(kb.is_shared = 1 OR kb.creator_id = %s)")
            params.append(current_user['user_id'])
        elif is_shared:
            where_clauses.append("kb.is_shared = 1")
        else:
            # Personal: only the caller's non-shared entries.
            where_clauses.append("kb.is_shared = 0 AND kb.creator_id = %s")
            params.append(current_user['user_id'])
        where_sql = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""

        count_query = "SELECT COUNT(*) as total " + base_query + where_sql
        cursor.execute(count_query, tuple(params))
        total = cursor.fetchone()['total']

        offset = (page - 1) * size
        # Only fixed SQL fragments are interpolated; all user values go through params.
        query = f"""
            SELECT
                kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
                kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at
            {base_query}
            {where_sql}
            ORDER BY kb.updated_at DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, tuple(params + [size, offset]))
        kbs_data = cursor.fetchall()

        kb_list = []
        for kb_data in kbs_data:
            # The list view leaves `tags` as the raw comma-separated string;
            # only the detail view resolves it into Tag objects.
            raw_ids = kb_data.get('source_meeting_ids')
            kb_data['source_meeting_count'] = (
                sum(1 for mid in raw_ids.split(',') if mid.strip()) if raw_ids else 0
            )
            kb_data['created_by_name'] = kb_data.get('creator_caption')
            kb_list.append(KnowledgeBase(**kb_data))

        return KnowledgeBaseListResponse(kbs=kb_list, total=total)
@router.post("/knowledge-bases")
def create_knowledge_base(
    request: CreateKnowledgeBaseRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """
    Insert a new knowledge-base row and queue its content generation.

    A missing/empty title is replaced with today's local date followed by
    " 知识条目" (``YYYY-MM-DD 知识条目``). The row gets UTC ``created_at`` /
    ``updated_at`` values. A generation task is registered through
    ``async_kb_service.start_generation`` on the same cursor, the transaction
    is committed, and ``async_kb_service._process_task`` is then scheduled as a
    FastAPI background task.

    Returns:
        ``{"task_id": <task id>, "kb_id": <new row id>}``.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor()

        if not request.title:
            # Default title: local calendar date + "knowledge entry".
            request.title = f"{datetime.datetime.now():%Y-%m-%d} 知识条目"

        timestamp = datetime.datetime.utcnow()
        owner_id = current_user['user_id']
        cursor.execute(
            """
            INSERT INTO knowledge_bases (title, creator_id, is_shared, source_meeting_ids, user_prompt, tags, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (
                request.title,
                owner_id,
                request.is_shared,
                request.source_meeting_ids,
                request.user_prompt,
                request.tags,  # expected to be None or empty when creating
                timestamp,
                timestamp,
            ),
        )
        new_kb_id = cursor.lastrowid

        task_id = async_kb_service.start_generation(
            user_id=owner_id,
            kb_id=new_kb_id,
            user_prompt=request.user_prompt,
            source_meeting_ids=request.source_meeting_ids,
            cursor=cursor,
        )
        connection.commit()

        # Queued only after the commit, presumably so the worker sees the
        # committed row and task -- confirm against async_kb_service.
        background_tasks.add_task(async_kb_service._process_task, task_id)
        return {"task_id": task_id, "kb_id": new_kb_id}
@router.get("/knowledge-bases/{kb_id}")
def get_knowledge_base_detail(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """
    Return one knowledge base with its tags and source meetings resolved.

    Shared entries are visible to any authenticated user. Non-shared entries
    are visible only to their creator.

    The returned dict holds the selected knowledge_bases columns plus:
      - ``creator_caption`` (users.caption) and ``created_by_name`` (users.username);
      - ``tags``: Tag objects resolved from the stored tag string (lookup only,
        no tags are created);
      - ``source_meetings``: ``[{'meeting_id', 'title'}]`` for referenced
        meetings that exist in ``meetings``, in the order listed in
        ``source_meeting_ids``;
      - ``source_meeting_count``: ``len(source_meetings)``.

    Raises:
        HTTPException(404): no knowledge base with ``kb_id``.
        HTTPException(403): the entry is not shared and the caller is not its creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        # NOTE(review): get_knowledge_bases sets created_by_name from users.caption,
        # while this query uses users.username -- confirm which one is intended.
        query = """
            SELECT
                kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
                kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at,
                u.username as created_by_name
            FROM knowledge_bases kb
            JOIN users u ON kb.creator_id = u.user_id
            WHERE kb.kb_id = %s
        """
        cursor.execute(query, (kb_id,))
        kb_data = cursor.fetchone()
        if not kb_data:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Access denied")

        # Resolve tag names to full Tag records (including color). No creator_id
        # is passed, so a read request never creates tags.
        kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))

        raw_ids = kb_data.get('source_meeting_ids') or ''
        meeting_ids = [mid.strip() for mid in raw_ids.split(',') if mid.strip()]
        source_meetings = []
        if meeting_ids:
            placeholders = ','.join(['%s'] * len(meeting_ids))
            meeting_query = f"""
                SELECT meeting_id, title
                FROM meetings
                WHERE meeting_id IN ({placeholders})
            """
            cursor.execute(meeting_query, tuple(meeting_ids))
            meetings_data = cursor.fetchall()

            # `IN (...)` without ORDER BY returns rows in no guaranteed order, so
            # re-sort to the stored order. Compare as strings: the stored ids are
            # text, while meeting_id may come back from the DB as an int.
            position = {}
            for index, mid in enumerate(meeting_ids):
                position.setdefault(mid, index)
            fallback = len(position)
            meetings_data = sorted(
                meetings_data,
                key=lambda m: position.get(str(m['meeting_id']), fallback),
            )
            source_meetings = [{'meeting_id': m['meeting_id'], 'title': m['title']} for m in meetings_data]

        # Counts only meetings that still exist, not every id in the stored string.
        kb_data['source_meeting_count'] = len(source_meetings)
        kb_data['source_meetings'] = source_meetings
        return kb_data
@router.put("/knowledge-bases/{kb_id}")
def update_knowledge_base(
    kb_id: int,
    request: UpdateKnowledgeBaseRequest,
    current_user: dict = Depends(get_current_user)
):
    """
    Overwrite title, content and tags of a knowledge base owned by the caller.

    If ``request.tags`` is non-empty, ``_process_tags`` first creates any tag
    names that do not exist yet, owned by the caller. The tag string itself is
    stored exactly as sent. ``updated_at`` is set to the current UTC time.

    Raises:
        HTTPException(404): no knowledge base with ``kb_id``.
        HTTPException(403): the caller is not the creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        cursor.execute("SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        existing = cursor.fetchone()
        if not existing:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        requester_id = current_user['user_id']
        if existing['creator_id'] != requester_id:
            raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")

        if request.tags:
            # Make sure every referenced tag exists before saving the tag string.
            _process_tags(cursor, request.tags, requester_id)

        cursor.execute(
            """
            UPDATE knowledge_bases
            SET title = %s, content = %s, tags = %s, updated_at = %s
            WHERE kb_id = %s
            """,
            (request.title, request.content, request.tags, datetime.datetime.utcnow(), kb_id),
        )
        connection.commit()

    return {"message": "Knowledge base updated successfully"}
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """
    Delete a knowledge base owned by the caller.

    Only a DELETE against ``knowledge_bases`` is issued here. Whether related
    rows are cleaned up depends on the schema's foreign-key rules.

    Raises:
        HTTPException(404): no knowledge base with ``kb_id``.
        HTTPException(403): the caller is not the creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute("SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        target = cursor.fetchone()

        if not target:
            raise HTTPException(status_code=404, detail="Knowledge base not found")
        if target['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")

        cursor.execute("DELETE FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        connection.commit()

    return {"message": "Knowledge base deleted successfully"}
@router.get("/knowledge-bases/tasks/{task_id}")
def get_task_status(task_id: str):
    """
    Report the state of a knowledge-base generation task.

    Returns:
        ``{"status": ..., "progress": ..., "error": ...}``. ``progress``
        defaults to 0 when the service omits it, and ``error`` mirrors the
        service's ``error_message``.

    Raises:
        HTTPException(404): the service reports the task as ``not_found``.
    """
    # NOTE(review): unlike the other routes in this module, this one has no
    # get_current_user dependency, so anyone holding a task_id can poll it --
    # confirm this is intended.
    info = async_kb_service.get_task_status(task_id)
    state = info.get('status')
    if state == 'not_found':
        raise HTTPException(status_code=404, detail="Task not found")

    return dict(
        status=state,
        progress=info.get('progress', 0),
        error=info.get('error_message'),
    )