nex_docus/backend/app/api/v1/search.py

270 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
文档搜索相关 API
"""
import asyncio
import logging
from pathlib import Path
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_

from app.core.database import get_db
from app.core.deps import get_current_user
from app.models.user import User
from app.models.project import Project, ProjectMember
from app.services.search_service import search_service
from app.services.storage import storage_service
from app.schemas.response import success_response
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router holding the document-search endpoints defined in this module.
router = APIRouter()
# Upper bound on the number of results returned to the client.
_MAX_SEARCH_RESULTS = 100
# A global (all-projects) filesystem scan only runs while fewer results than this exist.
_GLOBAL_SCAN_THRESHOLD = 20
# Number of Whoosh hits requested for a global search.
_WHOOSH_GLOBAL_LIMIT = 50
# Escape character used in LIKE patterns built from user input.
_LIKE_ESCAPE = "/"


def _escape_like(value: str, escape_char: str = _LIKE_ESCAPE) -> str:
    """Escape SQL ``LIKE`` wildcards in *value* so it is matched literally.

    The escape character itself is doubled first. Pass the same
    *escape_char* as ``escape=`` to ``like``/``ilike``.
    """
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


def _find_files_by_name(project_path: Path, keyword_lower: str) -> List[Path]:
    """Return ``.md``/``.pdf`` files whose name contains *keyword_lower*.

    Paths are returned relative to *project_path*, with Markdown matches
    first and PDF matches second. Files under an ``_assets`` directory
    inside the project are skipped. This performs blocking filesystem
    traversal, so call it from a worker thread. Returns an empty list when
    *project_path* does not exist.
    """
    if not project_path.exists():
        return []
    matches: List[Path] = []
    for pattern in ("*.md", "*.pdf"):
        for file_path in project_path.rglob(pattern):
            rel = file_path.relative_to(project_path)
            # Check relative parts so an "_assets" component in the storage
            # root itself does not hide the whole project.
            if "_assets" in rel.parts:
                continue
            if keyword_lower in file_path.name.lower():
                matches.append(rel)
    return matches


@router.get("/documents", response_model=dict)
async def search_documents(
    keyword: str = Query(..., min_length=1, description="搜索关键词"),
    project_id: Optional[int] = Query(None, description="限制在指定项目中搜索"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Search projects and documents the current user may access.

    Hybrid mode. Results are appended in this order:

    A. Database match on project name/description. This runs only for a
       global search; with ``project_id`` the caller is searching files
       inside a single project.
    B. Whoosh full-text search. This is best effort: failures are logged
       and skipped.
    C. File-name match on the filesystem, so that files missing from the
       index can still be found by name. It runs for the requested project,
       or globally only while fewer than ``_GLOBAL_SCAN_THRESHOLD`` results
       exist. This heuristic avoids scanning every project.

    Args:
        keyword: Search keyword. Surrounding whitespace is ignored.
        project_id: Restrict the search to this project.
        current_user: Authenticated user (injected).
        db: Database session (injected).

    Returns:
        ``success_response`` whose data is a list of at most
        ``_MAX_SEARCH_RESULTS`` result dicts.

    Raises:
        HTTPException: 404 if ``project_id`` does not exist, or 403 if the
            user may not access it. Any other error is logged and reported
            as an empty result.
    """
    keyword = keyword.strip()
    if not keyword:
        return success_response(data=[])

    try:
        # 1. Resolve the search scope: ids of projects the user may access.
        scoped_project = None
        if project_id:
            result = await db.execute(select(Project).where(Project.id == project_id))
            scoped_project = result.scalar_one_or_none()
            if not scoped_project:
                raise HTTPException(status_code=404, detail="项目不存在")
            # Owners and public projects are accessible; otherwise require membership.
            if scoped_project.owner_id != current_user.id and scoped_project.is_public != 1:
                member_result = await db.execute(
                    select(ProjectMember).where(
                        ProjectMember.project_id == project_id,
                        ProjectMember.user_id == current_user.id,
                    )
                )
                # first() tolerates duplicate membership rows, which
                # scalar_one_or_none() would reject with an exception.
                if member_result.scalars().first() is None:
                    raise HTTPException(status_code=403, detail="无权访问该项目")
            allowed_ids = {project_id}
        else:
            # Active projects the user owns...
            owned_result = await db.execute(
                select(Project.id).where(Project.owner_id == current_user.id, Project.status == 1)
            )
            # ...plus active projects the user is a member of.
            joined_result = await db.execute(
                select(ProjectMember.project_id)
                .join(Project, ProjectMember.project_id == Project.id)
                .where(
                    ProjectMember.user_id == current_user.id,
                    Project.status == 1,
                )
            )
            allowed_ids = set(owned_result.scalars().all()) | set(joined_result.scalars().all())

        if not allowed_ids:
            return success_response(data=[])

        # Keep integer ids for SQL IN clauses (sorted for a deterministic
        # query). Whoosh hits are matched on the string form of the id.
        allowed_id_list = sorted(allowed_ids)
        allowed_id_strs = {str(pid) for pid in allowed_ids}

        # 2. Run the searches.
        search_results = []

        # A. Project name/description match (global search only).
        if not project_id:
            pattern = f"%{_escape_like(keyword)}%"
            project_res = await db.execute(
                select(Project).where(
                    Project.id.in_(allowed_id_list),
                    or_(
                        Project.name.ilike(pattern, escape=_LIKE_ESCAPE),
                        Project.description.ilike(pattern, escape=_LIKE_ESCAPE),
                    ),
                )
            )
            for proj in project_res.scalars().all():
                search_results.append({
                    "type": "project",
                    "project_id": proj.id,
                    "project_name": proj.name,
                    "project_description": proj.description or "",
                    "match_type": "项目名称/描述",
                })

        # B. Whoosh full-text search (best effort).
        whoosh_results = []
        try:
            if project_id:
                hits = await search_service.search(keyword, str(project_id))
            else:
                hits = await search_service.search(keyword, limit=_WHOOSH_GLOBAL_LIMIT)
            # Keep only hits in permitted projects (also a guard in single-project mode).
            whoosh_results = [h for h in hits if str(h.get("project_id")) in allowed_id_strs]
        except Exception as e:
            logger.warning("Whoosh search failed: %s", e)

        project_name_map = {}
        if whoosh_results:
            # Every remaining id equals str() of an allowed integer id, so int() is safe.
            whoosh_pids = sorted({int(str(h["project_id"])) for h in whoosh_results})
            name_rows = await db.execute(
                select(Project.id, Project.name).where(Project.id.in_(whoosh_pids))
            )
            project_name_map = {str(pid): pname for pid, pname in name_rows.all()}

        for hit in whoosh_results:
            search_results.append({
                "type": "file",
                "project_id": hit["project_id"],
                "project_name": project_name_map.get(str(hit["project_id"]), "未知项目"),
                "file_path": hit["path"],
                "file_name": hit["title"],
                "highlights": hit.get("highlights"),
                "match_type": "全文检索",
            })

        # C. Filesystem file-name match, for files that are not indexed.
        if project_id:
            # Reuse the project row loaded during the permission check.
            projects_to_scan = [scoped_project]
        elif len(search_results) < _GLOBAL_SCAN_THRESHOLD:
            scan_res = await db.execute(select(Project).where(Project.id.in_(allowed_id_list)))
            projects_to_scan = scan_res.scalars().all()
        else:
            projects_to_scan = []

        # "<project_id>:<path>" keys of files already listed, for de-duplication.
        seen = {
            f"{r['project_id']}:{r['file_path']}"
            for r in search_results
            if r.get("type") == "file"
        }
        keyword_lower = keyword.lower()
        loop = asyncio.get_running_loop()
        for proj in projects_to_scan:
            # Results past the cap are discarded anyway, so stop scanning.
            if len(search_results) >= _MAX_SEARCH_RESULTS:
                break
            try:
                project_path = storage_service.get_secure_path(proj.storage_key)
                # rglob is blocking I/O; run it off the event loop.
                matches = await loop.run_in_executor(
                    None, _find_files_by_name, project_path, keyword_lower
                )
            except Exception as e:
                logger.warning("File name scan failed for project %s: %s", proj.id, e)
                continue
            for rel in matches:
                rel_path = str(rel)
                unique_key = f"{proj.id}:{rel_path}"
                if unique_key in seen:
                    continue
                seen.add(unique_key)
                search_results.append({
                    "type": "file",
                    "project_id": proj.id,
                    "project_name": proj.name,
                    "file_path": rel_path,
                    "file_name": rel.name,
                    "match_type": "文件名匹配",
                })

        return success_response(data=search_results[:_MAX_SEARCH_RESULTS])
    except HTTPException:
        # 404/403 must reach the client rather than look like "no results".
        raise
    except Exception as e:
        logger.exception("Search API error: %s", e)
        return success_response(data=[], message="搜索服务暂时不可用")
def _collect_markdown_files(project_root: Path) -> List[Path]:
    """Return the ``.md`` files under *project_root*, skipping ``_assets`` dirs.

    This performs blocking filesystem traversal and is meant to run in a
    worker thread. Returns an empty list when *project_root* does not exist.
    The ``_assets`` check uses the path relative to *project_root*, so an
    ``_assets`` component in the storage root itself does not hide the
    whole project.
    """
    if not project_root.exists():
        return []
    return [
        file_path
        for file_path in project_root.rglob("*.md")
        if "_assets" not in file_path.relative_to(project_root).parts
    ]


async def rebuild_index_task(db: AsyncSession):
    """Background task: rebuild the search index from all active projects.

    The task proceeds in four steps:

    1. Load every project with ``status == 1``.
    2. Read each non-asset ``.md`` file through ``storage_service.read_file``.
    3. Hand the collected documents to ``search_service.rebuild_index_sync``
       in a worker thread.
    4. Log the outcome.

    Each collected document is a dict with ``project_id``, ``path`` (relative
    to the project root), ``title`` (the file stem) and ``content`` keys.

    Errors never propagate. An unreadable file or a failing project is
    logged and skipped. A failure of the whole rebuild is logged with its
    traceback.

    NOTE(review): *db* is the request-scoped session from ``Depends(get_db)``.
    Depending on the FastAPI version, yield-dependency cleanup may already
    have run when background tasks execute. Consider opening a dedicated
    session inside this task. TODO confirm against the deployed FastAPI
    version.
    """
    logger.info("Starting index rebuild...")
    try:
        result = await db.execute(select(Project).where(Project.status == 1))
        projects = result.scalars().all()
        loop = asyncio.get_running_loop()

        documents = []
        for project in projects:
            try:
                project_root = storage_service.get_secure_path(project.storage_key)
                # rglob is blocking I/O; keep it off the event loop.
                md_files = await loop.run_in_executor(None, _collect_markdown_files, project_root)
            except Exception as e:
                logger.error("Error processing project %s: %s", project.id, e)
                continue

            for file_path in md_files:
                try:
                    content = await storage_service.read_file(file_path)
                except Exception as e:
                    # One unreadable file must not abort the rebuild, but it should be visible.
                    logger.warning("Skipping file %s during index rebuild: %s", file_path, e)
                    continue
                documents.append({
                    "project_id": project.id,
                    "path": str(file_path.relative_to(project_root)),
                    "title": file_path.stem,
                    "content": content,
                })

        # Write the index in one batch; rebuild_index_sync is a blocking call.
        await loop.run_in_executor(None, search_service.rebuild_index_sync, documents)
        logger.info("Index rebuild completed. Indexed %d documents.", len(documents))
    except Exception as e:
        logger.exception("Index rebuild failed: %s", e)
@router.post("/rebuild-index", response_model=dict)
async def rebuild_index(
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Schedule a full rebuild of the search index (superusers only).

    The rebuild itself is queued on ``background_tasks`` via
    ``rebuild_index_task``. This handler only schedules it and returns a
    confirmation.

    Raises:
        HTTPException: 403 when the caller is not a superuser.
    """
    is_admin = current_user.is_superuser
    if is_admin:
        # NOTE(review): the request-scoped session is handed to the
        # background task. Depending on the FastAPI version, it may already
        # be cleaned up when the task runs. TODO confirm.
        background_tasks.add_task(rebuild_index_task, db)
        return success_response(message="索引重建任务已启动")
    raise HTTPException(status_code=403, detail="权限不足")