diff --git a/mineru/backend/vlm/utils.py b/mineru/backend/vlm/utils.py
index 1be77c0d..a5aac6e3 100644
--- a/mineru/backend/vlm/utils.py
+++ b/mineru/backend/vlm/utils.py
@@ -20,6 +20,8 @@ def enable_custom_logits_processors() -> bool:
         compute_capability = "8.0"
     elif hasattr(torch, 'gcu') and torch.gcu.is_available():
         compute_capability = "8.0"
+    elif hasattr(torch, 'musa') and torch.musa.is_available():
+        compute_capability = "8.0"
     else:
         logger.info("CUDA not available, disabling custom_logits_processors")
         return False
diff --git a/mineru/utils/block_sort.py b/mineru/utils/block_sort.py
index b419f265..ab6e5f62 100644
--- a/mineru/utils/block_sort.py
+++ b/mineru/utils/block_sort.py
@@ -189,6 +189,12 @@ def model_init(model_name: str):
     elif device_name.startswith("gcu"):
         if torch.gcu.is_bf16_supported():
             bf_16_support = True
+    elif device_name.startswith("musa"):
+        if torch.musa.is_bf16_supported():
+            bf_16_support = True
+    elif device_name.startswith("npu"):
+        if torch.npu.is_bf16_supported():
+            bf_16_support = True
 
     if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
diff --git a/mineru/utils/config_reader.py b/mineru/utils/config_reader.py
index 7f73dc1a..22f975db 100644
--- a/mineru/utils/config_reader.py
+++ b/mineru/utils/config_reader.py
@@ -90,7 +90,11 @@ def get_device():
         if torch.gcu.is_available():
             return "gcu"
     except Exception as e:
-        pass
+        try:
+            if torch.musa.is_available():
+                return "musa"
+        except Exception as e:
+            pass
     return "cpu"
 
 
diff --git a/mineru/utils/model_utils.py b/mineru/utils/model_utils.py
index 63c67d97..33b14fe6 100644
--- a/mineru/utils/model_utils.py
+++ b/mineru/utils/model_utils.py
@@ -426,6 +426,9 @@ def clean_memory(device='cuda'):
     elif str(device).startswith("gcu"):
         if torch.gcu.is_available():
             torch.gcu.empty_cache()
+    elif str(device).startswith("musa"):
+        if torch.musa.is_available():
+            torch.musa.empty_cache()
     gc.collect()
 
 
@@ -464,5 +467,8 @@ def get_vram(device) -> int:
     elif str(device).startswith("gcu"):
         if torch.gcu.is_available():
             total_memory = round(torch.gcu.get_device_properties(device).total_memory / (1024 ** 3))  # 转为 GB
+    elif str(device).startswith("musa"):
+        if torch.musa.is_available():
+            total_memory = round(torch.musa.get_device_properties(device).total_memory / (1024 ** 3))  # 转为 GB
     return total_memory
 