diff --git a/mineru/model/mfd/yolo_v8.py b/mineru/model/mfd/yolo_v8.py index 15ff7e8a..d5c34b8f 100644 --- a/mineru/model/mfd/yolo_v8.py +++ b/mineru/model/mfd/yolo_v8.py @@ -1,5 +1,7 @@ import os from typing import List, Union + +import torch from tqdm import tqdm from ultralytics import YOLO import numpy as np @@ -18,8 +20,8 @@ class YOLOv8MFDModel: conf: float = 0.25, iou: float = 0.45, ): - self.model = YOLO(weight).to(device) - self.device = device + self.device = torch.device(device) + self.model = YOLO(weight).to(self.device) self.imgsz = imgsz self.conf = conf self.iou = iou