修复思维导图助手功能

master
Your Name 2026-06-12 18:27:45 +08:00
parent eee05f2cd5
commit 13265d2629
5 changed files with 336 additions and 21 deletions

View File

@ -4,12 +4,13 @@ import os
import re import re
import tempfile import tempfile
import asyncio import asyncio
import threading
import uvicorn import uvicorn
import click import click
import zipfile import zipfile
from pathlib import Path from pathlib import Path
import glob import glob
from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form, APIRouter from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form, APIRouter, Header
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse from fastapi.responses import JSONResponse, FileResponse
@ -33,6 +34,16 @@ logger.add(sys.stderr, level=log_level)
# --- 全局变量与辅助类 --- # --- 全局变量与辅助类 ---
_request_semaphore: Optional[asyncio.Semaphore] = None _request_semaphore: Optional[asyncio.Semaphore] = None
# --- 任务进度跟踪 ---
_task_progress: dict = {}
def _update_task_progress(task_id: Optional[str], progress: int, stage: str):
"""更新任务进度安全调用task_id 为 None 时静默跳过)"""
if task_id and task_id in _task_progress:
_task_progress[task_id]["progress"] = min(progress, 100)
_task_progress[task_id]["stage"] = stage
class ProgressTracker: class ProgressTracker:
def __init__(self): def __init__(self):
@ -47,6 +58,150 @@ class ProgressTracker:
return {"progress": self.progress, "status": self.status} return {"progress": self.progress, "status": self.status}
# --- 当前活跃任务ID用于日志捕获 ---
_current_task_id: Optional[str] = None
def _log_progress_sink(message):
"""loguru 日志捕获:根据日志内容自动更新进度"""
task_id = _current_task_id
if not task_id or task_id not in _task_progress:
return
try:
msg = str(message)
except Exception:
return
# VLM 引擎加载阶段 (5% ~ 40%)
if "Automatically detected platform" in msg:
_update_task_progress(task_id, 7, "检测计算平台")
elif "Resolved architecture" in msg:
_update_task_progress(task_id, 10, "解析模型架构")
elif "Starting to load model" in msg:
_update_task_progress(task_id, 12, "开始加载模型权重")
elif "Loading weights took" in msg:
_update_task_progress(task_id, 18, "模型权重加载完成")
elif "Model loading took" in msg:
_update_task_progress(task_id, 22, "模型加载完成,初始化引擎")
elif "Dynamo bytecode transform" in msg:
_update_task_progress(task_id, 27, "编译优化计算图")
elif "torch.compile takes" in msg:
_update_task_progress(task_id, 30, "计算图编译完成")
elif "Available KV cache memory" in msg:
_update_task_progress(task_id, 32, "分配显存缓存")
elif "Graph capturing finished" in msg:
_update_task_progress(task_id, 36, "CUDA图捕获完成")
elif "init engine" in msg and "took" in msg:
_update_task_progress(task_id, 38, "VLM引擎初始化完成")
elif "get vllm" in msg and "predictor cost" in msg:
_update_task_progress(task_id, 40, "VLM预测器就绪")
elif "hybrid batch ratio" in msg:
_update_task_progress(task_id, 42, "开始文档分析")
# VLM 推理阶段 (40% ~ 65%) 由 progress_callback 处理
# Pipeline 模型加载和推理阶段 (65% ~ 97%)
elif "Target directory already exists" in msg:
cur = _task_progress.get(task_id, {}).get("progress", 0)
if cur < 75:
_update_task_progress(task_id, max(cur, 66), "加载Pipeline模型")
elif "MFD Predict:" in msg:
_update_task_progress(task_id, 78, "数学公式检测")
elif "MFR Predict:" in msg:
_update_task_progress(task_id, 85, "数学公式识别")
elif "OCR-det:" in msg:
_update_task_progress(task_id, 90, "文字区域检测")
elif "OCR-rec Predict:" in msg:
_update_task_progress(task_id, 95, "文字识别")
elif "local output dir is" in msg:
_update_task_progress(task_id, 97, "保存输出结果")
class _StderrProgressCapture:
"""后台线程捕获 stderr 中的 tqdm 进度条输出,解析并更新任务进度"""
# tqdm 进度条模式:名称: 百分比|...| 当前/总数
_PATTERNS = [
(re.compile(r'Two Step Extraction:\s*(\d+)%.*?(\d+)/(\d+)'), 'extract'),
(re.compile(r'MFD Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'mfd'),
(re.compile(r'MFR Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'mfr'),
(re.compile(r'OCR-det:\s*(\d+)%.*?(\d+)/(\d+)'), 'ocr_det'),
(re.compile(r'OCR-rec Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'ocr_rec'),
(re.compile(r'Loading safetensors.*?:\s*(\d+)%'), 'load_model'),
(re.compile(r'Capturing CUDA graphs.*?:\s*(\d+)%'), 'cuda_graph'),
]
# 各阶段的进度映射范围 [start%, end%]
_RANGES = {
'load_model': (12, 18),
'cuda_graph': (33, 37),
'extract': (42, 65),
'mfd': (75, 80),
'mfr': (80, 87),
'ocr_det': (87, 92),
'ocr_rec': (92, 96),
}
_STAGE_LABELS = {
'load_model': '加载模型权重',
'cuda_graph': '捕获CUDA计算图',
'extract': 'VLM文档分析',
'mfd': '数学公式检测',
'mfr': '数学公式识别',
'ocr_det': '文字区域检测',
'ocr_rec': '文字识别',
}
def __init__(self, task_id: str):
self.task_id = task_id
self._active = False
self._thread: Optional[threading.Thread] = None
self._orig_stderr = None
def start(self):
self._active = True
self._orig_stderr = sys.stderr
self._thread = threading.Thread(target=self._reader_loop, daemon=True)
self._thread.start()
def stop(self):
self._active = False
if self._thread and self._thread.is_alive():
self._thread.join(timeout=2)
self._thread = None
def _reader_loop(self):
buf = ""
orig = self._orig_stderr
while self._active:
try:
ch = orig.read(1)
if not ch:
break
buf += ch
# tqdm 用 \r 更新同一行,\n 表示新行
if ch == '\r' or ch == '\n':
if buf.strip():
self._parse_line(buf.strip())
buf = ""
except Exception:
break
def _parse_line(self, line: str):
for pattern, stage in self._PATTERNS:
m = pattern.search(line)
if m:
pct = int(m.group(1))
lo, hi = self._RANGES.get(stage, (0, 100))
mapped = lo + int((hi - lo) * pct / 100)
label = self._STAGE_LABELS.get(stage, stage)
if stage == 'extract' and len(m.groups()) >= 3:
cur, total = m.group(2), m.group(3)
label = f"VLM文档分析 ({cur}/{total}页)"
elif stage in ('mfd', 'mfr', 'ocr_det', 'ocr_rec') and len(m.groups()) >= 3:
cur, total = m.group(2), m.group(3)
label = f"{label} ({cur}/{total})"
_update_task_progress(self.task_id, mapped, label)
break
async def limit_concurrency(): async def limit_concurrency():
if _request_semaphore is not None: if _request_semaphore is not None:
if _request_semaphore.locked(): if _request_semaphore.locked():
@ -93,6 +248,14 @@ def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str)
api_router = APIRouter(prefix="/api") api_router = APIRouter(prefix="/api")
@api_router.get("/parse_progress/{task_id}")
async def get_parse_progress(task_id: str):
"""查询解析任务的实时进度"""
if task_id not in _task_progress:
raise HTTPException(status_code=404, detail="Task not found")
return _task_progress[task_id]
@api_router.post(path="/file_parse", dependencies=[Depends(limit_concurrency)]) @api_router.post(path="/file_parse", dependencies=[Depends(limit_concurrency)])
async def parse_pdf( async def parse_pdf(
files: List[UploadFile] = File(..., description="Upload pdf, image, or Word files for parsing"), files: List[UploadFile] = File(..., description="Upload pdf, image, or Word files for parsing"),
@ -111,11 +274,20 @@ async def parse_pdf(
response_format_zip: bool = Form(False), response_format_zip: bool = Form(False),
start_page_id: int = Form(0), start_page_id: int = Form(0),
end_page_id: int = Form(99999), end_page_id: int = Form(99999),
x_task_id: Optional[str] = Header(None),
): ):
# 从 app 实例状态中获取配置 (FastAPI 实例会在下方创建) # 从 app 实例状态中获取配置 (FastAPI 实例会在下方创建)
from fastapi import Request from fastapi import Request
config = getattr(app.state, "config", {}) config = getattr(app.state, "config", {})
# 初始化进度跟踪
task_id = x_task_id or str(uuid.uuid4())
file_names_str = ", ".join(f.filename or "unknown" for f in files)
_task_progress[task_id] = {
"progress": 0, "stage": "准备中", "status": "processing",
"error": None, "file_names": file_names_str,
}
try: try:
unique_dir = os.path.join(output_dir, str(uuid.uuid4())) unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
os.makedirs(unique_dir, exist_ok=True) os.makedirs(unique_dir, exist_ok=True)
@ -146,6 +318,27 @@ async def parse_pdf(
if len(actual_lang_list) != len(pdf_file_names): if len(actual_lang_list) != len(pdf_file_names):
actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names) actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)
# 进度回调:将 common.py 内部进度映射为细粒度阶段
def progress_callback(pct, msg):
msg_str = str(msg)
if "处理文件" in msg_str:
_update_task_progress(task_id, 40, f"VLM文档分析: {msg_str}")
elif "完成文件" in msg_str:
_update_task_progress(task_id, 65, f"VLM分析完成: {msg_str}")
else:
_update_task_progress(task_id, int(40 + pct * 0.25), msg_str)
_update_task_progress(task_id, 5, "开始解析文档")
# 设置日志捕获将当前任务ID绑定到日志 sink
global _current_task_id
_current_task_id = task_id
sink_id = logger.add(_log_progress_sink, level="DEBUG")
# 启动 stderr 捕获(解析 tqdm 进度条输出)
stderr_capture = _StderrProgressCapture(task_id)
stderr_capture.start()
await aio_do_parse( await aio_do_parse(
output_dir=unique_dir, pdf_file_names=pdf_file_names, pdf_bytes_list=pdf_bytes_list, output_dir=unique_dir, pdf_file_names=pdf_file_names, pdf_bytes_list=pdf_bytes_list,
p_lang_list=actual_lang_list, backend=backend, parse_method=parse_method, p_lang_list=actual_lang_list, backend=backend, parse_method=parse_method,
@ -153,9 +346,20 @@ async def parse_pdf(
f_draw_layout_bbox=False, f_draw_span_bbox=False, f_dump_md=return_md, f_draw_layout_bbox=False, f_draw_span_bbox=False, f_dump_md=return_md,
f_dump_middle_json=return_middle_json, f_dump_model_output=return_model_output, f_dump_middle_json=return_middle_json, f_dump_model_output=return_model_output,
f_dump_orig_pdf=False, f_dump_content_list=return_content_list, f_dump_orig_pdf=False, f_dump_content_list=return_content_list,
start_page_id=start_page_id, end_page_id=end_page_id, **config start_page_id=start_page_id, end_page_id=end_page_id,
progress_callback=progress_callback, **config
) )
_update_task_progress(task_id, 97, "生成结果文件")
_update_task_progress(task_id, 100, "转换完成")
_task_progress[task_id]["status"] = "completed"
# 清理日志捕获和 stderr 捕获
stderr_capture.stop()
logger.remove(sink_id)
_current_task_id = None
if response_format_zip: if response_format_zip:
zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_") zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
os.close(zip_fd) os.close(zip_fd)
@ -205,6 +409,15 @@ async def parse_pdf(
content={"backend": backend, "version": __version__, "results": result_dict}) content={"backend": backend, "version": __version__, "results": result_dict})
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
# 清理日志捕获和 stderr 捕获
try:
stderr_capture.stop()
logger.remove(sink_id)
_current_task_id = None
except Exception:
pass
_task_progress[task_id]["status"] = "failed"
_task_progress[task_id]["error"] = str(e)
return JSONResponse(status_code=500, content={"error": f"Internal Error: {str(e)}"}) return JSONResponse(status_code=500, content={"error": f"Internal Error: {str(e)}"})

View File

@ -29,11 +29,19 @@ export interface ParseResult {
}> }>
} }
export interface ParseProgress {
progress: number
stage: string
status: 'processing' | 'completed' | 'failed'
error: string | null
file_names: string
}
export const documentApi = { export const documentApi = {
/** /**
* *
*/ */
parseDocument(params: ParseParams): Promise<ParseResult> { parseDocument(params: ParseParams, taskId?: string): Promise<ParseResult> {
const formData = new FormData() const formData = new FormData()
// 添加文件 // 添加文件
@ -58,16 +66,26 @@ export const documentApi = {
formData.append('server_url', params.server_url) formData.append('server_url', params.server_url)
} }
return request.post('/file_parse', formData, { return request.post('/api/file_parse', formData, {
headers: { headers: {
'Content-Type': 'multipart/form-data' 'Content-Type': 'multipart/form-data',
...(taskId ? { 'X-Task-ID': taskId } : {})
} }
}).then(result => { }).then(result => {
console.log("解析成功:", result); console.log("解析成功:", result);
return result; return result as unknown as ParseResult;
}).catch(error => { }).catch(error => {
console.error("解析失败:", error); console.error("解析失败:", error);
throw error; throw error;
}) })
},
/**
*
*/
getParseProgress(taskId: string): Promise<ParseProgress> {
return request.get(`/api/parse_progress/${taskId}`).then(result => {
return result as unknown as ParseProgress
})
} }
} }

View File

@ -155,6 +155,38 @@ export function useDocumentProcessor() {
error.value = null error.value = null
} }
// 进度相关
const progressPercent = ref(0)
const progressStage = ref('')
let progressTimer: ReturnType<typeof setInterval> | null = null
// 启动进度轮询
const startProgressPolling = (taskId: string) => {
stopProgressPolling()
progressPercent.value = 0
progressStage.value = '准备中'
progressTimer = setInterval(async () => {
try {
const data = await documentApi.getParseProgress(taskId)
progressPercent.value = data.progress
progressStage.value = data.stage
if (data.status === 'completed' || data.status === 'failed') {
stopProgressPolling()
}
} catch {
// 忽略轮询失败
}
}, 1000)
}
// 停止进度轮询
const stopProgressPolling = () => {
if (progressTimer) {
clearInterval(progressTimer)
progressTimer = null
}
}
// 处理文档转换 // 处理文档转换
const processDocument = async () => { const processDocument = async () => {
if (uploadedFiles.value.length === 0) { if (uploadedFiles.value.length === 0) {
@ -166,12 +198,16 @@ export function useDocumentProcessor() {
error.value = null error.value = null
processingStage.value = '准备提交解析任务' processingStage.value = '准备提交解析任务'
// 生成任务ID
const taskId = `task-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`
startProgressPolling(taskId)
try { try {
processingStage.value = '提交文档到解析服务' processingStage.value = '提交文档到解析服务'
const params: ParseParams = { const params: ParseParams = {
files: uploadedFiles.value, files: uploadedFiles.value,
output_dir: './output', output_dir: './output',
lang_list: [config.language], lang_list: config.language,
backend: config.backend, backend: config.backend,
parse_method: config.forceOcr ? 'ocr' : 'auto', parse_method: config.forceOcr ? 'ocr' : 'auto',
formula_enable: config.formulaEnable, formula_enable: config.formulaEnable,
@ -188,7 +224,7 @@ export function useDocumentProcessor() {
} }
processingStage.value = '服务端正在解析文档' processingStage.value = '服务端正在解析文档'
const response = await documentApi.parseDocument(params) const response = await documentApi.parseDocument(params, taskId)
if (response.results) { if (response.results) {
processingStage.value = '生成 Markdown 和思维导图' processingStage.value = '生成 Markdown 和思维导图'
@ -205,6 +241,7 @@ export function useDocumentProcessor() {
} catch (err: any) { } catch (err: any) {
error.value = err.message || '转换失败' error.value = err.message || '转换失败'
} finally { } finally {
stopProgressPolling()
processingStage.value = '' processingStage.value = ''
isProcessing.value = false isProcessing.value = false
} }
@ -242,6 +279,8 @@ export function useDocumentProcessor() {
isUploading, isUploading,
isProcessing, isProcessing,
processingStage, processingStage,
progressPercent,
progressStage,
error, error,
// 选项 // 选项

View File

@ -103,8 +103,19 @@
<div class="result-content"> <div class="result-content">
<div v-if="isProcessing" class="loading-container"> <div v-if="isProcessing" class="loading-container">
<el-icon class="loading-icon"><Loading /></el-icon> <div class="progress-wrapper">
<p class="loading-text">{{ processingStage || '正在处理文档...' }}</p> <div class="progress-info">
<span class="progress-stage">{{ progressStage || processingStage || '正在处理文档...' }}</span>
<span class="progress-percent">{{ progressPercent }}%</span>
</div>
<el-progress
:percentage="progressPercent"
:stroke-width="10"
:show-text="false"
:indeterminate="progressPercent === 0"
class="progress-bar"
/>
</div>
</div> </div>
<div v-else-if="!results" class="empty-state"> <div v-else-if="!results" class="empty-state">
@ -226,6 +237,8 @@ const {
isUploading, isUploading,
isProcessing, isProcessing,
processingStage, processingStage,
progressPercent,
progressStage,
error, error,
backendOptions, backendOptions,
languageOptions, languageOptions,
@ -536,7 +549,6 @@ onUnmounted(() => {
.drag-upload-area.collapsed { .drag-upload-area.collapsed {
padding: 8px 16px; padding: 8px 16px;
min-height: 36px; min-height: 36px;
max-height: 60px;
display: block; display: block;
} }
@ -745,7 +757,47 @@ onUnmounted(() => {
flex-direction: column; flex-direction: column;
} }
.loading-container, .loading-container {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
height: 400px;
gap: 16px;
}
.progress-wrapper {
width: 100%;
max-width: 500px;
display: flex;
flex-direction: column;
gap: 12px;
}
.progress-info {
display: flex;
justify-content: space-between;
align-items: center;
}
.progress-stage {
font-size: 14px;
color: #495057;
font-weight: 500;
}
.progress-percent {
font-size: 14px;
color: #165DFF;
font-weight: 600;
min-width: 40px;
text-align: right;
}
.progress-bar {
width: 100%;
}
.empty-state { .empty-state {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
@ -755,13 +807,6 @@ onUnmounted(() => {
gap: 16px; gap: 16px;
} }
.loading-icon {
font-size: 48px;
color: #165DFF;
animation: spin 1s linear infinite;
}
.loading-text,
.empty-text { .empty-text {
font-size: 14px; font-size: 14px;
color: #6C757D; color: #6C757D;

View File

@ -12,8 +12,8 @@ export default defineConfig({
server: { server: {
port: 3000, port: 3000,
proxy: { proxy: {
'/file_parse': { '/api': {
target: 'http://10.100.52.43:8000', target: 'http://localhost:8000',
changeOrigin: true changeOrigin: true
} }
} }