feat: add support for MUSA devices in Unimernet model initialization

This commit is contained in:
myhloli
2026-01-22 16:14:52 +08:00
parent 6a75b39940
commit ffecb89e33

View File

@@ -23,12 +23,12 @@ class MathDataset(Dataset):
class UnimernetModel(object):
def __init__(self, weight_dir, _device_="cpu"):
from .unimernet_hf import UnimernetModel
if _device_.startswith("mps") or _device_.startswith("npu"):
if _device_.startswith("mps") or _device_.startswith("npu") or _device_.startswith("musa"):
self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
else:
self.model = UnimernetModel.from_pretrained(weight_dir)
self.device = _device_
self.model.to(_device_)
self.device = torch.device(_device_)
self.model.to(self.device)
if not _device_.startswith("cpu"):
self.model = self.model.to(dtype=torch.float16)
self.model.eval()