feat: implement thread safety for PDF processing and model initialization

myhloli
2026-03-20 01:37:49 +08:00
parent 69f4f83b34
commit 4cd501ccc7
7 changed files with 414 additions and 276 deletions

View File

@@ -18,7 +18,12 @@ from mineru.backend.hybrid.hybrid_model_output_to_middle_json import (
result_to_middle_json,
)
from mineru.backend.pipeline.model_init import HybridModelSingleton
from mineru.backend.vlm.vlm_analyze import ModelSingleton
from mineru.backend.vlm.vlm_analyze import (
ModelSingleton,
aio_predictor_execution_guard,
predictor_execution_guard,
_maybe_enable_serial_execution,
)
from mineru.data.data_reader_writer import DataWriter
from mineru.utils.config_reader import get_device, get_low_memory_window_size
from mineru.utils.enum_class import ImageType, NotExtractType
@@ -538,6 +543,7 @@ def doc_analyze(
# Initialize the predictor
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
# Load images
load_images_start = time.time()
@@ -556,14 +562,16 @@ def doc_analyze(
infer_start = time.time()
# VLM extraction
if _vlm_ocr_enable:
model_list = predictor.batch_two_step_extract(images=images_pil_list)
with predictor_execution_guard(predictor):
model_list = predictor.batch_two_step_extract(images=images_pil_list)
hybrid_pipeline_model = None
else:
batch_ratio = get_batch_ratio(device)
model_list = predictor.batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
with predictor_execution_guard(predictor):
model_list = predictor.batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
images_pil_list,
model_list,
@@ -604,6 +612,7 @@ def doc_analyze_low_memory(
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
device = get_device()
_ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
@@ -648,12 +657,14 @@ def doc_analyze_low_memory(
f'({len(images_pil_list)} pages)'
)
if _vlm_ocr_enable:
window_model_list = predictor.batch_two_step_extract(images=images_pil_list)
with predictor_execution_guard(predictor):
window_model_list = predictor.batch_two_step_extract(images=images_pil_list)
else:
window_model_list = predictor.batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
with predictor_execution_guard(predictor):
window_model_list = predictor.batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
window_model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
images_pil_list,
window_model_list,
@@ -715,6 +726,7 @@ async def aio_doc_analyze(
# Initialize the predictor
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
# Load images
load_images_start = time.time()
@@ -733,14 +745,16 @@ async def aio_doc_analyze(
infer_start = time.time()
# VLM extraction
if _vlm_ocr_enable:
model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
async with aio_predictor_execution_guard(predictor):
model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
hybrid_pipeline_model = None
else:
batch_ratio = get_batch_ratio(device)
model_list = await predictor.aio_batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
async with aio_predictor_execution_guard(predictor):
model_list = await predictor.aio_batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
images_pil_list,
model_list,
@@ -781,6 +795,7 @@ async def aio_doc_analyze_low_memory(
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
device = get_device()
_ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
@@ -825,12 +840,14 @@ async def aio_doc_analyze_low_memory(
f'({len(images_pil_list)} pages)'
)
if _vlm_ocr_enable:
window_model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
async with aio_predictor_execution_guard(predictor):
window_model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
else:
window_model_list = await predictor.aio_batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
async with aio_predictor_execution_guard(predictor):
window_model_list = await predictor.aio_batch_two_step_extract(
images=images_pil_list,
not_extract_list=not_extract_list
)
window_model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
images_pil_list,
window_model_list,
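
Every call site in this file follows the same shape: after the predictor comes out of the singleton, `_maybe_enable_serial_execution` attaches a lock to it when the backend needs serialized access (currently the mlx backend), and every `batch_two_step_extract` call is wrapped in `predictor_execution_guard` / `aio_predictor_execution_guard`, which are no-ops when no lock was attached. A minimal sketch of that calling convention, assuming the `mineru` package from this commit is importable and using a stand-in object instead of a real `MinerUClient`:

```python
from mineru.backend.vlm.vlm_analyze import (
    _maybe_enable_serial_execution,
    predictor_execution_guard,
)


class FakePredictor:
    """Stand-in for MinerUClient, only here to show the guard behaviour."""

    def batch_two_step_extract(self, images):
        return [f"result-{i}" for i, _ in enumerate(images)]


predictor = FakePredictor()

# backend="mlx-engine" forces a lock onto the predictor; for other backends
# no lock is attached and the guard below degrades to a plain no-op.
predictor = _maybe_enable_serial_execution(predictor, backend="mlx-engine")
assert hasattr(predictor, "_mineru_execution_lock")

# Concurrent callers taking this guard on the same predictor run one at a time.
with predictor_execution_guard(predictor):
    results = predictor.batch_two_step_extract(images=["page-1", "page-2"])

print(results)
```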

View File

@@ -1,4 +1,5 @@
import os
import threading
import torch
from loguru import logger
@@ -112,10 +113,12 @@ def ocr_model_init(det_db_box_thresh=0.3,
class AtomModelSingleton:
_instance = None
_models = {}
_lock = threading.RLock()
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
@@ -145,8 +148,9 @@ class AtomModelSingleton:
else:
key = atom_model_name
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
with self._lock:
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
@@ -258,10 +262,12 @@ class MineruPipelineModel:
class HybridModelSingleton:
_instance = None
_models = {}
_lock = threading.RLock()
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
@@ -270,11 +276,12 @@ class HybridModelSingleton:
formula_enable=None,
):
key = (lang, formula_enable)
if key not in self._models:
self._models[key] = MineruHybridModel(
lang=lang,
formula_enable=formula_enable,
)
with self._lock:
if key not in self._models:
self._models[key] = MineruHybridModel(
lang=lang,
formula_enable=formula_enable,
)
return self._models[key]
def ocr_det_batch_setting():
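
All the singletons touched by this commit (`AtomModelSingleton` and `HybridModelSingleton` here, plus the pipeline and VLM `ModelSingleton` classes below) get the same treatment: a class-level `threading.RLock`, with instance creation and per-key model construction performed only while that lock is held, so two threads asking for the same model cannot initialize it twice. A minimal self-contained sketch of the pattern (hypothetical class, not code from this repository):

```python
import threading


class ThreadSafeModelCache:
    """Illustration only: lazy singleton plus a per-key model cache behind one lock."""

    _instance = None
    _models = {}
    _lock = threading.RLock()  # reentrant, so nested calls on the same thread cannot deadlock

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, key):
        with self._lock:
            if key not in self._models:
                # Expensive construction runs at most once per key, even when
                # many threads request the same key at the same time.
                self._models[key] = object()  # stand-in for the real model init
        return self._models[key]


# Both threads below end up with the exact same cached object.
cache = ThreadSafeModelCache()
threads = [threading.Thread(target=cache.get_model, args=("en",)) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(cache.get_model("en") is ThreadSafeModelCache().get_model("en"))
```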

View File

@@ -1,4 +1,5 @@
import os
import threading
import time
from typing import List, Tuple
@@ -22,10 +23,12 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # disable albumentations update checks
class ModelSingleton:
_instance = None
_models = {}
_lock = threading.RLock()
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
@@ -35,12 +38,13 @@ class ModelSingleton:
table_enable=None,
):
key = (lang, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(
lang=lang,
formula_enable=formula_enable,
table_enable=table_enable,
)
with self._lock:
if key not in self._models:
self._models[key] = custom_model_init(
lang=lang,
formula_enable=formula_enable,
table_enable=table_enable,
)
return self._models[key]

View File

@@ -1,7 +1,10 @@
# Copyright (c) Opendatalab. All rights reserved.
import asyncio
import os
import time
import json
import threading
from contextlib import asynccontextmanager, contextmanager
import pypdfium2 as pdfium
from loguru import logger
@@ -25,10 +28,12 @@ from packaging import version
class ModelSingleton:
_instance = None
_models = {}
_lock = threading.RLock()
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
@@ -39,188 +44,232 @@ class ModelSingleton:
**kwargs,
) -> MinerUClient:
key = (backend, model_path, server_url)
if key not in self._models:
start_time = time.time()
model = None
processor = None
vllm_llm = None
lmdeploy_engine = None
vllm_async_llm = None
batch_size = kwargs.get("batch_size", 0) # for transformers backend only
max_concurrency = kwargs.get("max_concurrency", 100) # for http-client backend only
http_timeout = kwargs.get("http_timeout", 600) # for http-client backend only
server_headers = kwargs.get("server_headers", None) # for http-client backend only
max_retries = kwargs.get("max_retries", 3) # for http-client backend only
retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5) # for http-client backend only
# Remove these parameters from kwargs so they are not passed to unrelated init functions
for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
if param in kwargs:
del kwargs[param]
if backend not in ["http-client"] and not model_path:
model_path = auto_download_and_get_model_root_path("/","vlm")
if backend == "transformers":
try:
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
)
from transformers import __version__ as transformers_version
except ImportError:
raise ImportError("Please install transformers to use the transformers backend.")
if version.parse(transformers_version) >= version.parse("4.56.0"):
dtype_key = "dtype"
else:
dtype_key = "torch_dtype"
device = get_device()
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map={"": device},
**{dtype_key: "auto"}, # type: ignore
)
processor = AutoProcessor.from_pretrained(
model_path,
use_fast=True,
)
if batch_size == 0:
batch_size = set_default_batch_size()
elif backend == "mlx-engine":
mlx_supported = is_mac_os_version_supported()
if not mlx_supported:
raise EnvironmentError("mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.")
try:
from mlx_vlm import load as mlx_load
except ImportError:
raise ImportError("Please install mlx-vlm to use the mlx-engine backend.")
model, processor = mlx_load(model_path)
else:
if os.getenv('OMP_NUM_THREADS') is None:
os.environ["OMP_NUM_THREADS"] = "1"
if backend == "vllm-engine":
with self._lock:
if key not in self._models:
start_time = time.time()
model = None
processor = None
vllm_llm = None
lmdeploy_engine = None
vllm_async_llm = None
batch_size = kwargs.get("batch_size", 0) # for transformers backend only
max_concurrency = kwargs.get("max_concurrency", 100) # for http-client backend only
http_timeout = kwargs.get("http_timeout", 600) # for http-client backend only
server_headers = kwargs.get("server_headers", None) # for http-client backend only
max_retries = kwargs.get("max_retries", 3) # for http-client backend only
retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5) # for http-client backend only
# Remove these parameters from kwargs so they are not passed to unrelated init functions
for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
if param in kwargs:
del kwargs[param]
if backend not in ["http-client"] and not model_path:
model_path = auto_download_and_get_model_root_path("/","vlm")
if backend == "transformers":
try:
import vllm
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
)
from transformers import __version__ as transformers_version
except ImportError:
raise ImportError("Please install vllm to use the vllm-engine backend.")
raise ImportError("Please install transformers to use the transformers backend.")
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="sync_engine")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], str):
try:
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
except json.JSONDecodeError:
logger.warning(
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
del kwargs["compilation_config"]
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs:
kwargs["model"] = model_path
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
from mineru_vl_utils import MinerULogitsProcessor
kwargs["logits_processors"] = [MinerULogitsProcessor]
# Initialize vllm with the parameters from kwargs
vllm_llm = vllm.LLM(**kwargs)
elif backend == "vllm-async-engine":
try:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.config import CompilationConfig
except ImportError:
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="async_engine")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], dict):
# If it is a dict, convert it to a CompilationConfig object
kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"])
elif isinstance(kwargs["compilation_config"], str):
# If it is a JSON string, parse it first and then convert
try:
config_dict = json.loads(kwargs["compilation_config"])
kwargs["compilation_config"] = CompilationConfig(**config_dict)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(
f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}")
del kwargs["compilation_config"]
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs:
kwargs["model"] = model_path
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
from mineru_vl_utils import MinerULogitsProcessor
kwargs["logits_processors"] = [MinerULogitsProcessor]
# Initialize vllm with the parameters from kwargs
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
elif backend == "lmdeploy-engine":
try:
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
except ImportError:
raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.")
if "cache_max_entry_count" not in kwargs:
kwargs["cache_max_entry_count"] = 0.5
device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "")
if device_type == "":
if "lmdeploy_device" in kwargs:
device_type = kwargs.pop("lmdeploy_device")
if device_type not in ["cuda", "ascend", "maca", "camb"]:
raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
else:
device_type = "cuda"
lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "")
if lm_backend == "":
if "lmdeploy_backend" in kwargs:
lm_backend = kwargs.pop("lmdeploy_backend")
if lm_backend not in ["pytorch", "turbomind"]:
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
else:
lm_backend = set_lmdeploy_backend(device_type)
logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}")
if lm_backend == "pytorch":
kwargs["device_type"] = device_type
backend_config = PytorchEngineConfig(**kwargs)
elif lm_backend == "turbomind":
backend_config = TurbomindEngineConfig(**kwargs)
if version.parse(transformers_version) >= version.parse("4.56.0"):
dtype_key = "dtype"
else:
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
log_level = 'ERROR'
from lmdeploy.utils import get_logger
lm_logger = get_logger('lmdeploy')
lm_logger.setLevel(log_level)
if os.getenv('TM_LOG_LEVEL') is None:
os.environ['TM_LOG_LEVEL'] = log_level
lmdeploy_engine = VLAsyncEngine(
dtype_key = "torch_dtype"
device = get_device()
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
backend=lm_backend,
backend_config=backend_config,
device_map={"": device},
**{dtype_key: "auto"}, # type: ignore
)
self._models[key] = MinerUClient(
backend=backend,
model=model,
processor=processor,
lmdeploy_engine=lmdeploy_engine,
vllm_llm=vllm_llm,
vllm_async_llm=vllm_async_llm,
server_url=server_url,
batch_size=batch_size,
max_concurrency=max_concurrency,
http_timeout=http_timeout,
server_headers=server_headers,
max_retries=max_retries,
retry_backoff_factor=retry_backoff_factor,
)
elapsed = round(time.time() - start_time, 2)
logger.info(f"get {backend} predictor cost: {elapsed}s")
processor = AutoProcessor.from_pretrained(
model_path,
use_fast=True,
)
if batch_size == 0:
batch_size = set_default_batch_size()
elif backend == "mlx-engine":
mlx_supported = is_mac_os_version_supported()
if not mlx_supported:
raise EnvironmentError("mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.")
try:
from mlx_vlm import load as mlx_load
except ImportError:
raise ImportError("Please install mlx-vlm to use the mlx-engine backend.")
model, processor = mlx_load(model_path)
else:
if os.getenv('OMP_NUM_THREADS') is None:
os.environ["OMP_NUM_THREADS"] = "1"
if backend == "vllm-engine":
try:
import vllm
except ImportError:
raise ImportError("Please install vllm to use the vllm-engine backend.")
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="sync_engine")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], str):
try:
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
except json.JSONDecodeError:
logger.warning(
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
del kwargs["compilation_config"]
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs:
kwargs["model"] = model_path
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
from mineru_vl_utils import MinerULogitsProcessor
kwargs["logits_processors"] = [MinerULogitsProcessor]
# Initialize vllm with the parameters from kwargs
vllm_llm = vllm.LLM(**kwargs)
elif backend == "vllm-async-engine":
try:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.config import CompilationConfig
except ImportError:
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="async_engine")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], dict):
# If it is a dict, convert it to a CompilationConfig object
kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"])
elif isinstance(kwargs["compilation_config"], str):
# If it is a JSON string, parse it first and then convert
try:
config_dict = json.loads(kwargs["compilation_config"])
kwargs["compilation_config"] = CompilationConfig(**config_dict)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(
f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}")
del kwargs["compilation_config"]
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs:
kwargs["model"] = model_path
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
from mineru_vl_utils import MinerULogitsProcessor
kwargs["logits_processors"] = [MinerULogitsProcessor]
# Initialize vllm with the parameters from kwargs
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
elif backend == "lmdeploy-engine":
try:
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
except ImportError:
raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.")
if "cache_max_entry_count" not in kwargs:
kwargs["cache_max_entry_count"] = 0.5
device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "")
if device_type == "":
if "lmdeploy_device" in kwargs:
device_type = kwargs.pop("lmdeploy_device")
if device_type not in ["cuda", "ascend", "maca", "camb"]:
raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
else:
device_type = "cuda"
lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "")
if lm_backend == "":
if "lmdeploy_backend" in kwargs:
lm_backend = kwargs.pop("lmdeploy_backend")
if lm_backend not in ["pytorch", "turbomind"]:
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
else:
lm_backend = set_lmdeploy_backend(device_type)
logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}")
if lm_backend == "pytorch":
kwargs["device_type"] = device_type
backend_config = PytorchEngineConfig(**kwargs)
elif lm_backend == "turbomind":
backend_config = TurbomindEngineConfig(**kwargs)
else:
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
log_level = 'ERROR'
from lmdeploy.utils import get_logger
lm_logger = get_logger('lmdeploy')
lm_logger.setLevel(log_level)
if os.getenv('TM_LOG_LEVEL') is None:
os.environ['TM_LOG_LEVEL'] = log_level
lmdeploy_engine = VLAsyncEngine(
model_path,
backend=lm_backend,
backend_config=backend_config,
)
predictor = MinerUClient(
backend=backend,
model=model,
processor=processor,
lmdeploy_engine=lmdeploy_engine,
vllm_llm=vllm_llm,
vllm_async_llm=vllm_async_llm,
server_url=server_url,
batch_size=batch_size,
max_concurrency=max_concurrency,
http_timeout=http_timeout,
server_headers=server_headers,
max_retries=max_retries,
retry_backoff_factor=retry_backoff_factor,
)
_maybe_enable_serial_execution(predictor, backend)
self._models[key] = predictor
elapsed = round(time.time() - start_time, 2)
logger.info(f"get {backend} predictor cost: {elapsed}s")
return self._models[key]
def _predictor_uses_mlx(predictor: MinerUClient, backend: str | None = None) -> bool:
if backend == "mlx-engine":
return True
client = getattr(predictor, "client", None)
return type(client).__module__.endswith(".mlx_client")
def _maybe_enable_serial_execution(
predictor: MinerUClient,
backend: str | None = None,
) -> MinerUClient:
if _predictor_uses_mlx(predictor, backend) and not hasattr(
predictor, "_mineru_execution_lock"
):
predictor._mineru_execution_lock = threading.Lock()
return predictor
@contextmanager
def predictor_execution_guard(predictor: MinerUClient):
lock = getattr(predictor, "_mineru_execution_lock", None)
if lock is None:
yield
return
with lock:
yield
@asynccontextmanager
async def aio_predictor_execution_guard(predictor: MinerUClient):
lock = getattr(predictor, "_mineru_execution_lock", None)
if lock is None:
yield
return
await asyncio.to_thread(lock.acquire)
try:
yield
finally:
lock.release()
def _close_images(images_list):
for image_dict in images_list or []:
pil_img = image_dict.get("img_pil")
@@ -242,6 +291,7 @@ def doc_analyze(
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
@@ -250,7 +300,8 @@ def doc_analyze(
logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
infer_start = time.time()
results = predictor.batch_two_step_extract(images=images_pil_list)
with predictor_execution_guard(predictor):
results = predictor.batch_two_step_extract(images=images_pil_list)
infer_time = round(time.time() - infer_start, 2)
logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
@@ -269,6 +320,7 @@ def doc_analyze_low_memory(
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
pdf_doc = pdfium.PdfDocument(pdf_bytes)
middle_json = init_middle_json()
@@ -304,7 +356,8 @@ def doc_analyze_low_memory(
f'pages {window_start + 1}-{window_end + 1}/{page_count} '
f'({len(images_pil_list)} pages)'
)
window_results = predictor.batch_two_step_extract(images=images_pil_list)
with predictor_execution_guard(predictor):
window_results = predictor.batch_two_step_extract(images=images_pil_list)
results.extend(window_results)
append_page_blocks_to_middle_json(
middle_json,
@@ -343,6 +396,7 @@ async def aio_doc_analyze(
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
@@ -351,7 +405,8 @@ async def aio_doc_analyze(
logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
infer_start = time.time()
results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
async with aio_predictor_execution_guard(predictor):
results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
infer_time = round(time.time() - infer_start, 2)
logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
@@ -369,6 +424,7 @@ async def aio_doc_analyze_low_memory(
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
pdf_doc = pdfium.PdfDocument(pdf_bytes)
middle_json = init_middle_json()
@@ -404,7 +460,8 @@ async def aio_doc_analyze_low_memory(
f'pages {window_start + 1}-{window_end + 1}/{page_count} '
f'({len(images_pil_list)} pages)'
)
window_results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
async with aio_predictor_execution_guard(predictor):
window_results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
results.extend(window_results)
append_page_blocks_to_middle_json(
middle_json,
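
The async guard above deliberately avoids calling `lock.acquire()` directly: a `threading.Lock` held by another request would block the whole event loop, so the blocking acquire is pushed onto a worker thread with `asyncio.to_thread` and the release happens in a `finally` block. A self-contained sketch of that pattern in isolation (toy names, not code from this repository):

```python
import asyncio
import threading
import time
from contextlib import asynccontextmanager

_lock = threading.Lock()


@asynccontextmanager
async def serialized():
    # Acquire the thread lock off the event loop so other coroutines keep running.
    await asyncio.to_thread(_lock.acquire)
    try:
        yield
    finally:
        _lock.release()


async def job(name: str) -> None:
    async with serialized():
        # Stand-in for a blocking, non-thread-safe inference step.
        await asyncio.to_thread(time.sleep, 0.2)
        print(f"{name} finished")


async def main() -> None:
    # Both jobs run, one after the other, without ever blocking the event loop.
    await asyncio.gather(job("job-1"), job("job-2"))


asyncio.run(main())
```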

View File

@@ -3,6 +3,7 @@ import io
import json
import os
import copy
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
@@ -37,6 +38,8 @@ office_suffixes = docx_suffixes + pptx_suffixes + xlsx_suffixes
os.environ["TOKENIZERS_PARALLELISM"] = "false"
_pdf_rewrite_lock = threading.Lock()
def read_fn(path):
if not isinstance(path, Path):
path = Path(path)
@@ -60,34 +63,46 @@ def prepare_env(output_dir, pdf_file_name, parse_method):
def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
pdf = pdfium.PdfDocument(pdf_bytes)
output_pdf = pdfium.PdfDocument.new()
try:
end_page_id = get_end_page_id(end_page_id, len(pdf))
# pypdfium2 document import/save is not thread-safe across concurrent FastAPI tasks.
with _pdf_rewrite_lock:
pdf = None
output_pdf = None
try:
pdf = pdfium.PdfDocument(pdf_bytes)
page_count = len(pdf)
end_page_id = get_end_page_id(end_page_id, page_count)
# Import page by page; skip pages that fail
output_index = 0
for page_index in range(start_page_id, end_page_id + 1):
try:
output_pdf.import_pages(pdf, pages=[page_index])
output_index += 1
except Exception as page_error:
output_pdf.del_page(output_index)
logger.warning(f"Failed to import page {page_index}: {page_error}, skipping this page.")
continue
# Avoid rewriting when the caller requests the whole document.
if start_page_id <= 0 and end_page_id >= page_count - 1:
return pdf_bytes
# Save the new PDF to an in-memory buffer
output_buffer = io.BytesIO()
output_pdf.save(output_buffer)
output_pdf = pdfium.PdfDocument.new()
# Get the byte data
output_bytes = output_buffer.getvalue()
except Exception as e:
logger.warning(f"Error in converting PDF bytes: {e}, Using original PDF bytes.")
output_bytes = pdf_bytes
pdf.close()
output_pdf.close()
return output_bytes
# Import page by page; skip pages that fail
output_index = 0
for page_index in range(start_page_id, end_page_id + 1):
try:
output_pdf.import_pages(pdf, pages=[page_index])
output_index += 1
except Exception as page_error:
output_pdf.del_page(output_index)
logger.warning(f"Failed to import page {page_index}: {page_error}, skipping this page.")
continue
# Save the new PDF to an in-memory buffer
output_buffer = io.BytesIO()
output_pdf.save(output_buffer)
# Get the byte data
return output_buffer.getvalue()
except Exception as e:
logger.warning(f"Error in converting PDF bytes: {e}, Using original PDF bytes.")
return pdf_bytes
finally:
if pdf is not None:
pdf.close()
if output_pdf is not None:
output_pdf.close()
def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
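
With the module-level `_pdf_rewrite_lock`, concurrent callers (for example parallel FastAPI tasks) can push bytes through the rewrite helper safely: the lock serializes pypdfium2's import/save, and whole-document requests now return the original bytes without rewriting at all. A rough usage sketch, assuming the `mineru` package from this commit is importable:

```python
import io
from concurrent.futures import ThreadPoolExecutor

import pypdfium2 as pdfium

from mineru.cli.common import convert_pdf_bytes_to_bytes_by_pypdfium2

# Build a small in-memory 3-page PDF so there is something to rewrite.
doc = pdfium.PdfDocument.new()
for _ in range(3):
    doc.new_page(595, 842)  # blank A4-sized page
buffer = io.BytesIO()
doc.save(buffer)
doc.close()
pdf_bytes = buffer.getvalue()

# Several threads trimming the same document at once: _pdf_rewrite_lock keeps
# pypdfium2's import/save calls from interleaving.
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [
        pool.submit(convert_pdf_bytes_to_bytes_by_pypdfium2, pdf_bytes, 0, 1)
        for _ in range(4)
    ]
    trimmed = [f.result() for f in futures]

print([len(b) for b in trimmed])
```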

View File

@@ -5,6 +5,7 @@ import re
import shutil
import sys
import tempfile
import threading
import uuid
import zipfile
from contextlib import asynccontextmanager, suppress
@@ -40,6 +41,7 @@ from mineru.cli.common import (
read_fn,
)
from mineru.utils.cli_parser import arg_parse
from mineru.utils.config_reader import get_device
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
from mineru.version import __version__
@@ -59,6 +61,7 @@ ALLOWED_PARSE_METHODS = {"auto", "txt", "ocr"}
# Concurrency controller
_request_semaphore: Optional[asyncio.Semaphore] = None
_mps_parse_lock = threading.Lock()
@dataclass
@@ -164,10 +167,10 @@ def create_app():
global _request_semaphore
try:
max_concurrent_requests = int(
os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "0")
os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "3")
)
except ValueError:
max_concurrent_requests = 0
max_concurrent_requests = 3
if max_concurrent_requests > 0:
_request_semaphore = asyncio.Semaphore(max_concurrent_requests)
@@ -654,12 +657,32 @@ async def run_parse_job(
)
if request_options.backend == "pipeline":
await asyncio.to_thread(do_parse, **parse_kwargs)
async with serialize_parse_job_if_needed(request_options.backend):
await asyncio.to_thread(do_parse, **parse_kwargs)
else:
await aio_do_parse(**parse_kwargs)
async with serialize_parse_job_if_needed(request_options.backend):
await aio_do_parse(**parse_kwargs)
return response_file_names
def should_serialize_parse_job(backend: str) -> bool:
if get_device() != "mps":
return False
return backend == "pipeline" or backend.startswith(("vlm-", "hybrid-"))
@asynccontextmanager
async def serialize_parse_job_if_needed(backend: str):
if not should_serialize_parse_job(backend):
yield
return
await asyncio.to_thread(_mps_parse_lock.acquire)
try:
yield
finally:
_mps_parse_lock.release()
def create_task_output_dir(task_id: str) -> str:
output_root = get_output_root()
task_output_dir = output_root / task_id
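
The API-level defaults now cap concurrency at 3 requests when `MINERU_API_MAX_CONCURRENT_REQUESTS` is unset or malformed, and parse jobs are only serialized when the device is Apple MPS and the backend actually drives the shared GPU. A small standalone sketch of that configuration logic (illustrative, simplified from the handlers above):

```python
import asyncio
import os

# Mirror of the env-var handling above: a missing or malformed value falls back
# to 3, and a non-positive value disables the semaphore entirely.
try:
    max_concurrent_requests = int(os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "3"))
except ValueError:
    max_concurrent_requests = 3

request_semaphore = (
    asyncio.Semaphore(max_concurrent_requests) if max_concurrent_requests > 0 else None
)


def should_serialize(device: str, backend: str) -> bool:
    # Only Apple MPS devices serialize parse jobs, and only for backends that
    # actually drive the shared GPU.
    if device != "mps":
        return False
    return backend == "pipeline" or backend.startswith(("vlm-", "hybrid-"))


print(max_concurrent_requests, should_serialize("mps", "pipeline"))
```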

View File

@@ -1,5 +1,6 @@
# Copyright (c) Opendatalab. All rights reserved.
import re
import threading
from io import BytesIO
import numpy as np
import pypdfium2 as pdfium
@@ -13,6 +14,8 @@ from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.layout import LAParams, LTImage, LTFigure
from pdfminer.converter import PDFPageAggregator
_pdf_sample_extract_lock = threading.Lock()
def classify(pdf_bytes):
"""
@@ -27,6 +30,8 @@ def classify(pdf_bytes):
# Load the PDF from byte data
sample_pdf_bytes = extract_pages(pdf_bytes)
if not sample_pdf_bytes:
return 'ocr'
pdf = pdfium.PdfDocument(sample_pdf_bytes)
try:
# Get the PDF page count
@@ -187,40 +192,50 @@ def extract_pages(src_pdf_bytes: bytes) -> bytes:
bytes: PDF bytes after page extraction
"""
# Load the PDF from byte data
pdf = pdfium.PdfDocument(src_pdf_bytes)
with _pdf_sample_extract_lock:
pdf = None
sample_docs = None
try:
# Load the PDF from byte data
pdf = pdfium.PdfDocument(src_pdf_bytes)
# Get the PDF page count
total_page = len(pdf)
if total_page == 0:
# If the PDF has no pages, return an empty document directly
logger.warning("PDF is empty, return empty document")
return b''
# Get the PDF page count
total_page = len(pdf)
if total_page == 0:
# If the PDF has no pages, return an empty document directly
logger.warning("PDF is empty, return empty document")
return b''
# Select at most 10 pages
select_page_cnt = min(10, total_page)
# Small documents reuse the original bytes directly, avoiding a pointless PDF rewrite.
if total_page <= 10:
return src_pdf_bytes
# Randomly select pages from the total page count
page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()
# Select at most 10 pages
select_page_cnt = min(10, total_page)
# Create a new PDF document
sample_docs = pdfium.PdfDocument.new()
# Randomly select pages from the total page count
page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()
try:
# Import the selected pages into the new document
sample_docs.import_pages(pdf, page_indices)
pdf.close()
# Create a new PDF document
sample_docs = pdfium.PdfDocument.new()
# Save the new PDF to an in-memory buffer
output_buffer = BytesIO()
sample_docs.save(output_buffer)
# Import the selected pages into the new document
sample_docs.import_pages(pdf, page_indices)
# Get the byte data
return output_buffer.getvalue()
except Exception as e:
pdf.close()
logger.exception(e)
return b'' # return empty bytes on error
# Save the new PDF to an in-memory buffer
output_buffer = BytesIO()
sample_docs.save(output_buffer)
# Get the byte data
return output_buffer.getvalue()
except Exception as e:
logger.exception(e)
return src_pdf_bytes
finally:
if pdf is not None:
pdf.close()
if sample_docs is not None:
sample_docs.close()
def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
@@ -263,4 +278,4 @@ def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
if __name__ == '__main__':
with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
p_bytes = f.read()
logger.info(f"PDF classification result: {classify(p_bytes)}")
logger.info(f"PDF classification result: {classify(p_bytes)}")