diff --git a/mineru/model/vlm/vllm_server.py b/mineru/model/vlm/vllm_server.py
index 4ba82110..68952525 100644
--- a/mineru/model/vlm/vllm_server.py
+++ b/mineru/model/vlm/vllm_server.py
@@ -2,6 +2,7 @@ import os
 import sys
 
 from mineru.backend.vlm.utils import set_default_gpu_memory_utilization, enable_custom_logits_processors
+from mineru.utils.config_reader import get_device
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 from vllm.entrypoints.cli.main import main as vllm_main
 
@@ -13,6 +14,8 @@ def main():
     has_port_arg = False
     has_gpu_memory_utilization_arg = False
     has_logits_processors_arg = False
+    has_block_size_arg = False
+    has_compilation_config = False
     model_path = None
     model_arg_indices = []
 
@@ -24,6 +27,10 @@ def main():
             has_gpu_memory_utilization_arg = True
         if arg == "--logits-processors" or arg.startswith("--logits-processors="):
             has_logits_processors_arg = True
+        if arg == "--block-size" or arg.startswith("--block-size="):
+            has_block_size_arg = True
+        if arg == "--compilation-config" or arg.startswith("--compilation-config="):
+            has_compilation_config = True
         if arg == "--model":
             if i + 1 < len(args):
                 model_path = args[i + 1]
@@ -49,6 +56,14 @@ def main():
         model_path = auto_download_and_get_model_root_path("/", "vlm")
     if (not has_logits_processors_arg) and custom_logits_processors:
         args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
+    device = get_device()
+    if device.startswith("musa"):
+        import torch
+        if torch.musa.is_available():
+            if not has_block_size_arg:
+                args.extend(["--block-size", "32"])
+            if not has_compilation_config:
+                args.extend(["--compilation-config", '{"cudagraph_capture_sizes": [1,2,3,4,5,6,7,8,10,12,14,16,18,20,24,28,30], "simple_cuda_graph": true}'])
 
     # 重构参数,将模型路径作为位置参数
     sys.argv = [sys.argv[0]] + ["serve", model_path] + args