mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 19:18:34 +07:00
188 lines
5.8 KiB
Python
188 lines
5.8 KiB
Python
# Copyright (c) Opendatalab. All rights reserved.
|
|
import json
|
|
import os
|
|
from loguru import logger
|
|
|
|
try:
|
|
import torch
|
|
import torch_npu
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
# Config file name constant; an absolute path may be supplied via the MINERU_TOOLS_CONFIG_JSON env var.
|
|
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
|
|
|
|
|
|
def read_config():
    """Load the MinerU JSON configuration file.

    CONFIG_FILE_NAME may be an absolute path; otherwise it is resolved
    relative to the user's home directory.

    Returns:
        dict: the parsed configuration, or None when the file is absent.
    """
    if os.path.isabs(CONFIG_FILE_NAME):
        config_path = CONFIG_FILE_NAME
    else:
        config_path = os.path.join(os.path.expanduser('~'), CONFIG_FILE_NAME)

    # Missing config is not an error: callers treat None as "use defaults".
    if not os.path.exists(config_path):
        return None

    with open(config_path, 'r', encoding='utf-8') as fp:
        return json.load(fp)
|
|
|
|
|
|
def get_s3_config(bucket_name: str):
    """Look up the s3 credentials (ak, sk, endpoint) for *bucket_name*.

    Entries come from the 'bucket_info' section of the user config file;
    the special '[default]' entry is used when the bucket has no entry of
    its own.

    Returns:
        tuple: (access_key, secret_key, storage_endpoint).

    Raises:
        Exception: if the config file is missing, lacks a 'bucket_info'
            section, or the resolved entry has a None ak/sk/endpoint.
    """
    config = read_config()
    # Original code dereferenced a possibly-None config (AttributeError)
    # and a possibly-missing 'bucket_info' (TypeError); fail clearly instead.
    if config is None:
        raise Exception(f'{CONFIG_FILE_NAME} not found, cannot read s3 config')

    bucket_info = config.get('bucket_info')
    if not bucket_info:
        raise Exception(f"'bucket_info' not found in {CONFIG_FILE_NAME}")

    if bucket_name not in bucket_info:
        access_key, secret_key, storage_endpoint = bucket_info['[default]']
    else:
        access_key, secret_key, storage_endpoint = bucket_info[bucket_name]

    if access_key is None or secret_key is None or storage_endpoint is None:
        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')

    return access_key, secret_key, storage_endpoint
|
|
|
|
|
|
def get_s3_config_dict(path: str):
    """Return the s3 credentials for *path*'s bucket as a dict with keys 'ak', 'sk', 'endpoint'."""
    ak, sk, endpoint = get_s3_config(get_bucket_name(path))
    return {'ak': ak, 'sk': sk, 'endpoint': endpoint}
|
|
|
|
|
|
def get_bucket_name(path):
    """Extract only the bucket component from an s3 path."""
    return parse_bucket_key(path)[0]
|
|
|
|
|
|
def parse_bucket_key(s3_full_path: str):
    """Split an s3 path into (bucket, key).

    Example:
        's3://bucket/path/to/my/file.txt' -> ('bucket', 'path/to/my/file.txt')

    A path with no key component (e.g. 's3://bucket') now yields an empty
    key instead of raising ValueError as split('/', 1) unpacking did.
    """
    s3_full_path = s3_full_path.strip()
    # Tolerate both 's3://bucket/key' and bare '/bucket/key' spellings.
    if s3_full_path.startswith("s3://"):
        s3_full_path = s3_full_path[5:]
    if s3_full_path.startswith("/"):
        s3_full_path = s3_full_path[1:]
    # str.partition never raises, unlike unpacking split('/', 1).
    bucket, _, key = s3_full_path.partition("/")
    return bucket, key
|
|
|
|
|
|
def get_device():
    """Pick the compute device string for model inference.

    MINERU_DEVICE_MODE, when set, wins unconditionally. Otherwise the
    backends are probed in order: cuda, mps, npu, gcu, musa, mlu, sdaa,
    falling back to "cpu".

    Every probe is wrapped in try/except because the backend module — or
    torch itself, imported best-effort at module top — may be absent.
    The original nested try/except pyramid only probed backend N+1 when
    probing backend N *raised*, so a backend that was importable but
    merely unavailable blocked all later probes; the flat loop fixes that.
    """
    device_mode = os.getenv('MINERU_DEVICE_MODE', None)
    if device_mode is not None:
        return device_mode

    # Lambdas defer attribute access so a missing module/attribute is
    # caught by the per-probe try/except instead of crashing the loop setup.
    probes = [
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
        ("npu", lambda: torch_npu.npu.is_available()),
        ("gcu", lambda: torch.gcu.is_available()),
        ("musa", lambda: torch.musa.is_available()),
        ("mlu", lambda: torch.mlu.is_available()),
        ("sdaa", lambda: torch.sdaa.is_available()),
    ]
    for name, is_available in probes:
        try:
            if is_available():
                return name
        except Exception:
            # Backend (or torch itself) not installed; try the next one.
            continue
    return "cpu"
|
|
|
|
|
|
def get_formula_enable(formula_enable):
    """Return *formula_enable*, unless MINERU_FORMULA_ENABLE overrides it (case-insensitive 'true')."""
    override = os.getenv('MINERU_FORMULA_ENABLE')
    if override is None:
        return formula_enable
    return override.lower() == 'true'
|
|
|
|
|
|
def get_table_enable(table_enable):
    """Return *table_enable*, unless MINERU_TABLE_ENABLE overrides it (case-insensitive 'true')."""
    override = os.getenv('MINERU_TABLE_ENABLE')
    if override is None:
        return table_enable
    return override.lower() == 'true'
|
|
|
|
|
|
def get_ocr_det_mask_inline_formula_enable(enable):
    """Return *enable*, unless MINERU_OCR_DET_MASK_INLINE_FORMULA_ENABLE overrides it."""
    override = os.getenv('MINERU_OCR_DET_MASK_INLINE_FORMULA_ENABLE')
    if override is None:
        return enable
    return override.lower() == 'true'
|
|
|
|
|
|
def get_processing_window_size(default: int = 64) -> int:
    """Window size from MINERU_PROCESSING_WINDOW_SIZE, clamped to >= 1.

    Falls back to *default* when the variable is unset or not an integer.
    """
    value = os.getenv('MINERU_PROCESSING_WINDOW_SIZE')
    if value is None:
        return default
    try:
        return max(1, int(value))
    except ValueError:
        # Non-numeric override: warn and keep the caller's default.
        logger.warning(
            f"Invalid MINERU_PROCESSING_WINDOW_SIZE value: {value}, use default {default}"
        )
        return default
|
|
|
|
|
|
def get_max_concurrent_requests(default: int = 3) -> int:
    """Concurrency limit from MINERU_API_MAX_CONCURRENT_REQUESTS, clamped to >= 0.

    Falls back to *default* when the variable is unset or not an integer.
    """
    value = os.getenv('MINERU_API_MAX_CONCURRENT_REQUESTS')
    if value is None:
        return default
    try:
        return max(0, int(value))
    except ValueError:
        # Non-numeric override: warn and keep the caller's default.
        logger.warning(
            f"Invalid MINERU_API_MAX_CONCURRENT_REQUESTS value: {value}, use default {default}"
        )
        return default
|
|
|
|
|
|
def get_latex_delimiter_config():
    """Return the 'latex-delimiter-config' section of the user config.

    Returns None when the config file is absent or the section is missing.
    """
    config = read_config()
    if config is None:
        return None
    # .get already yields None for a missing section.
    return config.get('latex-delimiter-config')
|
|
|
|
|
|
def get_llm_aided_config():
    """Return the 'llm-aided-config' section of the user config.

    Returns None when the config file is absent or the section is missing.
    """
    config = read_config()
    if config is None:
        return None
    # .get already yields None for a missing section.
    return config.get('llm-aided-config')
|
|
|
|
|
|
def get_local_models_dir():
    """Return the 'models-dir' value from the user config, or None.

    Emits a warning when the config file exists but the key resolves
    to None; a missing config file is silent.
    """
    config = read_config()
    if config is None:
        return None

    models_dir = config.get('models-dir')
    if models_dir is None:
        logger.warning(
            f"'models-dir' not found in {CONFIG_FILE_NAME}, use None as default"
        )
    return models_dir