Compare commits
2 Commits
7f9c9fb950
...
b674c75c65
| Author | SHA1 | Date |
|---|---|---|
|
|
b674c75c65 | |
|
|
a5f544d7a2 |
Binary file not shown.
|
|
@ -0,0 +1,281 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from typing import Optional, List
|
||||
from app.models.models import KnowledgeBase, KnowledgeBaseListResponse, CreateKnowledgeBaseRequest, UpdateKnowledgeBaseRequest, Tag
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.auth import get_current_user
|
||||
from app.core.response import create_api_response
|
||||
from app.services.async_knowledge_base_service import async_kb_service
|
||||
import datetime
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def _process_tags(cursor, tag_string: Optional[str]) -> List[Tag]:
|
||||
if not tag_string:
|
||||
return []
|
||||
tag_names = [name.strip() for name in tag_string.split(',') if name.strip()]
|
||||
if not tag_names:
|
||||
return []
|
||||
|
||||
placeholders = ','.join(['%s'] * len(tag_names))
|
||||
select_query = f"SELECT id, name, color FROM tags WHERE name IN ({placeholders})"
|
||||
cursor.execute(select_query, tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
||||
existing_tags = {tag['name']: tag for tag in tags_data}
|
||||
new_tags = [name for name in tag_names if name not in existing_tags]
|
||||
|
||||
if new_tags:
|
||||
insert_query = "INSERT INTO tags (name) VALUES (%s)"
|
||||
cursor.executemany(insert_query, [(name,) for name in new_tags])
|
||||
|
||||
# Re-fetch all tags to get their IDs and default colors
|
||||
cursor.execute(select_query, tuple(tag_names))
|
||||
tags_data = cursor.fetchall()
|
||||
|
||||
return [Tag(**tag) for tag in tags_data]
|
||||
|
||||
|
||||
@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
|
||||
def get_knowledge_bases(
|
||||
page: int = 1,
|
||||
size: int = 10,
|
||||
is_shared: Optional[bool] = None,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
base_query = "FROM knowledge_bases kb JOIN users u ON kb.creator_id = u.user_id"
|
||||
where_clauses = []
|
||||
params = []
|
||||
|
||||
if is_shared is not None:
|
||||
if is_shared:
|
||||
where_clauses.append("kb.is_shared = 1")
|
||||
else: # Personal
|
||||
where_clauses.append("kb.is_shared = 0 AND kb.creator_id = %s")
|
||||
params.append(current_user['user_id'])
|
||||
else: # Both personal and shared
|
||||
where_clauses.append("(kb.is_shared = 1 OR kb.creator_id = %s)")
|
||||
params.append(current_user['user_id'])
|
||||
|
||||
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||
|
||||
count_query = "SELECT COUNT(*) as total " + base_query + where_sql
|
||||
cursor.execute(count_query, tuple(params))
|
||||
total = cursor.fetchone()['total']
|
||||
|
||||
offset = (page - 1) * size
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
|
||||
kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at
|
||||
{base_query}
|
||||
{where_sql}
|
||||
ORDER BY kb.updated_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
|
||||
query_params = params + [size, offset]
|
||||
cursor.execute(query, tuple(query_params))
|
||||
kbs_data = cursor.fetchall()
|
||||
|
||||
kb_list = []
|
||||
for kb_data in kbs_data:
|
||||
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
|
||||
# Count source meetings - filter empty strings
|
||||
if kb_data.get('source_meeting_ids'):
|
||||
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
|
||||
kb_data['source_meeting_count'] = len(meeting_ids)
|
||||
else:
|
||||
kb_data['source_meeting_count'] = 0
|
||||
# Add created_by_name for consistency
|
||||
kb_data['created_by_name'] = kb_data.get('creator_caption')
|
||||
kb_list.append(KnowledgeBase(**kb_data))
|
||||
|
||||
return KnowledgeBaseListResponse(kbs=kb_list, total=total)
|
||||
|
||||
@router.post("/knowledge-bases")
|
||||
def create_knowledge_base(
|
||||
request: CreateKnowledgeBaseRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
|
||||
# 自动生成标题,格式为: YYYY-MM-DD 知识条目
|
||||
if not request.title:
|
||||
now = datetime.datetime.now()
|
||||
request.title = now.strftime("%Y-%m-%d") + " 知识条目"
|
||||
|
||||
# Create the knowledge base entry first
|
||||
insert_kb_query = """
|
||||
INSERT INTO knowledge_bases (title, creator_id, is_shared, source_meeting_ids, user_prompt, tags, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
now = datetime.datetime.utcnow()
|
||||
cursor.execute(insert_kb_query, (
|
||||
request.title,
|
||||
current_user['user_id'],
|
||||
request.is_shared,
|
||||
request.source_meeting_ids,
|
||||
request.user_prompt,
|
||||
request.tags,
|
||||
now,
|
||||
now
|
||||
))
|
||||
kb_id = cursor.lastrowid
|
||||
|
||||
# Start the async task
|
||||
task_id = async_kb_service.start_generation(
|
||||
user_id=current_user['user_id'],
|
||||
kb_id=kb_id,
|
||||
user_prompt=request.user_prompt,
|
||||
source_meeting_ids=request.source_meeting_ids,
|
||||
cursor=cursor
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
|
||||
# Add the background task to process the knowledge base generation
|
||||
background_tasks.add_task(async_kb_service._process_task, task_id)
|
||||
|
||||
return {"task_id": task_id, "kb_id": kb_id}
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}")
|
||||
def get_knowledge_base_detail(
|
||||
kb_id: int,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
kb.kb_id, kb.title, kb.content, kb.creator_id, u.caption as creator_caption,
|
||||
kb.is_shared, kb.source_meeting_ids, kb.user_prompt, kb.tags, kb.created_at, kb.updated_at,
|
||||
u.username as created_by_name
|
||||
FROM knowledge_bases kb
|
||||
JOIN users u ON kb.creator_id = u.user_id
|
||||
WHERE kb.kb_id = %s
|
||||
"""
|
||||
cursor.execute(query, (kb_id,))
|
||||
kb_data = cursor.fetchone()
|
||||
|
||||
if not kb_data:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
# Check access permissions
|
||||
if not kb_data['is_shared'] and kb_data['creator_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Process tags
|
||||
kb_data['tags'] = _process_tags(cursor, kb_data.get('tags'))
|
||||
|
||||
# Get source meetings details
|
||||
source_meetings = []
|
||||
if kb_data.get('source_meeting_ids'):
|
||||
meeting_ids = [mid.strip() for mid in kb_data['source_meeting_ids'].split(',') if mid.strip()]
|
||||
if meeting_ids:
|
||||
placeholders = ','.join(['%s'] * len(meeting_ids))
|
||||
meeting_query = f"""
|
||||
SELECT meeting_id, title
|
||||
FROM meetings
|
||||
WHERE meeting_id IN ({placeholders})
|
||||
"""
|
||||
cursor.execute(meeting_query, tuple(meeting_ids))
|
||||
meetings_data = cursor.fetchall()
|
||||
source_meetings = [{'meeting_id': m['meeting_id'], 'title': m['title']} for m in meetings_data]
|
||||
kb_data['source_meeting_count'] = len(source_meetings)
|
||||
else:
|
||||
kb_data['source_meeting_count'] = 0
|
||||
else:
|
||||
kb_data['source_meeting_count'] = 0
|
||||
|
||||
kb_data['source_meetings'] = source_meetings
|
||||
|
||||
return kb_data
|
||||
|
||||
@router.put("/knowledge-bases/{kb_id}")
|
||||
def update_knowledge_base(
|
||||
kb_id: int,
|
||||
request: UpdateKnowledgeBaseRequest,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if knowledge base exists and user has permission
|
||||
cursor.execute(
|
||||
"SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
|
||||
(kb_id,)
|
||||
)
|
||||
kb = cursor.fetchone()
|
||||
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
if kb['creator_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Only the creator can update this knowledge base")
|
||||
|
||||
# Update the knowledge base
|
||||
now = datetime.datetime.utcnow()
|
||||
update_query = """
|
||||
UPDATE knowledge_bases
|
||||
SET title = %s, content = %s, tags = %s, updated_at = %s
|
||||
WHERE kb_id = %s
|
||||
"""
|
||||
cursor.execute(update_query, (
|
||||
request.title,
|
||||
request.content,
|
||||
request.tags,
|
||||
now,
|
||||
kb_id
|
||||
))
|
||||
connection.commit()
|
||||
|
||||
return {"message": "Knowledge base updated successfully"}
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}")
|
||||
def delete_knowledge_base(
|
||||
kb_id: int,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if knowledge base exists and user has permission
|
||||
cursor.execute(
|
||||
"SELECT kb_id, creator_id FROM knowledge_bases WHERE kb_id = %s",
|
||||
(kb_id,)
|
||||
)
|
||||
kb = cursor.fetchone()
|
||||
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
if kb['creator_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Only the creator can delete this knowledge base")
|
||||
|
||||
# Delete the knowledge base
|
||||
cursor.execute("DELETE FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
|
||||
connection.commit()
|
||||
|
||||
return {"message": "Knowledge base deleted successfully"}
|
||||
|
||||
@router.get("/knowledge-bases/tasks/{task_id}")
|
||||
def get_task_status(task_id: str):
|
||||
"""获取知识库生成任务状态"""
|
||||
task_status = async_kb_service.get_task_status(task_id)
|
||||
|
||||
if task_status.get('status') == 'not_found':
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
return {
|
||||
"status": task_status.get('status'),
|
||||
"progress": task_status.get('progress', 0),
|
||||
"error": task_status.get('error_message')
|
||||
}
|
||||
|
||||
|
|
@ -117,4 +117,47 @@ class BatchTranscriptUpdateRequest(BaseModel):
|
|||
|
||||
class PasswordChangeRequest(BaseModel):
|
||||
old_password: str
|
||||
new_password: str
|
||||
new_password: str
|
||||
|
||||
class KnowledgeBase(BaseModel):
|
||||
kb_id: int
|
||||
title: str
|
||||
content: Optional[str] = None
|
||||
creator_id: int
|
||||
creator_caption: str # To show in the UI
|
||||
is_shared: bool
|
||||
source_meeting_ids: Optional[str] = None
|
||||
user_prompt: Optional[str] = None
|
||||
tags: Optional[List[Tag]] = []
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
source_meeting_count: Optional[int] = 0
|
||||
created_by_name: Optional[str] = None
|
||||
|
||||
class KnowledgeBaseTask(BaseModel):
|
||||
task_id: str
|
||||
user_id: int
|
||||
kb_id: int
|
||||
user_prompt: Optional[str] = None
|
||||
status: str
|
||||
progress: int
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
completed_at: Optional[datetime.datetime] = None
|
||||
|
||||
class CreateKnowledgeBaseRequest(BaseModel):
|
||||
title: Optional[str] = None # 改为可选,后台自动生成
|
||||
is_shared: bool
|
||||
user_prompt: Optional[str] = None
|
||||
source_meeting_ids: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
|
||||
class UpdateKnowledgeBaseRequest(BaseModel):
|
||||
title: str
|
||||
content: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
|
||||
class KnowledgeBaseListResponse(BaseModel):
|
||||
kbs: List[KnowledgeBase]
|
||||
total: int
|
||||
|
|
|
|||
|
|
@ -0,0 +1,201 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
import redis
|
||||
from app.core.database import get_db_connection
|
||||
from app.services.llm_service import LLMService
|
||||
|
||||
class AsyncKnowledgeBaseService:
|
||||
|
||||
def __init__(self):
|
||||
from app.core.config import REDIS_CONFIG
|
||||
if 'decode_responses' not in REDIS_CONFIG:
|
||||
REDIS_CONFIG['decode_responses'] = True
|
||||
self.redis_client = redis.Redis(**REDIS_CONFIG)
|
||||
self.llm_service = LLMService()
|
||||
|
||||
def start_generation(self, user_id: int, kb_id: int, user_prompt: Optional[str], source_meeting_ids: Optional[str], cursor=None) -> str:
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# If a cursor is passed, use it directly to avoid creating a new transaction
|
||||
if cursor:
|
||||
query = """
|
||||
INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, created_at)
|
||||
VALUES (%s, %s, %s, %s, NOW())
|
||||
"""
|
||||
cursor.execute(query, (task_id, user_id, kb_id, user_prompt))
|
||||
else:
|
||||
# Fallback to the old method if no cursor is provided
|
||||
self._save_task_to_db(task_id, user_id, kb_id, user_prompt)
|
||||
|
||||
current_time = datetime.now().isoformat()
|
||||
task_data = {
|
||||
'task_id': task_id,
|
||||
'user_id': str(user_id),
|
||||
'kb_id': str(kb_id),
|
||||
'user_prompt': user_prompt if user_prompt else "",
|
||||
'status': 'pending',
|
||||
'progress': '0',
|
||||
'created_at': current_time,
|
||||
'updated_at': current_time
|
||||
}
|
||||
self.redis_client.hset(f"kb_task:{task_id}", mapping=task_data)
|
||||
self.redis_client.expire(f"kb_task:{task_id}", 86400)
|
||||
|
||||
print(f"Knowledge base generation task created: {task_id} for kb_id: {kb_id}")
|
||||
return task_id
|
||||
|
||||
def _process_task(self, task_id: str):
|
||||
print(f"Background task started for knowledge base task: {task_id}")
|
||||
try:
|
||||
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
|
||||
if not task_data:
|
||||
print(f"Error: Task {task_id} not found in Redis for processing.")
|
||||
return
|
||||
|
||||
kb_id = int(task_data['kb_id'])
|
||||
user_prompt = task_data.get('user_prompt', '')
|
||||
|
||||
self._update_task_status_in_redis(task_id, 'processing', 10, message="任务已开始...")
|
||||
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Get source meeting summaries
|
||||
source_text = ""
|
||||
cursor.execute("SELECT source_meeting_ids FROM knowledge_bases WHERE kb_id = %s", (kb_id,))
|
||||
kb_info = cursor.fetchone()
|
||||
if kb_info and kb_info['source_meeting_ids']:
|
||||
self._update_task_status_in_redis(task_id, 'processing', 20, message="获取关联会议纪要...")
|
||||
meeting_ids = [int(m_id) for m_id in kb_info['source_meeting_ids'].split(',') if m_id.isdigit()]
|
||||
if meeting_ids:
|
||||
summaries = []
|
||||
for meeting_id in meeting_ids:
|
||||
cursor.execute("SELECT summary FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
summary = cursor.fetchone()
|
||||
if summary and summary['summary']:
|
||||
summaries.append(summary['summary'])
|
||||
source_text = "\n\n---\n\n".join(summaries)
|
||||
|
||||
# Get system prompt
|
||||
self._update_task_status_in_redis(task_id, 'processing', 30, message="获取知识库生成模版...")
|
||||
system_prompt = self._get_knowledge_task_prompt(cursor)
|
||||
|
||||
# Build final prompt
|
||||
final_prompt = f"{system_prompt}\n\n"
|
||||
if source_text:
|
||||
final_prompt += f"请参考以下会议纪要内容:\n{source_text}\n\n"
|
||||
final_prompt += f"用户要求:{user_prompt}"
|
||||
|
||||
self._update_task_status_in_redis(task_id, 'processing', 50, message="AI正在生成知识库...")
|
||||
generated_content = self.llm_service._call_llm_api(final_prompt)
|
||||
|
||||
if not generated_content:
|
||||
raise Exception("LLM API call failed or returned empty content")
|
||||
|
||||
self._update_task_status_in_redis(task_id, 'processing', 95, message="保存结果...")
|
||||
self._save_result_to_db(kb_id, generated_content, cursor)
|
||||
|
||||
self._update_task_in_db(task_id, 'completed', 100, cursor=cursor)
|
||||
self._update_task_status_in_redis(task_id, 'completed', 100)
|
||||
|
||||
connection.commit()
|
||||
print(f"Task {task_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"Task {task_id} failed: {error_msg}")
|
||||
# Use a new connection for error logging to avoid issues with a potentially broken transaction
|
||||
with get_db_connection() as err_conn:
|
||||
err_cursor = err_conn.cursor()
|
||||
self._update_task_in_db(task_id, 'failed', 0, error_message=error_msg, cursor=err_cursor)
|
||||
err_conn.commit()
|
||||
self._update_task_status_in_redis(task_id, 'failed', 0, error_message=error_msg)
|
||||
|
||||
def _get_knowledge_task_prompt(self, cursor) -> str:
|
||||
query = """
|
||||
SELECT p.content
|
||||
FROM prompt_config pc
|
||||
JOIN prompts p ON pc.prompt_id = p.id
|
||||
WHERE pc.task_name = 'KNOWLEDGE_TASK'
|
||||
"""
|
||||
cursor.execute(query)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['content']
|
||||
else:
|
||||
# Fallback prompt
|
||||
return "Please generate a knowledge base article based on the provided information."
|
||||
|
||||
|
||||
|
||||
def _save_result_to_db(self, kb_id: int, content: str, cursor):
|
||||
query = "UPDATE knowledge_bases SET content = %s, updated_at = NOW() WHERE kb_id = %s"
|
||||
cursor.execute(query, (content, kb_id))
|
||||
|
||||
def _update_task_in_db(self, task_id: str, status: str, progress: int, error_message: str = None, cursor=None):
|
||||
query = "UPDATE knowledge_base_tasks SET status = %s, progress = %s, error_message = %s, updated_at = NOW(), completed_at = IF(%s = 'completed', NOW(), completed_at) WHERE task_id = %s"
|
||||
cursor.execute(query, (status, progress, error_message, status, task_id))
|
||||
|
||||
def _update_task_status_in_redis(self, task_id: str, status: str, progress: int, message: str = None, error_message: str = None):
|
||||
update_data = {
|
||||
'status': status,
|
||||
'progress': str(progress),
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
if message: update_data['message'] = message
|
||||
if error_message: update_data['error_message'] = error_message
|
||||
self.redis_client.hset(f"kb_task:{task_id}", mapping=update_data)
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""获取任务状态 - 与 async_llm_service 保持一致"""
|
||||
try:
|
||||
task_data = self.redis_client.hgetall(f"kb_task:{task_id}")
|
||||
if not task_data:
|
||||
task_data = self._get_task_from_db(task_id)
|
||||
if not task_data:
|
||||
return {'task_id': task_id, 'status': 'not_found', 'error_message': 'Task not found'}
|
||||
|
||||
return {
|
||||
'task_id': task_id,
|
||||
'status': task_data.get('status', 'unknown'),
|
||||
'progress': int(task_data.get('progress', 0)),
|
||||
'kb_id': int(task_data.get('kb_id', 0)),
|
||||
'created_at': task_data.get('created_at'),
|
||||
'updated_at': task_data.get('updated_at'),
|
||||
'error_message': task_data.get('error_message')
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error getting task status: {e}")
|
||||
return {'task_id': task_id, 'status': 'error', 'error_message': str(e)}
|
||||
|
||||
def _save_task_to_db(self, task_id: str, user_id: int, kb_id: int, user_prompt: str):
|
||||
"""保存任务到数据库"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor()
|
||||
insert_query = "INSERT INTO knowledge_base_tasks (task_id, user_id, kb_id, user_prompt, status, progress, created_at) VALUES (%s, %s, %s, %s, 'pending', 0, NOW())"
|
||||
cursor.execute(insert_query, (task_id, user_id, kb_id, user_prompt))
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
print(f"Error saving task to database: {e}")
|
||||
raise
|
||||
|
||||
def _get_task_from_db(self, task_id: str) -> Optional[Dict[str, str]]:
|
||||
"""从数据库获取任务信息"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
query = "SELECT * FROM knowledge_base_tasks WHERE task_id = %s"
|
||||
cursor.execute(query, (task_id,))
|
||||
task = cursor.fetchone()
|
||||
if task:
|
||||
# 确保所有字段都是字符串,以匹配Redis的行为
|
||||
return {k: v.isoformat() if isinstance(v, datetime) else str(v) for k, v in task.items()}
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error getting task from database: {e}")
|
||||
return None
|
||||
|
||||
async_kb_service = AsyncKnowledgeBaseService()
|
||||
5
main.py
5
main.py
|
|
@ -2,7 +2,7 @@ import uvicorn
|
|||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts
|
||||
from app.api.endpoints import auth, users, meetings, tags, admin, tasks, prompts, knowledge_base
|
||||
from app.core.config import UPLOAD_DIR, API_CONFIG
|
||||
from app.api.endpoints.admin import load_system_config
|
||||
import os
|
||||
|
|
@ -19,7 +19,7 @@ load_system_config()
|
|||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
# allow_origins=API_CONFIG['cors_origins'],
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
|
@ -37,6 +37,7 @@ app.include_router(tags.router, prefix="/api", tags=["Tags"])
|
|||
app.include_router(admin.router, prefix="/api", tags=["Admin"])
|
||||
app.include_router(tasks.router, prefix="/api", tags=["Tasks"])
|
||||
app.include_router(prompts.router, prefix="/api", tags=["Prompts"])
|
||||
app.include_router(knowledge_base.router, prefix="/api", tags=["KnowledgeBase"])
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
|
|
|
|||
Loading…
Reference in New Issue