def _process_tags(cursor, tag_string: Optional[str]) -> List["Tag"]:
    """Resolve a comma-separated tag string into ``Tag`` models.

    Names are trimmed, blank entries are dropped and duplicates are removed
    (first occurrence wins). Any name without a row in ``tags`` is inserted,
    after which all requested rows are read back so the result carries ids
    and colors.

    Note: this may INSERT through ``cursor`` but never commits; whether the
    new rows persist is decided by the caller that owns the connection.

    Args:
        cursor: DB cursor whose rows are mappings with ``id``, ``name`` and
            ``color`` keys (rows are indexed by ``'name'`` and unpacked into
            ``Tag``).
        tag_string: Comma-separated tag names, or ``None``/empty.

    Returns:
        One ``Tag`` per matching row, in the order the database returns them;
        an empty list when no usable name was supplied.
    """
    if not tag_string:
        return []

    # dict.fromkeys de-duplicates while keeping first-seen order, so an input
    # such as "a,a" inserts a missing tag once instead of twice.
    stripped_names = (part.strip() for part in tag_string.split(','))
    tag_names = list(dict.fromkeys(name for name in stripped_names if name))
    if not tag_names:
        return []

    placeholders = ','.join(['%s'] * len(tag_names))
    select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
    cursor.execute(select_query, tuple(tag_names))
    tags_data = cursor.fetchall()

    existing_names = {tag['name'] for tag in tags_data}
    new_tags = [name for name in tag_names if name not in existing_names]

    if new_tags:
        insert_query = "INSERT INTO tags (name) VALUES (%s)"
        cursor.executemany(insert_query, [(name,) for name in new_tags])

        # Re-fetch all tags to get the ids and default colors of the new rows.
        cursor.execute(select_query, tuple(tag_names))
        tags_data = cursor.fetchall()

    return [Tag(**tag) for tag in tags_data]
@router.post("/knowledge-bases")
def create_knowledge_base(
    request: CreateKnowledgeBaseRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Create a knowledge base entry and schedule its content generation.

    The ``knowledge_bases`` row and the generation task are written through
    the same cursor and committed together; only after the commit is the
    worker scheduled as a FastAPI background task.

    Args:
        request: Creation payload; ``title`` is optional.
        background_tasks: FastAPI background task queue for this request.
        current_user: Authenticated user mapping; ``user_id`` is read from it.

    Returns:
        ``{"task_id": ..., "kb_id": ...}`` for polling the generation task.
    """
    # Auto-generate the title as "YYYY-MM-DD 知识条目" when none is given.
    # A whitespace-only title counts as missing so no blank title is stored.
    # The date intentionally stays in server-local time (user-facing text).
    title = request.title
    if not title or not title.strip():
        title = datetime.datetime.now().strftime("%Y-%m-%d") + " 知识条目"

    # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
    created_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    with get_db_connection() as connection:
        cursor = connection.cursor()

        # Create the knowledge base entry first so the task can reference it.
        insert_kb_query = """
            INSERT INTO knowledge_bases (title, creator_id, is_shared, source_meeting_ids, user_prompt, tags, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        """
        cursor.execute(insert_kb_query, (
            title,
            current_user['user_id'],
            request.is_shared,
            request.source_meeting_ids,
            request.user_prompt,
            request.tags,
            created_at,
            created_at,
        ))
        kb_id = cursor.lastrowid

        # Register the task on the same cursor so both inserts share one
        # transaction.
        task_id = async_kb_service.start_generation(
            user_id=current_user['user_id'],
            kb_id=kb_id,
            user_prompt=request.user_prompt,
            source_meeting_ids=request.source_meeting_ids,
            cursor=cursor
        )

        connection.commit()

        # Schedule the worker only after the commit, so it can see the row.
        background_tasks.add_task(async_kb_service._process_task, task_id)

        return {"task_id": task_id, "kb_id": kb_id}
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(
    kb_id: int,
    current_user: dict = Depends(get_current_user)
):
    """Delete a knowledge base; only its creator may do so.

    Raises:
        HTTPException: 404 when no row has this ``kb_id``; 403 when the
            caller is not the creator of the row.

    Returns:
        A confirmation message once the DELETE has been committed.
    """
    lookup_sql = "SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s"
    delete_sql = "DELETE FROM knowledge_bases WHERE kb_id = %s"
    key = (kb_id,)

    with get_db_connection() as conn:
        db = conn.cursor(dictionary=True)

        # Existence check first, then ownership check.
        db.execute(lookup_sql, key)
        row = db.fetchone()
        if not row:
            raise HTTPException(status_code=404, detail="Knowledge base not found")

        owner_id = row['creator_id']
        if owner_id != current_user['user_id']:
            raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")

        db.execute(delete_sql, key)
        conn.commit()

    return {"message": "Knowledge base deleted successfully"}
class AsyncKnowledgeBaseService:
    """Create, run and report on knowledge-base generation tasks.

    Each task is tracked in two places: a row in ``knowledge_base_tasks`` and
    a Redis hash ``kb_task:<task_id>`` that expires after 24 hours. Status
    polling reads Redis first and falls back to the database row.
    """

    # Lifetime of the Redis task hash, in seconds (24 hours).
    _TASK_TTL_SECONDS = 86400

    # One INSERT shared by both creation paths, so a new row always starts as
    # 'pending' / 0 no matter which path wrote it.
    _INSERT_TASK_SQL = (
        "INSERT INTO knowledge_base_tasks "
        "(task_id, user_id, kb_id, user_prompt, status, progress, created_at) "
        "VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
    )

    def __init__(self):
        from app.core.config import REDIS_CONFIG
        # Work on a private copy so the shared config dict is never mutated;
        # an explicit decode_responses in the config still takes precedence.
        redis_kwargs = {'decode_responses': True, **REDIS_CONFIG}
        self.redis_client = redis.Redis(**redis_kwargs)
        self.llm_service = LLMService()

    def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
        """Register a new generation task and return its id.

        Args:
            user_id: Owner of the task.
            kb_id: Knowledge base the generated content belongs to.
            user_prompt: Optional free-text instruction from the user.
            source_meeting_ids: Accepted for interface compatibility; not
                used here (the worker re-reads the ids from the
                ``knowledge_bases`` row).
            cursor: Optional open cursor. When given, the task row is
                inserted through it and the caller is responsible for the
                commit; otherwise a separate connection is used and
                committed.

        Returns:
            The generated task id (UUID4 string).
        """
        task_id = str(uuid.uuid4())

        if cursor is not None:
            # Reuse the caller's transaction instead of opening a new one.
            cursor.execute(self._INSERT_TASK_SQL, (task_id, user_id, kb_id, user_prompt))
        else:
            self._save_task_to_db(task_id, user_id, kb_id, user_prompt)

        current_time = datetime.now().isoformat()
        task_data = {
            'task_id': task_id,
            'user_id': str(user_id),
            'kb_id': str(kb_id),
            'user_prompt': user_prompt if user_prompt else "",
            'status': 'pending',
            'progress': '0',
            'created_at': current_time,
            'updated_at': current_time
        }
        redis_key = f"kb_task:{task_id}"
        self.redis_client.hset(redis_key, mapping=task_data)
        self.redis_client.expire(redis_key, self._TASK_TTL_SECONDS)

        print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
        return task_id

    def _process_task(self, task_id: str):
        """Run one generation task end to end; never raises.

        Reads the task from Redis, gathers the source meeting summaries and
        the system prompt, calls the LLM, then stores the content and marks
        the task completed. Any failure is recorded through
        ``_record_failure``.
        """
        print(f"Background task started for knowledge base task: {task_id}")
        try:
            task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
            if not task_data:
                print(f"Error: Task {task_id} not found in Redis for processing.")
                return

            kb_id = int(task_data['kb_id'])
            user_prompt = task_data.get('user_prompt', '')

            self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")

            # Phase 1: read the inputs, then release the connection so it is
            # not held open for the duration of the LLM call.
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                source_text = self._load_source_text(task_id, kb_id, cursor)

                self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
                system_prompt = self._get_knowledge_task_prompt(cursor)

            # Build the final prompt.
            final_prompt = f"{system_prompt}\n\n"
            if source_text:
                final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
            final_prompt += f"用户要求:{user_prompt}"

            self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
            generated_content = self.llm_service._call_llm_api(final_prompt)

            if not generated_content:
                raise RuntimeError("LLM API call failed or returned empty content")

            # Phase 2: persist content and task state in one transaction.
            self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
            with get_db_connection() as connection:
                cursor = connection.cursor()
                self._save_result_to_db(kb_id, generated_content, cursor)
                self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
                connection.commit()

            # Report completion only after the commit succeeded.
            self._update_task_status_in_redis(task_id, 'completed', 100)
            print(f"Task {task_id} completed successfully")

        except Exception as e:
            error_msg = str(e)
            print(f"Task {task_id} failed: {error_msg}")
            self._record_failure(task_id, error_msg)

    def _load_source_text(self, task_id: str, kb_id: int, cursor) -> str:
        """Return the joined summaries of the meetings linked to ``kb_id``.

        ``cursor`` must return rows as dicts. Returns an empty string when
        the knowledge base has no usable meeting ids or no summaries.
        """
        cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
        kb_info = cursor.fetchone()
        if not kb_info or not kb_info['source_meeting_ids']:
            return ""

        self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
        summaries = []
        for meeting_id in self._parse_meeting_ids(kb_info['source_meeting_ids']):
            cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
            row = cursor.fetchone()
            if row and row['summary']:
                summaries.append(row['summary'])
        return "\n\n---\n\n".join(summaries)

    @staticmethod
    def _parse_meeting_ids(raw: Optional[str]) -> List[int]:
        """Parse a comma-separated id string into ints.

        Entries are stripped before validation so "1, 2" yields both ids;
        blank and non-numeric entries are ignored.
        """
        if not raw:
            return []
        parts = (part.strip() for part in raw.split(','))
        return [int(part) for part in parts if part.isdecimal()]

    def _record_failure(self, task_id: str, error_msg: str):
        """Mark the task failed in the database and in Redis.

        The two stores are updated independently so that a failure in one
        does not prevent the other from being updated, and nothing escapes
        into the background-task runner.
        """
        try:
            # Use a new connection: the one used for the work may be in a
            # broken transaction.
            with get_db_connection() as err_conn:
                err_cursor = err_conn.cursor()
                self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
                err_conn.commit()
        except Exception as db_error:
            print(f"Task {task_id}: could not record failure in database: {db_error}")

        try:
            self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
        except Exception as redis_error:
            print(f"Task {task_id}: could not record failure in Redis: {redis_error}")

    def _get_knowledge_task_prompt(self, cursor) -> str:
        """Return the prompt configured for ``KNOWLEDGE_TASK``.

        ``cursor`` must return rows as dicts. Falls back to a built-in
        English prompt when no configuration row exists.
        """
        query = """
            SELECT p.content
            FROM prompt_config pc
            JOIN prompts p ON pc.prompt_id = p.id
            WHERE pc.task_name = 'KNOWLEDGE_TASK'
        """
        cursor.execute(query)
        result = cursor.fetchone()
        if result:
            return result['content']
        # Fallback prompt
        return "Please generate a knowledge base article based on the provided information."

    def _save_result_to_db(self, kb_id: int, content: str, cursor):
        """Store generated content on the knowledge base row (no commit)."""
        query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
        cursor.execute(query, (content, kb_id))

    def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: Optional[str] = None, cursor=None):
        """Update status, progress and error of a task row.

        ``completed_at`` is set only when ``status`` is ``'completed'``.
        With a ``cursor`` the statement runs on it and the caller commits;
        without one, a dedicated connection is opened and committed.
        """
        query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
        params = (status, progress, error_message, status, task_id)
        if cursor is not None:
            cursor.execute(query, params)
            return
        # The default cursor=None is now usable instead of failing on
        # None.execute.
        with get_db_connection() as connection:
            connection.cursor().execute(query, params)
            connection.commit()

    def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: Optional[str] = None, error_message: Optional[str] = None):
        """Write status/progress (and optional texts) to the Redis task hash."""
        update_data = {
            'status': status,
            'progress': str(progress),
            'updated_at': datetime.now().isoformat()
        }
        if message:
            update_data['message'] = message
        if error_message:
            update_data['error_message'] = error_message
        self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the task status (kept consistent with async_llm_service).

        Reads the Redis hash first and falls back to the database row. The
        result always has ``task_id`` and ``status``; ``status`` is
        ``'not_found'`` when neither store knows the task and ``'error'``
        when the lookup itself failed.
        """
        try:
            task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
            if not task_data:
                task_data = self._get_task_from_db(task_id)
            if not task_data:
                return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}

            return {
                'task_id': task_id,
                'status': task_data.get('status', 'unknown'),
                # "or 0" also covers missing or empty values.
                'progress': int(task_data.get('progress') or 0),
                'kb_id': int(task_data.get('kb_id') or 0),
                'created_at': task_data.get('created_at'),
                'updated_at': task_data.get('updated_at'),
                'error_message': task_data.get('error_message')
            }
        except Exception as e:
            print(f"Error getting task status: {e}")
            return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}

    def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
        """Insert the task row on a dedicated connection and commit.

        Errors are logged and re-raised.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor()
                cursor.execute(self._INSERT_TASK_SQL, (task_id, user_id, kb_id, user_prompt))
                connection.commit()
        except Exception as e:
            print(f"Error saving task to database: {e}")
            raise

    @staticmethod
    def _stringify_row(row: Dict[str, Any]) -> Dict[str, str]:
        """Convert a DB row to the all-strings shape a Redis hash has.

        NULL columns are omitted rather than turned into the string
        ``'None'``, so ``.get()`` on the result yields ``None`` for them.
        """
        return {
            key: value.isoformat() if isinstance(value, datetime) else str(value)
            for key, value in row.items()
            if value is not None
        }

    def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
        """Load a task row from the database.

        Returns the row with all values as strings, or ``None`` when the
        task does not exist or the lookup failed (the error is logged).
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                query = "SELECT * FROM knowledge_base_tasks WHERE task_id = %s"
                cursor.execute(query, (task_id,))
                task = cursor.fetchone()
                if task:
                    return self._stringify_row(task)
                return None
        except Exception as e:
            print(f"Error getting task from database: {e}")
            return None
# 添加CORS中间件 app.add_middleware( CORSMiddleware, - # allow_origins=API_CONFIG['cors_origins'], + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -37,6 +37,7 @@ app.include_router(tags.router, prefix="/api", tags=["Tags"]) app.include_router(admin.router, prefix="/api", tags=["Admin"]) app.include_router(tasks.router, prefix="/api", tags=["Tasks"]) app.include_router(prompts.router, prefix="/api", tags=["Prompts"]) +app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"]) @app.get("/") def read_root():