初步实现了知识库模块

main
mula.liu 2025-10-16 11:09:19 +08:00
parent 7f9c9fb950
commit a5f544d7a2
6 changed files with 481 additions and 3 deletions

BIN
.DS_Store vendored

Binary file not shown.

BIN
app/.DS_Store vendored

Binary file not shown.

View File

@ -0,0 +1,241 @@
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import Optional, List
from app.models.models import KnowledgeBase, KnowledgeBaseListResponse, CreateKnowledgeBaseRequest, Tag
from app.core.database import get_db_connection
from app.core.auth import get_current_user
from app.core.response import create_api_response
from app.services.async_knowledge_base_service import async_kb_service
import datetime
router = APIRouter()
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
if not tag_string:
return []
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
if not tag_names:
return []
placeholders = ','.join(['%s'] * len(tag_names))
select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
cursor.execute(select_query, tuple(tag_names))
tags_data = cursor.fetchall()
existing_tags = {tag['name']: tag for tag in tags_data}
new_tags = [name for name in tag_names if name not in existing_tags]
if new_tags:
insert_query = "INSERT INTO tags (name) VALUES (%s)"
cursor.executemany(insert_query, [(name,) for name in new_tags])
# Re-fetch all tags to get their IDs and default colors
cursor.execute(select_query, tuple(tag_names))
tags_data = cursor.fetchall()
return [Tag(**tag) for tag in tags_data]
@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
def get_knowledge_bases(
page: int = 1,
size: int = 10,
is_shared: Optional[bool] = None,
current_user: dict = Depends(get_current_user)
):
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
base_query = "FROM knowledge_bases kb JOIN users u ON kb.creator_id = u.user_id"
where_clauses = []
params = []
if is_shared is not None:
if is_shared:
where_clauses.append("kb.is_shared = 1")
else: # Personal
where_clauses.append("kb.is_shared = 0 AND kb.creator_id = %s")
params.append(current_user['user_id'])
else: # Both personal and shared
where_clauses.append("(kb.is_shared = 1 OR kb.creator_id = %s)")
params.append(current_user['user_id'])
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
count_query = "SELECT COUNT(*) as total " + base_query + where_sql
cursor.execute(count_query, tuple(params))
total = cursor.fetchone()['total']
offset = (page - 1) * size
query = f"""
SELECT
kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at
{base_query}
{where_sql}
ORDER BY kb.updated_at DESC
LIMIT %s OFFSET %s
"""
query_params = params + [size, offset]
cursor.execute(query, tuple(query_params))
kbs_data = cursor.fetchall()
kb_list = []
for kb_data in kbs_data:
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
# Count source meetings - filter empty strings
if kb_data.get('source_meeting_ids'):
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
kb_data['source_meeting_count'] = len(meeting_ids)
else:
kb_data['source_meeting_count'] = 0
# Add created_by_name for consistency
kb_data['created_by_name'] = kb_data.get('creator_caption')
kb_list.append(KnowledgeBase(**kb_data))
return KnowledgeBaseListResponse(kbs=kb_list, total=total)
@router.post("/knowledge-bases")
def create_knowledge_base(
request: CreateKnowledgeBaseRequest,
background_tasks: BackgroundTasks,
current_user: dict = Depends(get_current_user)
):
with get_db_connection() as connection:
cursor = connection.cursor()
# 自动生成标题,格式为: YYYY-MM-DD 知识条目
if not request.title:
now = datetime.datetime.now()
request.title = now.strftime("%Y-%m-%d") + " 知识条目"
# Create the knowledge base entry first
insert_kb_query = """
INSERT INTO knowledge_bases (title, creator_id, is_shared, source_meeting_ids, user_prompt, tags, created_at, updated_at)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
"""
now = datetime.datetime.utcnow()
cursor.execute(insert_kb_query, (
request.title,
current_user['user_id'],
request.is_shared,
request.source_meeting_ids,
request.user_prompt,
request.tags,
now,
now
))
kb_id = cursor.lastrowid
# Start the async task
task_id = async_kb_service.start_generation(
user_id=current_user['user_id'],
kb_id=kb_id,
user_prompt=request.user_prompt,
source_meeting_ids=request.source_meeting_ids,
cursor=cursor
)
connection.commit()
# Add the background task to process the knowledge base generation
background_tasks.add_task(async_kb_service._process_task, task_id)
return {"task_id": task_id, "kb_id": kb_id}
@router.get("/knowledge-bases/{kb_id}")
def get_knowledge_base_detail(
kb_id: int,
current_user: dict = Depends(get_current_user)
):
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
query = """
SELECT
kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at,
u.username as created_by_name
FROM knowledge_bases kb
JOIN users u ON kb.creator_id = u.user_id
WHERE kb.kb_id = %s
"""
cursor.execute(query, (kb_id,))
kb_data = cursor.fetchone()
if not kb_data:
raise HTTPException(status_code=404, detail="Knowledge base not found")
# Check access permissions
if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Access denied")
# Process tags
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
# Get source meetings details
source_meetings = []
if kb_data.get('source_meeting_ids'):
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
if meeting_ids:
placeholders = ','.join(['%s'] * len(meeting_ids))
meeting_query = f"""
SELECT meeting_id, title
FROM meetings
WHERE meeting_id IN ({placeholders})
"""
cursor.execute(meeting_query, tuple(meeting_ids))
meetings_data = cursor.fetchall()
source_meetings = [{'meeting_id': m['meeting_id'], 'title': m['title']} for m in meetings_data]
kb_data['source_meeting_count'] = len(source_meetings)
else:
kb_data['source_meeting_count'] = 0
else:
kb_data['source_meeting_count'] = 0
kb_data['source_meetings'] = source_meetings
return kb_data
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(
kb_id: int,
current_user: dict = Depends(get_current_user)
):
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# Check if knowledge base exists and user has permission
cursor.execute(
"SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
(kb_id,)
)
kb = cursor.fetchone()
if not kb:
raise HTTPException(status_code=404, detail="Knowledge base not found")
if kb['creator_id'] != current_user['user_id']:
raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")
# Delete the knowledge base
cursor.execute("DELETE FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
connection.commit()
return {"message": "Knowledge base deleted successfully"}
@router.get("/knowledge-bases/tasks/{task_id}")
def get_task_status(task_id: str):
"""获取知识库生成任务状态"""
task_status = async_kb_service.get_task_status(task_id)
if task_status.get('status') == 'not_found':
raise HTTPException(status_code=404, detail="Task not found")
return {
"status": task_status.get('status'),
"progress": task_status.get('progress', 0),
"error": task_status.get('error_message')
}

View File

@ -117,4 +117,39 @@ class BatchTranscriptUpdateRequest(BaseModel):
class PasswordChangeRequest(BaseModel):
old_password: str
new_password: str
new_password: str
class KnowledgeBase(BaseModel):
kb_id: int
title: str
content: Optional[str] = None
creator_id: int
creator_caption: str # To show in the UI
is_shared: bool
source_meeting_ids: Optional[str] = None
tags: Optional[List[Tag]] = []
created_at: datetime.datetime
updated_at: datetime.datetime
class KnowledgeBaseTask(BaseModel):
task_id: str
user_id: int
kb_id: int
user_prompt: Optional[str] = None
status: str
progress: int
error_message: Optional[str] = None
created_at: datetime.datetime
updated_at: datetime.datetime
completed_at: Optional[datetime.datetime] = None
class CreateKnowledgeBaseRequest(BaseModel):
title: Optional[str] = None # 改为可选,后台自动生成
is_shared: bool
user_prompt: Optional[str] = None
source_meeting_ids: Optional[str] = None
tags: Optional[str] = None
class KnowledgeBaseListResponse(BaseModel):
kbs: List[KnowledgeBase]
total: int

View File

@ -0,0 +1,201 @@
import uuid
from datetime import datetime
from typing import Optional, Dict, Any, List
import redis
from app.core.database import get_db_connection
from app.services.llm_service import LLMService
class AsyncKnowledgeBaseService:
def __init__(self):
from app.core.config import REDIS_CONFIG
if 'decode_responses' not in REDIS_CONFIG:
REDIS_CONFIG['decode_responses'] = True
self.redis_client = redis.Redis(**REDIS_CONFIG)
self.llm_service = LLMService()
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
task_id = str(uuid.uuid4())
# If a cursor is passed, use it directly to avoid creating a new transaction
if cursor:
query = """
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at)
VALUES (%s, %s, %s, %s, NOW())
"""
cursor.execute(query, (task_id, user_id, kb_id, user_prompt))
else:
# Fallback to the old method if no cursor is provided
self._save_task_to_db(task_id, user_id, kb_id, user_prompt)
current_time = datetime.now().isoformat()
task_data = {
'task_id': task_id,
'user_id': str(user_id),
'kb_id': str(kb_id),
'user_prompt': user_prompt if user_prompt else "",
'status': 'pending',
'progress': '0',
'created_at': current_time,
'updated_at': current_time
}
self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data)
self.redis_client.expire(f"kb_task:{task_id}", 86400)
print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
return task_id
def _process_task(self, task_id: str):
print(f"Background task started for knowledge base task: {task_id}")
try:
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data:
print(f"Error: Task {task_id} not found in Redis for processing.")
return
kb_id = int(task_data['kb_id'])
user_prompt = task_data.get('user_prompt', '')
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# Get source meeting summaries
source_text = ""
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
kb_info = cursor.fetchone()
if kb_info and kb_info['source_meeting_ids']:
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
if meeting_ids:
summaries = []
for meeting_id in meeting_ids:
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
summary = cursor.fetchone()
if summary and summary['summary']:
summaries.append(summary['summary'])
source_text = "\n\n---\n\n".join(summaries)
# Get system prompt
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
system_prompt = self._get_knowledge_task_prompt(cursor)
# Build final prompt
final_prompt = f"{system_prompt}\n\n"
if source_text:
final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
final_prompt += f"用户要求:{user_prompt}"
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
generated_content = self.llm_service._call_llm_api(final_prompt)
if not generated_content:
raise Exception("LLM API call failed or returned empty content")
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
self._save_result_to_db(kb_id, generated_content, cursor)
self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
self._update_task_status_in_redis(task_id, 'completed', 100)
connection.commit()
print(f"Task {task_id} completed successfully")
except Exception as e:
error_msg = str(e)
print(f"Task {task_id} failed: {error_msg}")
# Use a new connection for error logging to avoid issues with a potentially broken transaction
with get_db_connection() as err_conn:
err_cursor = err_conn.cursor()
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
err_conn.commit()
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
def _get_knowledge_task_prompt(self, cursor) -> str:
query = """
SELECT p.content
FROM prompt_config pc
JOIN prompts p ON pc.prompt_id = p.id
WHERE pc.task_name = 'KNOWLEDGE_TASK'
"""
cursor.execute(query)
result = cursor.fetchone()
if result:
return result['content']
else:
# Fallback prompt
return "Please generate a knowledge base article based on the provided information."
def _save_result_to_db(self, kb_id: int, content: str, cursor):
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
cursor.execute(query, (content, kb_id))
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None):
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
cursor.execute(query, (status, progress, error_message, status, task_id))
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
update_data = {
'status': status,
'progress': str(progress),
'updated_at': datetime.now().isoformat()
}
if message: update_data['message'] = message
if error_message: update_data['error_message'] = error_message
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""获取任务状态 - 与 async_llm_service 保持一致"""
try:
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
if not task_data:
task_data = self._get_task_from_db(task_id)
if not task_data:
return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}
return {
'task_id': task_id,
'status': task_data.get('status', 'unknown'),
'progress': int(task_data.get('progress', 0)),
'kb_id': int(task_data.get('kb_id', 0)),
'created_at': task_data.get('created_at'),
'updated_at': task_data.get('updated_at'),
'error_message': task_data.get('error_message')
}
except Exception as e:
print(f"Error getting task status: {e}")
return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
"""保存任务到数据库"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt))
connection.commit()
except Exception as e:
print(f"Error saving task to database: {e}")
raise
def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
"""从数据库获取任务信息"""
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
query = "SELECT * FROM knowledge_base_tasks WHERE task_id = %s"
cursor.execute(query, (task_id,))
task = cursor.fetchone()
if task:
# 确保所有字段都是字符串以匹配Redis的行为
return {k: v.isoformat() if isinstance(v, datetime) else str(v) for k, v in task.items()}
return None
except Exception as e:
print(f"Error getting task from database: {e}")
return None
async_kb_service = AsyncKnowledgeBaseService()

View File

@ -2,7 +2,7 @@ import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base
from app.core.config import UPLOAD_DIR, API_CONFIG
from app.api.endpoints.admin import load_system_config
import os
@ -19,7 +19,7 @@ load_system_config()
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
# allow_origins=API_CONFIG['cors_origins'],
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@ -37,6 +37,7 @@ app.include_router(tags.router, prefix="/api", tags=["Tags"])
app.include_router(admin.router, prefix="/api", tags=["Admin"])
app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
@app.get("/")
def read_root():