365 lines
12 KiB
Python
365 lines
12 KiB
Python
import os
|
|
|
|
import torch
|
|
from loguru import logger
|
|
|
|
from .model_list import AtomicModel
|
|
from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
|
|
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
|
|
from ...model.mfr.unimernet.Unimernet import UnimernetModel
|
|
from ...model.mfr.pp_formulanet_plus_m.predict_formula import FormulaRecognizer
|
|
from mineru.model.ocr.pytorch_paddle import PytorchPaddleOCR
|
|
from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
|
|
from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
|
|
# from ...model.table.rec.RapidTable import RapidTableModel
|
|
from ...model.table.rec.slanet_plus.main import RapidTableModel
|
|
from ...model.table.rec.unet_table.main import UnetTableModel
|
|
from ...utils.config_reader import get_device
|
|
from ...utils.enum_class import ModelPath
|
|
from ...utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
|
MFR_MODEL = os.getenv('MINERU_FORMULA_CH_SUPPORT', 'False')
|
|
if MFR_MODEL.lower() in ['true', '1', 'yes']:
|
|
MFR_MODEL = "pp_formulanet_plus_m"
|
|
elif MFR_MODEL.lower() in ['false', '0', 'no']:
|
|
MFR_MODEL = "unimernet_small"
|
|
else:
|
|
logger.warning(f"Invalid MINERU_FORMULA_CH_SUPPORT value: {MFR_MODEL}, set to default 'False'")
|
|
MFR_MODEL = "unimernet_small"
|
|
|
|
|
|
def img_orientation_cls_model_init():
|
|
atom_model_manager = AtomModelSingleton()
|
|
ocr_engine = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.OCR,
|
|
det_db_box_thresh=0.5,
|
|
det_db_unclip_ratio=1.6,
|
|
lang="ch_lite",
|
|
enable_merge_det_boxes=False
|
|
)
|
|
cls_model = PaddleOrientationClsModel(ocr_engine)
|
|
return cls_model
|
|
|
|
|
|
def table_cls_model_init():
|
|
return PaddleTableClsModel()
|
|
|
|
|
|
def wired_table_model_init(lang=None):
|
|
atom_model_manager = AtomModelSingleton()
|
|
ocr_engine = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.OCR,
|
|
det_db_box_thresh=0.5,
|
|
det_db_unclip_ratio=1.6,
|
|
lang=lang,
|
|
enable_merge_det_boxes=False
|
|
)
|
|
table_model = UnetTableModel(ocr_engine)
|
|
return table_model
|
|
|
|
|
|
def wireless_table_model_init(lang=None):
|
|
atom_model_manager = AtomModelSingleton()
|
|
ocr_engine = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.OCR,
|
|
det_db_box_thresh=0.5,
|
|
det_db_unclip_ratio=1.6,
|
|
lang=lang,
|
|
enable_merge_det_boxes=False
|
|
)
|
|
table_model = RapidTableModel(ocr_engine)
|
|
return table_model
|
|
|
|
|
|
def mfd_model_init(weight, device='cpu'):
|
|
if str(device).startswith('npu'):
|
|
device = torch.device(device)
|
|
mfd_model = YOLOv8MFDModel(weight, device)
|
|
return mfd_model
|
|
|
|
|
|
def mfr_model_init(weight_dir, device='cpu'):
|
|
if MFR_MODEL == "unimernet_small":
|
|
mfr_model = UnimernetModel(weight_dir, device)
|
|
elif MFR_MODEL == "pp_formulanet_plus_m":
|
|
mfr_model = FormulaRecognizer(weight_dir, device)
|
|
else:
|
|
logger.error('MFR model name not allow')
|
|
exit(1)
|
|
return mfr_model
|
|
|
|
|
|
def doclayout_yolo_model_init(weight, device='cpu'):
|
|
if str(device).startswith('npu'):
|
|
device = torch.device(device)
|
|
model = DocLayoutYOLOModel(weight, device)
|
|
return model
|
|
|
|
def ocr_model_init(det_db_box_thresh=0.3,
|
|
lang=None,
|
|
det_db_unclip_ratio=1.8,
|
|
enable_merge_det_boxes=True
|
|
):
|
|
if lang is not None and lang != '':
|
|
model = PytorchPaddleOCR(
|
|
det_db_box_thresh=det_db_box_thresh,
|
|
lang=lang,
|
|
use_dilation=True,
|
|
det_db_unclip_ratio=det_db_unclip_ratio,
|
|
enable_merge_det_boxes=enable_merge_det_boxes,
|
|
)
|
|
else:
|
|
model = PytorchPaddleOCR(
|
|
det_db_box_thresh=det_db_box_thresh,
|
|
use_dilation=True,
|
|
det_db_unclip_ratio=det_db_unclip_ratio,
|
|
enable_merge_det_boxes=enable_merge_det_boxes,
|
|
)
|
|
return model
|
|
|
|
|
|
class AtomModelSingleton:
|
|
_instance = None
|
|
_models = {}
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
if cls._instance is None:
|
|
cls._instance = super().__new__(cls)
|
|
return cls._instance
|
|
|
|
def get_atom_model(self, atom_model_name: str, **kwargs):
|
|
|
|
lang = kwargs.get('lang', None)
|
|
|
|
if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
|
|
key = (
|
|
atom_model_name,
|
|
lang
|
|
)
|
|
elif atom_model_name in [AtomicModel.OCR]:
|
|
key = (
|
|
atom_model_name,
|
|
kwargs.get('det_db_box_thresh', 0.3),
|
|
lang,
|
|
kwargs.get('det_db_unclip_ratio', 1.8),
|
|
kwargs.get('enable_merge_det_boxes', True)
|
|
)
|
|
else:
|
|
key = atom_model_name
|
|
|
|
if key not in self._models:
|
|
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
|
return self._models[key]
|
|
|
|
def atom_model_init(model_name: str, **kwargs):
|
|
atom_model = None
|
|
if model_name == AtomicModel.Layout:
|
|
atom_model = doclayout_yolo_model_init(
|
|
kwargs.get('doclayout_yolo_weights'),
|
|
kwargs.get('device')
|
|
)
|
|
elif model_name == AtomicModel.MFD:
|
|
atom_model = mfd_model_init(
|
|
kwargs.get('mfd_weights'),
|
|
kwargs.get('device')
|
|
)
|
|
elif model_name == AtomicModel.MFR:
|
|
atom_model = mfr_model_init(
|
|
kwargs.get('mfr_weight_dir'),
|
|
kwargs.get('device')
|
|
)
|
|
elif model_name == AtomicModel.OCR:
|
|
atom_model = ocr_model_init(
|
|
kwargs.get('det_db_box_thresh', 0.3),
|
|
kwargs.get('lang'),
|
|
kwargs.get('det_db_unclip_ratio', 1.8),
|
|
kwargs.get('enable_merge_det_boxes', True)
|
|
)
|
|
elif model_name == AtomicModel.WirelessTable:
|
|
atom_model = wireless_table_model_init(
|
|
kwargs.get('lang'),
|
|
)
|
|
elif model_name == AtomicModel.WiredTable:
|
|
atom_model = wired_table_model_init(
|
|
kwargs.get('lang'),
|
|
)
|
|
elif model_name == AtomicModel.TableCls:
|
|
atom_model = table_cls_model_init()
|
|
elif model_name == AtomicModel.ImgOrientationCls:
|
|
atom_model = img_orientation_cls_model_init()
|
|
else:
|
|
logger.error('model name not allow')
|
|
exit(1)
|
|
|
|
if atom_model is None:
|
|
logger.error('model init failed')
|
|
exit(1)
|
|
else:
|
|
return atom_model
|
|
|
|
|
|
class MineruPipelineModel:
|
|
def __init__(self, **kwargs):
|
|
self.formula_config = kwargs.get('formula_config')
|
|
self.apply_formula = self.formula_config.get('enable', True)
|
|
self.table_config = kwargs.get('table_config')
|
|
self.apply_table = self.table_config.get('enable', True)
|
|
self.lang = kwargs.get('lang', None)
|
|
self.device = kwargs.get('device', 'cpu')
|
|
logger.info(
|
|
'DocAnalysis init, this may take some times......'
|
|
)
|
|
atom_model_manager = AtomModelSingleton()
|
|
|
|
if self.apply_formula:
|
|
# 初始化公式检测模型
|
|
self.mfd_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.MFD,
|
|
mfd_weights=str(
|
|
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
|
|
),
|
|
device=self.device,
|
|
)
|
|
|
|
# 初始化公式解析模型
|
|
if MFR_MODEL == "unimernet_small":
|
|
mfr_model_path = ModelPath.unimernet_small
|
|
elif MFR_MODEL == "pp_formulanet_plus_m":
|
|
mfr_model_path = ModelPath.pp_formulanet_plus_m
|
|
else:
|
|
logger.error('MFR model name not allow')
|
|
exit(1)
|
|
|
|
self.mfr_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.MFR,
|
|
mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
|
|
device=self.device,
|
|
)
|
|
|
|
# 初始化layout模型
|
|
self.layout_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.Layout,
|
|
doclayout_yolo_weights=str(
|
|
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
|
|
),
|
|
device=self.device,
|
|
)
|
|
# 初始化ocr
|
|
self.ocr_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.OCR,
|
|
det_db_box_thresh=0.3,
|
|
lang=self.lang
|
|
)
|
|
# init table model
|
|
if self.apply_table:
|
|
self.wired_table_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.WiredTable,
|
|
lang=self.lang,
|
|
)
|
|
self.wireless_table_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.WirelessTable,
|
|
lang=self.lang,
|
|
)
|
|
self.table_cls_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.TableCls,
|
|
)
|
|
self.img_orientation_cls_model = atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.ImgOrientationCls,
|
|
lang=self.lang,
|
|
)
|
|
|
|
logger.info('DocAnalysis init done!')
|
|
|
|
|
|
class HybridModelSingleton:
|
|
_instance = None
|
|
_models = {}
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
if cls._instance is None:
|
|
cls._instance = super().__new__(cls)
|
|
return cls._instance
|
|
|
|
def get_model(
|
|
self,
|
|
lang=None,
|
|
formula_enable=None,
|
|
):
|
|
key = (lang, formula_enable)
|
|
if key not in self._models:
|
|
self._models[key] = MineruHybridModel(
|
|
lang=lang,
|
|
formula_enable=formula_enable,
|
|
)
|
|
return self._models[key]
|
|
|
|
def ocr_det_batch_setting(device):
|
|
# 检测torch的版本号
|
|
import torch
|
|
from packaging import version
|
|
if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
|
|
enable_ocr_det_batch = False
|
|
else:
|
|
enable_ocr_det_batch = True
|
|
return enable_ocr_det_batch
|
|
|
|
class MineruHybridModel:
|
|
def __init__(
|
|
self,
|
|
device=None,
|
|
lang=None,
|
|
formula_enable=True,
|
|
):
|
|
if device is not None:
|
|
self.device = device
|
|
else:
|
|
self.device = get_device()
|
|
|
|
self.lang = lang
|
|
|
|
self.enable_ocr_det_batch = ocr_det_batch_setting(self.device)
|
|
|
|
if str(self.device).startswith('npu'):
|
|
try:
|
|
import torch_npu
|
|
if torch_npu.npu.is_available():
|
|
torch_npu.npu.set_compile_mode(jit_compile=False)
|
|
except Exception as e:
|
|
raise RuntimeError(
|
|
"NPU is selected as device, but torch_npu is not available. "
|
|
"Please ensure that the torch_npu package is installed correctly."
|
|
) from e
|
|
|
|
self.atom_model_manager = AtomModelSingleton()
|
|
|
|
# 初始化OCR模型
|
|
self.ocr_model = self.atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.OCR,
|
|
det_db_box_thresh=0.3,
|
|
lang=self.lang
|
|
)
|
|
|
|
if formula_enable:
|
|
# 初始化公式检测模型
|
|
self.mfd_model = self.atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.MFD,
|
|
mfd_weights=str(
|
|
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
|
|
),
|
|
device=self.device,
|
|
)
|
|
|
|
# 初始化公式解析模型
|
|
if MFR_MODEL == "unimernet_small":
|
|
mfr_model_path = ModelPath.unimernet_small
|
|
elif MFR_MODEL == "pp_formulanet_plus_m":
|
|
mfr_model_path = ModelPath.pp_formulanet_plus_m
|
|
else:
|
|
logger.error('MFR model name not allow')
|
|
exit(1)
|
|
|
|
self.mfr_model = self.atom_model_manager.get_atom_model(
|
|
atom_model_name=AtomicModel.MFR,
|
|
mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
|
|
device=self.device,
|
|
) |