UnisMindMap/mineru/model/utils/pytorchocr/base_ocr_v20.py

40 lines
1.3 KiB
Python

import os
import torch
from .modeling.architectures.base_model import BaseModel
class BaseOCRV20:
def __init__(self, config, **kwargs):
self.config = config
self.build_net(**kwargs)
self.net.eval()
def build_net(self, **kwargs):
self.net = BaseModel(self.config, **kwargs)
def read_pytorch_weights(self, weights_path):
if not os.path.exists(weights_path):
raise FileNotFoundError('{} is not existed.'.format(weights_path))
weights = torch.load(weights_path)
return weights
def get_out_channels(self, weights):
if list(weights.keys())[-1].endswith('.weight') and len(list(weights.values())[-1].shape) == 2:
out_channels = list(weights.values())[-1].numpy().shape[1]
else:
out_channels = list(weights.values())[-1].numpy().shape[0]
return out_channels
def load_state_dict(self, weights):
self.net.load_state_dict(weights)
# print('weights is loaded.')
def load_pytorch_weights(self, weights_path):
self.net.load_state_dict(torch.load(weights_path, weights_only=True))
# print('model is loaded: {}'.format(weights_path))
def inference(self, inputs):
with torch.no_grad():
infer = self.net(inputs)
return infer