diff --git a/mineru/backend/vlm/utils.py b/mineru/backend/vlm/utils.py index 03c16e90..a5aac6e3 100644 --- a/mineru/backend/vlm/utils.py +++ b/mineru/backend/vlm/utils.py @@ -100,20 +100,4 @@ def set_default_batch_size() -> int: except Exception as e: logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1') batch_size = 1 - return batch_size - - -def set_compilation_config() -> dict: - device = get_device() - compilation_config = {} - if str(device).startswith('musa'): - try: - import torch - if torch.musa.is_available(): - compilation_config = { - "cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30], - "simple_cuda_graph": True - } - except Exception as e: - pass - return compilation_config \ No newline at end of file + return batch_size \ No newline at end of file diff --git a/mineru/backend/vlm/vlm_analyze.py b/mineru/backend/vlm/vlm_analyze.py index 7402a4bd..584d7fe1 100644 --- a/mineru/backend/vlm/vlm_analyze.py +++ b/mineru/backend/vlm/vlm_analyze.py @@ -6,7 +6,7 @@ import json from loguru import logger from .utils import enable_custom_logits_processors, set_default_gpu_memory_utilization, set_default_batch_size, \ - set_lmdeploy_backend, set_compilation_config + set_lmdeploy_backend from .model_output_to_middle_json import result_to_middle_json from ...data.data_reader_writer import DataWriter from mineru.utils.pdf_image_tools import load_images_from_pdf @@ -100,6 +100,17 @@ class ModelSingleton: import vllm except ImportError: raise ImportError("Please install vllm to use the vllm-engine backend.") + device = get_device() + if device.startswith("musa"): + import torch + if torch.musa.is_available(): + compilation_config = { + "cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30], + "simple_cuda_graph": True + } + block_size = 32 + kwargs["compilation_config"] = compilation_config + kwargs["block_size"] = block_size if "compilation_config" in kwargs: if isinstance(kwargs["compilation_config"], str): @@ -109,11 +120,6 @@ class ModelSingleton: logger.warning( f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}") del kwargs["compilation_config"] - else: - compilation_config = set_compilation_config() - if compilation_config: - kwargs["compilation_config"] = compilation_config - if "gpu_memory_utilization" not in kwargs: kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization() if "model" not in kwargs: @@ -129,6 +135,17 @@ class ModelSingleton: from vllm.v1.engine.async_llm import AsyncLLM except ImportError: raise ImportError("Please install vllm to use the vllm-async-engine backend.") + device = get_device() + if device.startswith("musa"): + import torch + if torch.musa.is_available(): + compilation_config = { + "cudagraph_capture_sizes": [1,2,3,4,5,6,7,8,10,12,14,16,18,20,24,28,30], + "simple_cuda_graph": True + } + block_size = 32 + kwargs["compilation_config"] = compilation_config + kwargs["block_size"] = block_size if "compilation_config" in kwargs: if isinstance(kwargs["compilation_config"], str): @@ -138,11 +155,6 @@ class ModelSingleton: logger.warning( f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}") del kwargs["compilation_config"] - else: - compilation_config = set_compilation_config() - if compilation_config: - kwargs["compilation_config"] = compilation_config - if "gpu_memory_utilization" not in kwargs: kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization() if "model" not in kwargs: