From 447ffcd32f425872ed36fcbb1acc818f9f8033b8 Mon Sep 17 00:00:00 2001 From: myhloli Date: Tue, 23 Dec 2025 17:00:02 +0800 Subject: [PATCH] refactor: implement dynamic batch ratio based on GPU memory and environment variable --- mineru/backend/hybrid/hybrid_analyze.py | 50 ++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/mineru/backend/hybrid/hybrid_analyze.py b/mineru/backend/hybrid/hybrid_analyze.py index 0a605c1c..014283d0 100644 --- a/mineru/backend/hybrid/hybrid_analyze.py +++ b/mineru/backend/hybrid/hybrid_analyze.py @@ -4,6 +4,7 @@ from collections import defaultdict import cv2 import numpy as np +from loguru import logger from mineru_vl_utils import MinerUClient from mineru_vl_utils.structs import BlockType from tqdm import tqdm @@ -12,8 +13,9 @@ from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_m from mineru.backend.pipeline.model_init import HybridModelSingleton from mineru.backend.vlm.vlm_analyze import ModelSingleton from mineru.data.data_reader_writer import DataWriter +from mineru.utils.config_reader import get_device from mineru.utils.enum_class import ImageType, NotExtractType -from mineru.utils.model_utils import crop_img +from mineru.utils.model_utils import crop_img, get_vram, clean_memory from mineru.utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, sorted_boxes, merge_det_boxes, \ update_det_boxes, OcrConfidence from mineru.utils.pdf_classify import classify @@ -203,7 +205,7 @@ def _process_ocr_and_formulas( inline_formula_enable, _ocr_enable, _vlm_ocr_enable, - batch_radio: int = 2, + batch_radio: int = 1, ): """处理OCR和公式识别""" @@ -317,6 +319,36 @@ def _normalize_bbox( normalize_poly_to_bbox(ocr_res, page_width, page_height) +def get_batch_ratio(device): + """ + 根据显存大小或环境变量获取 batch ratio + """ + # 1. 优先尝试从环境变量获取 + env_val = os.getenv("MINERU_HYBRID_BATCH_RATIO") + if env_val: + try: + batch_ratio = int(env_val) + logger.info(f"hybrid batch ratio (from env): {batch_ratio}") + return batch_ratio + except ValueError as e: + logger.warning(f"Invalid MINERU_HYBRID_BATCH_RATIO value: {env_val}, switching to auto mode. Error: {e}") + + # 2. 根据显存自动推断 + gpu_memory = get_vram(device) + + if gpu_memory >= 24: + batch_ratio = 8 + elif gpu_memory >= 16: + batch_ratio = 4 + elif gpu_memory >= 8: + batch_ratio = 2 + else: + batch_ratio = 1 + + logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}") + return batch_ratio + + def doc_analyze( pdf_bytes, image_writer: DataWriter | None, @@ -335,6 +367,9 @@ def doc_analyze( images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL) images_pil_list = [image_dict["img_pil"] for image_dict in images_list] + device = get_device() + batch_ratio = get_batch_ratio(device) + _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method) _vlm_ocr_enable = False if _ocr_enable and language in ["ch", "en"] and inline_formula_enable: @@ -350,6 +385,7 @@ def doc_analyze( inline_formula_enable, _ocr_enable, _vlm_ocr_enable, + batch_radio=batch_ratio, ) _normalize_bbox( @@ -369,6 +405,9 @@ def doc_analyze( _vlm_ocr_enable, hybrid_pipeline_model, ) + + clean_memory(device) + return middle_json, results @@ -390,6 +429,9 @@ async def aio_doc_analyze( images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL) images_pil_list = [image_dict["img_pil"] for image_dict in images_list] + device = get_device() + batch_ratio = get_batch_ratio(device) + _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method) _vlm_ocr_enable = False if _ocr_enable and language in ["ch", "en"] and inline_formula_enable: @@ -405,6 +447,7 @@ async def aio_doc_analyze( inline_formula_enable, _ocr_enable, _vlm_ocr_enable, + batch_radio=batch_ratio, ) _normalize_bbox( @@ -424,5 +467,8 @@ async def aio_doc_analyze( _vlm_ocr_enable, hybrid_pipeline_model, ) + + clean_memory(device) + return middle_json, results