mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
# Copyright (c) Opendatalab. All rights reserved.
|
|
import json
|
|
import os
|
|
|
|
import torch
|
|
from loguru import logger
|
|
|
|
# Name (or absolute path) of the MinerU JSON config file.
# Override it with the MINERU_TOOLS_CONFIG_JSON environment variable.
CONFIG_FILE_NAME = os.environ.get('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
|
|
|
|
|
|
def read_config():
    """Load and return the MinerU JSON configuration.

    An absolute CONFIG_FILE_NAME is used as-is; a relative one is resolved
    against the current user's home directory.

    Returns:
        The parsed JSON content, or None (after logging a warning) when the
        config file does not exist.
    """
    if os.path.isabs(CONFIG_FILE_NAME):
        config_path = CONFIG_FILE_NAME
    else:
        config_path = os.path.join(os.path.expanduser('~'), CONFIG_FILE_NAME)

    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as fp:
            return json.load(fp)

    logger.warning(f'{config_path} not found, using default configuration')
    return None
|
|
|
|
|
|
def get_s3_config(bucket_name: str):
    """Read the S3 credentials for ``bucket_name`` from the MinerU config file.

    The config's ``bucket_info`` section maps bucket names to
    ``[access_key, secret_key, endpoint]`` entries; a bucket without its own
    entry falls back to the ``'[default]'`` entry.

    Args:
        bucket_name: Name of the S3 bucket to look up.

    Returns:
        A ``(access_key, secret_key, storage_endpoint)`` tuple.

    Raises:
        Exception: if the config file or its ``bucket_info`` section is
            missing, or any of the three values is None.
        KeyError: if the bucket has no entry and no ``'[default]'`` entry
            exists.
    """
    config = read_config()
    # read_config() returns None when the file is missing; report that clearly
    # instead of failing with AttributeError / TypeError further down.
    bucket_info = config.get('bucket_info') if config is not None else None
    if bucket_info is None:
        raise Exception(f'bucket_info not found in {CONFIG_FILE_NAME}')

    # Buckets without a dedicated entry share the '[default]' credentials.
    entry_name = bucket_name if bucket_name in bucket_info else '[default]'
    access_key, secret_key, storage_endpoint = bucket_info[entry_name]

    if access_key is None or secret_key is None or storage_endpoint is None:
        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')

    return access_key, secret_key, storage_endpoint
|
|
|
|
|
|
def get_s3_config_dict(path: str):
    """Return the S3 credentials for the bucket named in ``path``.

    The result is a dict with the keys ``'ak'``, ``'sk'`` and ``'endpoint'``.
    """
    ak, sk, endpoint = get_s3_config(get_bucket_name(path))
    return dict(ak=ak, sk=sk, endpoint=endpoint)
|
|
|
|
|
|
def get_bucket_name(path):
    """Return only the bucket component of an S3 path."""
    bucket, _ = parse_bucket_key(path)
    return bucket
|
|
|
|
|
|
def parse_bucket_key(s3_full_path: str):
    """Split an S3 path into its bucket name and object key.

    Input  ``s3://bucket/path/to/my/file.txt``
    Output ``('bucket', 'path/to/my/file.txt')``

    Surrounding whitespace is ignored. The ``s3://`` scheme and a single
    leading ``/`` are both optional. A path with no key part (for example
    ``s3://bucket``) yields an empty key instead of raising.
    """
    s3_full_path = s3_full_path.strip()
    if s3_full_path.startswith("s3://"):
        s3_full_path = s3_full_path[5:]
    if s3_full_path.startswith("/"):
        s3_full_path = s3_full_path[1:]
    # partition() always yields three parts, so a bare bucket name does not
    # blow up the way `bucket, key = s.split("/", 1)` does.
    bucket, _, key = s3_full_path.partition("/")
    return bucket, key
|
|
|
|
|
|
def get_device():
    """Pick the torch device string to run on.

    The MINERU_DEVICE_MODE environment variable, when set, is returned
    verbatim. Otherwise the first available backend is chosen in the order
    "cuda", "mps", then "cpu".
    """
    override = os.getenv('MINERU_DEVICE_MODE', None)
    if override is not None:
        return override
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"
|
|
|
|
|
|
def get_table_recog_config():
    """Return the table-recognition switch as ``{"enable": value}``.

    The value comes from the MINERU_TABLE_ENABLE environment variable.
    ``true``/``false`` are accepted case-insensitively, so ``True`` and
    ``FALSE`` work too. Any other value is parsed as a JSON literal, so for
    example ``1`` becomes ``1``. When the variable is unset, a warning is
    logged and ``{"enable": True}`` is returned.

    Raises:
        json.JSONDecodeError: if the value is neither a boolean word nor a
            valid JSON literal.
    """
    table_enable = os.getenv('MINERU_TABLE_ENABLE', None)
    if table_enable is None:
        logger.warning("not found 'MINERU_TABLE_ENABLE' in environment variable, use 'true' as default.")
        return {"enable": True}
    # Parse the value on its own rather than splicing raw text into a JSON
    # document, which both rejected "True" and allowed extra keys to be injected.
    normalized = table_enable.strip().lower()
    if normalized in ('true', 'false'):
        return {"enable": normalized == 'true'}
    return {"enable": json.loads(table_enable)}
|
|
|
|
|
|
def get_formula_config():
    """Return the formula-recognition switch as ``{"enable": value}``.

    The value comes from the MINERU_FORMULA_ENABLE environment variable.
    ``true``/``false`` are accepted case-insensitively, so ``True`` and
    ``FALSE`` work too. Any other value is parsed as a JSON literal, so for
    example ``1`` becomes ``1``. When the variable is unset, a warning is
    logged and ``{"enable": True}`` is returned.

    Raises:
        json.JSONDecodeError: if the value is neither a boolean word nor a
            valid JSON literal.
    """
    formula_enable = os.getenv('MINERU_FORMULA_ENABLE', None)
    if formula_enable is None:
        logger.warning("not found 'MINERU_FORMULA_ENABLE' in environment variable, use 'true' as default.")
        return {"enable": True}
    # Parse the value on its own rather than splicing raw text into a JSON
    # document, which both rejected "True" and allowed extra keys to be injected.
    normalized = formula_enable.strip().lower()
    if normalized in ('true', 'false'):
        return {"enable": normalized == 'true'}
    return {"enable": json.loads(formula_enable)}
|
|
|
|
|
|
def get_latex_delimiter_config():
    """Return the 'latex-delimiter-config' section of the MinerU config.

    Returns:
        The section's value, or None (after logging a warning) when the
        section is absent or the config file could not be found.
    """
    config = read_config()
    # read_config() returns None when the config file is missing; treat that
    # like an absent section instead of crashing on None.get().
    latex_delimiter_config = config.get('latex-delimiter-config') if config is not None else None
    if latex_delimiter_config is None:
        logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
    return latex_delimiter_config
|
|
|
|
|
|
def get_llm_aided_config():
    """Return the 'llm-aided-config' section of the MinerU config.

    Returns:
        The section's value, or None (after logging a warning) when the
        section is absent or the config file could not be found.
    """
    config = read_config()
    # read_config() returns None when the config file is missing; treat that
    # like an absent section instead of crashing on None.get().
    llm_aided_config = config.get('llm-aided-config') if config is not None else None
    if llm_aided_config is None:
        logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
    return llm_aided_config
|
|
|
|
|
|
def get_local_models_dir():
    """Return the 'models-dir' entry of the MinerU config.

    Returns:
        The configured value, or None (after logging a warning) when the
        entry is absent or the config file could not be found.
    """
    config = read_config()
    # read_config() returns None when the config file is missing; treat that
    # like an absent entry instead of crashing on None.get().
    models_dir = config.get('models-dir') if config is not None else None
    if models_dir is None:
        logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use None as default")
    return models_dir