Merge pull request #4638 from myhloli/dev

Dev
This commit is contained in:
Xiaomeng Zhao
2026-03-23 10:53:50 +08:00
committed by GitHub
7 changed files with 242 additions and 131 deletions

View File

@@ -215,21 +215,36 @@ def _get_interline_equation_span(block):
return None
def _append_formula_number_tag(equation_block, formula_number_block):
    """Attach a formula number to an interline equation as a LaTeX ``\\tag{}``.

    The tag text is always extracted and normalized from the number block;
    the equation content is rewritten only when an interline-equation span
    is actually present in ``equation_block``.
    """
    span = _get_interline_equation_span(equation_block)
    raw_text = _extract_text_from_block(formula_number_block)
    tag = _normalize_formula_tag_content(raw_text)
    if span is None:
        return
    current = span.get("content", "")
    span["content"] = f"{current}\\tag{{{tag}}}"
def _optimize_formula_number_blocks(pdf_info_list):
for page_info in pdf_info_list:
optimized_blocks = []
for block in page_info.get("preproc_blocks", []):
blocks = page_info.get("preproc_blocks", [])
for index, block in enumerate(blocks):
if block.get("type") != BlockType.FORMULA_NUMBER:
optimized_blocks.append(block)
continue
prev_block = optimized_blocks[-1] if optimized_blocks else None
prev_block = blocks[index - 1] if index > 0 else None
if prev_block and prev_block.get("type") == BlockType.INTERLINE_EQUATION:
equation_span = _get_interline_equation_span(prev_block)
tag_content = _normalize_formula_tag_content(_extract_text_from_block(block))
if equation_span is not None:
formula = equation_span.get("content", "")
equation_span["content"] = f"{formula}\\tag{{{tag_content}}}"
_append_formula_number_tag(prev_block, block)
continue
next_block = blocks[index + 1] if index + 1 < len(blocks) else None
next_next_block = blocks[index + 2] if index + 2 < len(blocks) else None
if (
next_block
and next_block.get("type") == BlockType.INTERLINE_EQUATION
and (next_next_block is None or next_next_block.get("type") != BlockType.FORMULA_NUMBER)
):
_append_formula_number_tag(next_block, block)
continue
block["type"] = BlockType.TEXT

View File

@@ -287,6 +287,10 @@ def __merge_2_text_blocks(block1, block2):
and not span_start_with_num
# 下一个block的第一个字符是大写字母
and not span_start_with_big_char
# 下一个块的y0要比上一个块的y1小
and block1['bbox'][1] < block2['bbox'][3]
# 两个块任意一个块需要大于1行
and (len(block1['lines']) > 1 or len(block2['lines']) > 1)
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:

View File

@@ -1,19 +1,21 @@
import math
import os
import torch
import yaml
from pathlib import Path
from loguru import logger
import torch
import yaml
from tqdm import tqdm
from mineru.model.utils.tools.infer import pytorchocr_utility
from mineru.model.utils.pytorchocr.base_ocr_v20 import BaseOCRV20
from mineru.model.utils.tools.infer import pytorchocr_utility
from ..utils import build_mfr_batch_groups
from .processors import (
UniMERNetImgDecode,
UniMERNetTestTransform,
LatexImageFormat,
ToBatch,
UniMERNetDecode,
UniMERNetImgDecode,
UniMERNetTestTransform,
)
@@ -33,7 +35,7 @@ class FormulaRecognizer(BaseOCRV20):
"pytorchocr",
"utils",
"resources",
"pp_formulanet_arch_config.yaml"
"pp_formulanet_arch_config.yaml",
)
self.infer_yaml_path = os.path.join(
weight_dir,
@@ -41,7 +43,8 @@ class FormulaRecognizer(BaseOCRV20):
)
network_config = pytorchocr_utility.AnalysisConfig(
self.weights_path, self.yaml_path
self.weights_path,
self.yaml_path,
)
weights = self.read_pytorch_weights(self.weights_path)
@@ -112,28 +115,19 @@ class FormulaRecognizer(BaseOCRV20):
return formula_list, crop_targets
def predict(self, img_list, batch_size: int = 64):
# Reduce batch size by 50% to avoid potential memory issues during inference.
batch_size = max(1, int(0.5 * batch_size))
batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=img_list)
batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
inp = self.pre_tfs["ToBatch"](imgs=batch_imgs)
inp = torch.from_numpy(inp[0])
inp = inp.to(self.device)
rec_formula = []
with torch.no_grad():
with tqdm(total=len(inp), desc="MFR Predict") as pbar:
for index in range(0, len(inp), batch_size):
batch_data = inp[index: index + batch_size]
# with torch.amp.autocast(device_type=self.device.type):
# batch_preds = [self.net(batch_data)]
batch_preds = [self.net(batch_data)]
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
batch_preds = [bp.cpu().numpy() for bp in batch_preds]
rec_formula += self.post_op(batch_preds)
pbar.update(len(batch_preds))
return rec_formula
def predict(
self,
mfd_res,
image,
batch_size: int = 64,
interline_enable: bool = True,
) -> list:
return self.batch_predict(
[mfd_res],
[image],
batch_size=batch_size,
interline_enable=interline_enable,
)[0]
def batch_predict(
self,
@@ -142,17 +136,22 @@ class FormulaRecognizer(BaseOCRV20):
batch_size: int = 64,
interline_enable: bool = True,
) -> list:
if not images_mfd_res:
return []
if len(images_mfd_res) != len(images):
raise ValueError("images_mfd_res and images must have the same length.")
images_formula_list = []
mf_image_list = []
backfill_list = []
image_info = [] # Store (area, original_index, image) tuples
image_info = []
# Collect images with their original indices
for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index]
image = images[image_index]
for mfd_res, image in zip(images_mfd_res, images):
formula_list, crop_targets = self._build_formula_items(
mfd_res, image, interline_enable=interline_enable
mfd_res,
image,
interline_enable=interline_enable,
)
for formula_item, (xmin, ymin, xmax, ymax) in crop_targets:
@@ -166,24 +165,40 @@ class FormulaRecognizer(BaseOCRV20):
images_formula_list.append(formula_list)
# Stable sort by area
image_info.sort(key=lambda x: x[0]) # sort by area
if not image_info:
return images_formula_list
image_info.sort(key=lambda x: x[0])
sorted_areas = [x[0] for x in image_info]
sorted_indices = [x[1] for x in image_info]
sorted_images = [x[2] for x in image_info]
# Create mapping for results
index_mapping = {
new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)
}
if len(sorted_images) > 0:
# 进行预测
batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
rec_formula = self.predict(sorted_images, batch_size)
else:
rec_formula = []
formula_requested_batch_size = max(1, batch_size // 2)
batch_groups = build_mfr_batch_groups(
sorted_areas,
formula_requested_batch_size,
)
batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=sorted_images)
batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
inp = self.pre_tfs["ToBatch"](imgs=batch_imgs)
inp = torch.from_numpy(inp[0]).to(self.device)
rec_formula = []
with torch.no_grad():
with tqdm(total=len(inp), desc="MFR Predict") as pbar:
for batch_group in batch_groups:
batch_data = inp[batch_group]
batch_preds = [self.net(batch_data)]
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
batch_preds = [bp.cpu().numpy() for bp in batch_preds]
rec_formula += self.post_op(batch_preds)
pbar.update(len(batch_group))
# Restore original order
unsorted_results = [""] * len(rec_formula)
for new_idx, latex in enumerate(rec_formula):
original_idx = index_mapping[new_idx]

View File

@@ -4,6 +4,8 @@ import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from ..utils import build_mfr_batch_groups
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
@@ -23,8 +25,12 @@ class MathDataset(Dataset):
class UnimernetModel(object):
def __init__(self, weight_dir, _device_="cpu"):
from .unimernet_hf import UnimernetModel
if _device_.startswith("mps") or _device_.startswith("npu") or _device_.startswith("musa"):
self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
self.model = UnimernetModel.from_pretrained(
weight_dir,
attn_implementation="eager",
)
else:
self.model = UnimernetModel.from_pretrained(weight_dir)
self.device = torch.device(_device_)
@@ -79,50 +85,43 @@ class UnimernetModel(object):
return formula_list, crop_targets
def predict(self, mfd_res, image):
formula_list, crop_targets = self._build_formula_items(
mfd_res, image, interline_enable=True
)
mf_image_list = []
for _, (xmin, ymin, xmax, ymax) in crop_targets:
bbox_img = image[ymin:ymax, xmin:xmax]
mf_image_list.append(bbox_img)
if not mf_image_list:
return formula_list
dataset = MathDataset(mf_image_list, transform=self.model.transform)
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(dtype=self.model.dtype)
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.model.generate({"image": mf_img})
mfr_res.extend(output["fixed_str"])
for (res, _), latex in zip(crop_targets, mfr_res):
res["latex"] = latex
return formula_list
def predict(
self,
mfd_res,
image,
batch_size: int = 64,
interline_enable: bool = True,
) -> list:
return self.batch_predict(
[mfd_res],
[image],
batch_size=batch_size,
interline_enable=interline_enable,
)[0]
def batch_predict(
self,
images_mfd_res: list,
images: list,
batch_size: int = 64,
interline_enable: bool = True,
self,
images_mfd_res: list,
images: list,
batch_size: int = 64,
interline_enable: bool = True,
) -> list:
if not images_mfd_res:
return []
if len(images_mfd_res) != len(images):
raise ValueError("images_mfd_res and images must have the same length.")
images_formula_list = []
mf_image_list = []
backfill_list = []
image_info = [] # Store (area, original_index, image) tuples
image_info = []
# Collect images with their original indices
for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index]
image = images[image_index]
for mfd_res, image in zip(images_mfd_res, images):
formula_list, crop_targets = self._build_formula_items(
mfd_res, image, interline_enable=interline_enable
mfd_res,
image,
interline_enable=interline_enable,
)
for formula_item, (xmin, ymin, xmax, ymax) in crop_targets:
@@ -136,45 +135,40 @@ class UnimernetModel(object):
images_formula_list.append(formula_list)
# Stable sort by area
image_info.sort(key=lambda x: x[0]) # sort by area
if not image_info:
return images_formula_list
image_info.sort(key=lambda x: x[0])
sorted_areas = [x[0] for x in image_info]
sorted_indices = [x[1] for x in image_info]
sorted_images = [x[2] for x in image_info]
index_mapping = {
new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)
}
# Create mapping for results
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
# Create dataset with sorted images
batch_groups = build_mfr_batch_groups(sorted_areas, batch_size)
dataset = MathDataset(sorted_images, transform=self.model.transform)
dataloader = DataLoader(dataset, batch_sampler=batch_groups, num_workers=0)
# 如果batch_size > len(sorted_images)则设置为不超过len(sorted_images)的2的幂
batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# Process batches and store results
mfr_res = []
# for mf_img in dataloader:
with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
for index, mf_img in enumerate(dataloader):
for batch_group, mf_img in zip(batch_groups, dataloader):
current_batch_size = len(batch_group)
mf_img = mf_img.to(dtype=self.model.dtype)
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.model.generate({"image": mf_img}, batch_size=batch_size)
output = self.model.generate(
{"image": mf_img},
batch_size=current_batch_size,
)
mfr_res.extend(output["fixed_str"])
# 更新进度条每次增加batch_size但要注意最后一个batch可能不足batch_size
current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
pbar.update(current_batch_size)
# Restore original order
unsorted_results = [""] * len(mfr_res)
for new_idx, latex in enumerate(mfr_res):
original_idx = index_mapping[new_idx]
unsorted_results[original_idx] = latex
# Fill results back
for res, latex in zip(backfill_list, unsorted_results):
res["latex"] = latex

View File

@@ -291,6 +291,7 @@ REPLACEMENTS_PATTERNS = {
re.compile(r'\\vDash '): r'\\models ',
re.compile(r'\\sq \\sqcup '): r'\\square ',
re.compile(r'\\copyright'): r'©',
re.compile(r'\\Dot'): r'\\dot',
}
QQUAD_PATTERN = re.compile(r'\\qquad(?!\s)')
@@ -335,4 +336,77 @@ def latex_rm_whitespace(s: str):
while s.endswith('\\'):
s = s[:-1]
return s
return s
def largest_power_of_two_leq(value: int) -> int:
    """Return the greatest power of two that is <= ``value``.

    Returns 0 when ``value`` is smaller than 1 (there is no positive
    power of two below it).
    """
    return 0 if value < 1 else 1 << (value.bit_length() - 1)
def get_mfr_effective_batch_size(num_items: int, requested_batch_size: int) -> int:
    """Clamp the requested batch size to a power-of-two item capacity.

    The cap is the largest power of two not exceeding ``num_items``
    (``num_items`` is treated as at least 1), so the effective size never
    exceeds either the request or that power-of-two capacity.
    """
    capacity = largest_power_of_two_leq(max(1, num_items))
    return requested_batch_size if requested_batch_size < capacity else capacity
def get_mfr_min_dynamic_batch_size(requested_batch_size: int) -> int:
    """Floor for dynamically shrunk batch sizes: 1/8 of the request, minimum 1."""
    if requested_batch_size < 8:
        return 1
    return requested_batch_size // 8
def build_mfr_batch_groups(sorted_areas: list[int], requested_batch_size: int) -> list[list[int]]:
    """Partition item indices into area-aware batches for MFR inference.

    ``sorted_areas`` holds one area value per item; items are assumed to be
    pre-sorted by area (ascending) by the caller — TODO confirm against the
    call sites. The returned value is a list of index groups
    (``[[0, 1, ...], [...], ...]``) covering ``range(len(sorted_areas))`` in
    order. Batches start at an effective power-of-two size and are halved as
    the mean area of the upcoming window grows relative to the first window,
    so memory-heavy (large-area) items end up in smaller batches.
    """
    if not sorted_areas:
        return []
    effective_batch_size = get_mfr_effective_batch_size(
        len(sorted_areas),
        requested_batch_size,
    )
    if effective_batch_size < 1:
        return []
    # Batches are never shrunk below this floor (1/8 of the request, min 1).
    min_dynamic_batch_size = get_mfr_min_dynamic_batch_size(requested_batch_size)
    total_count = len(sorted_areas)
    if total_count < min_dynamic_batch_size:
        # Too few items to split: one batch with everything.
        return [list(range(total_count))]
    # Baseline: mean area of the first window, used as the reference scale.
    base_mean_area = sum(sorted_areas[:effective_batch_size]) / effective_batch_size
    batch_groups = []
    cursor = 0
    while cursor < total_count:
        remaining_count = total_count - cursor
        if remaining_count < min_dynamic_batch_size:
            # Tail smaller than the floor: emit it as one final batch.
            batch_groups.append(list(range(cursor, total_count)))
            break
        probe_size = min(effective_batch_size, remaining_count)
        current_mean_area = sum(sorted_areas[cursor : cursor + probe_size]) / probe_size
        # Guard against a non-positive baseline (e.g. all-zero areas).
        ratio = 1 if base_mean_area <= 0 else current_mean_area / base_mean_area
        candidate_batch_size = effective_batch_size
        threshold = 4
        # Halve the batch each time the area ratio crosses the next
        # doubling threshold (4x, 8x, 16x, ...), but never below the floor.
        while (
            ratio >= threshold
            and candidate_batch_size // 2 >= min_dynamic_batch_size
        ):
            candidate_batch_size //= 2
            threshold *= 2
        # Keep the batch a power of two that fits in what's left.
        candidate_batch_size = min(
            candidate_batch_size,
            largest_power_of_two_leq(remaining_count),
        )
        batch_groups.append(list(range(cursor, cursor + candidate_batch_size)))
        cursor += candidate_batch_size
    if (
        len(batch_groups) >= 2
        and len(batch_groups[-1]) < min_dynamic_batch_size
    ):
        # Merge an undersized trailing batch into its predecessor.
        tail_group = batch_groups.pop()
        batch_groups[-1].extend(tail_group)
    return batch_groups

View File

@@ -233,6 +233,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
BlockType.INTERLINE_EQUATION,
BlockType.LIST,
BlockType.INDEX,
BlockType.SEAL,
]:
bbox = block["bbox"]
page_block_list.append(bbox)

View File

@@ -8,6 +8,8 @@ import numpy as np
import pypdfium2 as pdfium
from loguru import logger
from PIL import Image, ImageOps
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.check_sys_env import is_windows_environment
@@ -331,29 +333,35 @@ def get_crop_np_img(bbox: tuple, input_img, scale=2):
def images_bytes_to_pdf_bytes(image_bytes):
# 内存缓冲区
pdf_buffer = BytesIO()
# 载入并转换所有图像为 RGB 模式
image = Image.open(BytesIO(image_bytes))
# 根据 EXIF 信息自动转正(处理手机拍摄的带 Orientation 标记的图片)
image = ImageOps.exif_transpose(image) or image
# 只在必要时转换
if image.mode != "RGB":
image = image.convert("RGB")
with Image.open(BytesIO(image_bytes)) as source_image:
image = ImageOps.exif_transpose(source_image)
if image.mode != "RGB":
image = image.convert("RGB")
else:
image = image.copy()
# 第一张图保存为 PDF其余追加
# Keep image inputs at the same raster size when CLI later renders the
# wrapper PDF at the default DPI; PIL defaults to 72 dpi for PDF output, which
# would upscale the image and noticeably hurt seal OCR detection quality.
image.save(
pdf_buffer,
format="PDF",
resolution=float(DEFAULT_PDF_IMAGE_DPI),
# save_all=True
# Preserve the original raster content when the CLI later renders the
# wrapper PDF back at DEFAULT_PDF_IMAGE_DPI. PIL's PDF writer introduces
# enough loss for layout detection to miss text blocks on some image inputs.
page_width = image.width * 72.0 / DEFAULT_PDF_IMAGE_DPI
page_height = image.height * 72.0 / DEFAULT_PDF_IMAGE_DPI
pdf_canvas = canvas.Canvas(pdf_buffer, pagesize=(page_width, page_height))
pdf_canvas.drawImage(
ImageReader(image),
0,
0,
width=page_width,
height=page_height,
preserveAspectRatio=False,
mask="auto",
)
pdf_canvas.showPage()
pdf_canvas.save()
# 获取 PDF bytes 并重置指针(可选)
pdf_bytes = pdf_buffer.getvalue()
pdf_buffer.close()
image.close()
return pdf_bytes