mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
feat: add support for MUSA devices in Unimernet model initialization
This commit is contained in:
@@ -23,12 +23,12 @@ class MathDataset(Dataset):
|
||||
class UnimernetModel(object):
|
||||
def __init__(self, weight_dir, _device_="cpu"):
|
||||
from .unimernet_hf import UnimernetModel
|
||||
if _device_.startswith("mps") or _device_.startswith("npu"):
|
||||
if _device_.startswith("mps") or _device_.startswith("npu") or _device_.startswith("musa"):
|
||||
self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
|
||||
else:
|
||||
self.model = UnimernetModel.from_pretrained(weight_dir)
|
||||
self.device = _device_
|
||||
self.model.to(_device_)
|
||||
self.device = torch.device(_device_)
|
||||
self.model.to(self.device)
|
||||
if not _device_.startswith("cpu"):
|
||||
self.model = self.model.to(dtype=torch.float16)
|
||||
self.model.eval()
|
||||
|
||||
Reference in New Issue
Block a user