diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py index 5dc5fd5e..5f8b9513 100644 --- a/mineru/model/mfr/unimernet/Unimernet.py +++ b/mineru/model/mfr/unimernet/Unimernet.py @@ -23,12 +23,12 @@ class MathDataset(Dataset): class UnimernetModel(object): def __init__(self, weight_dir, _device_="cpu"): from .unimernet_hf import UnimernetModel - if _device_.startswith("mps") or _device_.startswith("npu"): + if _device_.startswith("mps") or _device_.startswith("npu") or _device_.startswith("musa"): self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager") else: self.model = UnimernetModel.from_pretrained(weight_dir) - self.device = _device_ - self.model.to(_device_) + self.device = torch.device(_device_) + self.model.to(self.device) if not _device_.startswith("cpu"): self.model = self.model.to(dtype=torch.float16) self.model.eval()