""" 文档搜索相关 API """ from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, or_ from typing import Optional, List from pathlib import Path import logging from app.core.database import get_db from app.core.deps import get_current_user from app.models.user import User from app.models.project import Project, ProjectMember from app.services.search_service import search_service from app.services.storage import storage_service from app.schemas.response import success_response router = APIRouter() logger = logging.getLogger(__name__) @router.get("/documents", response_model=dict) async def search_documents( keyword: str = Query(..., min_length=1, description="搜索关键词"), project_id: Optional[int] = Query(None, description="限制在指定项目中搜索"), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ): """ 文档搜索 (混合模式:Whoosh 全文检索 + 数据库项目搜索 + 文件系统文件名搜索 fallback) """ try: if not keyword: return success_response(data=[]) # 1. 确定搜索范围 (项目ID列表) allowed_project_ids = [] if project_id: # 检查指定项目的访问权限 result = await db.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if not project: raise HTTPException(status_code=404, detail="项目不存在") # 检查权限 if project.owner_id != current_user.id and project.is_public != 1: member_result = await db.execute( select(ProjectMember).where( ProjectMember.project_id == project_id, ProjectMember.user_id == current_user.id ) ) if not member_result.scalar_one_or_none(): raise HTTPException(status_code=403, detail="无权访问该项目") allowed_project_ids.append(str(project_id)) else: # 获取所有可访问的项目 # 1. 用户创建的项目 owned_result = await db.execute( select(Project.id).where(Project.owner_id == current_user.id, Project.status == 1) ) allowed_project_ids.extend([str(pid) for pid in owned_result.scalars().all()]) # 2. 用户参与的项目 member_result = await db.execute( select(ProjectMember.project_id) .join(Project, ProjectMember.project_id == Project.id) .where( ProjectMember.user_id == current_user.id, Project.status == 1 ) ) allowed_project_ids.extend([str(pid) for pid in member_result.scalars().all()]) # 去重 allowed_project_ids = list(set(allowed_project_ids)) if not allowed_project_ids: return success_response(data=[]) # 2. 执行搜索 search_results = [] # A. 数据库项目搜索 (仅当未指定 project_id 时,或者需要搜项目本身) # 如果前端指定了 project_id,通常是在项目内搜文件,不需要搜项目本身 if not project_id: projects_query = select(Project).where( Project.id.in_(allowed_project_ids), or_( Project.name.ilike(f"%{keyword}%"), Project.description.ilike(f"%{keyword}%") ) ) project_res = await db.execute(projects_query) matched_projects = project_res.scalars().all() for proj in matched_projects: search_results.append({ "type": "project", "project_id": proj.id, "project_name": proj.name, "project_description": proj.description or "", "match_type": "项目名称/描述", }) # B. Whoosh 全文检索 whoosh_results = [] try: if project_id: whoosh_results = await search_service.search(keyword, str(project_id)) else: # 全局搜索 whoosh_results = await search_service.search(keyword, limit=50) # 过滤权限 whoosh_results = [r for r in whoosh_results if str(r['project_id']) in allowed_project_ids] except Exception as e: logger.warning(f"Whoosh search failed: {e}") pass # 获取 Whoosh 结果涉及的项目 ID whoosh_project_ids = set(res['project_id'] for res in whoosh_results if res.get('project_id')) # 查询项目名称映射 project_name_map = {} if whoosh_project_ids: p_res = await db.execute(select(Project.id, Project.name).where(Project.id.in_(whoosh_project_ids))) for pid, pname in p_res.all(): project_name_map[str(pid)] = pname # 添加 Whoosh 结果 for res in whoosh_results: pid_str = str(res['project_id']) search_results.append({ "type": "file", "project_id": res['project_id'], "project_name": project_name_map.get(pid_str, "未知项目"), "file_path": res['path'], "file_name": res['title'], "highlights": res.get('highlights'), "match_type": "全文检索" }) # C. 文件系统文件名搜索 (Fallback / Complementary) # 为了保证未索引的文件也能通过文件名搜到 # 获取需要扫描的项目 projects_to_scan = [] if project_id: # 单项目扫描 res = await db.execute(select(Project).where(Project.id == project_id)) p = res.scalar_one_or_none() if p: projects_to_scan = [p] elif len(search_results) < 20: # 全局扫描:仅当结果较少时才进行全盘扫描,避免性能问题 # 这是一个简单的启发式策略 res = await db.execute(select(Project).where(Project.id.in_(allowed_project_ids))) projects_to_scan = res.scalars().all() # 已存在的文件路径集合 (用于去重) existing_paths = set() for res in search_results: if res.get('type') == 'file': # 统一 key 格式 existing_paths.add(f"{res['project_id']}:{res['file_path']}") keyword_lower = keyword.lower() for project in projects_to_scan: try: project_path = storage_service.get_secure_path(project.storage_key) if not project_path.exists(): continue # 查找文件名匹配 md_files = list(project_path.rglob("*.md")) pdf_files = list(project_path.rglob("*.pdf")) for file_path in md_files + pdf_files: if "_assets" in file_path.parts: continue if keyword_lower in file_path.name.lower(): rel_path = str(file_path.relative_to(project_path)) unique_key = f"{project.id}:{rel_path}" if unique_key not in existing_paths: search_results.append({ "type": "file", "project_id": project.id, "project_name": project.name, "file_path": rel_path, "file_name": file_path.name, "match_type": "文件名匹配" }) existing_paths.add(unique_key) except Exception: continue return success_response(data=search_results[:100]) except Exception as e: logger.error(f"Search API error: {e}") return success_response(data=[], message="搜索服务暂时不可用") async def rebuild_index_task(db: AsyncSession): """后台任务:重建索引""" logger.info("Starting index rebuild...") try: # 获取所有项目 result = await db.execute(select(Project).where(Project.status == 1)) projects = result.scalars().all() documents = [] for project in projects: try: # 遍历项目文件 project_root = storage_service.get_secure_path(project.storage_key) if not project_root.exists(): continue # 查找所有 .md 文件 md_files = list(project_root.rglob("*.md")) for file_path in md_files: if "_assets" in file_path.parts: continue try: content = await storage_service.read_file(file_path) relative_path = str(file_path.relative_to(project_root)) documents.append({ "project_id": project.id, "path": relative_path, "title": file_path.stem, "content": content }) except Exception: continue except Exception as e: logger.error(f"Error processing project {project.id}: {e}") continue # 批量写入索引 import asyncio loop = asyncio.get_running_loop() await loop.run_in_executor(None, search_service.rebuild_index_sync, documents) logger.info(f"Index rebuild completed. Indexed {len(documents)} documents.") except Exception as e: logger.error(f"Index rebuild failed: {e}") @router.post("/rebuild-index", response_model=dict) async def rebuild_index( background_tasks: BackgroundTasks, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ): """ 重建搜索索引 (仅限超级管理员) """ if not current_user.is_superuser: raise HTTPException(status_code=403, detail="权限不足") background_tasks.add_task(rebuild_index_task, db) return success_response(message="索引重建任务已启动")