diff --git a/mineru/backend/vlm/utils.py b/mineru/backend/vlm/utils.py
index a5aac6e3..162adc67 100644
--- a/mineru/backend/vlm/utils.py
+++ b/mineru/backend/vlm/utils.py
@@ -100,4 +100,26 @@ def set_default_batch_size() -> int:
     except Exception as e:
         logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
         batch_size = 1
-    return batch_size
\ No newline at end of file
+    return batch_size
+
+
+def set_compilation_config() -> dict:
+    """Return default ``compilation_config`` overrides for the current device.
+
+    Returns ``{"simple_cuda_graph": True}`` when the device reported by
+    ``get_device()`` starts with ``'musa'`` and ``torch.musa`` is available;
+    otherwise returns an empty dict.
+    """
+    device = get_device()
+    compilation_config = {}
+    if str(device).startswith('musa'):
+        try:
+            import torch
+            if torch.musa.is_available():
+                compilation_config = {
+                    "simple_cuda_graph": True
+                }
+        except Exception as e:
+            # Best-effort probe: fall back to no overrides, but do not hide the reason.
+            logger.warning(f'Error checking MUSA availability: {e}, using default compilation_config')
+    return compilation_config
\ No newline at end of file
diff --git a/mineru/backend/vlm/vlm_analyze.py b/mineru/backend/vlm/vlm_analyze.py
index ca7373d4..2a9e6eb5 100644
--- a/mineru/backend/vlm/vlm_analyze.py
+++ b/mineru/backend/vlm/vlm_analyze.py
@@ -99,6 +99,21 @@ class ModelSingleton:
                             import vllm
                         except ImportError:
                             raise ImportError("Please install vllm to use the vllm-engine backend.")
+
+                        # A string value is decoded as JSON; if decoding fails the key is dropped.
+                        if "compilation_config" in kwargs:
+                            if isinstance(kwargs["compilation_config"], str):
+                                try:
+                                    kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
+                                except json.JSONDecodeError:
+                                    logger.warning(
+                                        f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
+                                    del kwargs["compilation_config"]
+                        else:
+                            compilation_config = set_compilation_config()
+                            if compilation_config:
+                                kwargs["compilation_config"] = compilation_config
+
                         if "gpu_memory_utilization" not in kwargs:
                             kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
                         if "model" not in kwargs:
@@ -114,6 +129,21 @@ class ModelSingleton:
                             from vllm.v1.engine.async_llm import AsyncLLM
                         except ImportError:
                             raise ImportError("Please install vllm to use the vllm-async-engine backend.")
+
+                        # A string value is decoded as JSON; if decoding fails the key is dropped.
+                        if "compilation_config" in kwargs:
+                            if isinstance(kwargs["compilation_config"], str):
+                                try:
+                                    kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
+                                except json.JSONDecodeError:
+                                    logger.warning(
+                                        f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
+                                    del kwargs["compilation_config"]
+                        else:
+                            compilation_config = set_compilation_config()
+                            if compilation_config:
+                                kwargs["compilation_config"] = compilation_config
+
                         if "gpu_memory_utilization" not in kwargs:
                             kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
                         if "model" not in kwargs: