feat: enhance table processing with inline object extraction and base64 image handling

This commit is contained in:
myhloli
2026-03-21 18:44:56 +08:00
parent 4d57a0fe58
commit 7685afc4de
3 changed files with 361 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import base64
import html
import cv2
@@ -107,6 +108,197 @@ class BatchAnalyze:
layout_res[:] = [item for item in layout_res if keep_item(item)]
@staticmethod
def _bbox_center(bbox: list[float]) -> tuple[float, float]:
return (float(bbox[0] + bbox[2]) / 2.0, float(bbox[1] + bbox[3]) / 2.0)
@staticmethod
def _is_point_in_bbox(point: tuple[float, float], bbox: list[float]) -> bool:
x, y = point
return bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]
@staticmethod
def _bbox_intersection(bbox1: list[float], bbox2: list[float]) -> list[float] | None:
x0 = max(float(bbox1[0]), float(bbox2[0]))
y0 = max(float(bbox1[1]), float(bbox2[1]))
x1 = min(float(bbox1[2]), float(bbox2[2]))
y1 = min(float(bbox1[3]), float(bbox2[3]))
if x1 <= x0 or y1 <= y0:
return None
return [x0, y0, x1, y1]
@classmethod
def _bbox_intersection_area(cls, bbox1: list[float], bbox2: list[float]) -> float:
    """Return the area shared by two boxes, or ``0.0`` when they do not overlap."""
    shared = cls._bbox_intersection(bbox1, bbox2)
    if shared is None:
        return 0.0
    width = float(shared[2] - shared[0])
    height = float(shared[3] - shared[1])
    return width * height
@staticmethod
def _bbox_to_relative_bbox(bbox: list[float], base_bbox: list[float]) -> list[float]:
return [
float(bbox[0]) - float(base_bbox[0]),
float(bbox[1]) - float(base_bbox[1]),
float(bbox[2]) - float(base_bbox[0]),
float(bbox[3]) - float(base_bbox[1]),
]
@staticmethod
def _bbox_to_quad(bbox: list[float]) -> np.ndarray:
x0, y0, x1, y1 = bbox
return np.asarray([[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.float32)
@staticmethod
def _encode_table_inline_image(np_img: np.ndarray, bbox: list[float]) -> str:
    """Crop ``bbox`` out of ``np_img`` and return the crop as a JPEG data URI.

    Args:
        np_img: Page image array; its first two dimensions are read as
            (height, width). The channel conversion below treats it as RGB.
        bbox: ``[x0, y0, x1, y1]`` region to crop, in page pixel coordinates.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or ``""`` when the box cannot
        be normalized, is empty, or JPEG encoding fails.
    """
    image_h, image_w = np_img.shape[:2]
    image_bbox = normalize_to_int_bbox(bbox, image_size=(image_h, image_w))
    if image_bbox is None:
        return ""

    x0, y0, x1, y1 = image_bbox
    if x1 <= x0 or y1 <= y0:
        return ""

    crop_rgb = np_img[y0:y1, x0:x1]
    if crop_rgb.size == 0:
        return ""

    # OpenCV encoders expect BGR channel order.
    crop_bgr = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2BGR)
    success, encoded = cv2.imencode(".jpg", crop_bgr)
    if not success:
        return ""

    b64_str = base64.b64encode(encoded.tobytes()).decode("ascii")
    # "image/jpeg" is the registered media type for JPEG; "image/jpg" is not a
    # registered type and strict data-URI consumers may reject it.
    return f"data:image/jpeg;base64,{b64_str}"
@staticmethod
def _get_virtual_image_bbox(bbox: list[float], box_size: float = 10.0) -> list[float]:
    """Return a ``box_size`` x ``box_size`` square centred on the middle of ``bbox``."""
    cx, cy = BatchAnalyze._bbox_center(bbox)
    radius = box_size / 2.0
    return [cx - radius, cy - radius, cx + radius, cy + radius]
@staticmethod
def _table_supports_inline_objects(table_res_dict: dict) -> bool:
return str(table_res_dict.get("rotate_label", "0")) == "0"
@staticmethod
def _sort_table_ocr_result(ocr_result: list[list]) -> None:
if not ocr_result:
return
sorted_result = sorted(
ocr_result,
key=lambda item: (float(np.asarray(item[0])[0][1]), float(np.asarray(item[0])[0][0])),
)
for i in range(len(sorted_result) - 1):
for j in range(i, -1, -1):
cur_box = np.asarray(sorted_result[j][0], dtype=np.float32)
next_box = np.asarray(sorted_result[j + 1][0], dtype=np.float32)
if (
abs(float(next_box[0][1]) - float(cur_box[0][1])) < 10
and float(next_box[0][0]) < float(cur_box[0][0])
):
sorted_result[j], sorted_result[j + 1] = sorted_result[j + 1], sorted_result[j]
else:
break
ocr_result[:] = sorted_result
@classmethod
def _extract_table_inline_objects(
    cls,
    layout_res: list[dict],
    np_img: np.ndarray,
    formula_enable: bool,
) -> dict[int, list[dict]]:
    """Move images (and optionally formulas) that sit inside tables out of ``layout_res``.

    A layout item is attached to a table when the centre of its box falls
    inside that table's box; if several tables qualify, the one with the
    largest overlap area wins (the earliest such table on ties). Attached
    items are removed from ``layout_res`` in place.

    Args:
        layout_res: Layout detections for one page; mutated in place.
        np_img: Page image; its first two dimensions are used as the image size.
        formula_enable: When truthy, ``inline_formula`` and ``display_formula``
            items are considered in addition to ``image`` items.

    Returns:
        Mapping from ``id(table_dict)`` to the list of inline-object records
        collected for that table. Empty dict when the page has no usable table.
    """
    page_h, page_w = np_img.shape[:2]
    page_size = (page_h, page_w)

    table_entries = []
    for candidate in layout_res:
        if candidate.get("label") == "table":
            normalized = normalize_to_int_bbox(candidate.get("bbox"), image_size=page_size)
            if normalized is not None:
                table_entries.append((candidate, normalized))
    if not table_entries:
        return {}

    inline_by_table = {id(table_item): [] for table_item, _ in table_entries}
    consumed_ids = set()
    wanted_labels = {"image"}
    if formula_enable:
        wanted_labels |= {"inline_formula", "display_formula"}

    for element in layout_res:
        label = element.get("label")
        if label not in wanted_labels:
            continue

        element_bbox = normalize_to_int_bbox(element.get("bbox"), image_size=page_size)
        if element_bbox is None:
            continue

        center = cls._bbox_center(element_bbox)
        containing = [
            (cls._bbox_intersection_area(element_bbox, tbl_bbox), tbl, tbl_bbox)
            for tbl, tbl_bbox in table_entries
            if cls._is_point_in_bbox(center, tbl_bbox)
        ]
        if not containing:
            continue

        # max() yields the first maximal candidate, the same pick as taking the
        # head of a stable descending sort.
        _, owner, owner_bbox = max(containing, key=lambda cand: cand[0])

        clipped = cls._bbox_intersection(element_bbox, owner_bbox)
        if clipped is None:
            continue
        rel_clipped = cls._bbox_to_relative_bbox(clipped, owner_bbox)
        confidence = float(element.get("score", 1.0))

        if label == "image":
            data_uri = cls._encode_table_inline_image(np_img, element_bbox)
            if not data_uri:
                continue
            markup = f'<img src="{data_uri}"/>'
            # Images are represented by a small square token at the overlap centre.
            anchor_bbox = cls._get_virtual_image_bbox(rel_clipped)
            kind = "image"
        else:
            latex = element.get("latex", "")
            if not latex:
                continue
            markup = f"<eq>{html.escape(latex)}</eq>"
            anchor_bbox = rel_clipped
            kind = "formula"

        record = {
            "kind": kind,
            "page_bbox": element_bbox,
            "table_rel_mask_bbox": rel_clipped,
            "table_token_bbox": anchor_bbox,
            "content": markup,
            "score": confidence,
        }
        inline_by_table[id(owner)].append(record)
        consumed_ids.add(id(element))

    if consumed_ids:
        layout_res[:] = [entry for entry in layout_res if id(entry) not in consumed_ids]
    return inline_by_table
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
@@ -167,6 +359,15 @@ class BatchAnalyze:
_, ocr_enable, _lang = images_with_extra_info[index]
layout_res = images_layout_res[index]
np_img = np_images[index]
table_inline_objects = (
self._extract_table_inline_objects(
layout_res,
np_img,
formula_enable=self.formula_enable,
)
if self.table_enable
else {}
)
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
@@ -191,11 +392,17 @@ class BatchAnalyze:
wireless_table_img = get_crop_table_img(scale = 1)
wired_table_img = get_crop_table_img(scale = 10/3)
table_page_bbox = normalize_to_int_bbox(
table_res.get("bbox"),
image_size=np_img.shape[:2],
) or [0, 0, 0, 0]
table_res_list_all_page.append({'table_res':table_res,
'lang':_lang,
'table_img':wireless_table_img,
'wired_table_img':wired_table_img,
'table_page_bbox':table_page_bbox,
'table_inline_objects':table_inline_objects.get(id(table_res), []),
})
# 表格识别 table recognition
@@ -243,7 +450,30 @@ class BatchAnalyze:
tqdm(table_res_list_all_page, desc="Table-ocr det")
):
bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]
table_inline_objects = (
table_res_dict.get("table_inline_objects", [])
if self._table_supports_inline_objects(table_res_dict)
else []
)
inline_mask_boxes = [
{"bbox": inline_object["table_rel_mask_bbox"]}
for inline_object in table_inline_objects
]
formula_mask_boxes = [
{"bbox": inline_object["table_rel_mask_bbox"]}
for inline_object in table_inline_objects
if inline_object["kind"] == "formula"
]
det_image = (
self._apply_mask_boxes_to_image(bgr_image, inline_mask_boxes)
if inline_mask_boxes
else bgr_image
)
ocr_result = det_ocr_engine.ocr(det_image, rec=False)[0]
if ocr_result and formula_mask_boxes:
ocr_result = update_det_boxes(ocr_result, formula_mask_boxes)
if ocr_result:
ocr_result = sorted_boxes(ocr_result)
# 构造需要 OCR 识别的图片字典包括cropped_img, dt_box, table_id并按照语言进行分组
for dt_box in ocr_result:
rec_img_lang_group[table_res_dict["lang"]].append(
@@ -281,6 +511,26 @@ class BatchAnalyze:
]
# 先对所有表格使用无线表格模型,然后对分类为有线的表格使用有线表格模型
for table_res_dict in table_res_list_all_page:
if not self._table_supports_inline_objects(table_res_dict):
continue
table_inline_objects = table_res_dict.get("table_inline_objects", [])
if not table_inline_objects:
continue
table_ocr_result = table_res_dict.setdefault("ocr_result", [])
for inline_object in table_inline_objects:
table_ocr_result.append(
[
self._bbox_to_quad(inline_object["table_token_bbox"]),
inline_object["content"],
inline_object["score"],
]
)
self._sort_table_ocr_result(table_ocr_result)
wireless_table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.WirelessTable,
)

View File

@@ -1,6 +1,8 @@
# Copyright (c) Opendatalab. All rights reserved.
import base64
import copy
import os
import re
import time
from loguru import logger
@@ -18,7 +20,69 @@ from mineru.utils.model_utils import clean_memory
from mineru.backend.pipeline.pipeline_magic_model import MagicModel
from mineru.utils.ocr_utils import OcrConfidence, rotate_vertical_crop_if_needed
from mineru.version import __version__
from mineru.utils.hash_utils import bytes_md5
from mineru.utils.hash_utils import bytes_md5, str_sha256
def _save_base64_image(b64_data_uri: str, image_writer, page_index: int):
    """Persist a data-URI image via image_writer and return a relative path.

    Args:
        b64_data_uri: String of the form ``data:image/<fmt>;base64,<payload>``.
        image_writer: Object whose ``write(path, bytes)`` method stores the image.
        page_index: Page number, used only in warning messages.

    Returns:
        The file name written (``<sha256 of the data URI>.<ext>``), or ``None``
        when the URI does not match the expected form or the payload cannot be
        decoded. A ``jpeg`` format is stored with the ``jpg`` extension.
    """
    m = re.match(r'data:image/(\w+);base64,(.+)', b64_data_uri, re.DOTALL)
    if not m:
        logger.warning(f"Unrecognized image_base64 format in page {page_index}, skipping.")
        return None

    fmt = m.group(1)
    ext = "jpg" if fmt == "jpeg" else fmt
    try:
        img_bytes = base64.b64decode(m.group(2))
    except ValueError as e:
        # b64decode on a str payload fails only with binascii.Error (a
        # ValueError subclass) or ValueError for non-ASCII input, so there is
        # no need to swallow unrelated exceptions here.
        logger.warning(f"Failed to decode image_base64 on page {page_index}: {e}")
        return None

    img_path = f"{str_sha256(b64_data_uri)}.{ext}"
    image_writer.write(img_path, img_bytes)
    return img_path
def _replace_inline_base64_img_src(markup: str, image_writer, page_index: int) -> str:
"""Replace inline base64 img src attributes with local relative paths."""
if not markup or "base64," not in markup:
return markup
def _replace_src(match, _writer=image_writer, _idx=page_index):
img_path = _save_base64_image(match.group(1), _writer, _idx)
if img_path:
return f'src="{img_path}"'
return match.group(0)
return re.sub(
r'src="(data:image/[^"]+)"',
_replace_src,
markup,
)
def _replace_inline_table_images(preproc_blocks: list[dict], image_writer, page_index: int) -> None:
"""Persist inline base64 images embedded inside table HTML."""
if not image_writer:
return
for block in preproc_blocks:
if block.get("type") != BlockType.TABLE:
continue
for sub_block in block.get("blocks", []):
if sub_block.get("type") != BlockType.TABLE_BODY:
continue
for line in sub_block.get("lines", []):
for span in line.get("spans", []):
if span.get("type") != ContentType.TABLE:
continue
span["html"] = _replace_inline_base64_img_src(
span.get("html", ""),
image_writer,
page_index,
)
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False):
@@ -53,6 +117,8 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
"""构造page_info"""
_replace_inline_table_images(preproc_blocks, image_writer, page_index)
page_info = make_page_info_dict(preproc_blocks, page_index, page_w, page_h, discarded_blocks)
return page_info

View File

@@ -1,3 +1,6 @@
import re
from html import unescape
from loguru import logger
from mineru.utils.char_utils import full_to_half_exclude_marks, is_hyphen_at_line_end
@@ -165,7 +168,10 @@ def render_visual_block_segments(block, img_buket_path=''):
if span['type'] != ContentType.TABLE:
continue
if span.get('html', ''):
rendered_segments.append((span['html'], 'html_block'))
rendered_segments.append((
_format_embedded_html(span['html'], img_buket_path),
'html_block',
))
elif span.get('image_path', ''):
rendered_segments.append((f"![]({img_buket_path}/{span['image_path']})", 'markdown_line'))
return rendered_segments
@@ -204,6 +210,38 @@ inline_right_delimiter = delimiters['inline']['right']
CJK_LANGS = {'zh', 'ja', 'ko'}
def _prefix_table_img_src(html, img_buket_path):
"""Prefix non-data image sources in table HTML with img_buket_path."""
if not html or not img_buket_path:
return html
return re.sub(
r'src="(?!data:)([^"]+)"',
lambda match: f'src="{img_buket_path}/{match.group(1)}"',
html,
)
def _replace_eq_tags_in_table_html(html):
"""Replace <eq>...</eq> tags in table HTML with inline math delimiters."""
if not html:
return html
return re.sub(
r'<eq>(.*?)</eq>',
lambda match: (
f" {inline_left_delimiter}{unescape(match.group(1))}{inline_right_delimiter} "
),
html,
flags=re.DOTALL,
)
def _format_embedded_html(html, img_buket_path):
    """Prepare table HTML for output: prefix image sources, then convert ``<eq>`` tags."""
    with_prefixed_images = _prefix_table_img_src(html, img_buket_path)
    return _replace_eq_tags_in_table_html(with_prefixed_images)
def merge_para_with_text(para_block):
if _is_fenced_code_block(para_block):
code_text = _merge_para_text(
@@ -419,7 +457,10 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size)
for span in line['spans']:
if span['type'] == ContentType.TABLE:
if span.get('html', ''):
para_content[BlockType.TABLE_BODY] = f"{span['html']}"
para_content[BlockType.TABLE_BODY] = _format_embedded_html(
span['html'],
img_buket_path,
)
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"