import shutil
import asyncio
import logging
from pathlib import Path
from typing import List, Dict, Optional

from whoosh import index
from whoosh.fields import Schema, TEXT, ID
from whoosh.qparser import QueryParser, MultifieldParser
from whoosh.analysis import Tokenizer, Token
from whoosh.highlight import HtmlFormatter
import jieba

from app.core.config import settings

logger = logging.getLogger(__name__)


# Chinese tokenizer for Whoosh, backed by jieba's search-mode segmentation.
class ChineseTokenizer(Tokenizer):
    def __call__(self, value, positions=False, chars=False, keeporiginal=False,
                 removestops=True, start_pos=0, start_char=0, mode='', **kwargs):
        # Whoosh tokenizers conventionally reuse a single Token instance,
        # mutating and yielding it once per segment.
        t = Token(positions, chars, removestops=removestops, mode=mode, **kwargs)
        try:
            # jieba.tokenize(mode='search') yields (word, start, end) tuples
            # with true character offsets. Search-mode segments overlap, so
            # offsets cannot be derived by accumulating segment lengths.
            for w, start, end in jieba.tokenize(value, mode='search'):
                t.original = t.text = w
                t.boost = 1.0
                if positions:
                    t.pos = start_pos
                    start_pos += 1
                if chars:
                    t.startchar = start_char + start
                    t.endchar = start_char + end
                yield t
        except Exception as e:
            logger.error(f"Jieba tokenization error: {e}")


def ChineseAnalyzer():
    return ChineseTokenizer()
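
# Illustrative only (not used by the service): search-mode segmentation emits
# overlapping sub-terms in addition to full words, so short queries can match
# inside longer phrases. Exact output depends on jieba's dictionary:
#
#   [t.text for t in ChineseAnalyzer()("搜索引擎")]
#   # -> something like ['搜索', '索引', '引擎', '搜索引擎'] (indicative)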


class SearchService:
    def __init__(self):
        # Index storage path (taken from the storage root in settings).
        storage_root = Path(settings.STORAGE_ROOT)
        # settings.STORAGE_ROOT defaults to an absolute path such as /data/...;
        # local development may use a relative path like ./storage, which is
        # resolved here against the backend root directory.
        if not storage_root.is_absolute():
            backend_dir = Path(__file__).parent.parent.parent
            storage_root = (backend_dir / storage_root).resolve()

        self.index_dir = storage_root / "search_index"

        try:
            self.index_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            logger.error(f"Failed to create search index directory: {e}")

        # "path" stores a composite key of the form "<project_id>:<path>", so
        # a document can be replaced or deleted via a single unique term.
        self.schema = Schema(
            project_id=ID(stored=True),
            path=ID(unique=True, stored=True),
            title=TEXT(stored=True, analyzer=ChineseAnalyzer()),
            content=TEXT(stored=True, analyzer=ChineseAnalyzer())
        )

        self.ix = None
        try:
            self._load_or_create_index()
        except Exception as e:
            logger.error(f"Failed to initialize search index: {e}")

    def _load_or_create_index(self):
        # Check whether the directory already contains a usable index.
        if index.exists_in(str(self.index_dir)):
            try:
                self.ix = index.open_dir(str(self.index_dir))
            except Exception as e:
                logger.warning(f"Failed to open index, trying to recreate: {e}")
                # If opening fails (e.g. the index is corrupted), wipe the
                # directory and rebuild from scratch.
                shutil.rmtree(str(self.index_dir))
                self.index_dir.mkdir(parents=True, exist_ok=True)
                self.ix = index.create_in(str(self.index_dir), self.schema)
        else:
            self.ix = index.create_in(str(self.index_dir), self.schema)

    def _add_document_sync(self, project_id: str, path: str, title: str, content: str):
        if not self.ix:
            return
        try:
            writer = self.ix.writer()
            # update_document replaces any existing document whose unique
            # "path" field matches, so add and update share one code path.
            writer.update_document(
                project_id=str(project_id),
                path=path,
                title=title,
                content=content
            )
            writer.commit()
        except Exception as e:
            logger.error(f"Failed to add document to index: {e}")
            # If this is a LockError, a stale write lock may need cleanup;
            # ignored for now.

    async def add_document(self, project_id: str, path: str, title: str, content: str):
        """Add or update a document in the index (async)."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._add_document_sync, project_id, path, title, content)

    def _delete_document_sync(self, unique_path: str):
        if not self.ix:
            return
        try:
            writer = self.ix.writer()
            writer.delete_by_term('path', unique_path)
            writer.commit()
        except Exception as e:
            logger.error(f"Failed to delete document from index: {e}")

    async def delete_document(self, project_id: str, path: str):
        """Delete a document from the index (async)."""
        unique_path = f"{project_id}:{path}"
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._delete_document_sync, unique_path)

    def _delete_project_documents_sync(self, project_id: str):
        if not self.ix:
            return
        try:
            writer = self.ix.writer()
            writer.delete_by_term('project_id', str(project_id))
            writer.commit()
        except Exception as e:
            logger.error(f"Failed to delete project documents: {e}")

    async def delete_project_documents(self, project_id: str):
        """Delete all document indexes under a project (async)."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._delete_project_documents_sync, project_id)

    def _search_sync(self, keyword: str, project_id: Optional[str] = None, limit: int = 20):
        if not keyword or not self.ix:
            return []

        try:
            with self.ix.searcher() as searcher:
                parser = MultifieldParser(["title", "content"], schema=self.ix.schema)
                query = parser.parse(keyword)

                filter_query = None
                if project_id:
                    filter_parser = QueryParser("project_id", schema=self.ix.schema)
                    filter_query = filter_parser.parse(str(project_id))

                results = searcher.search(query, filter=filter_query, limit=limit)
                results.formatter = HtmlFormatter(tagname="mark", classname="search-highlight", termclass="search-term")

                search_results = []
                for hit in results:
                    # Recover the original path by stripping the
                    # "<project_id>:" prefix from the composite key.
                    full_path = hit.get("path", "")
                    if ":" in full_path:
                        _, real_path = full_path.split(":", 1)
                    else:
                        real_path = full_path

                    # Fetch highlights defensively; fall back to the title.
                    try:
                        highlights = hit.highlights("content") or hit.highlights("title") or hit.get("title", "")
                    except Exception:
                        highlights = hit.get("title", "")

                    search_results.append({
                        "project_id": hit.get("project_id"),
                        "path": real_path,
                        "title": hit.get("title"),
                        "highlights": highlights,
                        "score": hit.score
                    })

                return search_results
        except Exception as e:
            logger.error(f"Search failed: {e}")
            # Return an empty list instead of propagating the exception.
            return []

    async def search(self, keyword: str, project_id: Optional[str] = None, limit: int = 20):
        """Search documents (async)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._search_sync, keyword, project_id, limit)
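
    # Each hit returned by search() is a plain dict; keys follow _search_sync
    # above, values here are illustrative:
    #
    #   {
    #       "project_id": "1",
    #       "path": "docs/intro.md",   # composite prefix already stripped
    #       "title": "简介",
    #       "highlights": '...<mark class="search-highlight ...">搜索</mark>...',
    #       "score": 4.2,
    #   }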

    async def update_doc(self, project_id: int, path: str, title: str, content: str):
        """Add or update a document (public API); builds the composite key."""
        unique_path = f"{project_id}:{path}"
        await self.add_document(str(project_id), unique_path, title, content)

    async def remove_doc(self, project_id: int, path: str):
        """Delete a document (public API)."""
        await self.delete_document(str(project_id), path)

    def rebuild_index_sync(self, documents: List[Dict]):
        """Rebuild the index synchronously."""
        if not self.ix:
            return

        try:
            writer = self.ix.writer()
            # Each document must contain project_id, path, title and content.
            for doc in documents:
                unique_path = f"{doc['project_id']}:{doc['path']}"
                writer.update_document(
                    project_id=str(doc['project_id']),
                    path=unique_path,
                    title=doc['title'],
                    content=doc['content']
                )
            writer.commit()
        except Exception as e:
            logger.error(f"Rebuild index failed: {e}")
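
    # Illustrative input for rebuild_index_sync; each dict needs project_id,
    # path, title and content (values made up):
    #
    #   search_service.rebuild_index_sync([
    #       {"project_id": 1, "path": "docs/intro.md",
    #        "title": "简介", "content": "全文内容"},
    #   ])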


# Module-level singleton shared by the app.
search_service = SearchService()
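

# Minimal async usage sketch (illustrative values; run inside an event loop,
# e.g. from an async web handler or asyncio.run):
#
#   await search_service.update_doc(1, "docs/intro.md", "简介", "这是一段示例正文")
#   hits = await search_service.search("示例", project_id="1", limit=10)
#   for hit in hits:
#       print(hit["path"], hit["score"], hit["highlights"])
#   await search_service.remove_doc(1, "docs/intro.md")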