From 7a365d92c96aea4732fffd4a3f46a2a739d81a7a Mon Sep 17 00:00:00 2001 From: myhloli Date: Sun, 22 Mar 2026 03:06:26 +0800 Subject: [PATCH 1/6] feat: enhance PDF generation by preserving original image raster content and optimizing image handling --- mineru/utils/pdf_image_tools.py | 44 +++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/mineru/utils/pdf_image_tools.py b/mineru/utils/pdf_image_tools.py index 545e3a9c..468bb6a6 100644 --- a/mineru/utils/pdf_image_tools.py +++ b/mineru/utils/pdf_image_tools.py @@ -8,6 +8,8 @@ import numpy as np import pypdfium2 as pdfium from loguru import logger from PIL import Image, ImageOps +from reportlab.lib.utils import ImageReader +from reportlab.pdfgen import canvas from mineru.data.data_reader_writer import FileBasedDataWriter from mineru.utils.check_sys_env import is_windows_environment @@ -331,29 +333,35 @@ def get_crop_np_img(bbox: tuple, input_img, scale=2): def images_bytes_to_pdf_bytes(image_bytes): - # 内存缓冲区 pdf_buffer = BytesIO() - # 载入并转换所有图像为 RGB 模式 - image = Image.open(BytesIO(image_bytes)) - # 根据 EXIF 信息自动转正(处理手机拍摄的带 Orientation 标记的图片) - image = ImageOps.exif_transpose(image) or image - # 只在必要时转换 - if image.mode != "RGB": - image = image.convert("RGB") + with Image.open(BytesIO(image_bytes)) as source_image: + image = ImageOps.exif_transpose(source_image) + if image.mode != "RGB": + image = image.convert("RGB") + else: + image = image.copy() - # 第一张图保存为 PDF,其余追加 - # Keep image inputs at the same raster size when CLI later renders the - # wrapper PDF at the default DPI; PIL defaults to 72 dpi for PDF output, which - # would upscale the image and noticeably hurt seal OCR detection quality. - image.save( - pdf_buffer, - format="PDF", - resolution=float(DEFAULT_PDF_IMAGE_DPI), - # save_all=True + # Preserve the original raster content when the CLI later renders the + # wrapper PDF back at DEFAULT_PDF_IMAGE_DPI. PIL's PDF writer introduces + # enough loss for layout detection to miss text blocks on some image inputs. + page_width = image.width * 72.0 / DEFAULT_PDF_IMAGE_DPI + page_height = image.height * 72.0 / DEFAULT_PDF_IMAGE_DPI + + pdf_canvas = canvas.Canvas(pdf_buffer, pagesize=(page_width, page_height)) + pdf_canvas.drawImage( + ImageReader(image), + 0, + 0, + width=page_width, + height=page_height, + preserveAspectRatio=False, + mask="auto", ) + pdf_canvas.showPage() + pdf_canvas.save() - # 获取 PDF bytes 并重置指针(可选) pdf_bytes = pdf_buffer.getvalue() pdf_buffer.close() + image.close() return pdf_bytes From fb7246540cf2d71b79602ca8a61323b6585415d9 Mon Sep 17 00:00:00 2001 From: myhloli Date: Sun, 22 Mar 2026 23:09:28 +0800 Subject: [PATCH 2/6] feat: improve paragraph splitting logic by adding conditions for block positioning and line count --- mineru/backend/pipeline/para_split.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mineru/backend/pipeline/para_split.py b/mineru/backend/pipeline/para_split.py index cec30ef7..afc67912 100644 --- a/mineru/backend/pipeline/para_split.py +++ b/mineru/backend/pipeline/para_split.py @@ -287,6 +287,10 @@ def __merge_2_text_blocks(block1, block2): and not span_start_with_num # 下一个block的第一个字符是大写字母 and not span_start_with_big_char + # 下一个块的y0要比上一个块的y1小 + and block1['bbox'][1] < block2['bbox'][3] + # 两个块任意一个块需要大于1行 + and (len(block1['lines']) > 1 or len(block2['lines']) > 1) ): if block1['page_num'] != block2['page_num']: for line in block1['lines']: From 01d8e18a13a23d777fc78cfb5b7dcd33d900f31a Mon Sep 17 00:00:00 2001 From: myhloli Date: Sun, 22 Mar 2026 23:20:32 +0800 Subject: [PATCH 3/6] feat: add support for SEAL block type in bounding box processing --- mineru/utils/draw_bbox.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mineru/utils/draw_bbox.py b/mineru/utils/draw_bbox.py index 9a93ebbf..77ccf3bd 100644 --- a/mineru/utils/draw_bbox.py +++ b/mineru/utils/draw_bbox.py @@ -233,6 +233,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): BlockType.INTERLINE_EQUATION, BlockType.LIST, BlockType.INDEX, + BlockType.SEAL, ]: bbox = block["bbox"] page_block_list.append(bbox) From 7423c135d139ce4cce0907adbe159e21b18d6b48 Mon Sep 17 00:00:00 2001 From: myhloli Date: Sun, 22 Mar 2026 23:47:54 +0800 Subject: [PATCH 4/6] feat: enhance formula number processing by appending tags to interline equations --- .../pipeline/model_json_to_middle_json.py | 29 ++++++++++++++----- mineru/model/mfr/utils.py | 1 + 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mineru/backend/pipeline/model_json_to_middle_json.py b/mineru/backend/pipeline/model_json_to_middle_json.py index 9460425b..53f8136a 100644 --- a/mineru/backend/pipeline/model_json_to_middle_json.py +++ b/mineru/backend/pipeline/model_json_to_middle_json.py @@ -215,21 +215,36 @@ def _get_interline_equation_span(block): return None +def _append_formula_number_tag(equation_block, formula_number_block): + equation_span = _get_interline_equation_span(equation_block) + tag_content = _normalize_formula_tag_content(_extract_text_from_block(formula_number_block)) + if equation_span is not None: + formula = equation_span.get("content", "") + equation_span["content"] = f"{formula}\\tag{{{tag_content}}}" + + def _optimize_formula_number_blocks(pdf_info_list): for page_info in pdf_info_list: optimized_blocks = [] - for block in page_info.get("preproc_blocks", []): + blocks = page_info.get("preproc_blocks", []) + for index, block in enumerate(blocks): if block.get("type") != BlockType.FORMULA_NUMBER: optimized_blocks.append(block) continue - prev_block = optimized_blocks[-1] if optimized_blocks else None + prev_block = blocks[index - 1] if index > 0 else None if prev_block and prev_block.get("type") == BlockType.INTERLINE_EQUATION: - equation_span = _get_interline_equation_span(prev_block) - tag_content = _normalize_formula_tag_content(_extract_text_from_block(block)) - if equation_span is not None: - formula = equation_span.get("content", "") - equation_span["content"] = f"{formula}\\tag{{{tag_content}}}" + _append_formula_number_tag(prev_block, block) + continue + + next_block = blocks[index + 1] if index + 1 < len(blocks) else None + next_next_block = blocks[index + 2] if index + 2 < len(blocks) else None + if ( + next_block + and next_block.get("type") == BlockType.INTERLINE_EQUATION + and (next_next_block is None or next_next_block.get("type") != BlockType.FORMULA_NUMBER) + ): + _append_formula_number_tag(next_block, block) continue block["type"] = BlockType.TEXT diff --git a/mineru/model/mfr/utils.py b/mineru/model/mfr/utils.py index 4de1053d..8936d26c 100644 --- a/mineru/model/mfr/utils.py +++ b/mineru/model/mfr/utils.py @@ -291,6 +291,7 @@ REPLACEMENTS_PATTERNS = { re.compile(r'\\vDash '): r'\\models ', re.compile(r'\\sq \\sqcup '): r'\\square ', re.compile(r'\\copyright'): r'©', + re.compile(r'\\Dot'): r'\\dot', } QQUAD_PATTERN = re.compile(r'\\qquad(?!\s)') From cbbabcb347ddb35fce4abb2e40b290c78769533c Mon Sep 17 00:00:00 2001 From: myhloli Date: Mon, 23 Mar 2026 00:44:48 +0800 Subject: [PATCH 5/6] feat: refactor prediction methods to streamline batch processing and enhance error handling --- .../pp_formulanet_plus_m/predict_formula.py | 94 ++++++++++--------- mineru/model/mfr/unimernet/Unimernet.py | 86 +++++++---------- 2 files changed, 84 insertions(+), 96 deletions(-) diff --git a/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py b/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py index 6582d4ba..62613d87 100644 --- a/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py +++ b/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py @@ -1,19 +1,21 @@ import math import os -import torch -import yaml from pathlib import Path +import torch +import yaml from loguru import logger from tqdm import tqdm -from mineru.model.utils.tools.infer import pytorchocr_utility + from mineru.model.utils.pytorchocr.base_ocr_v20 import BaseOCRV20 +from mineru.model.utils.tools.infer import pytorchocr_utility + from .processors import ( - UniMERNetImgDecode, - UniMERNetTestTransform, LatexImageFormat, ToBatch, UniMERNetDecode, + UniMERNetImgDecode, + UniMERNetTestTransform, ) @@ -33,7 +35,7 @@ class FormulaRecognizer(BaseOCRV20): "pytorchocr", "utils", "resources", - "pp_formulanet_arch_config.yaml" + "pp_formulanet_arch_config.yaml", ) self.infer_yaml_path = os.path.join( weight_dir, @@ -112,28 +114,19 @@ class FormulaRecognizer(BaseOCRV20): return formula_list, crop_targets - def predict(self, img_list, batch_size: int = 64): - # Reduce batch size by 50% to avoid potential memory issues during inference. - batch_size = max(1, int(0.5 * batch_size)) - batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=img_list) - batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs) - batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs) - inp = self.pre_tfs["ToBatch"](imgs=batch_imgs) - inp = torch.from_numpy(inp[0]) - inp = inp.to(self.device) - rec_formula = [] - with torch.no_grad(): - with tqdm(total=len(inp), desc="MFR Predict") as pbar: - for index in range(0, len(inp), batch_size): - batch_data = inp[index: index + batch_size] - # with torch.amp.autocast(device_type=self.device.type): - # batch_preds = [self.net(batch_data)] - batch_preds = [self.net(batch_data)] - batch_preds = [p.reshape([-1]) for p in batch_preds[0]] - batch_preds = [bp.cpu().numpy() for bp in batch_preds] - rec_formula += self.post_op(batch_preds) - pbar.update(len(batch_preds)) - return rec_formula + def predict( + self, + mfd_res, + image, + batch_size: int = 64, + interline_enable: bool = True, + ) -> list: + return self.batch_predict( + [mfd_res], + [image], + batch_size=batch_size, + interline_enable=interline_enable, + )[0] def batch_predict( self, @@ -142,15 +135,18 @@ class FormulaRecognizer(BaseOCRV20): batch_size: int = 64, interline_enable: bool = True, ) -> list: + if not images_mfd_res: + return [] + + if len(images_mfd_res) != len(images): + raise ValueError("images_mfd_res and images must have the same length.") + images_formula_list = [] mf_image_list = [] backfill_list = [] image_info = [] # Store (area, original_index, image) tuples - # Collect images with their original indices - for image_index in range(len(images_mfd_res)): - mfd_res = images_mfd_res[image_index] - image = images[image_index] + for mfd_res, image in zip(images_mfd_res, images): formula_list, crop_targets = self._build_formula_items( mfd_res, image, interline_enable=interline_enable ) @@ -166,24 +162,36 @@ class FormulaRecognizer(BaseOCRV20): images_formula_list.append(formula_list) - # Stable sort by area - image_info.sort(key=lambda x: x[0]) # sort by area + if not image_info: + return images_formula_list + + image_info.sort(key=lambda x: x[0]) sorted_indices = [x[1] for x in image_info] sorted_images = [x[2] for x in image_info] - - # Create mapping for results index_mapping = { new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices) } - if len(sorted_images) > 0: - # 进行预测 - batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1 - rec_formula = self.predict(sorted_images, batch_size) - else: - rec_formula = [] + batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) + batch_size = max(1, int(0.5 * batch_size)) + + batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=sorted_images) + batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs) + batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs) + inp = self.pre_tfs["ToBatch"](imgs=batch_imgs) + inp = torch.from_numpy(inp[0]).to(self.device) + + rec_formula = [] + with torch.no_grad(): + with tqdm(total=len(inp), desc="MFR Predict") as pbar: + for index in range(0, len(inp), batch_size): + batch_data = inp[index : index + batch_size] + batch_preds = [self.net(batch_data)] + batch_preds = [p.reshape([-1]) for p in batch_preds[0]] + batch_preds = [bp.cpu().numpy() for bp in batch_preds] + rec_formula += self.post_op(batch_preds) + pbar.update(len(batch_data)) - # Restore original order unsorted_results = [""] * len(rec_formula) for new_idx, latex in enumerate(rec_formula): original_idx = index_mapping[new_idx] diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py index 386b3c54..b7fc9974 100644 --- a/mineru/model/mfr/unimernet/Unimernet.py +++ b/mineru/model/mfr/unimernet/Unimernet.py @@ -23,6 +23,7 @@ class MathDataset(Dataset): class UnimernetModel(object): def __init__(self, weight_dir, _device_="cpu"): from .unimernet_hf import UnimernetModel + if _device_.startswith("mps") or _device_.startswith("npu") or _device_.startswith("musa"): self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager") else: @@ -79,48 +80,39 @@ class UnimernetModel(object): return formula_list, crop_targets - def predict(self, mfd_res, image): - formula_list, crop_targets = self._build_formula_items( - mfd_res, image, interline_enable=True - ) - mf_image_list = [] - - for _, (xmin, ymin, xmax, ymax) in crop_targets: - bbox_img = image[ymin:ymax, xmin:xmax] - mf_image_list.append(bbox_img) - - if not mf_image_list: - return formula_list - - dataset = MathDataset(mf_image_list, transform=self.model.transform) - dataloader = DataLoader(dataset, batch_size=32, num_workers=0) - mfr_res = [] - for mf_img in dataloader: - mf_img = mf_img.to(dtype=self.model.dtype) - mf_img = mf_img.to(self.device) - with torch.no_grad(): - output = self.model.generate({"image": mf_img}) - mfr_res.extend(output["fixed_str"]) - for (res, _), latex in zip(crop_targets, mfr_res): - res["latex"] = latex - return formula_list + def predict( + self, + mfd_res, + image, + batch_size: int = 64, + interline_enable: bool = True, + ) -> list: + return self.batch_predict( + [mfd_res], + [image], + batch_size=batch_size, + interline_enable=interline_enable, + )[0] def batch_predict( - self, - images_mfd_res: list, - images: list, - batch_size: int = 64, - interline_enable: bool = True, + self, + images_mfd_res: list, + images: list, + batch_size: int = 64, + interline_enable: bool = True, ) -> list: + if not images_mfd_res: + return [] + + if len(images_mfd_res) != len(images): + raise ValueError("images_mfd_res and images must have the same length.") + images_formula_list = [] mf_image_list = [] backfill_list = [] image_info = [] # Store (area, original_index, image) tuples - # Collect images with their original indices - for image_index in range(len(images_mfd_res)): - mfd_res = images_mfd_res[image_index] - image = images[image_index] + for mfd_res, image in zip(images_mfd_res, images): formula_list, crop_targets = self._build_formula_items( mfd_res, image, interline_enable=interline_enable ) @@ -136,45 +128,33 @@ class UnimernetModel(object): images_formula_list.append(formula_list) - # Stable sort by area - image_info.sort(key=lambda x: x[0]) # sort by area + if not image_info: + return images_formula_list + + image_info.sort(key=lambda x: x[0]) sorted_indices = [x[1] for x in image_info] sorted_images = [x[2] for x in image_info] - - # Create mapping for results index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)} - # Create dataset with sorted images dataset = MathDataset(sorted_images, transform=self.model.transform) - - # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂 - batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1 - + batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) - # Process batches and store results mfr_res = [] - # for mf_img in dataloader: - with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar: - for index, mf_img in enumerate(dataloader): + for mf_img in dataloader: mf_img = mf_img.to(dtype=self.model.dtype) mf_img = mf_img.to(self.device) with torch.no_grad(): output = self.model.generate({"image": mf_img}, batch_size=batch_size) mfr_res.extend(output["fixed_str"]) + pbar.update(len(mf_img)) - # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size - current_batch_size = min(batch_size, len(sorted_images) - index * batch_size) - pbar.update(current_batch_size) - - # Restore original order unsorted_results = [""] * len(mfr_res) for new_idx, latex in enumerate(mfr_res): original_idx = index_mapping[new_idx] unsorted_results[original_idx] = latex - # Fill results back for res, latex in zip(backfill_list, unsorted_results): res["latex"] = latex From 6eb91d3632b77ef31602d1fb153fdaad47ae211b Mon Sep 17 00:00:00 2001 From: myhloli Date: Mon, 23 Mar 2026 01:46:00 +0800 Subject: [PATCH 6/6] feat: optimize batch processing by implementing dynamic batch grouping and enhancing formula item handling --- .../pp_formulanet_plus_m/predict_formula.py | 25 ++++--- mineru/model/mfr/unimernet/Unimernet.py | 32 +++++--- mineru/model/mfr/utils.py | 75 ++++++++++++++++++- 3 files changed, 113 insertions(+), 19 deletions(-) diff --git a/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py b/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py index 62613d87..57b516dc 100644 --- a/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py +++ b/mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py @@ -4,12 +4,12 @@ from pathlib import Path import torch import yaml -from loguru import logger from tqdm import tqdm from mineru.model.utils.pytorchocr.base_ocr_v20 import BaseOCRV20 from mineru.model.utils.tools.infer import pytorchocr_utility +from ..utils import build_mfr_batch_groups from .processors import ( LatexImageFormat, ToBatch, @@ -43,7 +43,8 @@ class FormulaRecognizer(BaseOCRV20): ) network_config = pytorchocr_utility.AnalysisConfig( - self.weights_path, self.yaml_path + self.weights_path, + self.yaml_path, ) weights = self.read_pytorch_weights(self.weights_path) @@ -144,11 +145,13 @@ class FormulaRecognizer(BaseOCRV20): images_formula_list = [] mf_image_list = [] backfill_list = [] - image_info = [] # Store (area, original_index, image) tuples + image_info = [] for mfd_res, image in zip(images_mfd_res, images): formula_list, crop_targets = self._build_formula_items( - mfd_res, image, interline_enable=interline_enable + mfd_res, + image, + interline_enable=interline_enable, ) for formula_item, (xmin, ymin, xmax, ymax) in crop_targets: @@ -166,14 +169,18 @@ class FormulaRecognizer(BaseOCRV20): return images_formula_list image_info.sort(key=lambda x: x[0]) + sorted_areas = [x[0] for x in image_info] sorted_indices = [x[1] for x in image_info] sorted_images = [x[2] for x in image_info] index_mapping = { new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices) } - batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) - batch_size = max(1, int(0.5 * batch_size)) + formula_requested_batch_size = max(1, batch_size // 2) + batch_groups = build_mfr_batch_groups( + sorted_areas, + formula_requested_batch_size, + ) batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=sorted_images) batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs) @@ -184,13 +191,13 @@ class FormulaRecognizer(BaseOCRV20): rec_formula = [] with torch.no_grad(): with tqdm(total=len(inp), desc="MFR Predict") as pbar: - for index in range(0, len(inp), batch_size): - batch_data = inp[index : index + batch_size] + for batch_group in batch_groups: + batch_data = inp[batch_group] batch_preds = [self.net(batch_data)] batch_preds = [p.reshape([-1]) for p in batch_preds[0]] batch_preds = [bp.cpu().numpy() for bp in batch_preds] rec_formula += self.post_op(batch_preds) - pbar.update(len(batch_data)) + pbar.update(len(batch_group)) unsorted_results = [""] * len(rec_formula) for new_idx, latex in enumerate(rec_formula): diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py index b7fc9974..11151500 100644 --- a/mineru/model/mfr/unimernet/Unimernet.py +++ b/mineru/model/mfr/unimernet/Unimernet.py @@ -4,6 +4,8 @@ import torch from torch.utils.data import DataLoader, Dataset from tqdm import tqdm +from ..utils import build_mfr_batch_groups + class MathDataset(Dataset): def __init__(self, image_paths, transform=None): @@ -25,7 +27,10 @@ class UnimernetModel(object): from .unimernet_hf import UnimernetModel if _device_.startswith("mps") or _device_.startswith("npu") or _device_.startswith("musa"): - self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager") + self.model = UnimernetModel.from_pretrained( + weight_dir, + attn_implementation="eager", + ) else: self.model = UnimernetModel.from_pretrained(weight_dir) self.device = torch.device(_device_) @@ -110,11 +115,13 @@ class UnimernetModel(object): images_formula_list = [] mf_image_list = [] backfill_list = [] - image_info = [] # Store (area, original_index, image) tuples + image_info = [] for mfd_res, image in zip(images_mfd_res, images): formula_list, crop_targets = self._build_formula_items( - mfd_res, image, interline_enable=interline_enable + mfd_res, + image, + interline_enable=interline_enable, ) for formula_item, (xmin, ymin, xmax, ymax) in crop_targets: @@ -132,23 +139,30 @@ class UnimernetModel(object): return images_formula_list image_info.sort(key=lambda x: x[0]) + sorted_areas = [x[0] for x in image_info] sorted_indices = [x[1] for x in image_info] sorted_images = [x[2] for x in image_info] - index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)} + index_mapping = { + new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices) + } + batch_groups = build_mfr_batch_groups(sorted_areas, batch_size) dataset = MathDataset(sorted_images, transform=self.model.transform) - batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) + dataloader = DataLoader(dataset, batch_sampler=batch_groups, num_workers=0) mfr_res = [] with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar: - for mf_img in dataloader: + for batch_group, mf_img in zip(batch_groups, dataloader): + current_batch_size = len(batch_group) mf_img = mf_img.to(dtype=self.model.dtype) mf_img = mf_img.to(self.device) with torch.no_grad(): - output = self.model.generate({"image": mf_img}, batch_size=batch_size) + output = self.model.generate( + {"image": mf_img}, + batch_size=current_batch_size, + ) mfr_res.extend(output["fixed_str"]) - pbar.update(len(mf_img)) + pbar.update(current_batch_size) unsorted_results = [""] * len(mfr_res) for new_idx, latex in enumerate(mfr_res): diff --git a/mineru/model/mfr/utils.py b/mineru/model/mfr/utils.py index 8936d26c..aee712cd 100644 --- a/mineru/model/mfr/utils.py +++ b/mineru/model/mfr/utils.py @@ -336,4 +336,77 @@ def latex_rm_whitespace(s: str): while s.endswith('\\'): s = s[:-1] - return s \ No newline at end of file + return s + + +def largest_power_of_two_leq(value: int) -> int: + if value < 1: + return 0 + return 2 ** (value.bit_length() - 1) + + +def get_mfr_effective_batch_size(num_items: int, requested_batch_size: int) -> int: + return min( + requested_batch_size, + largest_power_of_two_leq(max(1, num_items)), + ) + + +def get_mfr_min_dynamic_batch_size(requested_batch_size: int) -> int: + return max(1, requested_batch_size // 8) + + +def build_mfr_batch_groups(sorted_areas: list[int], requested_batch_size: int) -> list[list[int]]: + if not sorted_areas: + return [] + + effective_batch_size = get_mfr_effective_batch_size( + len(sorted_areas), + requested_batch_size, + ) + if effective_batch_size < 1: + return [] + + min_dynamic_batch_size = get_mfr_min_dynamic_batch_size(requested_batch_size) + total_count = len(sorted_areas) + if total_count < min_dynamic_batch_size: + return [list(range(total_count))] + + base_mean_area = sum(sorted_areas[:effective_batch_size]) / effective_batch_size + batch_groups = [] + cursor = 0 + + while cursor < total_count: + remaining_count = total_count - cursor + if remaining_count < min_dynamic_batch_size: + batch_groups.append(list(range(cursor, total_count))) + break + + probe_size = min(effective_batch_size, remaining_count) + current_mean_area = sum(sorted_areas[cursor : cursor + probe_size]) / probe_size + ratio = 1 if base_mean_area <= 0 else current_mean_area / base_mean_area + + candidate_batch_size = effective_batch_size + threshold = 4 + while ( + ratio >= threshold + and candidate_batch_size // 2 >= min_dynamic_batch_size + ): + candidate_batch_size //= 2 + threshold *= 2 + + candidate_batch_size = min( + candidate_batch_size, + largest_power_of_two_leq(remaining_count), + ) + batch_groups.append(list(range(cursor, cursor + candidate_batch_size))) + cursor += candidate_batch_size + + if ( + len(batch_groups) >= 2 + and len(batch_groups[-1]) < min_dynamic_batch_size + ): + tail_group = batch_groups.pop() + batch_groups[-1].extend(tail_group) + + return batch_groups