mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
Merge pull request #1788 from opendatalab/dev
refactor(magic_pdf): remove bfloat16 support checks and usage
This commit is contained in:
@@ -338,24 +338,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
|
||||
|
||||
def model_init(model_name: str):
|
||||
from transformers import LayoutLMv3ForTokenClassification
|
||||
device = get_device()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
if torch.cuda.is_bf16_supported():
|
||||
supports_bfloat16 = True
|
||||
else:
|
||||
supports_bfloat16 = False
|
||||
elif str(device).startswith("npu"):
|
||||
import torch_npu
|
||||
if torch_npu.npu.is_available():
|
||||
device = torch.device('npu')
|
||||
supports_bfloat16 = False
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
supports_bfloat16 = False
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
supports_bfloat16 = False
|
||||
device = torch.device(get_device())
|
||||
|
||||
if model_name == 'layoutreader':
|
||||
# 检测modelscope的缓存目录是否存在
|
||||
@@ -371,9 +354,6 @@ def model_init(model_name: str):
|
||||
model = LayoutLMv3ForTokenClassification.from_pretrained(
|
||||
'hantian/layoutreader'
|
||||
)
|
||||
# 检查设备是否支持 bfloat16
|
||||
if supports_bfloat16:
|
||||
model.bfloat16()
|
||||
model.to(device).eval()
|
||||
else:
|
||||
logger.error('model name not allow')
|
||||
|
||||
Reference in New Issue
Block a user