282 lines
10 KiB
Python
282 lines
10 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
|
from typing import Optional, List
|
|
from app.models.models import KnowledgeBase, KnowledgeBaseListResponse, CreateKnowledgeBaseRequest, UpdateKnowledgeBaseRequest, Tag
|
|
from app.core.database import get_db_connection
|
|
from app.core.auth import get_current_user
|
|
from app.core.response import create_api_response
|
|
from app.services.async_knowledge_base_service import async_kb_service
|
|
import datetime
|
|
|
|
# Router collecting the knowledge-base endpoints defined in this module.
# NOTE(review): presumably included into the FastAPI app elsewhere (not visible here).
router = APIRouter()
|
|
|
|
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
    """Resolve a comma-separated tag string into ``Tag`` objects, creating missing tags.

    Names are split on ``,`` and stripped. Empty names are ignored, and
    duplicate names are collapsed (first occurrence wins), so one call never
    inserts the same tag twice.

    Args:
        cursor: DB cursor that returns rows as dicts (callers in this module
            create it with ``dictionary=True``).
        tag_string: Comma-separated tag names, or ``None``/empty string.

    Returns:
        ``Tag`` objects built from the ``tags`` rows whose ``name`` matches one
        of the requested names, in the order the database returns them.
        An empty list when there are no usable names.

    Note:
        Missing tags are INSERTed through ``cursor`` but NOT committed here;
        committing (or not) is up to the caller's transaction handling.
    """
    if not tag_string:
        return []

    # dict.fromkeys de-duplicates while keeping first-seen order; without it
    # "a,a" would put "a" into new_tags twice and insert it twice.
    tag_names = list(dict.fromkeys(
        name.strip() for name in tag_string.split(',') if name.strip()
    ))
    if not tag_names:
        return []

    # Only the placeholder list is interpolated; the names are bound parameters.
    placeholders = ','.join(['%s'] * len(tag_names))
    select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
    params = tuple(tag_names)

    cursor.execute(select_query, params)
    tags_data = cursor.fetchall()

    # NOTE(review): if tags.name uses a case-insensitive collation, "Foo" can
    # match an existing "foo" row yet still be treated as new here — confirm.
    existing_names = {tag['name'] for tag in tags_data}
    new_tags = [name for name in tag_names if name not in existing_names]

    if new_tags:
        cursor.executemany(
            "INSERT INTO tags (name) VALUES (%s)",
            [(name,) for name in new_tags],
        )
        # Re-fetch all tags to get their IDs and default colors.
        cursor.execute(select_query, params)
        tags_data = cursor.fetchall()

    return [Tag(**tag) for tag in tags_data]
|
|
|
|
|
|
@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
def get_knowledge_bases(
    page: int = 1,
    size: int = 10,
    is_shared: Optional[bool] = None,
    current_user: dict = Depends(get_current_user)
):
    """List knowledge bases visible to the current user, most recently updated first.

    Args:
        page: 1-based page number.
        size: Rows per page (0 returns an empty page but still reports ``total``).
        is_shared: ``True`` -> only shared entries. ``False`` -> only the
            caller's own non-shared entries. ``None`` -> shared entries plus
            every entry the caller created.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        ``KnowledgeBaseListResponse`` holding the requested page and the total
        number of matching rows.

    Raises:
        HTTPException: 400 when ``page < 1`` or ``size < 0``. Such values
            would otherwise produce a negative LIMIT/OFFSET, which the
            database rejects.
    """
    if page < 1 or size < 0:
        raise HTTPException(status_code=400, detail="page must be >= 1 and size must be >= 0")

    user_id = current_user['user_id']

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        base_query = "FROM knowledge_bases kb JOIN users u ON kb.creator_id = u.user_id"
        where_clauses = []
        params = []

        if is_shared is None:
            # Both shared entries and the caller's own entries.
            where_clauses.append("(kb.is_shared = 1 OR kb.creator_id = %s)")
            params.append(user_id)
        elif is_shared:
            where_clauses.append("kb.is_shared = 1")
        else:
            # Personal: only the caller's own, non-shared entries.
            where_clauses.append("kb.is_shared = 0 AND kb.creator_id = %s")
            params.append(user_id)

        where_sql = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""

        cursor.execute("SELECT COUNT(*) as total " + base_query + where_sql, tuple(params))
        total = cursor.fetchone()['total']

        offset = (page - 1) * size
        # Only fixed SQL fragments are interpolated; user values are bound parameters.
        query = f"""
            SELECT
                kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
                kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at
            {base_query}
            {where_sql}
            ORDER BY kb.updated_at DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, tuple(params + [size, offset]))
        kbs_data = cursor.fetchall()

        kb_list = []
        for kb_data in kbs_data:
            kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
            # Count source meetings in the comma-separated id list, ignoring empty entries.
            raw_ids = kb_data.get('source_meeting_ids') or ''
            kb_data['source_meeting_count'] = sum(1 for mid in raw_ids.split(',') if mid.strip())
            # NOTE(review): this uses the user's caption, while the detail endpoint
            # selects u.username as created_by_name — confirm which is intended.
            kb_data['created_by_name'] = kb_data.get('creator_caption')
            kb_list.append(KnowledgeBase(**kb_data))

        return KnowledgeBaseListResponse(kbs=kb_list, total=total)
|
|
|
|
@router.post("/knowledge-bases")
def create_knowledge_base(
    request: CreateKnowledgeBaseRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Create a knowledge-base row and schedule generation of its content.

    Steps:
    1. Insert the row (the ``content`` column is not written here).
    2. Register a generation task via ``async_kb_service.start_generation``
       using the same cursor.
    3. Commit.
    4. Queue ``async_kb_service._process_task`` on FastAPI's ``BackgroundTasks``.

    Args:
        request: Creation payload. ``request.title`` is filled in when missing.
        background_tasks: FastAPI background-task queue.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        ``{"task_id": ..., "kb_id": ...}`` so the client can poll the task.
    """
    user_id = current_user['user_id']

    # Auto-generate a title of the form "YYYY-MM-DD 知识条目" ("knowledge entry")
    # when none was sent. Whitespace-only titles count as missing too.
    # NOTE(review): this date uses server local time, while the stored timestamps are UTC.
    if not request.title or not request.title.strip():
        request.title = datetime.datetime.now().strftime("%Y-%m-%d") + " 知识条目"

    # Naive UTC timestamp: the same value datetime.utcnow() returned, without
    # using that API (deprecated since Python 3.12).
    created_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    insert_kb_query = """
        INSERT INTO knowledge_bases (title, creator_id, is_shared, source_meeting_ids, user_prompt, tags, created_at, updated_at)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
    """

    with get_db_connection() as connection:
        cursor = connection.cursor()

        # Create the knowledge base entry first so the task can reference its id.
        cursor.execute(insert_kb_query, (
            request.title,
            user_id,
            request.is_shared,
            request.source_meeting_ids,
            request.user_prompt,
            request.tags,
            created_at,
            created_at,
        ))
        kb_id = cursor.lastrowid

        # The task is registered on the same cursor, presumably so anything it
        # writes is committed in the same transaction as the new row.
        task_id = async_kb_service.start_generation(
            user_id=user_id,
            kb_id=kb_id,
            user_prompt=request.user_prompt,
            source_meeting_ids=request.source_meeting_ids,
            cursor=cursor
        )

        connection.commit()

    # Run the generation after the response has been returned.
    background_tasks.add_task(async_kb_service._process_task, task_id)

    return {"task_id": task_id, "kb_id": kb_id}
|
|
|
|
@router.get("/knowledge-bases/{kb_id}")
def get_knowledge_base_detail(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Return a single knowledge base with its tags and source meetings.

    Shared entries are visible to everyone. Non-shared entries are visible
    only to their creator.

    Args:
        kb_id: Primary key of the knowledge base.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        The row as a dict. Beyond the selected columns, it has:
        - ``tags``: replaced by ``Tag`` objects.
        - ``source_meetings``: ``{'meeting_id', 'title'}`` dicts for the
          referenced meetings that exist.
        - ``source_meeting_count``: the number of those meetings found.

    Raises:
        HTTPException: 404 if the entry does not exist, 403 if it is private
            and the caller is not its creator.
    """
    # NOTE(review): created_by_name is the username here but the caption in the
    # list endpoint — confirm which is intended.
    detail_sql = """
        SELECT
            kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
            kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at,
            u.username as created_by_name
        FROM knowledge_bases kb
        JOIN users u ON kb.creator_id = u.user_id
        WHERE kb.kb_id = %s
    """

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        cursor.execute(detail_sql, (kb_id,))
        record = cursor.fetchone()

        if not record:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        # Readable when shared, or when the caller created it.
        if not (record['is_shared'] or record['creator_id'] == current_user['user_id']):
            raise HTTPException(status_code=403, detail="Access denied")

        record['tags'] = _process_tags(cursor, record.get('tags'))

        raw_ids = record.get('source_meeting_ids')
        wanted_ids = [part.strip() for part in raw_ids.split(',') if part.strip()] if raw_ids else []

        found_meetings = []
        if wanted_ids:
            in_list = ','.join(['%s'] * len(wanted_ids))
            cursor.execute(
                f"""
                SELECT meeting_id, title
                FROM meetings
                WHERE meeting_id IN ({in_list})
                """,
                tuple(wanted_ids),
            )
            found_meetings = [
                {'meeting_id': row['meeting_id'], 'title': row['title']}
                for row in cursor.fetchall()
            ]

        # Counts meetings actually found, which can be fewer than the ids stored.
        record['source_meeting_count'] = len(found_meetings)
        record['source_meetings'] = found_meetings

        return record
|
|
|
|
@router.put("/knowledge-bases/{kb_id}")
def update_knowledge_base(
    kb_id: int,
    request: UpdateKnowledgeBaseRequest,
    current_user: dict = Depends(get_current_user)
):
    """Overwrite the title, content and tags of a knowledge base.

    Only the creator may update an entry. All three columns are written from
    ``request`` unconditionally, and ``updated_at`` is set to the current UTC time.

    Args:
        kb_id: Primary key of the knowledge base.
        request: New title, content and tags.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 if ``kb_id`` does not exist, 403 if the caller is not the creator.
    """
    # NOTE(review): if UpdateKnowledgeBaseRequest fields are optional, an omitted
    # field is written as NULL rather than left unchanged — confirm this is intended.
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Check that the knowledge base exists and the caller owns it.
        cursor.execute(
            "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
            (kb_id,)
        )
        kb = cursor.fetchone()

        if not kb:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        if kb['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")

        # Naive UTC timestamp: the same value datetime.utcnow() returned, without
        # using that API (deprecated since Python 3.12).
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        update_query = """
            UPDATE knowledge_bases
            SET title = %s, content = %s, tags = %s, updated_at = %s
            WHERE kb_id = %s
        """
        cursor.execute(update_query, (
            request.title,
            request.content,
            request.tags,
            now,
            kb_id
        ))
        connection.commit()

    return {"message": "Knowledge base updated successfully"}
|
|
|
|
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Delete a knowledge base. Only its creator is allowed to do so.

    Args:
        kb_id: Primary key of the knowledge base to remove.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 if ``kb_id`` does not exist, 403 if the caller is not the creator.
    """
    owner_lookup_sql = "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s"
    delete_sql = "DELETE FROM knowledge_bases WHERE kb_id = %s"

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        cursor.execute(owner_lookup_sql, (kb_id,))
        owner_row = cursor.fetchone()

        if not owner_row:
            raise HTTPException(status_code=404, detail="Knowledge base not found")
        if owner_row['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")

        cursor.execute(delete_sql, (kb_id,))
        connection.commit()

    return {"message": "Knowledge base deleted successfully"}
|
|
|
|
@router.get("/knowledge-bases/tasks/{task_id}")
def get_task_status(task_id: str, current_user: dict = Depends(get_current_user)):
    """Get the status of a knowledge-base generation task.

    Requires an authenticated user, like every other endpoint in this router.
    Previously this route was the only one reachable without authentication.

    Args:
        task_id: Id returned by ``POST /knowledge-bases``.
        current_user: Authenticated user injected by ``get_current_user``.
            Declaring it enforces authentication.

    Returns:
        ``{"status", "progress", "error"}``. ``progress`` defaults to 0, and
        ``error`` is the service's ``error_message`` (or ``None``).

    Raises:
        HTTPException: 404 when the service reports the task as ``not_found``
            or returns no record at all.
    """
    # TODO(review): if the service exposes the task owner, also verify it
    # matches current_user['user_id'] to prevent cross-user polling.
    task_status = async_kb_service.get_task_status(task_id)

    # A missing (None/empty) record is treated like an explicit 'not_found'
    # instead of failing with AttributeError on .get().
    if not task_status or task_status.get('status') == 'not_found':
        raise HTTPException(status_code=404, detail="Task not found")

    return {
        "status": task_status.get('status'),
        "progress": task_status.get('progress', 0),
        "error": task_status.get('error_message')
    }
|
|
|