From 4cd501ccc704e7c07bb4ccbaedbd2be4addd32e4 Mon Sep 17 00:00:00 2001
From: myhloli
Date: Fri, 20 Mar 2026 01:37:49 +0800
Subject: [PATCH] feat: implement thread safety for PDF processing and model initialization
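
MinerU's model singletons and pypdfium2 helpers were previously unsafe
to use from multiple threads or concurrent FastAPI requests. This
change:

- guards AtomModelSingleton, HybridModelSingleton and the pipeline/vlm
  ModelSingleton classes with an RLock and double-checked locking, so
  concurrent callers initialize each model at most once;
- attaches a per-predictor execution lock to MLX-backed predictors and
  adds predictor_execution_guard / aio_predictor_execution_guard
  context managers that serialize batch_two_step_extract calls (the
  async guard acquires the lock via asyncio.to_thread so the event
  loop is not blocked);
- serializes the pypdfium2 rewrite and page-sampling paths behind
  module-level locks, returning the original bytes when no rewrite is
  needed;
- serializes parse jobs on MPS devices in the FastAPI server and
  changes the default MINERU_API_MAX_CONCURRENT_REQUESTS from 0
  (unlimited) to 3.

Illustrative caller sketch (not part of the diff; model_path and
images are placeholders):

    from mineru.backend.vlm.vlm_analyze import (
        ModelSingleton,
        _maybe_enable_serial_execution,
        predictor_execution_guard,
    )

    predictor = ModelSingleton().get_model("mlx-engine", model_path, None)
    predictor = _maybe_enable_serial_execution(predictor, "mlx-engine")
    # The guard is a no-op for non-MLX backends; for MLX it serializes
    # concurrent extract calls on the shared predictor.
    with predictor_execution_guard(predictor):
        results = predictor.batch_two_step_extract(images=images)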
---
 mineru/backend/hybrid/hybrid_analyze.py     |  59 ++-
 mineru/backend/pipeline/model_init.py       |  29 +-
 mineru/backend/pipeline/pipeline_analyze.py |  20 +-
 mineru/backend/vlm/vlm_analyze.py           | 415 +++++++++++---------
 mineru/cli/common.py                        |  65 +--
 mineru/cli/fast_api.py                      |  31 +-
 mineru/utils/pdf_classify.py                |  71 ++--
 7 files changed, 414 insertions(+), 276 deletions(-)

diff --git a/mineru/backend/hybrid/hybrid_analyze.py b/mineru/backend/hybrid/hybrid_analyze.py
index 674b5a89..8d96db41 100644
--- a/mineru/backend/hybrid/hybrid_analyze.py
+++ b/mineru/backend/hybrid/hybrid_analyze.py
@@ -18,7 +18,12 @@ from mineru.backend.hybrid.hybrid_model_output_to_middle_json import (
     result_to_middle_json,
 )
 from mineru.backend.pipeline.model_init import HybridModelSingleton
-from mineru.backend.vlm.vlm_analyze import ModelSingleton
+from mineru.backend.vlm.vlm_analyze import (
+    ModelSingleton,
+    aio_predictor_execution_guard,
+    predictor_execution_guard,
+    _maybe_enable_serial_execution,
+)
 from mineru.data.data_reader_writer import DataWriter
 from mineru.utils.config_reader import get_device, get_low_memory_window_size
 from mineru.utils.enum_class import ImageType, NotExtractType
@@ -538,6 +543,7 @@ def doc_analyze(
     # Initialize the predictor
     if predictor is None:
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+        predictor = _maybe_enable_serial_execution(predictor, backend)
 
     # Load images
     load_images_start = time.time()
@@ -556,14 +562,16 @@ def doc_analyze(
     infer_start = time.time()
     # VLM extraction
     if _vlm_ocr_enable:
-        model_list = predictor.batch_two_step_extract(images=images_pil_list)
+        with predictor_execution_guard(predictor):
+            model_list = predictor.batch_two_step_extract(images=images_pil_list)
         hybrid_pipeline_model = None
     else:
         batch_ratio = get_batch_ratio(device)
-        model_list = predictor.batch_two_step_extract(
-            images=images_pil_list,
-            not_extract_list=not_extract_list
-        )
+        with predictor_execution_guard(predictor):
+            model_list = predictor.batch_two_step_extract(
+                images=images_pil_list,
+                not_extract_list=not_extract_list
+            )
         model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
             images_pil_list,
             model_list,
@@ -604,6 +612,7 @@ def doc_analyze_low_memory(
 ):
     if predictor is None:
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+        predictor = _maybe_enable_serial_execution(predictor, backend)
 
     device = get_device()
     _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
@@ -648,12 +657,14 @@
             f'({len(images_pil_list)} pages)'
         )
         if _vlm_ocr_enable:
-            window_model_list = predictor.batch_two_step_extract(images=images_pil_list)
+            with predictor_execution_guard(predictor):
+                window_model_list = predictor.batch_two_step_extract(images=images_pil_list)
         else:
-            window_model_list = predictor.batch_two_step_extract(
-                images=images_pil_list,
-                not_extract_list=not_extract_list
-            )
+            with predictor_execution_guard(predictor):
+                window_model_list = predictor.batch_two_step_extract(
+                    images=images_pil_list,
+                    not_extract_list=not_extract_list
+                )
             window_model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
                 images_pil_list,
                 window_model_list,
@@ -715,6 +726,7 @@ async def aio_doc_analyze(
     # Initialize the predictor
     if predictor is None:
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+        predictor = _maybe_enable_serial_execution(predictor, backend)
 
     # Load images
     load_images_start = time.time()
@@ -733,14 +745,16 @@ async def aio_doc_analyze(
     infer_start = time.time()
     # VLM extraction
     if _vlm_ocr_enable:
-        model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
+        async with aio_predictor_execution_guard(predictor):
+            model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
         hybrid_pipeline_model = None
     else:
         batch_ratio = get_batch_ratio(device)
-        model_list = await predictor.aio_batch_two_step_extract(
-            images=images_pil_list,
-            not_extract_list=not_extract_list
-        )
+        async with aio_predictor_execution_guard(predictor):
+            model_list = await predictor.aio_batch_two_step_extract(
+                images=images_pil_list,
+                not_extract_list=not_extract_list
+            )
         model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
             images_pil_list,
             model_list,
@@ -781,6 +795,7 @@ async def aio_doc_analyze_low_memory(
 ):
     if predictor is None:
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+        predictor = _maybe_enable_serial_execution(predictor, backend)
 
     device = get_device()
     _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
@@ -825,12 +840,14 @@
             f'({len(images_pil_list)} pages)'
         )
         if _vlm_ocr_enable:
-            window_model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
+            async with aio_predictor_execution_guard(predictor):
+                window_model_list = await predictor.aio_batch_two_step_extract(images=images_pil_list)
         else:
-            window_model_list = await predictor.aio_batch_two_step_extract(
-                images=images_pil_list,
-                not_extract_list=not_extract_list
-            )
+            async with aio_predictor_execution_guard(predictor):
+                window_model_list = await predictor.aio_batch_two_step_extract(
+                    images=images_pil_list,
+                    not_extract_list=not_extract_list
+                )
             window_model_list, hybrid_pipeline_model = _process_ocr_and_formulas(
                 images_pil_list,
                 window_model_list,
diff --git a/mineru/backend/pipeline/model_init.py b/mineru/backend/pipeline/model_init.py
index 634bd4ad..de170aca 100644
--- a/mineru/backend/pipeline/model_init.py
+++ b/mineru/backend/pipeline/model_init.py
@@ -1,4 +1,5 @@
 import os
+import threading
 
 import torch
 from loguru import logger
@@ -112,10 +113,12 @@ def ocr_model_init(det_db_box_thresh=0.3,
 class AtomModelSingleton:
     _instance = None
     _models = {}
+    _lock = threading.RLock()
 
     def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
@@ -145,8 +148,9 @@ class AtomModelSingleton:
         else:
             key = atom_model_name
 
-        if key not in self._models:
-            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        with self._lock:
+            if key not in self._models:
+                self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
         return self._models[key]
 
 def atom_model_init(model_name: str, **kwargs):
@@ -258,10 +262,12 @@ class MineruPipelineModel:
 class HybridModelSingleton:
     _instance = None
     _models = {}
+    _lock = threading.RLock()
 
    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
         return cls._instance
 
     def get_model(
         self,
         lang=None,
         formula_enable=None,
     ):
         key = (lang, formula_enable)
-        if key not in self._models:
-            self._models[key] = MineruHybridModel(
-                lang=lang,
-                formula_enable=formula_enable,
-            )
+        with self._lock:
+            if key not in self._models:
+                self._models[key] = MineruHybridModel(
+                    lang=lang,
+                    formula_enable=formula_enable,
+                )
         return self._models[key]
 
 def ocr_det_batch_setting():
diff --git a/mineru/backend/pipeline/pipeline_analyze.py b/mineru/backend/pipeline/pipeline_analyze.py
index 72ef5772..2384b945 100644
--- a/mineru/backend/pipeline/pipeline_analyze.py
+++ b/mineru/backend/pipeline/pipeline_analyze.py
@@ -1,4 +1,5 @@
 import os
+import threading
 import time
 from typing import List, Tuple
 
@@ -22,10 +23,12 @@
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # keep albumentations from checking for updates
 
 class ModelSingleton:
     _instance = None
     _models = {}
+    _lock = threading.RLock()
 
     def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
         return cls._instance
 
     def get_model(
         self,
         lang=None,
         formula_enable=None,
         table_enable=None,
     ):
         key = (lang, formula_enable, table_enable)
-        if key not in self._models:
-            self._models[key] = custom_model_init(
-                lang=lang,
-                formula_enable=formula_enable,
-                table_enable=table_enable,
-            )
+        with self._lock:
+            if key not in self._models:
+                self._models[key] = custom_model_init(
+                    lang=lang,
+                    formula_enable=formula_enable,
+                    table_enable=table_enable,
+                )
         return self._models[key]
diff --git a/mineru/backend/vlm/vlm_analyze.py b/mineru/backend/vlm/vlm_analyze.py
index c00bdb05..114a6848 100644
--- a/mineru/backend/vlm/vlm_analyze.py
+++ b/mineru/backend/vlm/vlm_analyze.py
@@ -1,7 +1,10 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import asyncio
 import os
 import time
 import json
+import threading
+from contextlib import asynccontextmanager, contextmanager
 
 import pypdfium2 as pdfium
 from loguru import logger
@@ -25,10 +28,12 @@ from packaging import version
 class ModelSingleton:
     _instance = None
     _models = {}
+    _lock = threading.RLock()
 
     def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
         return cls._instance
 
     def get_model(
         self,
         backend,
         model_path,
         server_url,
         **kwargs,
     ) -> MinerUClient:
         key = (backend, model_path, server_url)
-        if key not in self._models:
-            start_time = time.time()
-            model = None
-            processor = None
-            vllm_llm = None
-            lmdeploy_engine = None
-            vllm_async_llm = None
-            batch_size = kwargs.get("batch_size", 0)  # for transformers backend only
-            max_concurrency = kwargs.get("max_concurrency", 100)  # for http-client backend only
-            http_timeout = kwargs.get("http_timeout", 600)  # for http-client backend only
-            server_headers = kwargs.get("server_headers", None)  # for http-client backend only
-            max_retries = kwargs.get("max_retries", 3)  # for http-client backend only
-            retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5)  # for http-client backend only
-            # Remove these parameters from kwargs so they are not passed to unrelated init functions
-            for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
-                if param in kwargs:
-                    del kwargs[param]
-            if backend not in ["http-client"] and not model_path:
-                model_path = auto_download_and_get_model_root_path("/","vlm")
-            if backend == "transformers":
-                try:
-                    from transformers import (
-                        AutoProcessor,
-                        Qwen2VLForConditionalGeneration,
-                    )
-                    from transformers import __version__ as transformers_version
-                except ImportError:
-                    raise ImportError("Please install transformers to use the transformers backend.")
-
-                if version.parse(transformers_version) >= version.parse("4.56.0"):
-                    dtype_key = "dtype"
-                else:
-                    dtype_key = "torch_dtype"
-                device = get_device()
-                model = Qwen2VLForConditionalGeneration.from_pretrained(
-                    model_path,
-                    device_map={"": device},
-                    **{dtype_key: "auto"},  # type: ignore
-                )
-                processor = AutoProcessor.from_pretrained(
-                    model_path,
-                    use_fast=True,
-                )
-                if batch_size == 0:
-                    batch_size = set_default_batch_size()
-            elif backend == "mlx-engine":
-                mlx_supported = is_mac_os_version_supported()
-                if not mlx_supported:
-                    raise EnvironmentError("mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.")
-                try:
-                    from mlx_vlm import load as mlx_load
-                except ImportError:
-                    raise ImportError("Please install mlx-vlm to use the mlx-engine backend.")
-                model, processor = mlx_load(model_path)
-            else:
-                if os.getenv('OMP_NUM_THREADS') is None:
-                    os.environ["OMP_NUM_THREADS"] = "1"
-
-                if backend == "vllm-engine":
+        with self._lock:
+            if key not in self._models:
+                start_time = time.time()
+                model = None
+                processor = None
+                vllm_llm = None
+                lmdeploy_engine = None
+                vllm_async_llm = None
+                batch_size = kwargs.get("batch_size", 0)  # for transformers backend only
+                max_concurrency = kwargs.get("max_concurrency", 100)  # for http-client backend only
+                http_timeout = kwargs.get("http_timeout", 600)  # for http-client backend only
+                server_headers = kwargs.get("server_headers", None)  # for http-client backend only
+                max_retries = kwargs.get("max_retries", 3)  # for http-client backend only
+                retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5)  # for http-client backend only
+                # Remove these parameters from kwargs so they are not passed to unrelated init functions
+                for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
+                    if param in kwargs:
+                        del kwargs[param]
+                if backend not in ["http-client"] and not model_path:
+                    model_path = auto_download_and_get_model_root_path("/","vlm")
+                if backend == "transformers":
                     try:
-                        import vllm
+                        from transformers import (
+                            AutoProcessor,
+                            Qwen2VLForConditionalGeneration,
+                        )
+                        from transformers import __version__ as transformers_version
                     except ImportError:
-                        raise ImportError("Please install vllm to use the vllm-engine backend.")
+                        raise ImportError("Please install transformers to use the transformers backend.")
-                    kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="sync_engine")
-
-                    if "compilation_config" in kwargs:
-                        if isinstance(kwargs["compilation_config"], str):
-                            try:
-                                kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
-                            except json.JSONDecodeError:
-                                logger.warning(
-                                    f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
-                                del kwargs["compilation_config"]
-                    if "gpu_memory_utilization" not in kwargs:
-                        kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
-                    if "model" not in kwargs:
-                        kwargs["model"] = model_path
-                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
-                        from mineru_vl_utils import MinerULogitsProcessor
-                        kwargs["logits_processors"] = [MinerULogitsProcessor]
-                    # Initialize vllm with the parameters in kwargs
-                    vllm_llm = vllm.LLM(**kwargs)
-                elif backend == "vllm-async-engine":
-                    try:
-                        from vllm.engine.arg_utils import AsyncEngineArgs
-                        from vllm.v1.engine.async_llm import AsyncLLM
-                        from vllm.config import CompilationConfig
-                    except ImportError:
-                        raise ImportError("Please install vllm to use the vllm-async-engine backend.")
-
-                    kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="async_engine")
-
-                    if "compilation_config" in kwargs:
-                        if isinstance(kwargs["compilation_config"], dict):
-                            # If it is a dict, convert it to a CompilationConfig object
-                            kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"])
-                        elif isinstance(kwargs["compilation_config"], str):
-                            # If it is a JSON string, parse it first and then convert
-                            try:
-                                config_dict = json.loads(kwargs["compilation_config"])
-                                kwargs["compilation_config"] = CompilationConfig(**config_dict)
-                            except (json.JSONDecodeError, TypeError) as e:
-                                logger.warning(
-                                    f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}")
-                                del kwargs["compilation_config"]
-                    if "gpu_memory_utilization" not in kwargs:
-                        kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
-                    if "model" not in kwargs:
-                        kwargs["model"] = model_path
-                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
-                        from mineru_vl_utils import MinerULogitsProcessor
-                        kwargs["logits_processors"] = [MinerULogitsProcessor]
-                    # Initialize vllm with the parameters in kwargs
-                    vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
-                elif backend == "lmdeploy-engine":
-                    try:
-                        from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
-                        from lmdeploy.serve.vl_async_engine import VLAsyncEngine
-                    except ImportError:
-                        raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.")
-                    if "cache_max_entry_count" not in kwargs:
-                        kwargs["cache_max_entry_count"] = 0.5
-
-                    device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "")
-                    if device_type == "":
-                        if "lmdeploy_device" in kwargs:
-                            device_type = kwargs.pop("lmdeploy_device")
-                            if device_type not in ["cuda", "ascend", "maca", "camb"]:
-                                raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
-                        else:
-                            device_type = "cuda"
-                    lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "")
-                    if lm_backend == "":
-                        if "lmdeploy_backend" in kwargs:
-                            lm_backend = kwargs.pop("lmdeploy_backend")
-                            if lm_backend not in ["pytorch", "turbomind"]:
-                                raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
-                        else:
-                            lm_backend = set_lmdeploy_backend(device_type)
-                    logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}")
-
-                    if lm_backend == "pytorch":
-                        kwargs["device_type"] = device_type
-                        backend_config = PytorchEngineConfig(**kwargs)
-                    elif lm_backend == "turbomind":
-                        backend_config = TurbomindEngineConfig(**kwargs)
+                if version.parse(transformers_version) >= version.parse("4.56.0"):
+                    dtype_key = "dtype"
                     else:
-                        raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
-
-                    log_level = 'ERROR'
-                    from lmdeploy.utils import get_logger
-                    lm_logger = get_logger('lmdeploy')
-                    lm_logger.setLevel(log_level)
-                    if os.getenv('TM_LOG_LEVEL') is None:
-                        os.environ['TM_LOG_LEVEL'] = log_level
-
-                    lmdeploy_engine = VLAsyncEngine(
+                    dtype_key = "torch_dtype"
+                device = get_device()
+                model = Qwen2VLForConditionalGeneration.from_pretrained(
                     model_path,
-                        backend=lm_backend,
-                        backend_config=backend_config,
+                    device_map={"": device},
+                    **{dtype_key: "auto"},  # type: ignore
                 )
-            self._models[key] = MinerUClient(
-                backend=backend,
-                model=model,
-                processor=processor,
-                lmdeploy_engine=lmdeploy_engine,
-                vllm_llm=vllm_llm,
-                vllm_async_llm=vllm_async_llm,
-                server_url=server_url,
-                batch_size=batch_size,
-                max_concurrency=max_concurrency,
-                http_timeout=http_timeout,
-                server_headers=server_headers,
-                max_retries=max_retries,
-                retry_backoff_factor=retry_backoff_factor,
-            )
-            elapsed = round(time.time() - start_time, 2)
-            logger.info(f"get {backend} predictor cost: {elapsed}s")
logger.info(f"get {backend} predictor cost: {elapsed}s") + processor = AutoProcessor.from_pretrained( + model_path, + use_fast=True, + ) + if batch_size == 0: + batch_size = set_default_batch_size() + elif backend == "mlx-engine": + mlx_supported = is_mac_os_version_supported() + if not mlx_supported: + raise EnvironmentError("mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.") + try: + from mlx_vlm import load as mlx_load + except ImportError: + raise ImportError("Please install mlx-vlm to use the mlx-engine backend.") + model, processor = mlx_load(model_path) + else: + if os.getenv('OMP_NUM_THREADS') is None: + os.environ["OMP_NUM_THREADS"] = "1" + + if backend == "vllm-engine": + try: + import vllm + except ImportError: + raise ImportError("Please install vllm to use the vllm-engine backend.") + + kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="sync_engine") + + if "compilation_config" in kwargs: + if isinstance(kwargs["compilation_config"], str): + try: + kwargs["compilation_config"] = json.loads(kwargs["compilation_config"]) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}") + del kwargs["compilation_config"] + if "gpu_memory_utilization" not in kwargs: + kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization() + if "model" not in kwargs: + kwargs["model"] = model_path + if enable_custom_logits_processors() and ("logits_processors" not in kwargs): + from mineru_vl_utils import MinerULogitsProcessor + kwargs["logits_processors"] = [MinerULogitsProcessor] + # 使用kwargs为 vllm初始化参数 + vllm_llm = vllm.LLM(**kwargs) + elif backend == "vllm-async-engine": + try: + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.v1.engine.async_llm import AsyncLLM + from vllm.config import CompilationConfig + except ImportError: + raise ImportError("Please install vllm to use the vllm-async-engine backend.") + + kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="async_engine") + + if "compilation_config" in kwargs: + if isinstance(kwargs["compilation_config"], dict): + # 如果是字典,转换为 CompilationConfig 对象 + kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"]) + elif isinstance(kwargs["compilation_config"], str): + # 如果是 JSON 字符串,先解析再转换 + try: + config_dict = json.loads(kwargs["compilation_config"]) + kwargs["compilation_config"] = CompilationConfig(**config_dict) + except (json.JSONDecodeError, TypeError) as e: + logger.warning( + f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}") + del kwargs["compilation_config"] + if "gpu_memory_utilization" not in kwargs: + kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization() + if "model" not in kwargs: + kwargs["model"] = model_path + if enable_custom_logits_processors() and ("logits_processors" not in kwargs): + from mineru_vl_utils import MinerULogitsProcessor + kwargs["logits_processors"] = [MinerULogitsProcessor] + # 使用kwargs为 vllm初始化参数 + vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs)) + elif backend == "lmdeploy-engine": + try: + from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig + from lmdeploy.serve.vl_async_engine import VLAsyncEngine + except ImportError: + raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.") + if "cache_max_entry_count" not in kwargs: + kwargs["cache_max_entry_count"] = 0.5 + + device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "") + if device_type == "": + if 
"lmdeploy_device" in kwargs: + device_type = kwargs.pop("lmdeploy_device") + if device_type not in ["cuda", "ascend", "maca", "camb"]: + raise ValueError(f"Unsupported lmdeploy device type: {device_type}") + else: + device_type = "cuda" + lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "") + if lm_backend == "": + if "lmdeploy_backend" in kwargs: + lm_backend = kwargs.pop("lmdeploy_backend") + if lm_backend not in ["pytorch", "turbomind"]: + raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}") + else: + lm_backend = set_lmdeploy_backend(device_type) + logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}") + + if lm_backend == "pytorch": + kwargs["device_type"] = device_type + backend_config = PytorchEngineConfig(**kwargs) + elif lm_backend == "turbomind": + backend_config = TurbomindEngineConfig(**kwargs) + else: + raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}") + + log_level = 'ERROR' + from lmdeploy.utils import get_logger + lm_logger = get_logger('lmdeploy') + lm_logger.setLevel(log_level) + if os.getenv('TM_LOG_LEVEL') is None: + os.environ['TM_LOG_LEVEL'] = log_level + + lmdeploy_engine = VLAsyncEngine( + model_path, + backend=lm_backend, + backend_config=backend_config, + ) + predictor = MinerUClient( + backend=backend, + model=model, + processor=processor, + lmdeploy_engine=lmdeploy_engine, + vllm_llm=vllm_llm, + vllm_async_llm=vllm_async_llm, + server_url=server_url, + batch_size=batch_size, + max_concurrency=max_concurrency, + http_timeout=http_timeout, + server_headers=server_headers, + max_retries=max_retries, + retry_backoff_factor=retry_backoff_factor, + ) + _maybe_enable_serial_execution(predictor, backend) + self._models[key] = predictor + elapsed = round(time.time() - start_time, 2) + logger.info(f"get {backend} predictor cost: {elapsed}s") return self._models[key] +def _predictor_uses_mlx(predictor: MinerUClient, backend: str | None = None) -> bool: + if backend == "mlx-engine": + return True + client = getattr(predictor, "client", None) + return type(client).__module__.endswith(".mlx_client") + + +def _maybe_enable_serial_execution( + predictor: MinerUClient, + backend: str | None = None, +) -> MinerUClient: + if _predictor_uses_mlx(predictor, backend) and not hasattr( + predictor, "_mineru_execution_lock" + ): + predictor._mineru_execution_lock = threading.Lock() + return predictor + + +@contextmanager +def predictor_execution_guard(predictor: MinerUClient): + lock = getattr(predictor, "_mineru_execution_lock", None) + if lock is None: + yield + return + with lock: + yield + + +@asynccontextmanager +async def aio_predictor_execution_guard(predictor: MinerUClient): + lock = getattr(predictor, "_mineru_execution_lock", None) + if lock is None: + yield + return + await asyncio.to_thread(lock.acquire) + try: + yield + finally: + lock.release() + + def _close_images(images_list): for image_dict in images_list or []: pil_img = image_dict.get("img_pil") @@ -242,6 +291,7 @@ def doc_analyze( ): if predictor is None: predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs) + predictor = _maybe_enable_serial_execution(predictor, backend) load_images_start = time.time() images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL) @@ -250,7 +300,8 @@ def doc_analyze( logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s") infer_start = time.time() - results = predictor.batch_two_step_extract(images=images_pil_list) + 
+        results = predictor.batch_two_step_extract(images=images_pil_list)
     infer_time = round(time.time() - infer_start, 2)
     logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
 
@@ -269,6 +320,7 @@ def doc_analyze_low_memory(
 ):
     if predictor is None:
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+        predictor = _maybe_enable_serial_execution(predictor, backend)
 
     pdf_doc = pdfium.PdfDocument(pdf_bytes)
     middle_json = init_middle_json()
@@ -304,7 +356,8 @@ def doc_analyze_low_memory(
             f'pages {window_start + 1}-{window_end + 1}/{page_count} '
             f'({len(images_pil_list)} pages)'
         )
-        window_results = predictor.batch_two_step_extract(images=images_pil_list)
+        with predictor_execution_guard(predictor):
+            window_results = predictor.batch_two_step_extract(images=images_pil_list)
         results.extend(window_results)
         append_page_blocks_to_middle_json(
             middle_json,
@@ -343,6 +396,7 @@ async def aio_doc_analyze(
 ):
     if predictor is None:
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+        predictor = _maybe_enable_serial_execution(predictor, backend)
 
     load_images_start = time.time()
     images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
@@ -351,7 +405,8 @@ async def aio_doc_analyze(
     logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
 
     infer_start = time.time()
-    results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
+    async with aio_predictor_execution_guard(predictor):
+        results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
     infer_time = round(time.time() - infer_start, 2)
     logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
     middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
@@ -369,6 +424,7 @@ async def aio_doc_analyze_low_memory(
 ):
     if predictor is None:
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+        predictor = _maybe_enable_serial_execution(predictor, backend)
 
     pdf_doc = pdfium.PdfDocument(pdf_bytes)
     middle_json = init_middle_json()
@@ -404,7 +460,8 @@ async def aio_doc_analyze_low_memory(
             f'pages {window_start + 1}-{window_end + 1}/{page_count} '
             f'({len(images_pil_list)} pages)'
         )
-        window_results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
+        async with aio_predictor_execution_guard(predictor):
+            window_results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
         results.extend(window_results)
         append_page_blocks_to_middle_json(
             middle_json,
diff --git a/mineru/cli/common.py b/mineru/cli/common.py
index 68660b66..3c9e23f2 100644
--- a/mineru/cli/common.py
+++ b/mineru/cli/common.py
@@ -3,6 +3,7 @@
 import io
 import json
 import os
 import copy
+import threading
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
@@ -37,6 +38,8 @@ office_suffixes = docx_suffixes + pptx_suffixes + xlsx_suffixes
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+_pdf_rewrite_lock = threading.Lock()
+
 def read_fn(path):
     if not isinstance(path, Path):
         path = Path(path)
@@ -60,34 +63,46 @@ def prepare_env(output_dir, pdf_file_name, parse_method):
 
 def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
-    pdf = pdfium.PdfDocument(pdf_bytes)
-    output_pdf = pdfium.PdfDocument.new()
-    try:
-        end_page_id = get_end_page_id(end_page_id, len(pdf))
+    # pypdfium2 document import/save is not thread-safe across concurrent FastAPI tasks.
+    with _pdf_rewrite_lock:
+        pdf = None
+        output_pdf = None
+        try:
+            pdf = pdfium.PdfDocument(pdf_bytes)
+            page_count = len(pdf)
+            end_page_id = get_end_page_id(end_page_id, page_count)
 
-        # Import pages one by one; skip any page that fails
-        output_index = 0
-        for page_index in range(start_page_id, end_page_id + 1):
-            try:
-                output_pdf.import_pages(pdf, pages=[page_index])
-                output_index += 1
-            except Exception as page_error:
-                output_pdf.del_page(output_index)
-                logger.warning(f"Failed to import page {page_index}: {page_error}, skipping this page.")
-                continue
+            # Avoid rewriting when the caller requests the whole document.
+            if start_page_id <= 0 and end_page_id >= page_count - 1:
+                return pdf_bytes
 
-        # Save the new PDF to an in-memory buffer
-        output_buffer = io.BytesIO()
-        output_pdf.save(output_buffer)
+            output_pdf = pdfium.PdfDocument.new()
 
-        # Get the byte data
-        output_bytes = output_buffer.getvalue()
+            # Import pages one by one; skip any page that fails
+            output_index = 0
+            for page_index in range(start_page_id, end_page_id + 1):
+                try:
+                    output_pdf.import_pages(pdf, pages=[page_index])
+                    output_index += 1
+                except Exception as page_error:
+                    output_pdf.del_page(output_index)
+                    logger.warning(f"Failed to import page {page_index}: {page_error}, skipping this page.")
+                    continue
+
+            # Save the new PDF to an in-memory buffer
+            output_buffer = io.BytesIO()
+            output_pdf.save(output_buffer)
+
+            # Get the byte data
+            return output_buffer.getvalue()
-    except Exception as e:
-        logger.warning(f"Error in converting PDF bytes: {e}, Using original PDF bytes.")
-        output_bytes = pdf_bytes
-    pdf.close()
-    output_pdf.close()
-    return output_bytes
+        except Exception as e:
+            logger.warning(f"Error in converting PDF bytes: {e}, Using original PDF bytes.")
+            return pdf_bytes
+        finally:
+            if pdf is not None:
+                pdf.close()
+            if output_pdf is not None:
+                output_pdf.close()
 
 
 def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
diff --git a/mineru/cli/fast_api.py b/mineru/cli/fast_api.py
index 98fe9f43..cfe35fac 100644
--- a/mineru/cli/fast_api.py
+++ b/mineru/cli/fast_api.py
@@ -5,6 +5,7 @@ import re
 import shutil
 import sys
 import tempfile
+import threading
 import uuid
 import zipfile
 from contextlib import asynccontextmanager, suppress
@@ -40,6 +41,7 @@ from mineru.cli.common import (
     read_fn,
 )
 from mineru.utils.cli_parser import arg_parse
+from mineru.utils.config_reader import get_device
 from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
 from mineru.version import __version__
 
@@ -59,6 +61,7 @@ ALLOWED_PARSE_METHODS = {"auto", "txt", "ocr"}
 
 # Concurrency controller
 _request_semaphore: Optional[asyncio.Semaphore] = None
+_mps_parse_lock = threading.Lock()
 
 
 @dataclass
@@ -164,10 +167,10 @@ def create_app():
     global _request_semaphore
     try:
         max_concurrent_requests = int(
-            os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "0")
+            os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "3")
         )
     except ValueError:
-        max_concurrent_requests = 0
+        max_concurrent_requests = 3
     if max_concurrent_requests > 0:
         _request_semaphore = asyncio.Semaphore(max_concurrent_requests)
 
@@ -654,12 +657,32 @@ async def run_parse_job(
     )
     if request_options.backend == "pipeline":
-        await asyncio.to_thread(do_parse, **parse_kwargs)
+        async with serialize_parse_job_if_needed(request_options.backend):
+            await asyncio.to_thread(do_parse, **parse_kwargs)
     else:
-        await aio_do_parse(**parse_kwargs)
+        async with serialize_parse_job_if_needed(request_options.backend):
+            await aio_do_parse(**parse_kwargs)
 
     return response_file_names
 
 
+def should_serialize_parse_job(backend: str) -> bool:
+    if get_device() != "mps":
"mps": + return False + return backend == "pipeline" or backend.startswith(("vlm-", "hybrid-")) + + +@asynccontextmanager +async def serialize_parse_job_if_needed(backend: str): + if not should_serialize_parse_job(backend): + yield + return + await asyncio.to_thread(_mps_parse_lock.acquire) + try: + yield + finally: + _mps_parse_lock.release() + + def create_task_output_dir(task_id: str) -> str: output_root = get_output_root() task_output_dir = output_root / task_id diff --git a/mineru/utils/pdf_classify.py b/mineru/utils/pdf_classify.py index b0468d2b..f6103c4c 100644 --- a/mineru/utils/pdf_classify.py +++ b/mineru/utils/pdf_classify.py @@ -1,5 +1,6 @@ # Copyright (c) Opendatalab. All rights reserved. import re +import threading from io import BytesIO import numpy as np import pypdfium2 as pdfium @@ -13,6 +14,8 @@ from pdfminer.pdfinterp import PDFPageInterpreter from pdfminer.layout import LAParams, LTImage, LTFigure from pdfminer.converter import PDFPageAggregator +_pdf_sample_extract_lock = threading.Lock() + def classify(pdf_bytes): """ @@ -27,6 +30,8 @@ def classify(pdf_bytes): # 从字节数据加载PDF sample_pdf_bytes = extract_pages(pdf_bytes) + if not sample_pdf_bytes: + return 'ocr' pdf = pdfium.PdfDocument(sample_pdf_bytes) try: # 获取PDF页数 @@ -187,40 +192,50 @@ def extract_pages(src_pdf_bytes: bytes) -> bytes: bytes: 提取页面后的PDF字节数据 """ - # 从字节数据加载PDF - pdf = pdfium.PdfDocument(src_pdf_bytes) + with _pdf_sample_extract_lock: + pdf = None + sample_docs = None + try: + # 从字节数据加载PDF + pdf = pdfium.PdfDocument(src_pdf_bytes) - # 获取PDF页数 - total_page = len(pdf) - if total_page == 0: - # 如果PDF没有页面,直接返回空文档 - logger.warning("PDF is empty, return empty document") - return b'' + # 获取PDF页数 + total_page = len(pdf) + if total_page == 0: + # 如果PDF没有页面,直接返回空文档 + logger.warning("PDF is empty, return empty document") + return b'' - # 选择最多10页 - select_page_cnt = min(10, total_page) + # 小文档直接复用原始字节,避免无意义的 PDF 重写。 + if total_page <= 10: + return src_pdf_bytes - # 从总页数中随机选择页面 - page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist() + # 选择最多10页 + select_page_cnt = min(10, total_page) - # 创建一个新的PDF文档 - sample_docs = pdfium.PdfDocument.new() + # 从总页数中随机选择页面 + page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist() - try: - # 将选择的页面导入新文档 - sample_docs.import_pages(pdf, page_indices) - pdf.close() + # 创建一个新的PDF文档 + sample_docs = pdfium.PdfDocument.new() - # 将新PDF保存到内存缓冲区 - output_buffer = BytesIO() - sample_docs.save(output_buffer) + # 将选择的页面导入新文档 + sample_docs.import_pages(pdf, page_indices) - # 获取字节数据 - return output_buffer.getvalue() - except Exception as e: - pdf.close() - logger.exception(e) - return b'' # 出错时返回空字节 + # 将新PDF保存到内存缓冲区 + output_buffer = BytesIO() + sample_docs.save(output_buffer) + + # 获取字节数据 + return output_buffer.getvalue() + except Exception as e: + logger.exception(e) + return src_pdf_bytes + finally: + if pdf is not None: + pdf.close() + if sample_docs is not None: + sample_docs.close() def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool: @@ -263,4 +278,4 @@ def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool: if __name__ == '__main__': with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f: p_bytes = f.read() - logger.info(f"PDF分类结果: {classify(p_bytes)}") \ No newline at end of file + logger.info(f"PDF分类结果: {classify(p_bytes)}")