mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
feat: enhance table processing with inline object extraction and base64 image handling
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import html
|
||||
|
||||
import cv2
|
||||
@@ -107,6 +108,197 @@ class BatchAnalyze:
|
||||
|
||||
layout_res[:] = [item for item in layout_res if keep_item(item)]
|
||||
|
||||
@staticmethod
|
||||
def _bbox_center(bbox: list[float]) -> tuple[float, float]:
|
||||
return (float(bbox[0] + bbox[2]) / 2.0, float(bbox[1] + bbox[3]) / 2.0)
|
||||
|
||||
@staticmethod
|
||||
def _is_point_in_bbox(point: tuple[float, float], bbox: list[float]) -> bool:
|
||||
x, y = point
|
||||
return bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]
|
||||
|
||||
@staticmethod
|
||||
def _bbox_intersection(bbox1: list[float], bbox2: list[float]) -> list[float] | None:
|
||||
x0 = max(float(bbox1[0]), float(bbox2[0]))
|
||||
y0 = max(float(bbox1[1]), float(bbox2[1]))
|
||||
x1 = min(float(bbox1[2]), float(bbox2[2]))
|
||||
y1 = min(float(bbox1[3]), float(bbox2[3]))
|
||||
if x1 <= x0 or y1 <= y0:
|
||||
return None
|
||||
return [x0, y0, x1, y1]
|
||||
|
||||
@classmethod
|
||||
def _bbox_intersection_area(cls, bbox1: list[float], bbox2: list[float]) -> float:
|
||||
overlap_bbox = cls._bbox_intersection(bbox1, bbox2)
|
||||
if overlap_bbox is None:
|
||||
return 0.0
|
||||
return float(overlap_bbox[2] - overlap_bbox[0]) * float(overlap_bbox[3] - overlap_bbox[1])
|
||||
|
||||
@staticmethod
|
||||
def _bbox_to_relative_bbox(bbox: list[float], base_bbox: list[float]) -> list[float]:
|
||||
return [
|
||||
float(bbox[0]) - float(base_bbox[0]),
|
||||
float(bbox[1]) - float(base_bbox[1]),
|
||||
float(bbox[2]) - float(base_bbox[0]),
|
||||
float(bbox[3]) - float(base_bbox[1]),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _bbox_to_quad(bbox: list[float]) -> np.ndarray:
|
||||
x0, y0, x1, y1 = bbox
|
||||
return np.asarray([[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.float32)
|
||||
|
||||
@staticmethod
|
||||
def _encode_table_inline_image(np_img: np.ndarray, bbox: list[float]) -> str:
|
||||
image_h, image_w = np_img.shape[:2]
|
||||
image_bbox = normalize_to_int_bbox(bbox, image_size=(image_h, image_w))
|
||||
if image_bbox is None:
|
||||
return ""
|
||||
|
||||
x0, y0, x1, y1 = image_bbox
|
||||
if x1 <= x0 or y1 <= y0:
|
||||
return ""
|
||||
|
||||
crop_rgb = np_img[y0:y1, x0:x1]
|
||||
if crop_rgb.size == 0:
|
||||
return ""
|
||||
|
||||
crop_bgr = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2BGR)
|
||||
success, encoded = cv2.imencode(".jpg", crop_bgr)
|
||||
if not success:
|
||||
return ""
|
||||
|
||||
b64_str = base64.b64encode(encoded.tobytes()).decode("ascii")
|
||||
return f"data:image/jpg;base64,{b64_str}"
|
||||
|
||||
@staticmethod
|
||||
def _get_virtual_image_bbox(bbox: list[float], box_size: float = 10.0) -> list[float]:
|
||||
center_x, center_y = BatchAnalyze._bbox_center(bbox)
|
||||
half_size = box_size / 2.0
|
||||
return [
|
||||
center_x - half_size,
|
||||
center_y - half_size,
|
||||
center_x + half_size,
|
||||
center_y + half_size,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _table_supports_inline_objects(table_res_dict: dict) -> bool:
|
||||
return str(table_res_dict.get("rotate_label", "0")) == "0"
|
||||
|
||||
@staticmethod
|
||||
def _sort_table_ocr_result(ocr_result: list[list]) -> None:
|
||||
if not ocr_result:
|
||||
return
|
||||
|
||||
sorted_result = sorted(
|
||||
ocr_result,
|
||||
key=lambda item: (float(np.asarray(item[0])[0][1]), float(np.asarray(item[0])[0][0])),
|
||||
)
|
||||
|
||||
for i in range(len(sorted_result) - 1):
|
||||
for j in range(i, -1, -1):
|
||||
cur_box = np.asarray(sorted_result[j][0], dtype=np.float32)
|
||||
next_box = np.asarray(sorted_result[j + 1][0], dtype=np.float32)
|
||||
if (
|
||||
abs(float(next_box[0][1]) - float(cur_box[0][1])) < 10
|
||||
and float(next_box[0][0]) < float(cur_box[0][0])
|
||||
):
|
||||
sorted_result[j], sorted_result[j + 1] = sorted_result[j + 1], sorted_result[j]
|
||||
else:
|
||||
break
|
||||
|
||||
ocr_result[:] = sorted_result
|
||||
|
||||
@classmethod
|
||||
def _extract_table_inline_objects(
|
||||
cls,
|
||||
layout_res: list[dict],
|
||||
np_img: np.ndarray,
|
||||
formula_enable: bool,
|
||||
) -> dict[int, list[dict]]:
|
||||
image_h, image_w = np_img.shape[:2]
|
||||
image_size = (image_h, image_w)
|
||||
|
||||
tables = []
|
||||
for res in layout_res:
|
||||
if res.get("label") != "table":
|
||||
continue
|
||||
table_bbox = normalize_to_int_bbox(res.get("bbox"), image_size=image_size)
|
||||
if table_bbox is None:
|
||||
continue
|
||||
tables.append((res, table_bbox))
|
||||
|
||||
if not tables:
|
||||
return {}
|
||||
|
||||
table_inline_objects = {id(table_res): [] for table_res, _ in tables}
|
||||
remove_ids = set()
|
||||
candidate_labels = {"image"}
|
||||
if formula_enable:
|
||||
candidate_labels.update({"inline_formula", "display_formula"})
|
||||
|
||||
for layout_item in layout_res:
|
||||
label = layout_item.get("label")
|
||||
if label not in candidate_labels:
|
||||
continue
|
||||
|
||||
item_bbox = normalize_to_int_bbox(layout_item.get("bbox"), image_size=image_size)
|
||||
if item_bbox is None:
|
||||
continue
|
||||
|
||||
item_center = cls._bbox_center(item_bbox)
|
||||
matched_tables = []
|
||||
for table_res, table_bbox in tables:
|
||||
if not cls._is_point_in_bbox(item_center, table_bbox):
|
||||
continue
|
||||
overlap_area = cls._bbox_intersection_area(item_bbox, table_bbox)
|
||||
matched_tables.append((overlap_area, table_res, table_bbox))
|
||||
|
||||
if not matched_tables:
|
||||
continue
|
||||
|
||||
matched_tables.sort(key=lambda item: item[0], reverse=True)
|
||||
_, table_res, table_bbox = matched_tables[0]
|
||||
overlap_bbox = cls._bbox_intersection(item_bbox, table_bbox)
|
||||
if overlap_bbox is None:
|
||||
continue
|
||||
|
||||
rel_overlap_bbox = cls._bbox_to_relative_bbox(overlap_bbox, table_bbox)
|
||||
score = float(layout_item.get("score", 1.0))
|
||||
|
||||
if label == "image":
|
||||
image_src = cls._encode_table_inline_image(np_img, item_bbox)
|
||||
if not image_src:
|
||||
continue
|
||||
content = f'<img src="{image_src}"/>'
|
||||
token_bbox = cls._get_virtual_image_bbox(rel_overlap_bbox)
|
||||
kind = "image"
|
||||
else:
|
||||
latex = layout_item.get("latex", "")
|
||||
if not latex:
|
||||
continue
|
||||
content = f"<eq>{html.escape(latex)}</eq>"
|
||||
token_bbox = rel_overlap_bbox
|
||||
kind = "formula"
|
||||
|
||||
table_inline_objects[id(table_res)].append(
|
||||
{
|
||||
"kind": kind,
|
||||
"page_bbox": item_bbox,
|
||||
"table_rel_mask_bbox": rel_overlap_bbox,
|
||||
"table_token_bbox": token_bbox,
|
||||
"content": content,
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
remove_ids.add(id(layout_item))
|
||||
|
||||
if remove_ids:
|
||||
layout_res[:] = [item for item in layout_res if id(item) not in remove_ids]
|
||||
|
||||
return table_inline_objects
|
||||
|
||||
|
||||
def __call__(self, images_with_extra_info: list) -> list:
|
||||
if len(images_with_extra_info) == 0:
|
||||
@@ -167,6 +359,15 @@ class BatchAnalyze:
|
||||
_, ocr_enable, _lang = images_with_extra_info[index]
|
||||
layout_res = images_layout_res[index]
|
||||
np_img = np_images[index]
|
||||
table_inline_objects = (
|
||||
self._extract_table_inline_objects(
|
||||
layout_res,
|
||||
np_img,
|
||||
formula_enable=self.formula_enable,
|
||||
)
|
||||
if self.table_enable
|
||||
else {}
|
||||
)
|
||||
|
||||
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
|
||||
get_res_list_from_layout_res(layout_res)
|
||||
@@ -191,11 +392,17 @@ class BatchAnalyze:
|
||||
|
||||
wireless_table_img = get_crop_table_img(scale = 1)
|
||||
wired_table_img = get_crop_table_img(scale = 10/3)
|
||||
table_page_bbox = normalize_to_int_bbox(
|
||||
table_res.get("bbox"),
|
||||
image_size=np_img.shape[:2],
|
||||
) or [0, 0, 0, 0]
|
||||
|
||||
table_res_list_all_page.append({'table_res':table_res,
|
||||
'lang':_lang,
|
||||
'table_img':wireless_table_img,
|
||||
'wired_table_img':wired_table_img,
|
||||
'table_page_bbox':table_page_bbox,
|
||||
'table_inline_objects':table_inline_objects.get(id(table_res), []),
|
||||
})
|
||||
|
||||
# 表格识别 table recognition
|
||||
@@ -243,7 +450,30 @@ class BatchAnalyze:
|
||||
tqdm(table_res_list_all_page, desc="Table-ocr det")
|
||||
):
|
||||
bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
|
||||
ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]
|
||||
table_inline_objects = (
|
||||
table_res_dict.get("table_inline_objects", [])
|
||||
if self._table_supports_inline_objects(table_res_dict)
|
||||
else []
|
||||
)
|
||||
inline_mask_boxes = [
|
||||
{"bbox": inline_object["table_rel_mask_bbox"]}
|
||||
for inline_object in table_inline_objects
|
||||
]
|
||||
formula_mask_boxes = [
|
||||
{"bbox": inline_object["table_rel_mask_bbox"]}
|
||||
for inline_object in table_inline_objects
|
||||
if inline_object["kind"] == "formula"
|
||||
]
|
||||
det_image = (
|
||||
self._apply_mask_boxes_to_image(bgr_image, inline_mask_boxes)
|
||||
if inline_mask_boxes
|
||||
else bgr_image
|
||||
)
|
||||
ocr_result = det_ocr_engine.ocr(det_image, rec=False)[0]
|
||||
if ocr_result and formula_mask_boxes:
|
||||
ocr_result = update_det_boxes(ocr_result, formula_mask_boxes)
|
||||
if ocr_result:
|
||||
ocr_result = sorted_boxes(ocr_result)
|
||||
# 构造需要 OCR 识别的图片字典,包括cropped_img, dt_box, table_id,并按照语言进行分组
|
||||
for dt_box in ocr_result:
|
||||
rec_img_lang_group[table_res_dict["lang"]].append(
|
||||
@@ -281,6 +511,26 @@ class BatchAnalyze:
|
||||
]
|
||||
|
||||
# 先对所有表格使用无线表格模型,然后对分类为有线的表格使用有线表格模型
|
||||
for table_res_dict in table_res_list_all_page:
|
||||
if not self._table_supports_inline_objects(table_res_dict):
|
||||
continue
|
||||
|
||||
table_inline_objects = table_res_dict.get("table_inline_objects", [])
|
||||
if not table_inline_objects:
|
||||
continue
|
||||
|
||||
table_ocr_result = table_res_dict.setdefault("ocr_result", [])
|
||||
for inline_object in table_inline_objects:
|
||||
table_ocr_result.append(
|
||||
[
|
||||
self._bbox_to_quad(inline_object["table_token_bbox"]),
|
||||
inline_object["content"],
|
||||
inline_object["score"],
|
||||
]
|
||||
)
|
||||
|
||||
self._sort_table_ocr_result(table_ocr_result)
|
||||
|
||||
wireless_table_model = atom_model_manager.get_atom_model(
|
||||
atom_model_name=AtomicModel.WirelessTable,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Opendatalab. All rights reserved.
|
||||
import base64
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
from loguru import logger
|
||||
@@ -18,7 +20,69 @@ from mineru.utils.model_utils import clean_memory
|
||||
from mineru.backend.pipeline.pipeline_magic_model import MagicModel
|
||||
from mineru.utils.ocr_utils import OcrConfidence, rotate_vertical_crop_if_needed
|
||||
from mineru.version import __version__
|
||||
from mineru.utils.hash_utils import bytes_md5
|
||||
from mineru.utils.hash_utils import bytes_md5, str_sha256
|
||||
|
||||
|
||||
def _save_base64_image(b64_data_uri: str, image_writer, page_index: int):
|
||||
"""Persist a data-URI image via image_writer and return a relative path."""
|
||||
m = re.match(r'data:image/(\w+);base64,(.+)', b64_data_uri, re.DOTALL)
|
||||
if not m:
|
||||
logger.warning(f"Unrecognized image_base64 format in page {page_index}, skipping.")
|
||||
return None
|
||||
|
||||
fmt = m.group(1)
|
||||
ext = "jpg" if fmt == "jpeg" else fmt
|
||||
try:
|
||||
img_bytes = base64.b64decode(m.group(2))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode image_base64 on page {page_index}: {e}")
|
||||
return None
|
||||
|
||||
img_path = f"{str_sha256(b64_data_uri)}.{ext}"
|
||||
image_writer.write(img_path, img_bytes)
|
||||
return img_path
|
||||
|
||||
|
||||
def _replace_inline_base64_img_src(markup: str, image_writer, page_index: int) -> str:
|
||||
"""Replace inline base64 img src attributes with local relative paths."""
|
||||
if not markup or "base64," not in markup:
|
||||
return markup
|
||||
|
||||
def _replace_src(match, _writer=image_writer, _idx=page_index):
|
||||
img_path = _save_base64_image(match.group(1), _writer, _idx)
|
||||
if img_path:
|
||||
return f'src="{img_path}"'
|
||||
return match.group(0)
|
||||
|
||||
return re.sub(
|
||||
r'src="(data:image/[^"]+)"',
|
||||
_replace_src,
|
||||
markup,
|
||||
)
|
||||
|
||||
|
||||
def _replace_inline_table_images(preproc_blocks: list[dict], image_writer, page_index: int) -> None:
|
||||
"""Persist inline base64 images embedded inside table HTML."""
|
||||
if not image_writer:
|
||||
return
|
||||
|
||||
for block in preproc_blocks:
|
||||
if block.get("type") != BlockType.TABLE:
|
||||
continue
|
||||
|
||||
for sub_block in block.get("blocks", []):
|
||||
if sub_block.get("type") != BlockType.TABLE_BODY:
|
||||
continue
|
||||
|
||||
for line in sub_block.get("lines", []):
|
||||
for span in line.get("spans", []):
|
||||
if span.get("type") != ContentType.TABLE:
|
||||
continue
|
||||
span["html"] = _replace_inline_base64_img_src(
|
||||
span.get("html", ""),
|
||||
image_writer,
|
||||
page_index,
|
||||
)
|
||||
|
||||
|
||||
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False):
|
||||
@@ -53,6 +117,8 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
|
||||
span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
|
||||
|
||||
"""构造page_info"""
|
||||
_replace_inline_table_images(preproc_blocks, image_writer, page_index)
|
||||
|
||||
page_info = make_page_info_dict(preproc_blocks, page_index, page_w, page_h, discarded_blocks)
|
||||
|
||||
return page_info
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import re
|
||||
from html import unescape
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from mineru.utils.char_utils import full_to_half_exclude_marks, is_hyphen_at_line_end
|
||||
@@ -165,7 +168,10 @@ def render_visual_block_segments(block, img_buket_path=''):
|
||||
if span['type'] != ContentType.TABLE:
|
||||
continue
|
||||
if span.get('html', ''):
|
||||
rendered_segments.append((span['html'], 'html_block'))
|
||||
rendered_segments.append((
|
||||
_format_embedded_html(span['html'], img_buket_path),
|
||||
'html_block',
|
||||
))
|
||||
elif span.get('image_path', ''):
|
||||
rendered_segments.append((f"", 'markdown_line'))
|
||||
return rendered_segments
|
||||
@@ -204,6 +210,38 @@ inline_right_delimiter = delimiters['inline']['right']
|
||||
CJK_LANGS = {'zh', 'ja', 'ko'}
|
||||
|
||||
|
||||
def _prefix_table_img_src(html, img_buket_path):
|
||||
"""Prefix non-data image sources in table HTML with img_buket_path."""
|
||||
if not html or not img_buket_path:
|
||||
return html
|
||||
|
||||
return re.sub(
|
||||
r'src="(?!data:)([^"]+)"',
|
||||
lambda match: f'src="{img_buket_path}/{match.group(1)}"',
|
||||
html,
|
||||
)
|
||||
|
||||
|
||||
def _replace_eq_tags_in_table_html(html):
|
||||
"""Replace <eq>...</eq> tags in table HTML with inline math delimiters."""
|
||||
if not html:
|
||||
return html
|
||||
|
||||
return re.sub(
|
||||
r'<eq>(.*?)</eq>',
|
||||
lambda match: (
|
||||
f" {inline_left_delimiter}{unescape(match.group(1))}{inline_right_delimiter} "
|
||||
),
|
||||
html,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def _format_embedded_html(html, img_buket_path):
|
||||
"""Normalize embedded table HTML for markdown/content outputs."""
|
||||
return _replace_eq_tags_in_table_html(_prefix_table_img_src(html, img_buket_path))
|
||||
|
||||
|
||||
def merge_para_with_text(para_block):
|
||||
if _is_fenced_code_block(para_block):
|
||||
code_text = _merge_para_text(
|
||||
@@ -419,7 +457,10 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size)
|
||||
for span in line['spans']:
|
||||
if span['type'] == ContentType.TABLE:
|
||||
if span.get('html', ''):
|
||||
para_content[BlockType.TABLE_BODY] = f"{span['html']}"
|
||||
para_content[BlockType.TABLE_BODY] = _format_embedded_html(
|
||||
span['html'],
|
||||
img_buket_path,
|
||||
)
|
||||
|
||||
if span.get('image_path', ''):
|
||||
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
|
||||
|
||||
Reference in New Issue
Block a user