refactor: enhance batch processing and IOU filtering in hybrid analysis

This commit is contained in:
myhloli
2025-12-22 00:41:52 +08:00
parent 37a43e3318
commit 996be34534
3 changed files with 83 additions and 13 deletions

View File

@@ -42,7 +42,8 @@ def ocr_det(
np_images,
results,
mfd_res,
_ocr_enable
_ocr_enable,
batch_radio: int = 1,
):
ocr_res_list = []
if not hybrid_pipeline_model.enable_ocr_det_batch:
@@ -137,7 +138,7 @@ def ocr_det(
batch_images.append(padded_img)
# 批处理检测
det_batch_size = min(len(batch_images), OCR_DET_BASE_BATCH_SIZE)
det_batch_size = min(len(batch_images), batch_radio*OCR_DET_BASE_BATCH_SIZE)
batch_results = hybrid_pipeline_model.ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
# 处理批处理结果
@@ -202,6 +203,7 @@ def _process_ocr_and_formulas(
inline_formula_enable,
_ocr_enable,
_vlm_ocr_enable,
batch_radio: int = 2,
):
"""处理OCR和公式识别"""
@@ -225,19 +227,20 @@ def _process_ocr_and_formulas(
# 在进行`行内`公式检测和识别前,先将图像中的图片、表格、`行间`公式区域mask掉
np_images = mask_image_regions(np_images, results)
# 公式检测
images_mfd_res = hybrid_pipeline_model.mfd_model.batch_predict(np_images, 1)
images_mfd_res = hybrid_pipeline_model.mfd_model.batch_predict(np_images, batch_size=1, conf=0.5)
# 公式识别
inline_formula_list = hybrid_pipeline_model.mfr_model.batch_predict(
images_mfd_res,
np_images,
batch_size=MFR_BASE_BATCH_SIZE,
interline_enable=False,
batch_size=batch_radio*MFR_BASE_BATCH_SIZE,
interline_enable=True,
)
mfd_res = []
for page_inline_formula_list in inline_formula_list:
page_mfd_res = []
for formula in page_inline_formula_list:
formula['category_id'] = 13
page_mfd_res.append({
"bbox": [int(formula['poly'][0]), int(formula['poly'][1]),
int(formula['poly'][4]), int(formula['poly'][5])],
@@ -251,6 +254,7 @@ def _process_ocr_and_formulas(
results,
mfd_res,
_ocr_enable,
batch_radio=batch_radio,
)
# 如果需要ocr则做ocr_rec

View File

@@ -27,31 +27,37 @@ class YOLOv8MFDModel:
def _run_predict(
    self,
    inputs: Union[np.ndarray, Image.Image, List],
    is_batch: bool = False,
    conf: Union[float, None] = None,
) -> List:
    """Run the underlying model on ``inputs`` and move the predictions to CPU.

    Args:
        inputs: A single image or a list of images; forwarded unchanged to
            ``self.model.predict``.
        is_batch: When true, every prediction is returned as a list.
            Otherwise only the first prediction is returned.
        conf: Confidence threshold for this call only. ``None`` falls back
            to the instance-wide ``self.conf``.

    Returns:
        A list of CPU predictions when ``is_batch`` is true, otherwise the
        first prediction moved to CPU.
    """
    # Compare against None rather than truthiness so that an explicit
    # threshold of 0.0 is honoured instead of being replaced by self.conf.
    threshold = self.conf if conf is None else conf
    preds = self.model.predict(
        inputs,
        imgsz=self.imgsz,
        conf=threshold,
        iou=self.iou,
        verbose=False,
        device=self.device,
    )
    if is_batch:
        return [pred.cpu() for pred in preds]
    return preds[0].cpu()
def predict(
    self,
    image: Union[np.ndarray, Image.Image],
    conf: Union[float, None] = None,
):
    """Run the detector on a single image.

    Args:
        image: The image to analyse.
        conf: Optional per-call confidence threshold, passed through to
            ``_run_predict``. ``None`` leaves the choice of threshold to
            ``_run_predict``.

    Returns:
        Whatever ``_run_predict`` returns for a non-batch call.
    """
    return self._run_predict(image, is_batch=False, conf=conf)
def batch_predict(
    self,
    images: List[Union[np.ndarray, Image.Image]],
    batch_size: int = 4,
    conf: Union[float, None] = None,
) -> List:
    """Run the detector over ``images`` in chunks of ``batch_size``.

    A tqdm progress bar labelled "MFD Predict" advances by the size of each
    processed chunk.

    Args:
        images: Images to analyse, processed in the order given.
        batch_size: Maximum number of images passed to the model per call.
        conf: Optional per-call confidence threshold, passed through to
            ``_run_predict`` for every chunk.

    Returns:
        One prediction per input image, in input order.
    """
    results = []
    with tqdm(total=len(images), desc="MFD Predict") as pbar:
        for idx in range(0, len(images), batch_size):
            batch = images[idx: idx + batch_size]
            # Single call per chunk; the threshold override must reach it.
            batch_preds = self._run_predict(batch, is_batch=True, conf=conf)
            results.extend(batch_preds)
            pbar.update(len(batch))
    return results

View File

@@ -2,6 +2,8 @@ import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from mineru.utils.boxbase import calculate_iou
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
@@ -31,11 +33,64 @@ class UnimernetModel(object):
self.model = self.model.to(dtype=torch.float16)
self.model.eval()
@staticmethod
def _filter_boxes_by_iou(xyxy, conf, cla, iou_threshold=0.8):
    """Drop heavily overlapping boxes, keeping the more confident one.

    Boxes are compared pairwise in the order given. Whenever two boxes that
    are both still kept have an IoU strictly greater than ``iou_threshold``,
    the one with the lower confidence is discarded; on a tie the earlier
    box is kept. Class labels are not consulted when comparing boxes.

    Args:
        xyxy: Box coordinate tensor of shape (N, 4).
        conf: Confidence tensor of shape (N,).
        cla: Class tensor of shape (N,).
        iou_threshold: IoU above which two boxes count as duplicates.
            Defaults to 0.8.

    Returns:
        The filtered ``(xyxy, conf, cla)`` tensors, in the original order.
        The input tensors are returned as-is when nothing is removed.
    """
    if len(xyxy) == 0:
        return xyxy, conf, cla

    # Convert to plain Python lists once up front; converting a row per
    # comparison inside the quadratic loop is needlessly slow.
    boxes = xyxy.cpu().tolist()
    scores = conf.cpu().tolist()
    n = len(boxes)
    keep = [True] * n

    for i in range(n):
        if not keep[i]:
            continue
        for j in range(i + 1, n):
            if not keep[j]:
                continue
            if calculate_iou(boxes[i], boxes[j]) > iou_threshold:
                if scores[i] >= scores[j]:
                    keep[j] = False
                else:
                    keep[i] = False
                    break  # box i is gone; stop comparing against it

    if all(keep):
        return xyxy, conf, cla

    keep_indices = torch.tensor(
        [i for i, kept in enumerate(keep) if kept], dtype=torch.long
    )
    return xyxy[keep_indices], conf[keep_indices], cla[keep_indices]
def predict(self, mfd_res, image):
formula_list = []
mf_image_list = []
# 对检测框进行IOU去重保留置信度较高的框
xyxy_filtered, conf_filtered, cla_filtered = self._filter_boxes_by_iou(
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
)
for xyxy, conf, cla in zip(
mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
xyxy_filtered.cpu(), conf_filtered.cpu(), cla_filtered.cpu()
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
@@ -79,8 +134,13 @@ class UnimernetModel(object):
image = images[image_index]
formula_list = []
# 对检测框进行IOU去重保留置信度较高的框
xyxy_filtered, conf_filtered, cla_filtered = self._filter_boxes_by_iou(
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
)
for idx, (xyxy, conf, cla) in enumerate(zip(
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
xyxy_filtered, conf_filtered, cla_filtered
)):
if not interline_enable and cla.item() == 1:
continue # Skip interline regions if not enabled