mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
feat: implement thread safety for PDF processing and model initialization
This commit is contained in:
@@ -18,7 +18,12 @@ from mineru.backend.hybrid.hybrid_model_output_to_middle_json import (
|
||||
result_to_middle_json,
|
||||
)
|
||||
from mineru.backend.pipeline.model_init import HybridModelSingleton
|
||||
from mineru.backend.vlm.vlm_analyze import ModelSingleton
|
||||
from mineru.backend.vlm.vlm_analyze import (
|
||||
ModelSingleton,
|
||||
aio_predictor_execution_guard,
|
||||
predictor_execution_guard,
|
||||
_maybe_enable_serial_execution,
|
||||
)
|
||||
from mineru.data.data_reader_writer import DataWriter
|
||||
from mineru.utils.config_reader import get_device, get_low_memory_window_size
|
||||
from mineru.utils.enum_class import ImageType, NotExtractType
|
||||
@@ -538,6 +543,7 @@ def doc_analyze(
|
||||
# 初始化预测器
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
# 加载图像
|
||||
load_images_start = time.time()
|
||||
@@ -556,14 +562,16 @@ def doc_analyze(
|
||||
infer_start = time.time()
|
||||
# VLM提取
|
||||
if _vlm_ocr_enable:
|
||||
model_list = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
with predictor_execution_guard(predictor):
|
||||
model_list = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
hybrid_pipeline_model = None
|
||||
else:
|
||||
batch_ratio = get_batch_ratio(device)
|
||||
model_list = predictor.batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
with predictor_execution_guard(predictor):
|
||||
model_list = predictor.batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
|
||||
images_pil_list,
|
||||
model_list,
|
||||
@@ -604,6 +612,7 @@ def doc_analyze_low_memory(
|
||||
):
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
device = get_device()
|
||||
_ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
|
||||
@@ -648,12 +657,14 @@ def doc_analyze_low_memory(
|
||||
f'({len(images_pil_list)} pages)'
|
||||
)
|
||||
if _vlm_ocr_enable:
|
||||
window_model_list = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
with predictor_execution_guard(predictor):
|
||||
window_model_list = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
else:
|
||||
window_model_list = predictor.batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
with predictor_execution_guard(predictor):
|
||||
window_model_list = predictor.batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
window_model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
|
||||
images_pil_list,
|
||||
window_model_list,
|
||||
@@ -715,6 +726,7 @@ async def aio_doc_analyze(
|
||||
# 初始化预测器
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
# 加载图像
|
||||
load_images_start = time.time()
|
||||
@@ -733,14 +745,16 @@ async def aio_doc_analyze(
|
||||
infer_start = time.time()
|
||||
# VLM提取
|
||||
if _vlm_ocr_enable:
|
||||
model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
async with aio_predictor_execution_guard(predictor):
|
||||
model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
hybrid_pipeline_model = None
|
||||
else:
|
||||
batch_ratio = get_batch_ratio(device)
|
||||
model_list = await predictor.aio_batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
async with aio_predictor_execution_guard(predictor):
|
||||
model_list = await predictor.aio_batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
|
||||
images_pil_list,
|
||||
model_list,
|
||||
@@ -781,6 +795,7 @@ async def aio_doc_analyze_low_memory(
|
||||
):
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
device = get_device()
|
||||
_ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
|
||||
@@ -825,12 +840,14 @@ async def aio_doc_analyze_low_memory(
|
||||
f'({len(images_pil_list)} pages)'
|
||||
)
|
||||
if _vlm_ocr_enable:
|
||||
window_model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
async with aio_predictor_execution_guard(predictor):
|
||||
window_model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
else:
|
||||
window_model_list = await predictor.aio_batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
async with aio_predictor_execution_guard(predictor):
|
||||
window_model_list = await predictor.aio_batch_two_step_extract(
|
||||
images=images_pil_list,
|
||||
not_extract_list=not_extract_list
|
||||
)
|
||||
window_model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
|
||||
images_pil_list,
|
||||
window_model_list,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
@@ -112,10 +113,12 @@ def ocr_model_init(det_db_box_thresh=0.3,
|
||||
class AtomModelSingleton:
|
||||
_instance = None
|
||||
_models = {}
|
||||
_lock = threading.RLock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def get_atom_model(self, atom_model_name: str, **kwargs):
|
||||
@@ -145,8 +148,9 @@ class AtomModelSingleton:
|
||||
else:
|
||||
key = atom_model_name
|
||||
|
||||
if key not in self._models:
|
||||
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
||||
with self._lock:
|
||||
if key not in self._models:
|
||||
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
||||
return self._models[key]
|
||||
|
||||
def atom_model_init(model_name: str, **kwargs):
|
||||
@@ -258,10 +262,12 @@ class MineruPipelineModel:
|
||||
class HybridModelSingleton:
|
||||
_instance = None
|
||||
_models = {}
|
||||
_lock = threading.RLock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def get_model(
|
||||
@@ -270,11 +276,12 @@ class HybridModelSingleton:
|
||||
formula_enable=None,
|
||||
):
|
||||
key = (lang, formula_enable)
|
||||
if key not in self._models:
|
||||
self._models[key] = MineruHybridModel(
|
||||
lang=lang,
|
||||
formula_enable=formula_enable,
|
||||
)
|
||||
with self._lock:
|
||||
if key not in self._models:
|
||||
self._models[key] = MineruHybridModel(
|
||||
lang=lang,
|
||||
formula_enable=formula_enable,
|
||||
)
|
||||
return self._models[key]
|
||||
|
||||
def ocr_det_batch_setting():
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -22,10 +23,12 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
||||
class ModelSingleton:
|
||||
_instance = None
|
||||
_models = {}
|
||||
_lock = threading.RLock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def get_model(
|
||||
@@ -35,12 +38,13 @@ class ModelSingleton:
|
||||
table_enable=None,
|
||||
):
|
||||
key = (lang, formula_enable, table_enable)
|
||||
if key not in self._models:
|
||||
self._models[key] = custom_model_init(
|
||||
lang=lang,
|
||||
formula_enable=formula_enable,
|
||||
table_enable=table_enable,
|
||||
)
|
||||
with self._lock:
|
||||
if key not in self._models:
|
||||
self._models[key] = custom_model_init(
|
||||
lang=lang,
|
||||
formula_enable=formula_enable,
|
||||
table_enable=table_enable,
|
||||
)
|
||||
return self._models[key]
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# Copyright (c) Opendatalab. All rights reserved.
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import threading
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
|
||||
import pypdfium2 as pdfium
|
||||
from loguru import logger
|
||||
@@ -25,10 +28,12 @@ from packaging import version
|
||||
class ModelSingleton:
|
||||
_instance = None
|
||||
_models = {}
|
||||
_lock = threading.RLock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def get_model(
|
||||
@@ -39,188 +44,232 @@ class ModelSingleton:
|
||||
**kwargs,
|
||||
) -> MinerUClient:
|
||||
key = (backend, model_path, server_url)
|
||||
if key not in self._models:
|
||||
start_time = time.time()
|
||||
model = None
|
||||
processor = None
|
||||
vllm_llm = None
|
||||
lmdeploy_engine = None
|
||||
vllm_async_llm = None
|
||||
batch_size = kwargs.get("batch_size", 0) # for transformers backend only
|
||||
max_concurrency = kwargs.get("max_concurrency", 100) # for http-client backend only
|
||||
http_timeout = kwargs.get("http_timeout", 600) # for http-client backend only
|
||||
server_headers = kwargs.get("server_headers", None) # for http-client backend only
|
||||
max_retries = kwargs.get("max_retries", 3) # for http-client backend only
|
||||
retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5) # for http-client backend only
|
||||
# 从kwargs中移除这些参数,避免传递给不相关的初始化函数
|
||||
for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
|
||||
if param in kwargs:
|
||||
del kwargs[param]
|
||||
if backend not in ["http-client"] and not model_path:
|
||||
model_path = auto_download_and_get_model_root_path("/","vlm")
|
||||
if backend == "transformers":
|
||||
try:
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
from transformers import __version__ as transformers_version
|
||||
except ImportError:
|
||||
raise ImportError("Please install transformers to use the transformers backend.")
|
||||
|
||||
if version.parse(transformers_version) >= version.parse("4.56.0"):
|
||||
dtype_key = "dtype"
|
||||
else:
|
||||
dtype_key = "torch_dtype"
|
||||
device = get_device()
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
device_map={"": device},
|
||||
**{dtype_key: "auto"}, # type: ignore
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_path,
|
||||
use_fast=True,
|
||||
)
|
||||
if batch_size == 0:
|
||||
batch_size = set_default_batch_size()
|
||||
elif backend == "mlx-engine":
|
||||
mlx_supported = is_mac_os_version_supported()
|
||||
if not mlx_supported:
|
||||
raise EnvironmentError("mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.")
|
||||
try:
|
||||
from mlx_vlm import load as mlx_load
|
||||
except ImportError:
|
||||
raise ImportError("Please install mlx-vlm to use the mlx-engine backend.")
|
||||
model, processor = mlx_load(model_path)
|
||||
else:
|
||||
if os.getenv('OMP_NUM_THREADS') is None:
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
|
||||
if backend == "vllm-engine":
|
||||
with self._lock:
|
||||
if key not in self._models:
|
||||
start_time = time.time()
|
||||
model = None
|
||||
processor = None
|
||||
vllm_llm = None
|
||||
lmdeploy_engine = None
|
||||
vllm_async_llm = None
|
||||
batch_size = kwargs.get("batch_size", 0) # for transformers backend only
|
||||
max_concurrency = kwargs.get("max_concurrency", 100) # for http-client backend only
|
||||
http_timeout = kwargs.get("http_timeout", 600) # for http-client backend only
|
||||
server_headers = kwargs.get("server_headers", None) # for http-client backend only
|
||||
max_retries = kwargs.get("max_retries", 3) # for http-client backend only
|
||||
retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5) # for http-client backend only
|
||||
# 从kwargs中移除这些参数,避免传递给不相关的初始化函数
|
||||
for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
|
||||
if param in kwargs:
|
||||
del kwargs[param]
|
||||
if backend not in ["http-client"] and not model_path:
|
||||
model_path = auto_download_and_get_model_root_path("/","vlm")
|
||||
if backend == "transformers":
|
||||
try:
|
||||
import vllm
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
from transformers import __version__ as transformers_version
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-engine backend.")
|
||||
raise ImportError("Please install transformers to use the transformers backend.")
|
||||
|
||||
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="sync_engine")
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], str):
|
||||
try:
|
||||
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
|
||||
del kwargs["compilation_config"]
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
kwargs["model"] = model_path
|
||||
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
|
||||
from mineru_vl_utils import MinerULogitsProcessor
|
||||
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
||||
# 使用kwargs为 vllm初始化参数
|
||||
vllm_llm = vllm.LLM(**kwargs)
|
||||
elif backend == "vllm-async-engine":
|
||||
try:
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.config import CompilationConfig
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
|
||||
|
||||
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="async_engine")
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], dict):
|
||||
# 如果是字典,转换为 CompilationConfig 对象
|
||||
kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"])
|
||||
elif isinstance(kwargs["compilation_config"], str):
|
||||
# 如果是 JSON 字符串,先解析再转换
|
||||
try:
|
||||
config_dict = json.loads(kwargs["compilation_config"])
|
||||
kwargs["compilation_config"] = CompilationConfig(**config_dict)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}")
|
||||
del kwargs["compilation_config"]
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
kwargs["model"] = model_path
|
||||
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
|
||||
from mineru_vl_utils import MinerULogitsProcessor
|
||||
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
||||
# 使用kwargs为 vllm初始化参数
|
||||
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
|
||||
elif backend == "lmdeploy-engine":
|
||||
try:
|
||||
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
|
||||
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
|
||||
except ImportError:
|
||||
raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.")
|
||||
if "cache_max_entry_count" not in kwargs:
|
||||
kwargs["cache_max_entry_count"] = 0.5
|
||||
|
||||
device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "")
|
||||
if device_type == "":
|
||||
if "lmdeploy_device" in kwargs:
|
||||
device_type = kwargs.pop("lmdeploy_device")
|
||||
if device_type not in ["cuda", "ascend", "maca", "camb"]:
|
||||
raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
|
||||
else:
|
||||
device_type = "cuda"
|
||||
lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "")
|
||||
if lm_backend == "":
|
||||
if "lmdeploy_backend" in kwargs:
|
||||
lm_backend = kwargs.pop("lmdeploy_backend")
|
||||
if lm_backend not in ["pytorch", "turbomind"]:
|
||||
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
|
||||
else:
|
||||
lm_backend = set_lmdeploy_backend(device_type)
|
||||
logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}")
|
||||
|
||||
if lm_backend == "pytorch":
|
||||
kwargs["device_type"] = device_type
|
||||
backend_config = PytorchEngineConfig(**kwargs)
|
||||
elif lm_backend == "turbomind":
|
||||
backend_config = TurbomindEngineConfig(**kwargs)
|
||||
if version.parse(transformers_version) >= version.parse("4.56.0"):
|
||||
dtype_key = "dtype"
|
||||
else:
|
||||
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
|
||||
|
||||
log_level = 'ERROR'
|
||||
from lmdeploy.utils import get_logger
|
||||
lm_logger = get_logger('lmdeploy')
|
||||
lm_logger.setLevel(log_level)
|
||||
if os.getenv('TM_LOG_LEVEL') is None:
|
||||
os.environ['TM_LOG_LEVEL'] = log_level
|
||||
|
||||
lmdeploy_engine = VLAsyncEngine(
|
||||
dtype_key = "torch_dtype"
|
||||
device = get_device()
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
backend=lm_backend,
|
||||
backend_config=backend_config,
|
||||
device_map={"": device},
|
||||
**{dtype_key: "auto"}, # type: ignore
|
||||
)
|
||||
self._models[key] = MinerUClient(
|
||||
backend=backend,
|
||||
model=model,
|
||||
processor=processor,
|
||||
lmdeploy_engine=lmdeploy_engine,
|
||||
vllm_llm=vllm_llm,
|
||||
vllm_async_llm=vllm_async_llm,
|
||||
server_url=server_url,
|
||||
batch_size=batch_size,
|
||||
max_concurrency=max_concurrency,
|
||||
http_timeout=http_timeout,
|
||||
server_headers=server_headers,
|
||||
max_retries=max_retries,
|
||||
retry_backoff_factor=retry_backoff_factor,
|
||||
)
|
||||
elapsed = round(time.time() - start_time, 2)
|
||||
logger.info(f"get {backend} predictor cost: {elapsed}s")
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_path,
|
||||
use_fast=True,
|
||||
)
|
||||
if batch_size == 0:
|
||||
batch_size = set_default_batch_size()
|
||||
elif backend == "mlx-engine":
|
||||
mlx_supported = is_mac_os_version_supported()
|
||||
if not mlx_supported:
|
||||
raise EnvironmentError("mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.")
|
||||
try:
|
||||
from mlx_vlm import load as mlx_load
|
||||
except ImportError:
|
||||
raise ImportError("Please install mlx-vlm to use the mlx-engine backend.")
|
||||
model, processor = mlx_load(model_path)
|
||||
else:
|
||||
if os.getenv('OMP_NUM_THREADS') is None:
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
|
||||
if backend == "vllm-engine":
|
||||
try:
|
||||
import vllm
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-engine backend.")
|
||||
|
||||
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="sync_engine")
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], str):
|
||||
try:
|
||||
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
|
||||
del kwargs["compilation_config"]
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
kwargs["model"] = model_path
|
||||
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
|
||||
from mineru_vl_utils import MinerULogitsProcessor
|
||||
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
||||
# 使用kwargs为 vllm初始化参数
|
||||
vllm_llm = vllm.LLM(**kwargs)
|
||||
elif backend == "vllm-async-engine":
|
||||
try:
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.config import CompilationConfig
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
|
||||
|
||||
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="async_engine")
|
||||
|
||||
if "compilation_config" in kwargs:
|
||||
if isinstance(kwargs["compilation_config"], dict):
|
||||
# 如果是字典,转换为 CompilationConfig 对象
|
||||
kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"])
|
||||
elif isinstance(kwargs["compilation_config"], str):
|
||||
# 如果是 JSON 字符串,先解析再转换
|
||||
try:
|
||||
config_dict = json.loads(kwargs["compilation_config"])
|
||||
kwargs["compilation_config"] = CompilationConfig(**config_dict)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}")
|
||||
del kwargs["compilation_config"]
|
||||
if "gpu_memory_utilization" not in kwargs:
|
||||
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
|
||||
if "model" not in kwargs:
|
||||
kwargs["model"] = model_path
|
||||
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
|
||||
from mineru_vl_utils import MinerULogitsProcessor
|
||||
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
||||
# 使用kwargs为 vllm初始化参数
|
||||
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
|
||||
elif backend == "lmdeploy-engine":
|
||||
try:
|
||||
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
|
||||
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
|
||||
except ImportError:
|
||||
raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.")
|
||||
if "cache_max_entry_count" not in kwargs:
|
||||
kwargs["cache_max_entry_count"] = 0.5
|
||||
|
||||
device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "")
|
||||
if device_type == "":
|
||||
if "lmdeploy_device" in kwargs:
|
||||
device_type = kwargs.pop("lmdeploy_device")
|
||||
if device_type not in ["cuda", "ascend", "maca", "camb"]:
|
||||
raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
|
||||
else:
|
||||
device_type = "cuda"
|
||||
lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "")
|
||||
if lm_backend == "":
|
||||
if "lmdeploy_backend" in kwargs:
|
||||
lm_backend = kwargs.pop("lmdeploy_backend")
|
||||
if lm_backend not in ["pytorch", "turbomind"]:
|
||||
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
|
||||
else:
|
||||
lm_backend = set_lmdeploy_backend(device_type)
|
||||
logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}")
|
||||
|
||||
if lm_backend == "pytorch":
|
||||
kwargs["device_type"] = device_type
|
||||
backend_config = PytorchEngineConfig(**kwargs)
|
||||
elif lm_backend == "turbomind":
|
||||
backend_config = TurbomindEngineConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
|
||||
|
||||
log_level = 'ERROR'
|
||||
from lmdeploy.utils import get_logger
|
||||
lm_logger = get_logger('lmdeploy')
|
||||
lm_logger.setLevel(log_level)
|
||||
if os.getenv('TM_LOG_LEVEL') is None:
|
||||
os.environ['TM_LOG_LEVEL'] = log_level
|
||||
|
||||
lmdeploy_engine = VLAsyncEngine(
|
||||
model_path,
|
||||
backend=lm_backend,
|
||||
backend_config=backend_config,
|
||||
)
|
||||
predictor = MinerUClient(
|
||||
backend=backend,
|
||||
model=model,
|
||||
processor=processor,
|
||||
lmdeploy_engine=lmdeploy_engine,
|
||||
vllm_llm=vllm_llm,
|
||||
vllm_async_llm=vllm_async_llm,
|
||||
server_url=server_url,
|
||||
batch_size=batch_size,
|
||||
max_concurrency=max_concurrency,
|
||||
http_timeout=http_timeout,
|
||||
server_headers=server_headers,
|
||||
max_retries=max_retries,
|
||||
retry_backoff_factor=retry_backoff_factor,
|
||||
)
|
||||
_maybe_enable_serial_execution(predictor, backend)
|
||||
self._models[key] = predictor
|
||||
elapsed = round(time.time() - start_time, 2)
|
||||
logger.info(f"get {backend} predictor cost: {elapsed}s")
|
||||
return self._models[key]
|
||||
|
||||
|
||||
def _predictor_uses_mlx(predictor: MinerUClient, backend: str | None = None) -> bool:
|
||||
if backend == "mlx-engine":
|
||||
return True
|
||||
client = getattr(predictor, "client", None)
|
||||
return type(client).__module__.endswith(".mlx_client")
|
||||
|
||||
|
||||
def _maybe_enable_serial_execution(
|
||||
predictor: MinerUClient,
|
||||
backend: str | None = None,
|
||||
) -> MinerUClient:
|
||||
if _predictor_uses_mlx(predictor, backend) and not hasattr(
|
||||
predictor, "_mineru_execution_lock"
|
||||
):
|
||||
predictor._mineru_execution_lock = threading.Lock()
|
||||
return predictor
|
||||
|
||||
|
||||
@contextmanager
|
||||
def predictor_execution_guard(predictor: MinerUClient):
|
||||
lock = getattr(predictor, "_mineru_execution_lock", None)
|
||||
if lock is None:
|
||||
yield
|
||||
return
|
||||
with lock:
|
||||
yield
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def aio_predictor_execution_guard(predictor: MinerUClient):
|
||||
lock = getattr(predictor, "_mineru_execution_lock", None)
|
||||
if lock is None:
|
||||
yield
|
||||
return
|
||||
await asyncio.to_thread(lock.acquire)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
def _close_images(images_list):
|
||||
for image_dict in images_list or []:
|
||||
pil_img = image_dict.get("img_pil")
|
||||
@@ -242,6 +291,7 @@ def doc_analyze(
|
||||
):
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
load_images_start = time.time()
|
||||
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
|
||||
@@ -250,7 +300,8 @@ def doc_analyze(
|
||||
logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
|
||||
|
||||
infer_start = time.time()
|
||||
results = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
with predictor_execution_guard(predictor):
|
||||
results = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
infer_time = round(time.time() - infer_start, 2)
|
||||
logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
|
||||
|
||||
@@ -269,6 +320,7 @@ def doc_analyze_low_memory(
|
||||
):
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
pdf_doc = pdfium.PdfDocument(pdf_bytes)
|
||||
middle_json = init_middle_json()
|
||||
@@ -304,7 +356,8 @@ def doc_analyze_low_memory(
|
||||
f'pages {window_start + 1}-{window_end + 1}/{page_count} '
|
||||
f'({len(images_pil_list)} pages)'
|
||||
)
|
||||
window_results = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
with predictor_execution_guard(predictor):
|
||||
window_results = predictor.batch_two_step_extract(images=images_pil_list)
|
||||
results.extend(window_results)
|
||||
append_page_blocks_to_middle_json(
|
||||
middle_json,
|
||||
@@ -343,6 +396,7 @@ async def aio_doc_analyze(
|
||||
):
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
load_images_start = time.time()
|
||||
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
|
||||
@@ -351,7 +405,8 @@ async def aio_doc_analyze(
|
||||
logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
|
||||
|
||||
infer_start = time.time()
|
||||
results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
async with aio_predictor_execution_guard(predictor):
|
||||
results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
infer_time = round(time.time() - infer_start, 2)
|
||||
logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
|
||||
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
|
||||
@@ -369,6 +424,7 @@ async def aio_doc_analyze_low_memory(
|
||||
):
|
||||
if predictor is None:
|
||||
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
|
||||
predictor = _maybe_enable_serial_execution(predictor, backend)
|
||||
|
||||
pdf_doc = pdfium.PdfDocument(pdf_bytes)
|
||||
middle_json = init_middle_json()
|
||||
@@ -404,7 +460,8 @@ async def aio_doc_analyze_low_memory(
|
||||
f'pages {window_start + 1}-{window_end + 1}/{page_count} '
|
||||
f'({len(images_pil_list)} pages)'
|
||||
)
|
||||
window_results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
async with aio_predictor_execution_guard(predictor):
|
||||
window_results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
|
||||
results.extend(window_results)
|
||||
append_page_blocks_to_middle_json(
|
||||
middle_json,
|
||||
|
||||
@@ -3,6 +3,7 @@ import io
|
||||
import json
|
||||
import os
|
||||
import copy
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
@@ -37,6 +38,8 @@ office_suffixes = docx_suffixes + pptx_suffixes + xlsx_suffixes
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
_pdf_rewrite_lock = threading.Lock()
|
||||
|
||||
def read_fn(path):
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
@@ -60,34 +63,46 @@ def prepare_env(output_dir, pdf_file_name, parse_method):
|
||||
|
||||
|
||||
def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
|
||||
pdf = pdfium.PdfDocument(pdf_bytes)
|
||||
output_pdf = pdfium.PdfDocument.new()
|
||||
try:
|
||||
end_page_id = get_end_page_id(end_page_id, len(pdf))
|
||||
# pypdfium2 document import/save is not thread-safe across concurrent FastAPI tasks.
|
||||
with _pdf_rewrite_lock:
|
||||
pdf = None
|
||||
output_pdf = None
|
||||
try:
|
||||
pdf = pdfium.PdfDocument(pdf_bytes)
|
||||
page_count = len(pdf)
|
||||
end_page_id = get_end_page_id(end_page_id, page_count)
|
||||
|
||||
# 逐页导入,失败则跳过
|
||||
output_index = 0
|
||||
for page_index in range(start_page_id, end_page_id + 1):
|
||||
try:
|
||||
output_pdf.import_pages(pdf, pages=[page_index])
|
||||
output_index += 1
|
||||
except Exception as page_error:
|
||||
output_pdf.del_page(output_index)
|
||||
logger.warning(f"Failed to import page {page_index}: {page_error}, skipping this page.")
|
||||
continue
|
||||
# Avoid rewriting when the caller requests the whole document.
|
||||
if start_page_id <= 0 and end_page_id >= page_count - 1:
|
||||
return pdf_bytes
|
||||
|
||||
# 将新PDF保存到内存缓冲区
|
||||
output_buffer = io.BytesIO()
|
||||
output_pdf.save(output_buffer)
|
||||
output_pdf = pdfium.PdfDocument.new()
|
||||
|
||||
# 获取字节数据
|
||||
output_bytes = output_buffer.getvalue()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in converting PDF bytes: {e}, Using original PDF bytes.")
|
||||
output_bytes = pdf_bytes
|
||||
pdf.close()
|
||||
output_pdf.close()
|
||||
return output_bytes
|
||||
# 逐页导入,失败则跳过
|
||||
output_index = 0
|
||||
for page_index in range(start_page_id, end_page_id + 1):
|
||||
try:
|
||||
output_pdf.import_pages(pdf, pages=[page_index])
|
||||
output_index += 1
|
||||
except Exception as page_error:
|
||||
output_pdf.del_page(output_index)
|
||||
logger.warning(f"Failed to import page {page_index}: {page_error}, skipping this page.")
|
||||
continue
|
||||
|
||||
# 将新PDF保存到内存缓冲区
|
||||
output_buffer = io.BytesIO()
|
||||
output_pdf.save(output_buffer)
|
||||
|
||||
# 获取字节数据
|
||||
return output_buffer.getvalue()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in converting PDF bytes: {e}, Using original PDF bytes.")
|
||||
return pdf_bytes
|
||||
finally:
|
||||
if pdf is not None:
|
||||
pdf.close()
|
||||
if output_pdf is not None:
|
||||
output_pdf.close()
|
||||
|
||||
|
||||
def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
import zipfile
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
@@ -40,6 +41,7 @@ from mineru.cli.common import (
|
||||
read_fn,
|
||||
)
|
||||
from mineru.utils.cli_parser import arg_parse
|
||||
from mineru.utils.config_reader import get_device
|
||||
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
|
||||
from mineru.version import __version__
|
||||
|
||||
@@ -59,6 +61,7 @@ ALLOWED_PARSE_METHODS = {"auto", "txt", "ocr"}
|
||||
|
||||
# 并发控制器
|
||||
_request_semaphore: Optional[asyncio.Semaphore] = None
|
||||
_mps_parse_lock = threading.Lock()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -164,10 +167,10 @@ def create_app():
|
||||
global _request_semaphore
|
||||
try:
|
||||
max_concurrent_requests = int(
|
||||
os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "0")
|
||||
os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "3")
|
||||
)
|
||||
except ValueError:
|
||||
max_concurrent_requests = 0
|
||||
max_concurrent_requests = 3
|
||||
|
||||
if max_concurrent_requests > 0:
|
||||
_request_semaphore = asyncio.Semaphore(max_concurrent_requests)
|
||||
@@ -654,12 +657,32 @@ async def run_parse_job(
|
||||
)
|
||||
|
||||
if request_options.backend == "pipeline":
|
||||
await asyncio.to_thread(do_parse, **parse_kwargs)
|
||||
async with serialize_parse_job_if_needed(request_options.backend):
|
||||
await asyncio.to_thread(do_parse, **parse_kwargs)
|
||||
else:
|
||||
await aio_do_parse(**parse_kwargs)
|
||||
async with serialize_parse_job_if_needed(request_options.backend):
|
||||
await aio_do_parse(**parse_kwargs)
|
||||
return response_file_names
|
||||
|
||||
|
||||
def should_serialize_parse_job(backend: str) -> bool:
|
||||
if get_device() != "mps":
|
||||
return False
|
||||
return backend == "pipeline" or backend.startswith(("vlm-", "hybrid-"))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def serialize_parse_job_if_needed(backend: str):
|
||||
if not should_serialize_parse_job(backend):
|
||||
yield
|
||||
return
|
||||
await asyncio.to_thread(_mps_parse_lock.acquire)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_mps_parse_lock.release()
|
||||
|
||||
|
||||
def create_task_output_dir(task_id: str) -> str:
|
||||
output_root = get_output_root()
|
||||
task_output_dir = output_root / task_id
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Opendatalab. All rights reserved.
|
||||
import re
|
||||
import threading
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import pypdfium2 as pdfium
|
||||
@@ -13,6 +14,8 @@ from pdfminer.pdfinterp import PDFPageInterpreter
|
||||
from pdfminer.layout import LAParams, LTImage, LTFigure
|
||||
from pdfminer.converter import PDFPageAggregator
|
||||
|
||||
_pdf_sample_extract_lock = threading.Lock()
|
||||
|
||||
|
||||
def classify(pdf_bytes):
|
||||
"""
|
||||
@@ -27,6 +30,8 @@ def classify(pdf_bytes):
|
||||
|
||||
# 从字节数据加载PDF
|
||||
sample_pdf_bytes = extract_pages(pdf_bytes)
|
||||
if not sample_pdf_bytes:
|
||||
return 'ocr'
|
||||
pdf = pdfium.PdfDocument(sample_pdf_bytes)
|
||||
try:
|
||||
# 获取PDF页数
|
||||
@@ -187,40 +192,50 @@ def extract_pages(src_pdf_bytes: bytes) -> bytes:
|
||||
bytes: 提取页面后的PDF字节数据
|
||||
"""
|
||||
|
||||
# 从字节数据加载PDF
|
||||
pdf = pdfium.PdfDocument(src_pdf_bytes)
|
||||
with _pdf_sample_extract_lock:
|
||||
pdf = None
|
||||
sample_docs = None
|
||||
try:
|
||||
# 从字节数据加载PDF
|
||||
pdf = pdfium.PdfDocument(src_pdf_bytes)
|
||||
|
||||
# 获取PDF页数
|
||||
total_page = len(pdf)
|
||||
if total_page == 0:
|
||||
# 如果PDF没有页面,直接返回空文档
|
||||
logger.warning("PDF is empty, return empty document")
|
||||
return b''
|
||||
# 获取PDF页数
|
||||
total_page = len(pdf)
|
||||
if total_page == 0:
|
||||
# 如果PDF没有页面,直接返回空文档
|
||||
logger.warning("PDF is empty, return empty document")
|
||||
return b''
|
||||
|
||||
# 选择最多10页
|
||||
select_page_cnt = min(10, total_page)
|
||||
# 小文档直接复用原始字节,避免无意义的 PDF 重写。
|
||||
if total_page <= 10:
|
||||
return src_pdf_bytes
|
||||
|
||||
# 从总页数中随机选择页面
|
||||
page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()
|
||||
# 选择最多10页
|
||||
select_page_cnt = min(10, total_page)
|
||||
|
||||
# 创建一个新的PDF文档
|
||||
sample_docs = pdfium.PdfDocument.new()
|
||||
# 从总页数中随机选择页面
|
||||
page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()
|
||||
|
||||
try:
|
||||
# 将选择的页面导入新文档
|
||||
sample_docs.import_pages(pdf, page_indices)
|
||||
pdf.close()
|
||||
# 创建一个新的PDF文档
|
||||
sample_docs = pdfium.PdfDocument.new()
|
||||
|
||||
# 将新PDF保存到内存缓冲区
|
||||
output_buffer = BytesIO()
|
||||
sample_docs.save(output_buffer)
|
||||
# 将选择的页面导入新文档
|
||||
sample_docs.import_pages(pdf, page_indices)
|
||||
|
||||
# 获取字节数据
|
||||
return output_buffer.getvalue()
|
||||
except Exception as e:
|
||||
pdf.close()
|
||||
logger.exception(e)
|
||||
return b'' # 出错时返回空字节
|
||||
# 将新PDF保存到内存缓冲区
|
||||
output_buffer = BytesIO()
|
||||
sample_docs.save(output_buffer)
|
||||
|
||||
# 获取字节数据
|
||||
return output_buffer.getvalue()
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return src_pdf_bytes
|
||||
finally:
|
||||
if pdf is not None:
|
||||
pdf.close()
|
||||
if sample_docs is not None:
|
||||
sample_docs.close()
|
||||
|
||||
|
||||
def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
|
||||
@@ -263,4 +278,4 @@ def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
|
||||
if __name__ == '__main__':
|
||||
with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
|
||||
p_bytes = f.read()
|
||||
logger.info(f"PDF分类结果: {classify(p_bytes)}")
|
||||
logger.info(f"PDF分类结果: {classify(p_bytes)}")
|
||||
|
||||
Reference in New Issue
Block a user