diff --git a/mineru/backend/pipeline/batch_analyze.py b/mineru/backend/pipeline/batch_analyze.py
index d0f87819..89e7264c 100644
--- a/mineru/backend/pipeline/batch_analyze.py
+++ b/mineru/backend/pipeline/batch_analyze.py
@@ -1,3 +1,4 @@
+import base64
 import html
 
 import cv2
@@ -107,6 +108,197 @@ class BatchAnalyze:
 
         layout_res[:] = [item for item in layout_res if keep_item(item)]
 
+    @staticmethod
+    def _bbox_center(bbox: list[float]) -> tuple[float, float]:  # midpoint (x, y) of an [x0, y0, x1, y1] box
+        return (float(bbox[0] + bbox[2]) / 2.0, float(bbox[1] + bbox[3]) / 2.0)
+
+    @staticmethod
+    def _is_point_in_bbox(point: tuple[float, float], bbox: list[float]) -> bool:  # inclusive on all four edges
+        x, y = point
+        return bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]
+
+    @staticmethod
+    def _bbox_intersection(bbox1: list[float], bbox2: list[float]) -> list[float] | None:  # overlap rectangle, or None
+        x0 = max(float(bbox1[0]), float(bbox2[0]))
+        y0 = max(float(bbox1[1]), float(bbox2[1]))
+        x1 = min(float(bbox1[2]), float(bbox2[2]))
+        y1 = min(float(bbox1[3]), float(bbox2[3]))
+        if x1 <= x0 or y1 <= y0:  # boxes that only touch (zero area) count as non-overlapping
+            return None
+        return [x0, y0, x1, y1]
+
+    @classmethod
+    def _bbox_intersection_area(cls, bbox1: list[float], bbox2: list[float]) -> float:  # 0.0 when there is no overlap
+        overlap_bbox = cls._bbox_intersection(bbox1, bbox2)
+        if overlap_bbox is None:
+            return 0.0
+        return float(overlap_bbox[2] - overlap_bbox[0]) * float(overlap_bbox[3] - overlap_bbox[1])
+
+    @staticmethod
+    def _bbox_to_relative_bbox(bbox: list[float], base_bbox: list[float]) -> list[float]:  # shift so base_bbox's (x0, y0) is the origin
+        return [
+            float(bbox[0]) - float(base_bbox[0]),
+            float(bbox[1]) - float(base_bbox[1]),
+            float(bbox[2]) - float(base_bbox[0]),
+            float(bbox[3]) - float(base_bbox[1]),
+        ]
+
+    @staticmethod
+    def _bbox_to_quad(bbox: list[float]) -> np.ndarray:  # 4x2 float32 array of the box corners
+        x0, y0, x1, y1 = bbox
+        return np.asarray([[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.float32)
+
+    @staticmethod
+    def _encode_table_inline_image(np_img: np.ndarray, bbox: list[float]) -> str:  # JPEG data URI of the crop, "" on any failure
+        image_h, image_w = np_img.shape[:2]
+        image_bbox = normalize_to_int_bbox(bbox, image_size=(image_h, image_w))
+        if image_bbox is None:
+            return ""
+
+        x0, y0, x1, y1 = image_bbox
+        if x1 <= x0 or y1 <= y0:
+            return ""
+
+        crop_rgb = np_img[y0:y1, x0:x1]
+        if crop_rgb.size == 0:
+            return ""
+
+        crop_bgr = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2BGR)  # cv2.imencode expects BGR channel order
+        success, encoded = cv2.imencode(".jpg", crop_bgr)
+        if not success:
+            return ""
+
+        b64_str = base64.b64encode(encoded.tobytes()).decode("ascii")
+        return f"data:image/jpg;base64,{b64_str}"  # NOTE(review): the registered MIME type is image/jpeg - confirm consumers before changing
+
+    @staticmethod
+    def _get_virtual_image_bbox(bbox: list[float], box_size: float = 10.0) -> list[float]:  # box_size x box_size square centred on bbox
+        center_x, center_y = BatchAnalyze._bbox_center(bbox)
+        half_size = box_size / 2.0
+        return [
+            center_x - half_size,
+            center_y - half_size,
+            center_x + half_size,
+            center_y + half_size,
+        ]
+
+    @staticmethod
+    def _table_supports_inline_objects(table_res_dict: dict) -> bool:  # True only when rotate_label is "0" (a missing key counts as "0")
+        return str(table_res_dict.get("rotate_label", "0")) == "0"
+
+    @staticmethod
+    def _sort_table_ocr_result(ocr_result: list[list]) -> None:  # sorts ocr_result in place; item[0] is the quad
+        if not ocr_result:
+            return
+
+        sorted_result = sorted(
+            ocr_result,
+            key=lambda item: (float(np.asarray(item[0])[0][1]), float(np.asarray(item[0])[0][0])),  # first corner: y, then x
+        )
+
+        for i in range(len(sorted_result) - 1):
+            for j in range(i, -1, -1):  # bubble entry j + 1 leftwards while it belongs earlier on the same row
+                cur_box = np.asarray(sorted_result[j][0], dtype=np.float32)
+                next_box = np.asarray(sorted_result[j + 1][0], dtype=np.float32)
+                if (
+                    abs(float(next_box[0][1]) - float(cur_box[0][1])) < 10
+                    and float(next_box[0][0]) < float(cur_box[0][0])
+                ):
+                    sorted_result[j], sorted_result[j + 1] = sorted_result[j + 1], sorted_result[j]
+                else:
+                    break
+
+        ocr_result[:] = sorted_result
+
+    @classmethod
+    def _extract_table_inline_objects(
+        cls,
+        layout_res: list[dict],
+        np_img: np.ndarray,
+        formula_enable: bool,
+    ) -> dict[int, list[dict]]:  # maps id(table layout dict) -> inline objects found inside that table
+        image_h, image_w = np_img.shape[:2]
+        image_size = (image_h, image_w)
+
+        tables = []
+        for res in layout_res:
+            if res.get("label") != "table":
+                continue
+            table_bbox = normalize_to_int_bbox(res.get("bbox"), image_size=image_size)
+            if table_bbox is None:
+                continue
+            tables.append((res, table_bbox))
+
+        if not tables:
+            return {}
+
+        table_inline_objects = {id(table_res): [] for table_res, _ in tables}
+        remove_ids = set()  # id()s of layout items that get absorbed into a table
+        candidate_labels = {"image"}
+        if formula_enable:
+            candidate_labels.update({"inline_formula", "display_formula"})
+
+        for layout_item in layout_res:
+            label = layout_item.get("label")
+            if label not in candidate_labels:
+                continue
+
+            item_bbox = normalize_to_int_bbox(layout_item.get("bbox"), image_size=image_size)
+            if item_bbox is None:
+                continue
+
+            item_center = cls._bbox_center(item_bbox)
+            matched_tables = []
+            for table_res, table_bbox in tables:
+                if not cls._is_point_in_bbox(item_center, table_bbox):  # the item's centre must lie inside the table
+                    continue
+                overlap_area = cls._bbox_intersection_area(item_bbox, table_bbox)
+                matched_tables.append((overlap_area, table_res, table_bbox))
+
+            if not matched_tables:
+                continue
+
+            matched_tables.sort(key=lambda item: item[0], reverse=True)  # the table with the largest overlap wins
+            _, table_res, table_bbox = matched_tables[0]
+            overlap_bbox = cls._bbox_intersection(item_bbox, table_bbox)
+            if overlap_bbox is None:
+                continue
+
+            rel_overlap_bbox = cls._bbox_to_relative_bbox(overlap_bbox, table_bbox)
+            score = float(layout_item.get("score", 1.0))
+
+            if label == "image":
+                image_src = cls._encode_table_inline_image(np_img, item_bbox)
+                if not image_src:
+                    continue
+                content = f'<img src="{image_src}"/>'  # double-quoted src so _replace_inline_base64_img_src can match it
+                token_bbox = cls._get_virtual_image_bbox(rel_overlap_bbox)
+                kind = "image"
+            else:
+                latex = layout_item.get("latex", "")
+                if not latex:
+                    continue
+                content = f"<eq>{html.escape(latex)}</eq>"  # <eq> marks the LaTeX for _replace_eq_tags_in_table_html
+                token_bbox = rel_overlap_bbox
+                kind = "formula"
+
+            table_inline_objects[id(table_res)].append(
+                {
+                    "kind": kind,
+                    "page_bbox": item_bbox,
+                    "table_rel_mask_bbox": rel_overlap_bbox,
+                    "table_token_bbox": token_bbox,
+                    "content": content,
+                    "score": score,
+                }
+            )
+            remove_ids.add(id(layout_item))
+
+        if remove_ids:
+            layout_res[:] = [item for item in layout_res if id(item) not in remove_ids]  # in place, so the caller's list drops them too
+
+        return table_inline_objects
+
     def __call__(self, images_with_extra_info: list) -> list:
         if len(images_with_extra_info) == 0:
@@ -167,6 +359,15 @@ class BatchAnalyze:
             _, ocr_enable, _lang = images_with_extra_info[index]
             layout_res = images_layout_res[index]
             np_img = np_images[index]
+            table_inline_objects = (
+                self._extract_table_inline_objects(
+                    layout_res,
+                    np_img,
+                    formula_enable=self.formula_enable,
+                )
+                if self.table_enable
+                else {}
+            )
 
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
@@ -191,11 +392,17 @@ class BatchAnalyze:
 
                     wireless_table_img = get_crop_table_img(scale = 1)
                     wired_table_img = get_crop_table_img(scale = 10/3)
+                    table_page_bbox = normalize_to_int_bbox(
+                        table_res.get("bbox"),
+                        image_size=np_img.shape[:2],
+                    ) or [0, 0, 0, 0]
 
                     table_res_list_all_page.append({'table_res':table_res,
                                                     'lang':_lang,
                                                     'table_img':wireless_table_img,
                                                     'wired_table_img':wired_table_img,
+                                                    'table_page_bbox':table_page_bbox,
+                                                    'table_inline_objects':table_inline_objects.get(id(table_res), []),
                                                   })
 
         # 表格识别 table recognition
@@ -243,7 +450,30 @@ class BatchAnalyze:
                     tqdm(table_res_list_all_page, desc="Table-ocr det")
             ):
                 bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
-                ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]
+                table_inline_objects = (
+                    table_res_dict.get("table_inline_objects", [])
+                    if self._table_supports_inline_objects(table_res_dict)
+                    else []
+                )
+                inline_mask_boxes = [
+                    {"bbox": inline_object["table_rel_mask_bbox"]}
+                    for inline_object in table_inline_objects
+                ]
+                formula_mask_boxes = [
+                    {"bbox": inline_object["table_rel_mask_bbox"]}
+                    for inline_object in table_inline_objects
+                    if inline_object["kind"] == "formula"
+                ]
+                det_image = (
+                    self._apply_mask_boxes_to_image(bgr_image, inline_mask_boxes)
+                    if inline_mask_boxes
+                    else bgr_image
+                )
+                ocr_result = det_ocr_engine.ocr(det_image, rec=False)[0]
+                if ocr_result and formula_mask_boxes:
+                    ocr_result = update_det_boxes(ocr_result, formula_mask_boxes)
+                if ocr_result:
+                    ocr_result = sorted_boxes(ocr_result)
                 # 构造需要 OCR 识别的图片字典,包括cropped_img, dt_box, table_id,并按照语言进行分组
                 for dt_box in ocr_result:
                     rec_img_lang_group[table_res_dict["lang"]].append(
@@ -281,6 +511,26 @@ class BatchAnalyze:
             ]
 
             # 先对所有表格使用无线表格模型,然后对分类为有线的表格使用有线表格模型
+            for table_res_dict in table_res_list_all_page:
+                if not self._table_supports_inline_objects(table_res_dict):
+                    continue
+
+                table_inline_objects = table_res_dict.get("table_inline_objects", [])
+                if not table_inline_objects:
+                    continue
+
+                table_ocr_result = table_res_dict.setdefault("ocr_result", [])
+                for inline_object in table_inline_objects:
+                    table_ocr_result.append(  # inline objects join the OCR list as [quad, text, score] entries
+                        [
+                            self._bbox_to_quad(inline_object["table_token_bbox"]),
+                            inline_object["content"],
+                            inline_object["score"],
+                        ]
+                    )
+
+                self._sort_table_ocr_result(table_ocr_result)
+
             wireless_table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.WirelessTable,
             )
diff --git a/mineru/backend/pipeline/model_json_to_middle_json.py b/mineru/backend/pipeline/model_json_to_middle_json.py
index e2c01586..9460425b 100644
--- a/mineru/backend/pipeline/model_json_to_middle_json.py
+++ b/mineru/backend/pipeline/model_json_to_middle_json.py
@@ -1,6 +1,8 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import base64
 import copy
 import os
+import re
 import time
 
 from loguru import logger
@@ -18,7 +20,69 @@ from mineru.utils.model_utils import clean_memory
 from mineru.backend.pipeline.pipeline_magic_model import MagicModel
 from mineru.utils.ocr_utils import OcrConfidence, rotate_vertical_crop_if_needed
 from mineru.version import __version__
-from mineru.utils.hash_utils import bytes_md5
+from mineru.utils.hash_utils import bytes_md5, str_sha256
+
+
+def _save_base64_image(b64_data_uri: str, image_writer, page_index: int):
+    """Persist a data-URI image via image_writer and return a relative path."""
+    m = re.match(r'data:image/(\w+);base64,(.+)', b64_data_uri, re.DOTALL)
+    if not m:
+        logger.warning(f"Unrecognized image_base64 format in page {page_index}, skipping.")
+        return None
+
+    fmt = m.group(1)
+    ext = "jpg" if fmt == "jpeg" else fmt  # "jpeg" is the only subtype that gets renamed
+    try:
+        img_bytes = base64.b64decode(m.group(2))
+    except Exception as e:
+        logger.warning(f"Failed to decode image_base64 on page {page_index}: {e}")
+        return None
+
+    img_path = f"{str_sha256(b64_data_uri)}.{ext}"  # file name is str_sha256 of the whole data URI plus the extension
+    image_writer.write(img_path, img_bytes)
+    return img_path
+
+
+def _replace_inline_base64_img_src(markup: str, image_writer, page_index: int) -> str:
+    """Replace inline base64 img src attributes with local relative paths."""
+    if not markup or "base64," not in markup:
+        return markup
+
+    def _replace_src(match, _writer=image_writer, _idx=page_index):
+        img_path = _save_base64_image(match.group(1), _writer, _idx)
+        if img_path:
+            return f'src="{img_path}"'
+        return match.group(0)  # keep the original src when the image could not be saved
+
+    return re.sub(
+        r'src="(data:image/[^"]+)"',
+        _replace_src,
+        markup,
+    )
+
+
+def _replace_inline_table_images(preproc_blocks: list[dict], image_writer, page_index: int) -> None:
+    """Persist inline base64 images embedded inside table HTML."""
+    if not image_writer:
+        return
+
+    for block in preproc_blocks:
+        if block.get("type") != BlockType.TABLE:
+            continue
+
+        for sub_block in block.get("blocks", []):
+            if sub_block.get("type") != BlockType.TABLE_BODY:
+                continue
+
+            for line in sub_block.get("lines", []):
+                for span in line.get("spans", []):
+                    if span.get("type") != ContentType.TABLE:
+                        continue
+                    span["html"] = _replace_inline_base64_img_src(  # rewrites the span's html in place
+                        span.get("html", ""),
+                        image_writer,
+                        page_index,
+                    )
 
 
 def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False):
@@ -53,6 +117,8 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
             span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
 
     """构造page_info"""
+    _replace_inline_table_images(preproc_blocks, image_writer, page_index)
+
     page_info = make_page_info_dict(preproc_blocks, page_index, page_w, page_h, discarded_blocks)
 
     return page_info
diff --git a/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py b/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py
index dc464a35..0fe70ff4 100644
--- a/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py
+++ b/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py
@@ -1,3 +1,6 @@
+import re
+from html import unescape
+
 from loguru import logger
 
 from mineru.utils.char_utils import full_to_half_exclude_marks, is_hyphen_at_line_end
@@ -165,7 +168,10 @@ def render_visual_block_segments(block, img_buket_path=''):
                 if span['type'] != ContentType.TABLE:
                     continue
                 if span.get('html', ''):
-                    rendered_segments.append((span['html'], 'html_block'))
+                    rendered_segments.append((
+                        _format_embedded_html(span['html'], img_buket_path),
+                        'html_block',
+                    ))
                 elif span.get('image_path', ''):
                     rendered_segments.append((f"![]({img_buket_path}/{span['image_path']})", 'markdown_line'))
     return rendered_segments
@@ -204,6 +210,38 @@ inline_right_delimiter = delimiters['inline']['right']
 CJK_LANGS = {'zh', 'ja', 'ko'}
 
 
+def _prefix_table_img_src(html, img_buket_path):
+    """Prefix non-data image sources in table HTML with img_buket_path."""
+    if not html or not img_buket_path:
+        return html
+
+    return re.sub(
+        r'src="(?!data:)([^"]+)"',
+        lambda match: f'src="{img_buket_path}/{match.group(1)}"',
+        html,
+    )
+
+
+def _replace_eq_tags_in_table_html(html):
+    """Replace <eq>...</eq> tags in table HTML with inline math delimiters."""
+    if not html:
+        return html
+
+    return re.sub(
+        r'<eq>(.*?)</eq>',  # non-greedy, so each <eq> pair is converted on its own
+        lambda match: (
+            f" {inline_left_delimiter}{unescape(match.group(1))}{inline_right_delimiter} "
+        ),
+        html,
+        flags=re.DOTALL,
+    )
+
+
+def _format_embedded_html(html, img_buket_path):
+    """Normalize embedded table HTML for markdown/content outputs."""
+    return _replace_eq_tags_in_table_html(_prefix_table_img_src(html, img_buket_path))
+
+
 def merge_para_with_text(para_block):
     if _is_fenced_code_block(para_block):
         code_text = _merge_para_text(
@@ -419,7 +457,10 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size)
                     for span in line['spans']:
                         if span['type'] == ContentType.TABLE:
                             if span.get('html', ''):
-                                para_content[BlockType.TABLE_BODY] = f"{span['html']}"
+                                para_content[BlockType.TABLE_BODY] = _format_embedded_html(
+                                    span['html'],
+                                    img_buket_path,
+                                )
 
                             if span.get('image_path', ''):
                                 para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"