270 lines
11 KiB
Python
270 lines
11 KiB
Python
"""
|
||
文档搜索相关 API
|
||
"""
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, or_
|
||
from typing import Optional, List
|
||
from pathlib import Path
|
||
import logging
|
||
|
||
from app.core.database import get_db
|
||
from app.core.deps import get_current_user
|
||
from app.models.user import User
|
||
from app.models.project import Project, ProjectMember
|
||
from app.services.search_service import search_service
|
||
from app.services.storage import storage_service
|
||
from app.schemas.response import success_response
|
||
|
||
# Module logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Router that the search endpoints below register on.
router = APIRouter()
|
||
|
||
# Upper bound on the number of items returned by ``GET /documents``.
_MAX_SEARCH_RESULTS = 100
# The cross-project filename scan walks every accessible project tree, so it
# only runs when the cheaper strategies produced fewer results than this.
_GLOBAL_SCAN_THRESHOLD = 20
# File suffixes considered by the filename fallback, scanned in this order.
_SEARCHABLE_SUFFIXES = (".md", ".pdf")
# Escape character for LIKE patterns; "/" sidesteps the backslash-quoting
# differences between database dialects.
_LIKE_ESCAPE = "/"


@router.get("/documents", response_model=dict)
async def search_documents(
    keyword: str = Query(..., min_length=1, description="搜索关键词"),
    project_id: Optional[int] = Query(None, description="限制在指定项目中搜索"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Search documents (hybrid: Whoosh full-text + DB project search + filesystem filename fallback).

    Results are returned in this order, capped at ``_MAX_SEARCH_RESULTS``:

    1. ``type == "project"`` entries whose name or description contains
       ``keyword`` (only for global searches, i.e. without ``project_id``);
    2. ``type == "file"`` entries from the Whoosh full-text index;
    3. ``type == "file"`` entries whose ``.md``/``.pdf`` file name contains
       ``keyword`` (case-insensitive), skipping paths already returned.

    A global search covers projects with ``status == 1`` that the user owns or
    is a member of. With ``project_id``, the user must own the project, be a
    member of it, or the project must be public (``is_public == 1``).

    Raises:
        HTTPException: 404 if ``project_id`` does not exist, 403 if the user
            may not access it.

    Any other unexpected error is logged and answered with an empty result
    and a "service unavailable" message.
    """
    try:
        if not keyword:
            return success_response(data=[])

        # 1. Resolve the search scope as a set of integer project ids.
        scoped_project = None
        if project_id:
            scoped_project = await _get_accessible_project(db, project_id, current_user)
            allowed_ids = {scoped_project.id}
        else:
            allowed_ids = await _get_accessible_project_ids(db, current_user)

        if not allowed_ids:
            return success_response(data=[])

        # 2. Run the search strategies.
        search_results = []

        # A. Matching projects themselves - only for global searches; inside a
        #    single project the caller is looking for files.
        if scoped_project is None:
            search_results.extend(await _search_projects(db, keyword, allowed_ids))

        # B. Whoosh full-text results.
        search_results.extend(
            await _search_full_text(db, keyword, project_id, allowed_ids)
        )

        # C. Filename fallback, so files that are not indexed can still be found.
        if scoped_project is not None:
            projects_to_scan = [scoped_project]
        elif len(search_results) < _GLOBAL_SCAN_THRESHOLD:
            result = await db.execute(
                select(Project).where(Project.id.in_(sorted(allowed_ids)))
            )
            projects_to_scan = result.scalars().all()
        else:
            projects_to_scan = []
        await _append_filename_matches(search_results, projects_to_scan, keyword)

        return success_response(data=search_results[:_MAX_SEARCH_RESULTS])

    except HTTPException:
        # 404/403 from the access check must reach the client unchanged.
        raise
    except Exception as e:
        logger.exception("Search API error: %s", e)
        return success_response(data=[], message="搜索服务暂时不可用")


async def _get_accessible_project(db: AsyncSession, project_id: int, current_user: User):
    """Load ``project_id`` and check that ``current_user`` may search it (404/403 otherwise)."""
    result = await db.execute(select(Project).where(Project.id == project_id))
    project = result.scalar_one_or_none()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    if project.owner_id != current_user.id and project.is_public != 1:
        member_result = await db.execute(
            select(ProjectMember).where(
                ProjectMember.project_id == project_id,
                ProjectMember.user_id == current_user.id,
            )
        )
        # first() rather than scalar_one_or_none(): a duplicated membership
        # row must not turn into a MultipleResultsFound error.
        if member_result.scalars().first() is None:
            raise HTTPException(status_code=403, detail="无权访问该项目")

    return project


async def _get_accessible_project_ids(db: AsyncSession, current_user: User) -> set:
    """Return the integer ids of ``status == 1`` projects the user owns or is a member of."""
    owned_result = await db.execute(
        select(Project.id).where(Project.owner_id == current_user.id, Project.status == 1)
    )
    member_result = await db.execute(
        select(ProjectMember.project_id)
        .join(Project, ProjectMember.project_id == Project.id)
        .where(
            ProjectMember.user_id == current_user.id,
            Project.status == 1,
        )
    )
    return set(owned_result.scalars().all()) | set(member_result.scalars().all())


def _escape_like(value: str) -> str:
    """Escape LIKE wildcards in ``value`` so it is matched literally (escape char ``_LIKE_ESCAPE``)."""
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _to_int(value) -> Optional[int]:
    """Convert an id coming from the search index to ``int``; ``None`` if impossible."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


async def _search_projects(db: AsyncSession, keyword: str, allowed_ids: set) -> List[dict]:
    """Return project entries among ``allowed_ids`` whose name/description contains ``keyword``."""
    pattern = f"%{_escape_like(keyword)}%"
    result = await db.execute(
        select(Project).where(
            Project.id.in_(sorted(allowed_ids)),
            or_(
                Project.name.ilike(pattern, escape=_LIKE_ESCAPE),
                Project.description.ilike(pattern, escape=_LIKE_ESCAPE),
            ),
        )
    )
    return [
        {
            "type": "project",
            "project_id": proj.id,
            "project_name": proj.name,
            "project_description": proj.description or "",
            "match_type": "项目名称/描述",
        }
        for proj in result.scalars().all()
    ]


async def _search_full_text(
    db: AsyncSession, keyword: str, project_id: Optional[int], allowed_ids: set
) -> List[dict]:
    """Query Whoosh and return file entries restricted to ``allowed_ids`` (best effort)."""
    try:
        if project_id:
            whoosh_results = await search_service.search(keyword, str(project_id))
        else:
            # Global search: the index spans every project, so filter by permission.
            allowed_strs = {str(pid) for pid in allowed_ids}
            whoosh_results = [
                r for r in await search_service.search(keyword, limit=50)
                if str(r['project_id']) in allowed_strs
            ]
    except Exception as e:
        # Fail closed: on any error contribute nothing instead of possibly
        # unfiltered results.
        logger.warning("Whoosh search failed: %s", e)
        return []

    # Resolve project names for the hits; ids from the index may be strings,
    # so convert them before comparing against the integer primary key.
    hit_ids = {
        pid for pid in (_to_int(r.get('project_id')) for r in whoosh_results)
        if pid is not None
    }
    project_name_map = {}
    if hit_ids:
        rows = await db.execute(
            select(Project.id, Project.name).where(Project.id.in_(sorted(hit_ids)))
        )
        project_name_map = {str(pid): pname for pid, pname in rows.all()}

    return [
        {
            "type": "file",
            "project_id": r['project_id'],
            "project_name": project_name_map.get(str(r['project_id']), "未知项目"),
            "file_path": r['path'],
            "file_name": r['title'],
            "highlights": r.get('highlights'),
            "match_type": "全文检索",
        }
        for r in whoosh_results
    ]


def _match_files_by_name(project_path: Path, keyword_lower: str) -> List[tuple]:
    """Return ``(relative_path, file_name)`` for searchable files under ``project_path``
    whose name contains ``keyword_lower``; ``_assets`` directories are skipped.

    Blocking filesystem walk - callers run it in an executor.
    """
    if not project_path.exists():
        return []
    matches = []
    for suffix in _SEARCHABLE_SUFFIXES:
        for file_path in project_path.rglob(f"*{suffix}"):
            if "_assets" in file_path.parts:
                continue
            if keyword_lower in file_path.name.lower():
                matches.append((str(file_path.relative_to(project_path)), file_path.name))
    return matches


async def _append_filename_matches(search_results: List[dict], projects, keyword: str) -> None:
    """Append filename matches from ``projects`` to ``search_results``, skipping known paths."""
    import asyncio

    if not projects:
        return

    # Same "project_id:path" key format as used for the full-text hits.
    existing_paths = {
        f"{r['project_id']}:{r['file_path']}"
        for r in search_results
        if r.get('type') == 'file'
    }
    keyword_lower = keyword.lower()
    loop = asyncio.get_running_loop()

    for project in projects:
        try:
            project_path = storage_service.get_secure_path(project.storage_key)
            # Directory walks are blocking; keep them off the event loop.
            matches = await loop.run_in_executor(
                None, _match_files_by_name, project_path, keyword_lower
            )
        except Exception as e:
            logger.warning("Filename search failed for project %s: %s", project.id, e)
            continue

        for rel_path, file_name in matches:
            unique_key = f"{project.id}:{rel_path}"
            if unique_key in existing_paths:
                continue
            existing_paths.add(unique_key)
            search_results.append({
                "type": "file",
                "project_id": project.id,
                "project_name": project.name,
                "file_path": rel_path,
                "file_name": file_name,
                "match_type": "文件名匹配",
            })
|
||
|
||
|
||
async def rebuild_index_task(db: AsyncSession):
    """Background task: rebuild the search index from all ``status == 1`` projects.

    Every ``.md`` file of each project (outside ``_assets`` directories) is read
    and the whole batch is handed to ``search_service.rebuild_index_sync`` in the
    default thread-pool executor. Errors are logged, never raised, because nothing
    awaits a background task.

    Args:
        db: Session used to list the projects.
            NOTE(review): callers pass the request-scoped session from
            ``get_db``; depending on the FastAPI version the dependency's
            cleanup may already have run when this task starts - confirm, or
            open a dedicated session here.
    """
    import asyncio

    logger.info("Starting index rebuild...")
    try:
        result = await db.execute(select(Project).where(Project.status == 1))
        projects = result.scalars().all()

        documents = []
        for project in projects:
            try:
                documents.extend(await _collect_project_documents(project))
            except Exception as e:
                logger.error("Error processing project %s: %s", project.id, e)

        # rebuild_index_sync is synchronous; run it off the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, search_service.rebuild_index_sync, documents)
        logger.info("Index rebuild completed. Indexed %d documents.", len(documents))

    except Exception as e:
        logger.exception("Index rebuild failed: %s", e)


async def _collect_project_documents(project: Project) -> List[dict]:
    """Read the indexable Markdown files of ``project`` into index documents.

    Files that cannot be read are logged and skipped; errors resolving the
    project directory propagate to the caller.
    """
    project_root = storage_service.get_secure_path(project.storage_key)
    if not project_root.exists():
        return []

    # Materialize the listing before awaiting reads, so changes to the tree
    # during the reads do not disturb the iteration.
    md_files = list(project_root.rglob("*.md"))

    documents = []
    for file_path in md_files:
        if "_assets" in file_path.parts:
            continue
        try:
            content = await storage_service.read_file(file_path)
        except Exception as e:
            logger.warning("Skipping %s during index rebuild: %s", file_path, e)
            continue
        documents.append({
            "project_id": project.id,
            "path": str(file_path.relative_to(project_root)),
            "title": file_path.stem,
            "content": content,
        })
    return documents
|
||
|
||
|
||
@router.post("/rebuild-index", response_model=dict)
async def rebuild_index(
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Schedule a full search-index rebuild (superusers only).

    The rebuild runs as ``rebuild_index_task`` after the response is sent; this
    endpoint only queues it.

    Raises:
        HTTPException: 403 when the caller is not a superuser.
    """
    is_admin = current_user.is_superuser
    if is_admin:
        background_tasks.add_task(rebuild_index_task, db)
        return success_response(message="索引重建任务已启动")

    raise HTTPException(status_code=403, detail="权限不足")