# MinerU/mineru/backend/vlm/vlm_analyze.py
# Copyright (c) Opendatalab. All rights reserved.
import asyncio
import os
import time
import json
import threading
from contextlib import asynccontextmanager, contextmanager
import pypdfium2 as pdfium
from loguru import logger
from tqdm import tqdm
from .utils import enable_custom_logits_processors, set_default_gpu_memory_utilization, set_default_batch_size, \
set_lmdeploy_backend, mod_kwargs_by_device_type
from .model_output_to_middle_json import (
append_page_blocks_to_middle_json,
finalize_middle_json,
init_middle_json,
)
from mineru.backend.utils import exclude_progress_bar_idle_time
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf_doc
from ...utils.check_sys_env import is_mac_os_version_supported
from ...utils.config_reader import get_device, get_processing_window_size
from ...utils.enum_class import ImageType
from ...utils.pdfium_guard import (
close_pdfium_document,
get_pdfium_document_page_count,
open_pdfium_document,
)
from ...utils.models_download_utils import auto_download_and_get_model_root_path
from mineru_vl_utils import MinerUClient
from packaging import version
class ModelSingleton:
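    """Process-wide cache of MinerUClient predictors.

    Predictors are keyed by (backend, model_path, server_url), so repeated
    calls with the same configuration reuse the already-initialized engine.
    Creation is guarded by a reentrant lock to keep it thread-safe.
    """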
_instance = None
_models = {}
_lock = threading.RLock()
def __new__(cls, *args, **kwargs):
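        """Create or return the single shared instance under the class lock."""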
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
self,
backend: str,
model_path: str | None,
server_url: str | None,
**kwargs,
) -> MinerUClient:
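        """Return a cached MinerUClient for the given backend configuration.

        The first call for a (backend, model_path, server_url) combination
        initializes the requested engine (transformers, mlx-engine,
        vllm-engine, vllm-async-engine, or lmdeploy-engine; http-client only
        needs server_url) and caches the resulting predictor; later calls
        return it directly. Remaining kwargs are forwarded to the engine
        constructor of the selected backend.
        """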
key = (backend, model_path, server_url)
with self._lock:
if key not in self._models:
start_time = time.time()
model = None
processor = None
vllm_llm = None
lmdeploy_engine = None
vllm_async_llm = None
batch_size = kwargs.get("batch_size", 0) # for transformers backend only
max_concurrency = kwargs.get("max_concurrency", 100) # for http-client backend only
http_timeout = kwargs.get("http_timeout", 600) # for http-client backend only
server_headers = kwargs.get("server_headers", None) # for http-client backend only
max_retries = kwargs.get("max_retries", 3) # for http-client backend only
retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5) # for http-client backend only
                # Remove these parameters from kwargs so they are not passed to unrelated engine initializers
for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
if param in kwargs:
del kwargs[param]
if backend not in ["http-client"] and not model_path:
model_path = auto_download_and_get_model_root_path("/","vlm")
if backend == "transformers":
try:
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
)
from transformers import __version__ as transformers_version
except ImportError:
raise ImportError("Please install transformers to use the transformers backend.")
if version.parse(transformers_version) >= version.parse("4.56.0"):
dtype_key = "dtype"
else:
dtype_key = "torch_dtype"
device = get_device()
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map={"": device},
**{dtype_key: "auto"}, # type: ignore
)
processor = AutoProcessor.from_pretrained(
model_path,
use_fast=True,
)
if batch_size == 0:
batch_size = set_default_batch_size()
elif backend == "mlx-engine":
mlx_supported = is_mac_os_version_supported()
if not mlx_supported:
raise EnvironmentError("mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.")
try:
from mlx_vlm import load as mlx_load
except ImportError:
raise ImportError("Please install mlx-vlm to use the mlx-engine backend.")
model, processor = mlx_load(model_path)
else:
if os.getenv('OMP_NUM_THREADS') is None:
os.environ["OMP_NUM_THREADS"] = "1"
if backend == "vllm-engine":
try:
import vllm
except ImportError:
raise ImportError("Please install vllm to use the vllm-engine backend.")
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="sync_engine")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], str):
try:
kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
except json.JSONDecodeError:
logger.warning(
f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
del kwargs["compilation_config"]
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs:
kwargs["model"] = model_path
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
from mineru_vl_utils import MinerULogitsProcessor
kwargs["logits_processors"] = [MinerULogitsProcessor]
                        # Initialize the vllm LLM with the remaining kwargs
vllm_llm = vllm.LLM(**kwargs)
elif backend == "vllm-async-engine":
try:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.config import CompilationConfig
except ImportError:
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
kwargs = mod_kwargs_by_device_type(kwargs, vllm_mode="async_engine")
if "compilation_config" in kwargs:
if isinstance(kwargs["compilation_config"], dict):
                                # If it is a dict, convert it to a CompilationConfig object
kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"])
elif isinstance(kwargs["compilation_config"], str):
                                # If it is a JSON string, parse it first and then convert
try:
config_dict = json.loads(kwargs["compilation_config"])
kwargs["compilation_config"] = CompilationConfig(**config_dict)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(
f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}")
del kwargs["compilation_config"]
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
if "model" not in kwargs:
kwargs["model"] = model_path
if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
from mineru_vl_utils import MinerULogitsProcessor
kwargs["logits_processors"] = [MinerULogitsProcessor]
                        # Initialize the async vllm engine with the remaining kwargs
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
elif backend == "lmdeploy-engine":
try:
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
except ImportError:
raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.")
if "cache_max_entry_count" not in kwargs:
kwargs["cache_max_entry_count"] = 0.5
device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "")
if device_type == "":
if "lmdeploy_device" in kwargs:
device_type = kwargs.pop("lmdeploy_device")
if device_type not in ["cuda", "ascend", "maca", "camb"]:
raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
else:
device_type = "cuda"
lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "")
if lm_backend == "":
if "lmdeploy_backend" in kwargs:
lm_backend = kwargs.pop("lmdeploy_backend")
if lm_backend not in ["pytorch", "turbomind"]:
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
else:
lm_backend = set_lmdeploy_backend(device_type)
logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}")
if lm_backend == "pytorch":
kwargs["device_type"] = device_type
backend_config = PytorchEngineConfig(**kwargs)
elif lm_backend == "turbomind":
backend_config = TurbomindEngineConfig(**kwargs)
else:
raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
log_level = 'ERROR'
from lmdeploy.utils import get_logger
lm_logger = get_logger('lmdeploy')
lm_logger.setLevel(log_level)
if os.getenv('TM_LOG_LEVEL') is None:
os.environ['TM_LOG_LEVEL'] = log_level
lmdeploy_engine = VLAsyncEngine(
model_path,
backend=lm_backend,
backend_config=backend_config,
)
predictor = MinerUClient(
backend=backend,
model=model,
processor=processor,
lmdeploy_engine=lmdeploy_engine,
vllm_llm=vllm_llm,
vllm_async_llm=vllm_async_llm,
server_url=server_url,
batch_size=batch_size,
max_concurrency=max_concurrency,
http_timeout=http_timeout,
server_headers=server_headers,
max_retries=max_retries,
retry_backoff_factor=retry_backoff_factor,
)
_maybe_enable_serial_execution(predictor, backend)
self._models[key] = predictor
elapsed = round(time.time() - start_time, 2)
logger.info(f"get {backend} predictor cost: {elapsed}s")
return self._models[key]
def _predictor_uses_mlx(predictor: MinerUClient, backend: str | None = None) -> bool:
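    """Return True if the backend is mlx-engine or the predictor's client comes from the mlx client module."""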
if backend == "mlx-engine":
return True
client = getattr(predictor, "client", None)
return type(client).__module__.endswith(".mlx_client")
def _maybe_enable_serial_execution(
predictor: MinerUClient,
backend: str | None = None,
) -> MinerUClient:
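    """Attach a threading.Lock to MLX-backed predictors (once).

    The execution guards below use this lock to serialize inference calls.
    """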
if _predictor_uses_mlx(predictor, backend) and not hasattr(
predictor, "_mineru_execution_lock"
):
predictor._mineru_execution_lock = threading.Lock()
return predictor
@contextmanager
def predictor_execution_guard(predictor: MinerUClient):
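    """Serialize predictor execution when a _mineru_execution_lock is present; otherwise a no-op."""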
lock = getattr(predictor, "_mineru_execution_lock", None)
if lock is None:
yield
return
with lock:
yield
@asynccontextmanager
async def aio_predictor_execution_guard(predictor: MinerUClient):
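    """Async guard: acquires the predictor lock in a worker thread so the event loop is not blocked."""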
lock = getattr(predictor, "_mineru_execution_lock", None)
if lock is None:
yield
return
await asyncio.to_thread(lock.acquire)
try:
yield
finally:
lock.release()
def _close_images(images_list):
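    """Close the PIL images in images_list, ignoring any errors raised on close."""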
for image_dict in images_list or []:
pil_img = image_dict.get("img_pil")
if pil_img is not None:
try:
pil_img.close()
except Exception:
pass
def doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: MinerUClient | None = None,
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
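    """Run VLM extraction over a PDF and build middle JSON.

    Pages are processed in windows of up to get_processing_window_size()
    pages (default 64): each window is rendered to PIL images, sent through
    the predictor's two-step extraction, appended to the middle JSON, and
    then released. Returns (middle_json, results), where results holds the
    raw per-page model output.
    """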
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
pdf_doc = open_pdfium_document(pdfium.PdfDocument, pdf_bytes)
middle_json = init_middle_json()
results = []
doc_closed = False
try:
page_count = get_pdfium_document_page_count(pdf_doc)
configured_window_size = get_processing_window_size(default=64)
effective_window_size = min(page_count, configured_window_size) if page_count else 0
total_windows = (
(page_count + effective_window_size - 1) // effective_window_size
if effective_window_size
else 0
)
logger.info(
f'VLM processing-window run. page_count={page_count}, '
f'window_size={configured_window_size}, total_windows={total_windows}'
)
infer_start = time.time()
progress_bar = None
last_append_end_time = None
try:
for window_index, window_start in enumerate(range(0, page_count, effective_window_size or 1)):
window_end = min(page_count - 1, window_start + effective_window_size - 1)
images_list = load_images_from_pdf_doc(
pdf_doc,
start_page_id=window_start,
end_page_id=window_end,
image_type=ImageType.PIL,
)
try:
images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
logger.info(
f'VLM processing window {window_index + 1}/{total_windows}: '
f'pages {window_start + 1}-{window_end + 1}/{page_count} '
f'({len(images_pil_list)} pages)'
)
with predictor_execution_guard(predictor):
window_results = predictor.batch_two_step_extract(images=images_pil_list)
results.extend(window_results)
if progress_bar is None:
progress_bar = tqdm(total=page_count, desc="Processing pages")
else:
exclude_progress_bar_idle_time(
progress_bar,
last_append_end_time,
now=time.time(),
)
append_page_blocks_to_middle_json(
middle_json,
window_results,
images_list,
pdf_doc,
image_writer,
page_start_index=window_start,
progress_bar=progress_bar,
)
last_append_end_time = time.time()
finally:
_close_images(images_list)
finally:
if progress_bar is not None:
progress_bar.close()
infer_time = round(time.time() - infer_start, 2)
if infer_time > 0 and page_count > 0:
logger.debug(
f"processing-window infer finished, cost: {infer_time}, "
f"speed: {round(len(results) / infer_time, 3)} page/s"
)
finalize_middle_json(middle_json["pdf_info"])
close_pdfium_document(pdf_doc)
doc_closed = True
return middle_json, results
finally:
if not doc_closed:
close_pdfium_document(pdf_doc)
async def aio_doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: MinerUClient | None = None,
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
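    """Async counterpart of doc_analyze; uses aio_batch_two_step_extract and the async execution guard."""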
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
predictor = _maybe_enable_serial_execution(predictor, backend)
pdf_doc = open_pdfium_document(pdfium.PdfDocument, pdf_bytes)
middle_json = init_middle_json()
results = []
doc_closed = False
try:
page_count = get_pdfium_document_page_count(pdf_doc)
configured_window_size = get_processing_window_size(default=64)
effective_window_size = min(page_count, configured_window_size) if page_count else 0
total_windows = (
(page_count + effective_window_size - 1) // effective_window_size
if effective_window_size
else 0
)
logger.info(
f'VLM processing-window run. page_count={page_count}, '
f'window_size={configured_window_size}, total_windows={total_windows}'
)
infer_start = time.time()
progress_bar = None
last_append_end_time = None
try:
for window_index, window_start in enumerate(range(0, page_count, effective_window_size or 1)):
window_end = min(page_count - 1, window_start + effective_window_size - 1)
images_list = load_images_from_pdf_doc(
pdf_doc,
start_page_id=window_start,
end_page_id=window_end,
image_type=ImageType.PIL,
)
try:
images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
logger.info(
f'VLM processing window {window_index + 1}/{total_windows}: '
f'pages {window_start + 1}-{window_end + 1}/{page_count} '
f'({len(images_pil_list)} pages)'
)
async with aio_predictor_execution_guard(predictor):
window_results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
results.extend(window_results)
if progress_bar is None:
progress_bar = tqdm(total=page_count, desc="Processing pages")
else:
exclude_progress_bar_idle_time(
progress_bar,
last_append_end_time,
now=time.time(),
)
append_page_blocks_to_middle_json(
middle_json,
window_results,
images_list,
pdf_doc,
image_writer,
page_start_index=window_start,
progress_bar=progress_bar,
)
last_append_end_time = time.time()
finally:
_close_images(images_list)
finally:
if progress_bar is not None:
progress_bar.close()
infer_time = round(time.time() - infer_start, 2)
if infer_time > 0 and page_count > 0:
logger.debug(
f"processing-window infer finished, cost: {infer_time}, "
f"speed: {round(len(results) / infer_time, 3)} page/s"
)
finalize_middle_json(middle_json["pdf_info"])
close_pdfium_document(pdf_doc)
doc_closed = True
return middle_json, results
finally:
if not doc_closed:
close_pdfium_document(pdf_doc)
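

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. Assumptions: a local
    # "example.pdf" exists (hypothetical path) and the default transformers
    # backend can auto-download its model. Passing image_writer=None skips
    # writing extracted page images to disk.
    with open("example.pdf", "rb") as f:
        sample_pdf_bytes = f.read()
    sample_middle_json, sample_results = doc_analyze(sample_pdf_bytes, image_writer=None)
    logger.info(f"analyzed {len(sample_results)} pages")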