# imetting_backend/app/api/endpoints/knowledge_base.py

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import Optional, List
from app.models.models import KnowledgeBase, KnowledgeBaseListResponse, CreateKnowledgeBaseRequest, UpdateKnowledgeBaseRequest, Tag
from app.core.database import get_db_connection
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.async_knowledge_base_service import async_kb_service
import datetime
# Router collecting the knowledge-base endpoints; every handler below registers
# itself on it via @router.<method>(...).
router = APIRouter()
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
    """Resolve a comma-separated tag string into ``Tag`` models, creating missing tags.

    Names are stripped, blank entries are dropped and duplicates are collapsed
    (first occurrence wins). Any name not already present in the ``tags`` table
    is inserted with only its ``name``. The rows are then re-selected, so every
    returned ``Tag`` carries the ``id`` and ``color`` stored in the database.

    Args:
        cursor: DB cursor whose rows are dict-like (columns are read by name,
            e.g. ``tag['name']``).
        tag_string: Comma-separated tag names such as ``"a, b,c"``; may be
            ``None`` or empty.

    Returns:
        One ``Tag`` per matching row, in the order the database returns them.
        Returns an empty list when ``tag_string`` contains no usable names.

    NOTE(review): this helper may issue INSERTs but never commits. Callers that
    run it from GET endpoints rely on the connection's commit/rollback policy.
    """
    if not tag_string:
        return []
    # dict.fromkeys dedupes while keeping first-seen order. Without it, "a,a"
    # would put "a" into new_tags twice and insert it twice.
    tag_names = list(dict.fromkeys(
        name.strip() for name in tag_string.split(',') if name.strip()
    ))
    if not tag_names:
        return []
    placeholders = ','.join(['%s'] * len(tag_names))
    select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
    cursor.execute(select_query, tuple(tag_names))
    tags_data = cursor.fetchall()
    # NOTE(review): exact string comparison. If the `name` column uses a
    # case-insensitive collation, "Foo" vs an existing "foo" is still treated
    # as new and inserted — confirm against the schema.
    existing_names = {tag['name'] for tag in tags_data}
    new_tags = [name for name in tag_names if name not in existing_names]
    if new_tags:
        cursor.executemany(
            "INSERT INTO tags (name) VALUES (%s)",
            [(name,) for name in new_tags],
        )
        # Re-fetch all tags to get their IDs and default colors
        cursor.execute(select_query, tuple(tag_names))
        tags_data = cursor.fetchall()
    return [Tag(**tag) for tag in tags_data]
@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
def get_knowledge_bases(
    page: int = 1,
    size: int = 10,
    is_shared: Optional[bool] = None,
    current_user: dict = Depends(get_current_user)
):
    """List knowledge bases visible to the current user, most recently updated first.

    Args:
        page: 1-based page number.
        size: Number of rows per page.
        is_shared: ``True`` returns only shared entries. ``False`` returns only
            the caller's own non-shared entries. ``None`` returns all shared
            entries plus everything the caller created.
        current_user: Authenticated user dict; its ``user_id`` is read.

    Returns:
        ``KnowledgeBaseListResponse`` holding the requested page of
        ``KnowledgeBase`` items and the total number of matching rows.

    Raises:
        HTTPException(400): if ``page`` or ``size`` is smaller than 1.
    """
    # page < 1 gives a negative OFFSET and size < 0 a negative LIMIT. The
    # database rejects both, which previously surfaced as a 500.
    if page < 1 or size < 1:
        raise HTTPException(status_code=400, detail="page and size must be >= 1")

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        base_query = "FROM knowledge_bases kb JOIN users u ON kb.creator_id = u.user_id"
        where_clauses = []
        params = []
        if is_shared is None:
            # Both shared entries and the caller's own entries (shared or not).
            where_clauses.append("(kb.is_shared = 1 OR kb.creator_id = %s)")
            params.append(current_user['user_id'])
        elif is_shared:
            where_clauses.append("kb.is_shared = 1")
        else:
            # Personal: only the caller's own non-shared entries.
            where_clauses.append("kb.is_shared = 0 AND kb.creator_id = %s")
            params.append(current_user['user_id'])
        where_sql = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""

        count_query = "SELECT COUNT(*) as total " + base_query + where_sql
        cursor.execute(count_query, tuple(params))
        total = cursor.fetchone()['total']

        offset = (page - 1) * size
        query = f"""
            SELECT
                kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
                kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at
            {base_query}
            {where_sql}
            ORDER BY kb.updated_at DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, tuple(params + [size, offset]))
        kbs_data = cursor.fetchall()

        kb_list = []
        for kb_data in kbs_data:
            # NOTE(review): one tag lookup per row (N+1 queries per page).
            # _process_tags may also INSERT unknown tag names even though
            # this is a GET.
            kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
            # Count stored source meeting IDs, ignoring empty entries (e.g. "1,,2" -> 2).
            raw_ids = kb_data.get('source_meeting_ids')
            if raw_ids:
                kb_data['source_meeting_count'] = len(
                    [mid for mid in raw_ids.split(',') if mid.strip()]
                )
            else:
                kb_data['source_meeting_count'] = 0
            # NOTE(review): this uses the creator's caption, whereas the detail
            # endpoint selects u.username as created_by_name — confirm which
            # one the client expects.
            kb_data['created_by_name'] = kb_data.get('creator_caption')
            kb_list.append(KnowledgeBase(**kb_data))
        return KnowledgeBaseListResponse(kbs=kb_list, total=total)
@router.post("/knowledge-bases")
def create_knowledge_base(
    request: CreateKnowledgeBaseRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Insert a new knowledge-base row and schedule generation of its content.

    If the request has no title, one is derived from the server's local date
    in the form ``YYYY-MM-DD 知识条目``. Both ``created_at`` and ``updated_at``
    are stamped with ``datetime.datetime.utcnow()``.

    Returns:
        ``{"task_id": ..., "kb_id": ...}``. The task id comes from
        ``async_kb_service.start_generation``.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor()

        # Fallback title built from today's local date, e.g. "2024-01-31 知识条目".
        if not request.title:
            request.title = datetime.datetime.now().strftime("%Y-%m-%d") + " 知识条目"

        stamp = datetime.datetime.utcnow()
        row_values = (
            request.title,
            current_user['user_id'],
            request.is_shared,
            request.source_meeting_ids,
            request.user_prompt,
            request.tags,
            stamp,  # created_at
            stamp,  # updated_at
        )
        cursor.execute(
            """
            INSERT INTO knowledge_bases (title, creator_id, is_shared, source_meeting_ids, user_prompt, tags, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """,
            row_values,
        )
        new_kb_id = cursor.lastrowid

        # start_generation receives this same cursor before the single commit
        # below, presumably so its task record lands in the same transaction
        # as the new row.
        generation_task_id = async_kb_service.start_generation(
            user_id=current_user['user_id'],
            kb_id=new_kb_id,
            user_prompt=request.user_prompt,
            source_meeting_ids=request.source_meeting_ids,
            cursor=cursor
        )
        connection.commit()

        # FastAPI runs queued background tasks after the response is sent.
        background_tasks.add_task(async_kb_service._process_task, generation_task_id)
        return {"task_id": generation_task_id, "kb_id": new_kb_id}
@router.get("/knowledge-bases/{kb_id}")
def get_knowledge_base_detail(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Return one knowledge base with its resolved tags and source meetings.

    The row is returned as a dict. ``tags`` is replaced by ``Tag`` models,
    ``source_meetings`` lists ``{'meeting_id', 'title'}`` for every referenced
    meeting found in the ``meetings`` table, and ``source_meeting_count`` is
    the length of that list.

    Raises:
        HTTPException(404): no knowledge base has this ``kb_id``.
        HTTPException(403): the entry is not shared and the caller did not create it.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(
            """
            SELECT
                kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
                kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at,
                u.username as created_by_name
            FROM knowledge_bases kb
            JOIN users u ON kb.creator_id = u.user_id
            WHERE kb.kb_id = %s
            """,
            (kb_id,),
        )
        kb_data = cursor.fetchone()
        if not kb_data:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        # Readable when shared, or when the caller is the creator.
        if not (kb_data['is_shared'] or kb_data['creator_id'] == current_user['user_id']):
            raise HTTPException(status_code=403, detail="Access denied")

        kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))

        # A missing/empty ID string yields no IDs; blank entries are skipped.
        meeting_ids = [
            mid.strip()
            for mid in (kb_data.get('source_meeting_ids') or '').split(',')
            if mid.strip()
        ]

        source_meetings = []
        if meeting_ids:
            placeholders = ','.join(['%s'] * len(meeting_ids))
            cursor.execute(
                f"""
                SELECT meeting_id, title
                FROM meetings
                WHERE meeting_id IN ({placeholders})
                """,
                tuple(meeting_ids),
            )
            source_meetings = [
                {'meeting_id': row['meeting_id'], 'title': row['title']}
                for row in cursor.fetchall()
            ]

        # Counts meetings actually found, so IDs with no matching meeting row are not counted.
        kb_data['source_meeting_count'] = len(source_meetings)
        kb_data['source_meetings'] = source_meetings
        return kb_data
@router.put("/knowledge-bases/{kb_id}")
def update_knowledge_base(
    kb_id: int,
    request: UpdateKnowledgeBaseRequest,
    current_user: dict = Depends(get_current_user)
):
    """Overwrite title, content and tags of a knowledge base owned by the caller.

    The three columns are written from the request as-is, and ``updated_at``
    is set to ``datetime.datetime.utcnow()``.

    NOTE(review): if any of these request fields are optional, omitting one
    writes NULL instead of keeping the stored value — confirm against
    ``UpdateKnowledgeBaseRequest``.

    Raises:
        HTTPException(404): no knowledge base has this ``kb_id``.
        HTTPException(403): the caller is not the creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        cursor.execute(
            "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
            (kb_id,)
        )
        owner_row = cursor.fetchone()
        if not owner_row:
            raise HTTPException(status_code=404, detail="Knowledge base not found")
        if owner_row['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")

        modified_at = datetime.datetime.utcnow()
        cursor.execute(
            """
            UPDATE knowledge_bases
            SET title = %s, content = %s, tags = %s, updated_at = %s
            WHERE kb_id = %s
            """,
            (request.title, request.content, request.tags, modified_at, kb_id),
        )
        connection.commit()
        return {"message": "Knowledge base updated successfully"}
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Delete a knowledge base; only its creator may do so.

    Raises:
        HTTPException(404): no knowledge base has this ``kb_id``.
        HTTPException(403): the caller is not the creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Look up the owner first so a missing row (404) and a foreign row
        # (403) are reported distinctly.
        cursor.execute(
            "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
            (kb_id,)
        )
        target = cursor.fetchone()
        if not target:
            raise HTTPException(status_code=404, detail="Knowledge base not found")
        if target['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")

        cursor.execute("DELETE FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        connection.commit()
        return {"message": "Knowledge base deleted successfully"}
@router.get("/knowledge-bases/tasks/{task_id}")
def get_task_status(
    task_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Return the status of a knowledge-base generation task.

    Requires an authenticated user, consistent with the other endpoints in
    this router. Previously this route was reachable without any credentials.

    NOTE(review): the task is not checked against ``current_user``. Whether
    the task status carries an owner id to compare against is not visible
    here — confirm and add an ownership check if it does.

    Returns:
        ``{"status": ..., "progress": ..., "error": ...}``. ``progress``
        defaults to 0 and ``error`` comes from the task's ``error_message``.

    Raises:
        HTTPException(404): the service reports ``not_found`` or returns nothing.
    """
    task_status = async_kb_service.get_task_status(task_id)
    # Guard a falsy result too, so `.get` is never called on None.
    if not task_status or task_status.get('status') == 'not_found':
        raise HTTPException(status_code=404, detail="Task not found")
    return {
        "status": task_status.get('status'),
        "progress": task_status.get('progress', 0),
        "error": task_status.get('error_message')
    }