437 lines
16 KiB
Python
437 lines
16 KiB
Python
"""
|
||
MinerU Tianshu - SQLite Task Database Manager
|
||
天枢任务数据库管理器
|
||
|
||
负责任务的持久化存储、状态管理和原子性操作
|
||
"""
|
||
import sqlite3
|
||
import json
|
||
import uuid
|
||
from contextlib import contextmanager
|
||
from typing import Optional, List, Dict
|
||
from pathlib import Path
|
||
|
||
|
||
class TaskDB:
|
||
"""任务数据库管理类"""
|
||
|
||
def __init__(self, db_path='mineru_tianshu.db'):
|
||
self.db_path = db_path
|
||
self._init_db()
|
||
|
||
def _get_conn(self):
|
||
"""获取数据库连接(每次创建新连接,避免 pickle 问题)
|
||
|
||
并发安全说明:
|
||
- 使用 check_same_thread=False 是安全的,因为:
|
||
1. 每次调用都创建新连接,不跨线程共享
|
||
2. 连接使用完立即关闭(在 get_cursor 上下文管理器中)
|
||
3. 不使用连接池,避免线程间共享同一连接
|
||
- timeout=30.0 防止死锁,如果锁等待超过30秒会抛出异常
|
||
"""
|
||
conn = sqlite3.connect(
|
||
self.db_path,
|
||
check_same_thread=False,
|
||
timeout=30.0
|
||
)
|
||
conn.row_factory = sqlite3.Row
|
||
return conn
|
||
|
||
@contextmanager
|
||
def get_cursor(self):
|
||
"""上下文管理器,自动提交和错误处理"""
|
||
conn = self._get_conn()
|
||
cursor = conn.cursor()
|
||
try:
|
||
yield cursor
|
||
conn.commit()
|
||
except Exception as e:
|
||
conn.rollback()
|
||
raise e
|
||
finally:
|
||
conn.close() # 关闭连接
|
||
|
||
def _init_db(self):
|
||
"""初始化数据库表"""
|
||
with self.get_cursor() as cursor:
|
||
cursor.execute('''
|
||
CREATE TABLE IF NOT EXISTS tasks (
|
||
task_id TEXT PRIMARY KEY,
|
||
file_name TEXT NOT NULL,
|
||
file_path TEXT,
|
||
status TEXT DEFAULT 'pending',
|
||
priority INTEGER DEFAULT 0,
|
||
backend TEXT DEFAULT 'pipeline',
|
||
options TEXT,
|
||
result_path TEXT,
|
||
error_message TEXT,
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||
started_at TIMESTAMP,
|
||
completed_at TIMESTAMP,
|
||
worker_id TEXT,
|
||
retry_count INTEGER DEFAULT 0
|
||
)
|
||
''')
|
||
|
||
# 创建索引加速查询
|
||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_priority ON tasks(priority DESC)')
|
||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created_at ON tasks(created_at)')
|
||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_worker_id ON tasks(worker_id)')
|
||
|
||
def create_task(self, file_name: str, file_path: str,
|
||
backend: str = 'pipeline', options: dict = None,
|
||
priority: int = 0) -> str:
|
||
"""
|
||
创建新任务
|
||
|
||
Args:
|
||
file_name: 文件名
|
||
file_path: 文件路径
|
||
backend: 处理后端 (pipeline/vlm-transformers/vlm-vllm-engine)
|
||
options: 处理选项 (dict)
|
||
priority: 优先级,数字越大越优先
|
||
|
||
Returns:
|
||
task_id: 任务ID
|
||
"""
|
||
task_id = str(uuid.uuid4())
|
||
with self.get_cursor() as cursor:
|
||
cursor.execute('''
|
||
INSERT INTO tasks (task_id, file_name, file_path, backend, options, priority)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
|
||
return task_id
|
||
|
||
def get_next_task(self, worker_id: str, max_retries: int = 3) -> Optional[Dict]:
|
||
"""
|
||
获取下一个待处理任务(原子操作,防止并发冲突)
|
||
|
||
Args:
|
||
worker_id: Worker ID
|
||
max_retries: 当任务被其他 worker 抢走时的最大重试次数(默认3次)
|
||
|
||
Returns:
|
||
task: 任务字典,如果没有任务返回 None
|
||
|
||
并发安全说明:
|
||
1. 使用 BEGIN IMMEDIATE 立即获取写锁
|
||
2. UPDATE 时检查 status = 'pending' 防止重复拉取
|
||
3. 检查 rowcount 确保更新成功
|
||
4. 如果任务被抢走,立即重试而不是返回 None(避免不必要的等待)
|
||
"""
|
||
for attempt in range(max_retries):
|
||
with self.get_cursor() as cursor:
|
||
# 使用事务确保原子性
|
||
cursor.execute('BEGIN IMMEDIATE')
|
||
|
||
# 按优先级和创建时间获取任务
|
||
cursor.execute('''
|
||
SELECT * FROM tasks
|
||
WHERE status = 'pending'
|
||
ORDER BY priority DESC, created_at ASC
|
||
LIMIT 1
|
||
''')
|
||
|
||
task = cursor.fetchone()
|
||
if task:
|
||
# 立即标记为 processing,并确保状态仍是 pending
|
||
cursor.execute('''
|
||
UPDATE tasks
|
||
SET status = 'processing',
|
||
started_at = CURRENT_TIMESTAMP,
|
||
worker_id = ?
|
||
WHERE task_id = ? AND status = 'pending'
|
||
''', (worker_id, task['task_id']))
|
||
|
||
# 检查是否更新成功(防止被其他 worker 抢走)
|
||
if cursor.rowcount == 0:
|
||
# 任务被其他进程抢走了,立即重试
|
||
# 因为队列中可能还有其他待处理任务
|
||
continue
|
||
|
||
return dict(task)
|
||
else:
|
||
# 队列中没有待处理任务,返回 None
|
||
return None
|
||
|
||
# 重试次数用尽,仍未获取到任务(高并发场景)
|
||
return None
|
||
|
||
def _build_update_clauses(self, status: str, result_path: str = None,
|
||
error_message: str = None, worker_id: str = None,
|
||
task_id: str = None):
|
||
"""
|
||
构建 UPDATE 和 WHERE 子句的辅助方法
|
||
|
||
Args:
|
||
status: 新状态
|
||
result_path: 结果路径(可选)
|
||
error_message: 错误信息(可选)
|
||
worker_id: Worker ID(可选)
|
||
task_id: 任务ID(可选)
|
||
|
||
Returns:
|
||
tuple: (update_clauses, update_params, where_clauses, where_params)
|
||
"""
|
||
update_clauses = ['status = ?']
|
||
update_params = [status]
|
||
where_clauses = []
|
||
where_params = []
|
||
|
||
# 添加 task_id 条件(如果提供)
|
||
if task_id:
|
||
where_clauses.append('task_id = ?')
|
||
where_params.append(task_id)
|
||
|
||
# 处理 completed 状态
|
||
if status == 'completed':
|
||
update_clauses.append('completed_at = CURRENT_TIMESTAMP')
|
||
if result_path:
|
||
update_clauses.append('result_path = ?')
|
||
update_params.append(result_path)
|
||
# 只更新正在处理的任务
|
||
where_clauses.append("status = 'processing'")
|
||
if worker_id:
|
||
where_clauses.append('worker_id = ?')
|
||
where_params.append(worker_id)
|
||
|
||
# 处理 failed 状态
|
||
elif status == 'failed':
|
||
update_clauses.append('completed_at = CURRENT_TIMESTAMP')
|
||
if error_message:
|
||
update_clauses.append('error_message = ?')
|
||
update_params.append(error_message)
|
||
# 只更新正在处理的任务
|
||
where_clauses.append("status = 'processing'")
|
||
if worker_id:
|
||
where_clauses.append('worker_id = ?')
|
||
where_params.append(worker_id)
|
||
|
||
return update_clauses, update_params, where_clauses, where_params
|
||
|
||
def update_task_status(self, task_id: str, status: str,
|
||
result_path: str = None, error_message: str = None,
|
||
worker_id: str = None):
|
||
"""
|
||
更新任务状态
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
status: 新状态 (pending/processing/completed/failed/cancelled)
|
||
result_path: 结果路径(可选)
|
||
error_message: 错误信息(可选)
|
||
worker_id: Worker ID(可选,用于并发检查)
|
||
|
||
Returns:
|
||
bool: 更新是否成功
|
||
|
||
并发安全说明:
|
||
1. 更新为 completed/failed 时会检查状态是 processing
|
||
2. 如果提供 worker_id,会检查任务是否属于该 worker
|
||
3. 返回 False 表示任务被其他进程修改了
|
||
"""
|
||
with self.get_cursor() as cursor:
|
||
# 使用辅助方法构建 UPDATE 和 WHERE 子句
|
||
update_clauses, update_params, where_clauses, where_params = \
|
||
self._build_update_clauses(status, result_path, error_message, worker_id, task_id)
|
||
|
||
# 合并参数:先 UPDATE 部分,再 WHERE 部分
|
||
all_params = update_params + where_params
|
||
|
||
sql = f'''
|
||
UPDATE tasks
|
||
SET {', '.join(update_clauses)}
|
||
WHERE {' AND '.join(where_clauses)}
|
||
'''
|
||
|
||
cursor.execute(sql, all_params)
|
||
|
||
# 检查更新是否成功
|
||
success = cursor.rowcount > 0
|
||
|
||
# 调试日志(仅在失败时)
|
||
if not success and status in ['completed', 'failed']:
|
||
from loguru import logger
|
||
logger.debug(
|
||
f"Status update failed: task_id={task_id}, status={status}, "
|
||
f"worker_id={worker_id}, SQL: {sql}, params: {all_params}"
|
||
)
|
||
|
||
return success
|
||
|
||
def get_task(self, task_id: str) -> Optional[Dict]:
|
||
"""
|
||
查询任务详情
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
|
||
Returns:
|
||
task: 任务字典,如果不存在返回 None
|
||
"""
|
||
with self.get_cursor() as cursor:
|
||
cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
|
||
task = cursor.fetchone()
|
||
return dict(task) if task else None
|
||
|
||
def get_queue_stats(self) -> Dict[str, int]:
|
||
"""
|
||
获取队列统计信息
|
||
|
||
Returns:
|
||
stats: 各状态的任务数量
|
||
"""
|
||
with self.get_cursor() as cursor:
|
||
cursor.execute('''
|
||
SELECT status, COUNT(*) as count
|
||
FROM tasks
|
||
GROUP BY status
|
||
''')
|
||
stats = {row['status']: row['count'] for row in cursor.fetchall()}
|
||
return stats
|
||
|
||
def get_tasks_by_status(self, status: str, limit: int = 100) -> List[Dict]:
|
||
"""
|
||
根据状态获取任务列表
|
||
|
||
Args:
|
||
status: 任务状态
|
||
limit: 返回数量限制
|
||
|
||
Returns:
|
||
tasks: 任务列表
|
||
"""
|
||
with self.get_cursor() as cursor:
|
||
cursor.execute('''
|
||
SELECT * FROM tasks
|
||
WHERE status = ?
|
||
ORDER BY created_at DESC
|
||
LIMIT ?
|
||
''', (status, limit))
|
||
return [dict(row) for row in cursor.fetchall()]
|
||
|
||
def cleanup_old_task_files(self, days: int = 7):
|
||
"""
|
||
清理旧任务的结果文件(保留数据库记录)
|
||
|
||
Args:
|
||
days: 清理多少天前的任务文件
|
||
|
||
Returns:
|
||
int: 删除的文件目录数
|
||
|
||
注意:
|
||
- 只删除结果文件,保留数据库记录
|
||
- 数据库中的 result_path 字段会被清空
|
||
- 用户仍可查询任务状态和历史记录
|
||
"""
|
||
from pathlib import Path
|
||
import shutil
|
||
|
||
with self.get_cursor() as cursor:
|
||
# 查询要清理文件的任务
|
||
cursor.execute('''
|
||
SELECT task_id, result_path FROM tasks
|
||
WHERE completed_at < datetime('now', '-' || ? || ' days')
|
||
AND status IN ('completed', 'failed')
|
||
AND result_path IS NOT NULL
|
||
''', (days,))
|
||
|
||
old_tasks = cursor.fetchall()
|
||
file_count = 0
|
||
|
||
# 删除结果文件
|
||
for task in old_tasks:
|
||
if task['result_path']:
|
||
result_path = Path(task['result_path'])
|
||
if result_path.exists() and result_path.is_dir():
|
||
try:
|
||
shutil.rmtree(result_path)
|
||
file_count += 1
|
||
|
||
# 清空数据库中的 result_path,表示文件已被清理
|
||
cursor.execute('''
|
||
UPDATE tasks
|
||
SET result_path = NULL
|
||
WHERE task_id = ?
|
||
''', (task['task_id'],))
|
||
|
||
except Exception as e:
|
||
from loguru import logger
|
||
logger.warning(f"Failed to delete result files for task {task['task_id']}: {e}")
|
||
|
||
return file_count
|
||
|
||
def cleanup_old_task_records(self, days: int = 30):
|
||
"""
|
||
清理极旧的任务记录(可选功能)
|
||
|
||
Args:
|
||
days: 删除多少天前的任务记录
|
||
|
||
Returns:
|
||
int: 删除的记录数
|
||
|
||
注意:
|
||
- 这个方法会永久删除数据库记录
|
||
- 建议设置较长的保留期(如30-90天)
|
||
- 一般情况下不需要调用此方法
|
||
"""
|
||
with self.get_cursor() as cursor:
|
||
cursor.execute('''
|
||
DELETE FROM tasks
|
||
WHERE completed_at < datetime('now', '-' || ? || ' days')
|
||
AND status IN ('completed', 'failed')
|
||
''', (days,))
|
||
|
||
deleted_count = cursor.rowcount
|
||
return deleted_count
|
||
|
||
def reset_stale_tasks(self, timeout_minutes: int = 60):
|
||
"""
|
||
重置超时的 processing 任务为 pending
|
||
|
||
Args:
|
||
timeout_minutes: 超时时间(分钟)
|
||
"""
|
||
with self.get_cursor() as cursor:
|
||
cursor.execute('''
|
||
UPDATE tasks
|
||
SET status = 'pending',
|
||
worker_id = NULL,
|
||
retry_count = retry_count + 1
|
||
WHERE status = 'processing'
|
||
AND started_at < datetime('now', '-' || ? || ' minutes')
|
||
''', (timeout_minutes,))
|
||
reset_count = cursor.rowcount
|
||
return reset_count
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 测试代码
|
||
db = TaskDB('test_tianshu.db')
|
||
|
||
# 创建测试任务
|
||
task_id = db.create_task(
|
||
file_name='test.pdf',
|
||
file_path='/tmp/test.pdf',
|
||
backend='pipeline',
|
||
options={'lang': 'ch', 'formula_enable': True},
|
||
priority=1
|
||
)
|
||
print(f"Created task: {task_id}")
|
||
|
||
# 查询任务
|
||
task = db.get_task(task_id)
|
||
print(f"Task details: {task}")
|
||
|
||
# 获取统计
|
||
stats = db.get_queue_stats()
|
||
print(f"Queue stats: {stats}")
|
||
|
||
# 清理测试数据库
|
||
Path('test_tianshu.db').unlink(missing_ok=True)
|
||
print("Test completed!")
|
||
|