119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
import os
|
|
from typing import List, Dict, Union
|
|
|
|
from doclayout_yolo import YOLOv10
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
from PIL import Image, ImageDraw
|
|
|
|
from mineru.utils.enum_class import ModelPath
|
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
|
|
|
class DocLayoutYOLOModel:
    """Layout detector wrapping a DocLayout-YOLO ``YOLOv10`` model.

    Every detection is converted into a plain dict::

        {
            "category_id": int,   # the model's predicted class index
            "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
            "score": float,       # confidence rounded to 3 decimals
        }

    ``poly`` lists the four corners of the axis-aligned box in the order
    (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax); the coordinates
    are the model's float box coordinates truncated to ``int``.
    """

    def __init__(
        self,
        weight: str,
        device: str = "cuda",
        imgsz: int = 1280,
        conf: float = 0.1,
        iou: float = 0.45,
    ):
        """Load the model weights and store the default inference settings.

        Args:
            weight: Weights file passed to ``YOLOv10``.
            device: Device the model is moved to (e.g. ``"cuda"`` or ``"cpu"``).
            imgsz: Inference image size forwarded to ``YOLOv10.predict``.
            conf: Confidence threshold forwarded to ``YOLOv10.predict``.
            iou: IoU threshold forwarded to ``YOLOv10.predict``.
        """
        self.model = YOLOv10(weight).to(device)
        self.device = device
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou

    def _parse_prediction(self, prediction) -> List[Dict]:
        """Convert one YOLO result object into a list of layout dicts.

        Returns an empty list when *prediction* has no ``boxes`` attribute or
        its ``boxes`` is ``None``.
        """
        layout_res = []

        # Fault tolerance: a result without boxes yields no detections.
        if not hasattr(prediction, "boxes") or prediction.boxes is None:
            return layout_res

        boxes = prediction.boxes
        for xyxy, conf, cls in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            # int() truncates the float box coordinates toward zero.
            xmin, ymin, xmax, ymax = map(int, xyxy.tolist())
            layout_res.append({
                "category_id": int(cls.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 3),
            })
        return layout_res

    def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
        """Detect layout elements in a single image.

        Args:
            image: NumPy array or PIL image, passed unchanged to ``YOLOv10.predict``.

        Returns:
            The layout dicts for *image* (see the class docstring).
        """
        prediction = self.model.predict(
            image,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
            verbose=False,
        )[0]  # one input image -> only the first (single) result is used
        return self._parse_prediction(prediction)

    def batch_predict(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        batch_size: int = 4,
    ) -> List[List[Dict]]:
        """Detect layout elements in many images, *batch_size* images per model call.

        Args:
            images: NumPy arrays or PIL images to process.
            batch_size: Number of images sent to the model per call; must be >= 1.

        Returns:
            One list of layout dicts per input image, in input order (assuming
            ``YOLOv10.predict`` returns its results in batch order).

        Raises:
            ValueError: If *batch_size* is smaller than 1.
        """
        if batch_size < 1:
            # range() would fail opaquely for 0 and silently skip every image for
            # negative values, returning fewer results than there are images.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # With batch_size == 1 the threshold is deliberately lowered to 90% of
        # self.conf. NOTE(review): the reason is not visible here; presumably it
        # compensates for single-image vs. batched inference differences.
        conf = 0.9 * self.conf if batch_size == 1 else self.conf

        results: List[List[Dict]] = []
        with tqdm(total=len(images), desc="Layout Predict") as pbar:
            for idx in range(0, len(images), batch_size):
                batch = images[idx: idx + batch_size]
                predictions = self.model.predict(
                    batch,
                    imgsz=self.imgsz,
                    conf=conf,
                    iou=self.iou,
                    verbose=False,
                )
                results.extend(self._parse_prediction(pred) for pred in predictions)
                pbar.update(len(batch))
        return results

    def visualize(
        self,
        image: Union[np.ndarray, Image.Image],
        results: List[Dict],
    ) -> Image.Image:
        """Draw detected boxes and their scores on *image*.

        A NumPy array is first converted with ``Image.fromarray``. A PIL image
        is drawn on in place and the same object is returned. Every box is also
        printed to stdout.

        Args:
            image: The image the results were computed on.
            results: Layout dicts as returned by ``predict``.

        Returns:
            The annotated PIL image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        draw = ImageDraw.Draw(image)
        for res in results:
            poly = res['poly']
            # poly[0:2] is the (xmin, ymin) corner, poly[4:6] the (xmax, ymax) corner.
            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
            print(
                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
            # Draw the bounding box on the image with PIL.
            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
            # Write the confidence 10 px right of the box's right edge and 10 px
            # below its top edge. NOTE: the font_size keyword needs Pillow >= 10.1.
            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
        return image
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: run layout detection on one image and display the annotated result.
    # Usage: python <this file> [IMAGE_PATH]
    import sys

    # The image path may be given on the command line; otherwise fall back to
    # the original hard-coded sample path.
    image_path = sys.argv[1] if len(sys.argv) > 1 else r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"

    # Locate the DocLayout-YOLO weights (the helper presumably downloads them
    # when they are not present locally).
    doclayout_yolo_weights = os.path.join(
        auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
        ModelPath.doclayout_yolo,
    )
    device = 'cuda'
    model = DocLayoutYOLOModel(
        weight=doclayout_yolo_weights,
        device=device,
    )

    # Use a context manager so the image file handle is released when done.
    with Image.open(image_path) as image:
        results = model.predict(image)
        annotated = model.visualize(image, results)
        annotated.show()  # display the annotated image