from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import Optional, List
from app.models.models import KnowledgeBase, KnowledgeBaseListResponse, CreateKnowledgeBaseRequest, UpdateKnowledgeBaseRequest, Tag
from app.core.database import get_db_connection
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.async_knowledge_base_service import async_kb_service
import datetime

router = APIRouter()


def _split_csv(value: Optional[str]) -> List[str]:
    """Split a comma-separated string into stripped, non-empty items.

    Returns an empty list for ``None`` or an empty string.
    """
    if not value:
        return []
    return [item.strip() for item in value.split(',') if item.strip()]


def _utc_now() -> datetime.datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as the deprecated ``datetime.utcnow()`` so the
    timestamps written to the database keep their existing (naive UTC) form.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


def _process_tags(cursor, tag_string: Optional[str], creator_id: Optional[int] = None) -> List[Tag]:
    """Resolve a comma-separated tag string into ``Tag`` objects.

    Looks up the named tags in the ``tags`` table. When ``creator_id`` is
    given, tags that do not exist yet are created first (``INSERT IGNORE``).

    Args:
        cursor: Database cursor. Rows are unpacked with ``Tag(**row)``, so the
            cursor must return mapping rows (e.g. ``dictionary=True``).
        tag_string: Comma-separated tag names; may be ``None`` or empty.
        creator_id: Owner recorded on newly created tags. ``None`` means
            "look up only, never create".

    Returns:
        The ``Tag`` objects found for the given names (empty list when no
        names were supplied).
    """
    # De-duplicate while preserving order so repeated names do not produce
    # redundant inserts or placeholders.
    tag_names = list(dict.fromkeys(_split_csv(tag_string)))
    if not tag_names:
        return []

    # Compare against None explicitly: a creator id of 0 is still a valid id.
    if creator_id is not None:
        insert_ignore_query = "INSERT IGNORE INTO tags (name, creator_id) VALUES (%s, %s)"
        cursor.executemany(insert_ignore_query, [(name, creator_id) for name in tag_names])

    # Fetch the full tag records (including colour) for every requested name.
    format_strings = ', '.join(['%s'] * len(tag_names))
    cursor.execute(
        f"SELECT id, name, color FROM tags WHERE name IN ({format_strings})",
        tuple(tag_names)
    )
    tags_data = cursor.fetchall()
    return [Tag(**tag) for tag in tags_data]


@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
def get_knowledge_bases(
    page: int = 1,
    size: int = 10,
    is_shared: Optional[bool] = None,
    current_user: dict = Depends(get_current_user)
):
    """List knowledge bases visible to the current user, newest first.

    ``is_shared=True`` lists shared entries, ``is_shared=False`` lists the
    caller's own non-shared entries, and omitting it lists both.

    Raises:
        HTTPException: 400 when ``page`` or ``size`` is less than 1.
    """
    # Reject values that would yield a negative OFFSET / LIMIT and surface as
    # a raw database error.
    if page < 1 or size < 1:
        raise HTTPException(status_code=400, detail="page and size must be positive integers")

    user_id = current_user['user_id']
    if is_shared is None:
        # Both personal and shared
        where_sql = " WHERE (kb.is_shared = 1 OR kb.creator_id = %s)"
        params = [user_id]
    elif is_shared:
        where_sql = " WHERE kb.is_shared = 1"
        params = []
    else:
        # Personal
        where_sql = " WHERE kb.is_shared = 0 AND kb.creator_id = %s"
        params = [user_id]

    base_query = "FROM knowledge_bases kb JOIN users u ON kb.creator_id = u.user_id"

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        count_query = "SELECT COUNT(*) as total " + base_query + where_sql
        cursor.execute(count_query, tuple(params))
        total = cursor.fetchone()['total']

        offset = (page - 1) * size
        query = f"""
            SELECT kb.kb_id, kb.title, kb.content, kb.creator_id,
                   u.caption as creator_caption, kb.is_shared,
                   kb.source_meeting_ids, kb.user_prompt, kb.tags,
                   kb.created_at, kb.updated_at
            {base_query}
            {where_sql}
            ORDER BY kb.updated_at DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, tuple(params + [size, offset]))
        kbs_data = cursor.fetchall()

        kb_list = []
        for kb_data in kbs_data:
            # The list view does not expand tags: kb_data['tags'] stays the
            # raw comma-separated string of tag names.
            kb_data['source_meeting_count'] = len(_split_csv(kb_data.get('source_meeting_ids')))
            # Add created_by_name for consistency
            kb_data['created_by_name'] = kb_data.get('creator_caption')
            kb_list.append(KnowledgeBase(**kb_data))

        return KnowledgeBaseListResponse(kbs=kb_list, total=total)


@router.post("/knowledge-bases")
def create_knowledge_base(
    request: CreateKnowledgeBaseRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Create a knowledge base row and schedule its content generation.

    Returns:
        A dict with the generation ``task_id`` and the new ``kb_id``.
    """
    # Auto-generate a title of the form "YYYY-MM-DD 知识条目" when none is given.
    # NOTE(review): the title date uses local time while created_at/updated_at
    # use UTC — confirm this is intended.
    title = request.title
    if not title:
        title = datetime.datetime.now().strftime("%Y-%m-%d") + " 知识条目"

    with get_db_connection() as connection:
        cursor = connection.cursor()

        # Create the knowledge base entry first
        insert_kb_query = """
            INSERT INTO knowledge_bases
                (title, creator_id, is_shared, source_meeting_ids,
                 user_prompt, tags, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        """
        timestamp = _utc_now()
        cursor.execute(insert_kb_query, (
            title,
            current_user['user_id'],
            request.is_shared,
            request.source_meeting_ids,
            request.user_prompt,
            request.tags,  # expected to be None or an empty string at creation time
            timestamp,
            timestamp
        ))
        kb_id = cursor.lastrowid

        # Start the async task; it receives this cursor and the commit below
        # follows it.
        task_id = async_kb_service.start_generation(
            user_id=current_user['user_id'],
            kb_id=kb_id,
            user_prompt=request.user_prompt,
            source_meeting_ids=request.source_meeting_ids,
            cursor=cursor
        )
        connection.commit()

        # Add the background task to process the knowledge base generation
        background_tasks.add_task(async_kb_service._process_task, task_id)

        return {"task_id": task_id, "kb_id": kb_id}


@router.get("/knowledge-bases/{kb_id}")
def get_knowledge_base_detail(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Return one knowledge base with expanded tags and source meetings.

    Raises:
        HTTPException: 404 when the entry does not exist; 403 when it is not
            shared and the caller is not its creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # NOTE(review): created_by_name comes from u.username here, whereas the
        # list endpoint fills it from u.caption — confirm which is intended.
        query = """
            SELECT kb.kb_id, kb.title, kb.content, kb.creator_id,
                   u.caption as creator_caption, kb.is_shared,
                   kb.source_meeting_ids, kb.user_prompt, kb.tags,
                   kb.created_at, kb.updated_at,
                   u.username as created_by_name
            FROM knowledge_bases kb
            JOIN users u ON kb.creator_id = u.user_id
            WHERE kb.kb_id = %s
        """
        cursor.execute(query, (kb_id,))
        kb_data = cursor.fetchone()

        if not kb_data:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        # Check access permissions
        if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Access denied")

        # Expand tags into full records (including colour). The detail view
        # must not create tags, so no creator_id is passed.
        kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))

        # Get source meetings details
        source_meetings = []
        meeting_ids = _split_csv(kb_data.get('source_meeting_ids'))
        if meeting_ids:
            placeholders = ','.join(['%s'] * len(meeting_ids))
            meeting_query = f"""
                SELECT meeting_id, title
                FROM meetings
                WHERE meeting_id IN ({placeholders})
            """
            cursor.execute(meeting_query, tuple(meeting_ids))
            meetings_data = cursor.fetchall()
            source_meetings = [
                {'meeting_id': m['meeting_id'], 'title': m['title']}
                for m in meetings_data
            ]

        # Counts only the meetings actually found; 0 when no ids are stored.
        kb_data['source_meeting_count'] = len(source_meetings)
        kb_data['source_meetings'] = source_meetings

        return kb_data


@router.put("/knowledge-bases/{kb_id}")
def update_knowledge_base(
    kb_id: int,
    request: UpdateKnowledgeBaseRequest,
    current_user: dict = Depends(get_current_user)
):
    """Update title, content and tags of a knowledge base owned by the caller.

    Raises:
        HTTPException: 404 when the entry does not exist; 403 when the caller
            is not its creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Check if knowledge base exists and user has permission
        cursor.execute(
            "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
            (kb_id,)
        )
        kb = cursor.fetchone()

        if not kb:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        if kb['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")

        # Run the tags through _process_tags so any new tag names are created.
        if request.tags:
            _process_tags(cursor, request.tags, current_user['user_id'])

        # Update the knowledge base
        update_query = """
            UPDATE knowledge_bases
            SET title = %s, content = %s, tags = %s, updated_at = %s
            WHERE kb_id = %s
        """
        cursor.execute(update_query, (
            request.title,
            request.content,
            request.tags,
            _utc_now(),
            kb_id
        ))
        connection.commit()

        return {"message": "Knowledge base updated successfully"}


@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Delete a knowledge base owned by the caller.

    Raises:
        HTTPException: 404 when the entry does not exist; 403 when the caller
            is not its creator.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Check if knowledge base exists and user has permission
        cursor.execute(
            "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
            (kb_id,)
        )
        kb = cursor.fetchone()

        if not kb:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        if kb['creator_id'] != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")

        # Delete the knowledge base
        cursor.execute("DELETE FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        connection.commit()

        return {"message": "Knowledge base deleted successfully"}


@router.get("/knowledge-bases/tasks/{task_id}")
def get_task_status(task_id: str):
    """Return the status of a knowledge base generation task.

    Raises:
        HTTPException: 404 when the task service reports ``not_found``.
    """
    # NOTE(review): unlike the other endpoints this one has no
    # get_current_user dependency — confirm that is intentional.
    task_status = async_kb_service.get_task_status(task_id)

    if task_status.get('status') == 'not_found':
        raise HTTPException(status_code=404, detail="Task not found")

    return {
        "status": task_status.get('status'),
        "progress": task_status.get('progress', 0),
        "error": task_status.get('error_message')
    }