refactor: enhance batch ratio calculation based on GPU compute capability

This commit is contained in:
myhloli
2025-12-24 14:00:13 +08:00
parent 41d5b4843a
commit 0e4c9aee00

View File

@@ -5,10 +5,17 @@ from collections import defaultdict
import cv2 import cv2
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from packaging import version
from mineru_vl_utils import MinerUClient from mineru_vl_utils import MinerUClient
from mineru_vl_utils.structs import BlockType from mineru_vl_utils.structs import BlockType
from tqdm import tqdm from tqdm import tqdm
try:
import torch
import torch_npu
except ImportError:
pass
from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_middle_json from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_middle_json
from mineru.backend.pipeline.model_init import HybridModelSingleton from mineru.backend.pipeline.model_init import HybridModelSingleton
from mineru.backend.vlm.vlm_analyze import ModelSingleton from mineru.backend.vlm.vlm_analyze import ModelSingleton
@@ -362,6 +369,15 @@ def get_batch_ratio(device):
else: else:
batch_ratio = 1 batch_ratio = 1
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
# 正确计算Compute Capability
compute_capability = f"{major}.{minor}"
elif hasattr(torch, 'npu') and torch.npu.is_available():
compute_capability = "8.0"
if version.parse(compute_capability) < version.parse("8.0"):
batch_ratio = max(1, batch_ratio // 2)
logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}") logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}")
return batch_ratio return batch_ratio