From 0e4c9aee0096d398511fceee5bb45c1b4bb88afb Mon Sep 17 00:00:00 2001
From: myhloli
Date: Wed, 24 Dec 2025 14:00:13 +0800
Subject: [PATCH] refactor: enhance batch ratio calculation based on GPU
 compute capability

---
 mineru/backend/hybrid/hybrid_analyze.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mineru/backend/hybrid/hybrid_analyze.py b/mineru/backend/hybrid/hybrid_analyze.py
index b756d5f3..85a2ef51 100644
--- a/mineru/backend/hybrid/hybrid_analyze.py
+++ b/mineru/backend/hybrid/hybrid_analyze.py
@@ -5,10 +5,21 @@ from collections import defaultdict
 import cv2
 import numpy as np
 from loguru import logger
+from packaging import version
 from mineru_vl_utils import MinerUClient
 from mineru_vl_utils.structs import BlockType
 from tqdm import tqdm
 
+# torch/torch_npu are optional: degrade gracefully on CPU-only installs.
+try:
+    import torch
+except ImportError:
+    torch = None
+try:
+    import torch_npu  # noqa: F401  (import registers the Ascend NPU backend)
+except ImportError:
+    pass
+
 from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_middle_json
 from mineru.backend.pipeline.model_init import HybridModelSingleton
 from mineru.backend.vlm.vlm_analyze import ModelSingleton
@@ -362,6 +373,19 @@ def get_batch_ratio(device):
     else:
         batch_ratio = 1
 
+    if torch is not None and torch.cuda.is_available():
+        # CUDA compute capability as "major.minor", e.g. (8, 0) -> "8.0".
+        major, minor = torch.cuda.get_device_capability()
+        compute_capability = f"{major}.{minor}"
+    elif torch is not None and hasattr(torch, 'npu') and torch.npu.is_available():
+        # Ascend NPUs expose no CUDA-style capability; treat them as modern (>= 8.0).
+        compute_capability = "8.0"
+    else:
+        # No accelerator detected (CPU/MPS path): leave batch_ratio untouched.
+        compute_capability = None
+    # Pre-Ampere GPUs (capability < 8.0) get half the batch ratio, floored at 1.
+    if compute_capability is not None and version.parse(compute_capability) < version.parse("8.0"):
+        batch_ratio = max(1, batch_ratio // 2)
+
     logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}")
     return batch_ratio