mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
refactor: enhance batch processing and IOU filtering in hybrid analysis
This commit is contained in:
@@ -42,7 +42,8 @@ def ocr_det(
|
||||
np_images,
|
||||
results,
|
||||
mfd_res,
|
||||
_ocr_enable
|
||||
_ocr_enable,
|
||||
batch_radio: int = 1,
|
||||
):
|
||||
ocr_res_list = []
|
||||
if not hybrid_pipeline_model.enable_ocr_det_batch:
|
||||
@@ -137,7 +138,7 @@ def ocr_det(
|
||||
batch_images.append(padded_img)
|
||||
|
||||
# 批处理检测
|
||||
det_batch_size = min(len(batch_images), OCR_DET_BASE_BATCH_SIZE)
|
||||
det_batch_size = min(len(batch_images), batch_radio*OCR_DET_BASE_BATCH_SIZE)
|
||||
batch_results = hybrid_pipeline_model.ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
|
||||
|
||||
# 处理批处理结果
|
||||
@@ -202,6 +203,7 @@ def _process_ocr_and_formulas(
|
||||
inline_formula_enable,
|
||||
_ocr_enable,
|
||||
_vlm_ocr_enable,
|
||||
batch_radio: int = 2,
|
||||
):
|
||||
"""处理OCR和公式识别"""
|
||||
|
||||
@@ -225,19 +227,20 @@ def _process_ocr_and_formulas(
|
||||
# 在进行`行内`公式检测和识别前,先将图像中的图片、表格、`行间`公式区域mask掉
|
||||
np_images = mask_image_regions(np_images, results)
|
||||
# 公式检测
|
||||
images_mfd_res = hybrid_pipeline_model.mfd_model.batch_predict(np_images, 1)
|
||||
images_mfd_res = hybrid_pipeline_model.mfd_model.batch_predict(np_images, batch_size=1, conf=0.5)
|
||||
# 公式识别
|
||||
inline_formula_list = hybrid_pipeline_model.mfr_model.batch_predict(
|
||||
images_mfd_res,
|
||||
np_images,
|
||||
batch_size=MFR_BASE_BATCH_SIZE,
|
||||
interline_enable=False,
|
||||
batch_size=batch_radio*MFR_BASE_BATCH_SIZE,
|
||||
interline_enable=True,
|
||||
)
|
||||
|
||||
mfd_res = []
|
||||
for page_inline_formula_list in inline_formula_list:
|
||||
page_mfd_res = []
|
||||
for formula in page_inline_formula_list:
|
||||
formula['category_id'] = 13
|
||||
page_mfd_res.append({
|
||||
"bbox": [int(formula['poly'][0]), int(formula['poly'][1]),
|
||||
int(formula['poly'][4]), int(formula['poly'][5])],
|
||||
@@ -251,6 +254,7 @@ def _process_ocr_and_formulas(
|
||||
results,
|
||||
mfd_res,
|
||||
_ocr_enable,
|
||||
batch_radio=batch_radio,
|
||||
)
|
||||
|
||||
# 如果需要ocr则做ocr_rec
|
||||
|
||||
@@ -27,31 +27,37 @@ class YOLOv8MFDModel:
|
||||
def _run_predict(
    self,
    inputs: Union[np.ndarray, Image.Image, List],
    is_batch: bool = False,
    conf: Union[float, None] = None,
) -> List:
    """Run the underlying detector on one image or a list of images.

    Args:
        inputs: A single image or a list of images, passed straight to
            ``self.model.predict``.
        is_batch: When True, return one result per prediction; otherwise
            return only the first prediction.
        conf: Per-call confidence threshold. ``None`` (the default) falls
            back to ``self.conf``.

    Returns:
        A list of predictions moved to CPU when ``is_batch`` is True,
        otherwise the first prediction moved to CPU.
    """
    # An explicit per-call threshold wins over the instance default.
    # `is not None` is used so that a threshold of 0.0 is still honoured.
    effective_conf = conf if conf is not None else self.conf
    preds = self.model.predict(
        inputs,
        imgsz=self.imgsz,
        conf=effective_conf,
        iou=self.iou,
        verbose=False,
        device=self.device,
    )
    if is_batch:
        return [pred.cpu() for pred in preds]
    return preds[0].cpu()
|
||||
|
||||
def predict(
    self,
    image: Union[np.ndarray, Image.Image],
    conf: Union[float, None] = None,
):
    """Detect on a single image.

    Args:
        image: The image to run detection on.
        conf: Optional per-call confidence threshold; ``None`` leaves the
            choice to ``_run_predict``.

    Returns:
        Whatever ``_run_predict`` returns for a non-batch call.
    """
    return self._run_predict(image, is_batch=False, conf=conf)
|
||||
|
||||
def batch_predict(
    self,
    images: List[Union[np.ndarray, Image.Image]],
    batch_size: int = 4,
    conf: Union[float, None] = None,
) -> List:
    """Detect on a list of images in chunks of ``batch_size``.

    Args:
        images: Images to process, in order.
        batch_size: Number of images handed to ``_run_predict`` per call.
        conf: Optional per-call confidence threshold, forwarded unchanged
            to ``_run_predict`` for every chunk.

    Returns:
        A flat list with the predictions of all chunks, in input order.
    """
    results = []
    with tqdm(total=len(images), desc="MFD Predict") as pbar:
        for idx in range(0, len(images), batch_size):
            batch = images[idx: idx + batch_size]
            # Forward the caller's threshold so every chunk uses the same one.
            batch_preds = self._run_predict(batch, is_batch=True, conf=conf)
            results.extend(batch_preds)
            # Advance by the real chunk size; the last chunk may be shorter.
            pbar.update(len(batch))
    return results
|
||||
|
||||
@@ -2,6 +2,8 @@ import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from mineru.utils.boxbase import calculate_iou
|
||||
|
||||
|
||||
class MathDataset(Dataset):
|
||||
def __init__(self, image_paths, transform=None):
|
||||
@@ -31,11 +33,64 @@ class UnimernetModel(object):
|
||||
self.model = self.model.to(dtype=torch.float16)
|
||||
self.model.eval()
|
||||
|
||||
@staticmethod
def _filter_boxes_by_iou(xyxy, conf, cla, iou_threshold=0.8):
    """Remove heavily overlapping boxes, keeping the more confident one.

    Boxes are compared pairwise in index order. Whenever the IoU of a pair
    (as computed by ``calculate_iou``) is strictly greater than
    ``iou_threshold``, the box with the lower confidence is discarded; on a
    tie the earlier box is kept.

    Args:
        xyxy: Box coordinate tensor, shape (N, 4).
        conf: Confidence tensor, shape (N,).
        cla: Class tensor, shape (N,).
        iou_threshold: IoU above which two boxes count as duplicates.
            Defaults to 0.8.

    Returns:
        A ``(xyxy, conf, cla)`` tuple containing only the surviving boxes.
        When nothing is removed, the input tensors are returned as-is.
    """
    if len(xyxy) == 0:
        return xyxy, conf, cla

    # Convert to plain Python values once up front; the pairwise loop below
    # would otherwise repeat the tensor -> list conversion for every pair.
    boxes = xyxy.cpu().tolist()
    scores = conf.cpu().tolist()

    n = len(boxes)
    keep = [True] * n

    for i in range(n):
        if not keep[i]:
            continue
        for j in range(i + 1, n):
            if not keep[j]:
                continue
            if calculate_iou(boxes[i], boxes[j]) > iou_threshold:
                if scores[i] >= scores[j]:
                    keep[j] = False
                else:
                    # Box i lost, so there is nothing left to compare it with.
                    keep[i] = False
                    break

    keep_indices = [i for i, kept in enumerate(keep) if kept]
    if len(keep_indices) == n:
        # Nothing was dropped: hand back the original tensors untouched.
        return xyxy, conf, cla

    keep_indices = torch.tensor(keep_indices, dtype=torch.long)
    return xyxy[keep_indices], conf[keep_indices], cla[keep_indices]
|
||||
|
||||
def predict(self, mfd_res, image):
|
||||
formula_list = []
|
||||
mf_image_list = []
|
||||
|
||||
# 对检测框进行IOU去重,保留置信度较高的框
|
||||
xyxy_filtered, conf_filtered, cla_filtered = self._filter_boxes_by_iou(
|
||||
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
||||
)
|
||||
|
||||
for xyxy, conf, cla in zip(
|
||||
mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
|
||||
xyxy_filtered.cpu(), conf_filtered.cpu(), cla_filtered.cpu()
|
||||
):
|
||||
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
||||
new_item = {
|
||||
@@ -79,8 +134,13 @@ class UnimernetModel(object):
|
||||
image = images[image_index]
|
||||
formula_list = []
|
||||
|
||||
# 对检测框进行IOU去重,保留置信度较高的框
|
||||
xyxy_filtered, conf_filtered, cla_filtered = self._filter_boxes_by_iou(
|
||||
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
||||
)
|
||||
|
||||
for idx, (xyxy, conf, cla) in enumerate(zip(
|
||||
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
||||
xyxy_filtered, conf_filtered, cla_filtered
|
||||
)):
|
||||
if not interline_enable and cla.item() == 1:
|
||||
continue # Skip interline regions if not enabled
|
||||
|
||||
Reference in New Issue
Block a user