Files
MinerU/mineru/cli/api_client.py

526 lines
17 KiB
Python

import asyncio
import atexit
import json
import mimetypes
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import threading
import time
import zipfile
from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
from typing import Callable, Optional, Sequence

import click
import httpx
from loguru import logger

from mineru.cli.api_protocol import (
    API_PROTOCOL_VERSION,
    DEFAULT_MAX_CONCURRENT_REQUESTS,
    get_max_concurrent_requests as read_max_concurrent_requests,
)
# Endpoint paths on the MinerU API server (appended to a normalized base URL).
HEALTH_ENDPOINT = "/health"
TASKS_ENDPOINT = "/tasks"

# Polling cadence and the default upper bound for waiting on a task result.
TASK_STATUS_POLL_INTERVAL_SECONDS = 1.0
TASK_RESULT_TIMEOUT_SECONDS = 60 * 60

# Lifecycle tuning for the locally spawned API subprocess.
LOCAL_API_STARTUP_TIMEOUT_SECONDS = 30
LOCAL_API_CLEANUP_RETRIES = 8
LOCAL_API_CLEANUP_RETRY_INTERVAL_SECONDS = 0.25
@dataclass(frozen=True, eq=True)
class UploadAsset:
    """A local file to upload, paired with the filename to send it under."""

    path: Path  # local file that is read and uploaded
    upload_name: str  # filename reported in the multipart "files" part
@dataclass(frozen=True, eq=True)
class ServerHealth:
    """Validated contents of a MinerU API ``/health`` response."""

    base_url: str  # server the health data was fetched from
    max_concurrent_requests: int  # limit advertised by the server
    processing_window_size: int  # clamped to >= 1 by the health validator
@dataclass(frozen=True, eq=True)
class SubmitResponse:
    """Handle returned by the API after a parse task has been accepted."""

    task_id: str
    status_url: str  # polled while the task is pending/processing
    result_url: str  # fetched once the task has completed
    file_names: tuple[str, ...] = ()  # empty when the server omitted or mistyped it
    queued_ahead: Optional[int] = None  # None when the server did not report an int
@dataclass(frozen=True, eq=True)
class TaskStatusSnapshot:
    """Status reported for a task that has not finished yet."""

    status: str
    # Presumably the number of tasks queued ahead of this one; None if not reported.
    queued_ahead: Optional[int] = None
class LocalAPIServer:
    """Run a private ``mineru.cli.fast_api`` subprocess bound to 127.0.0.1.

    Each instance owns a temporary directory. Its ``output`` subdirectory is
    handed to the child through ``MINERU_API_OUTPUT_ROOT``. ``stop()`` is also
    registered with ``atexit`` on first start. It terminates the child and
    removes the temporary directory.
    """

    def __init__(self, extra_cli_args: Sequence[str] = ()):
        self.temp_dir = tempfile.TemporaryDirectory(prefix="mineru-api-client-")
        self.temp_root = Path(self.temp_dir.name)
        self.output_root = self.temp_root / "output"
        self.base_url: str | None = None  # assigned by start()
        self.process: subprocess.Popen[bytes] | None = None
        self._atexit_registered = False
        # Appended verbatim after the host/port arguments of the child command.
        self.extra_cli_args = tuple(extra_cli_args)

    def start(self) -> str:
        """Spawn the API process on a free loopback port and return its base URL.

        Raises:
            RuntimeError: if this instance already holds a process handle.
        """
        if self.process is not None:
            raise RuntimeError("Local API server is already running")

        port = find_free_port()
        self.base_url = f"http://127.0.0.1:{port}"

        # Inherit the current environment, overriding the server's knobs.
        child_env = {
            **os.environ,
            "MINERU_API_OUTPUT_ROOT": str(self.output_root),
            "MINERU_API_MAX_CONCURRENT_REQUESTS": str(
                read_max_concurrent_requests(default=DEFAULT_MAX_CONCURRENT_REQUESTS)
            ),
            "MINERU_API_DISABLE_ACCESS_LOG": "1",
        }
        self.output_root.mkdir(parents=True, exist_ok=True)

        argv = [sys.executable, "-m", "mineru.cli.fast_api"]
        argv += ["--host", "127.0.0.1", "--port", str(port)]
        argv.extend(self.extra_cli_args)
        self.process = subprocess.Popen(argv, cwd=os.getcwd(), env=child_env)

        # Clean up at interpreter exit even if the caller never calls stop().
        if not self._atexit_registered:
            atexit.register(self.stop)
            self._atexit_registered = True
        return self.base_url

    def stop(self) -> None:
        """Terminate the child process and remove the temporary directory.

        The child gets 5 seconds to exit after ``terminate()`` before it is
        killed. The atexit hook is unregistered and temp-dir cleanup runs even
        if terminating the process raises.
        """
        process, self.process = self.process, None
        try:
            if process is None or process.poll() is not None:
                return
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait(timeout=5)
        finally:
            if self._atexit_registered:
                try:
                    atexit.unregister(self.stop)
                except Exception:
                    pass
                self._atexit_registered = False
            self._cleanup_temp_dir()

    def _cleanup_temp_dir(self) -> None:
        """Remove the temp dir, retrying briefly; only log a warning on final failure."""
        failure: Exception | None = None
        attempts_left = LOCAL_API_CLEANUP_RETRIES
        while attempts_left > 0:
            attempts_left -= 1
            try:
                self.temp_dir.cleanup()
            except FileNotFoundError:
                # Already gone: nothing left to do.
                return
            except Exception as exc:
                # Handles may still be held briefly by the exiting child; retry.
                failure = exc
                if attempts_left > 0:
                    time.sleep(LOCAL_API_CLEANUP_RETRY_INTERVAL_SECONDS)
            else:
                return
        if failure is None:
            return
        logger.warning(
            "Failed to clean up temporary MinerU API directory {}: {}. "
            "You can remove it manually after processes release any open handles.",
            self.temp_root,
            failure,
        )
class ReusableLocalAPIServer:
    """Lazily start one LocalAPIServer and hand it out while its process is alive.

    State changes are guarded by a ``threading.Lock``, so concurrent callers of
    ``ensure_started()`` share a single running server.
    """

    def __init__(self, extra_cli_args: Sequence[str] = ()):
        self._lock = threading.Lock()
        self._server: LocalAPIServer | None = None
        self._extra_cli_args = tuple(extra_cli_args)

    def configure(self, extra_cli_args: Sequence[str]) -> None:
        """Set the CLI arguments used for servers started from now on.

        A server that is still running is kept and keeps the arguments it was
        started with. A server whose process is gone is detached and stopped,
        so its temp dir and atexit hook are released now instead of lingering
        until interpreter exit.
        """
        with self._lock:
            self._extra_cli_args = tuple(extra_cli_args)
            stale_server = self._server
            if stale_server is None:
                return
            process = stale_server.process
            if process is not None and process.poll() is None:
                return
            self._server = None
        # Stop outside the lock: temp-dir cleanup may retry with sleeps.
        stale_server.stop()

    def ensure_started(self) -> tuple[LocalAPIServer, bool]:
        """Return ``(server, started_now)``, starting a fresh server if none is alive.

        A previously held server whose process has exited is detached and
        stopped before its replacement is started. If starting the replacement
        fails, it is stopped (releasing its temp dir) and the error propagates.
        """
        with self._lock:
            server = self._server
            if server is not None and server.process is not None and server.process.poll() is None:
                return server, False
            if server is not None:
                # Detach first so a failing stop() never leaves a dead server cached.
                self._server = None
                server.stop()
            fresh = LocalAPIServer(extra_cli_args=self._extra_cli_args)
            try:
                fresh.start()
            except BaseException:
                fresh.stop()
                raise
            self._server = fresh
            return fresh, True

    def stop(self) -> None:
        """Detach the current server (if any) and stop it outside the lock."""
        with self._lock:
            server, self._server = self._server, None
        if server is not None:
            server.stop()
def build_http_timeout() -> httpx.Timeout:
    """Build the httpx timeout configuration (seconds) used for API uploads.

    The write phase gets the largest budget (300 s), presumably to
    accommodate large multipart uploads.
    """
    phase_limits = {"connect": 10, "read": 60, "write": 300, "pool": 30}
    return httpx.Timeout(**phase_limits)
def find_free_port() -> int:
    """Ask the OS for an unused TCP port on 127.0.0.1 and return its number.

    The probe socket is closed before returning, so another process could
    claim the port before the caller binds it. Callers must tolerate that race.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        probe.listen(1)
        _host, chosen_port = probe.getsockname()
    return int(chosen_port)
def normalize_base_url(url: str) -> str:
    """Return *url* with every trailing ``/`` removed so endpoint paths can be appended."""
    trimmed = url
    while trimmed.endswith("/"):
        trimmed = trimmed[:-1]
    return trimmed
def resolve_effective_max_concurrent_requests(
    local_max: int,
    server_max: int,
) -> int:
    """Combine the client-side and server-side concurrency limits.

    Non-positive values are ignored (treated as "unset"). Returns the smaller
    of the positive limits, or 0 when neither side provides one.
    """
    positive_limits = [limit for limit in (local_max, server_max) if limit > 0]
    return min(positive_limits) if positive_limits else 0
def response_detail(response: httpx.Response) -> str:
    """Extract a human-readable error description from an API response.

    For a JSON object, the first string among its "detail", "error" and
    "message" fields wins. Any other JSON body is re-serialized. A body that
    cannot be decoded as JSON yields its stripped text, or the HTTP reason
    phrase when the text is empty.
    """
    try:
        payload = response.json()
    except Exception:
        return response.text.strip() or response.reason_phrase
    if isinstance(payload, dict):
        for field_name in ("detail", "error", "message"):
            candidate = payload.get(field_name)
            if isinstance(candidate, str):
                return candidate
    return json.dumps(payload, ensure_ascii=False)
def validate_server_health_payload(payload: dict, base_url: str) -> ServerHealth:
    """Check a decoded ``/health`` body and convert it into a ServerHealth.

    Args:
        payload: Decoded JSON body of the health endpoint (expected to be an object).
        base_url: Server URL, used in error messages and stored on the result.

    Returns:
        ServerHealth with ``processing_window_size`` clamped to at least 1.

    Raises:
        click.ClickException: if any of the following holds:
            - the payload is not a JSON object;
            - ``status`` is not "healthy";
            - ``protocol_version`` differs from API_PROTOCOL_VERSION;
            - ``max_concurrent_requests`` or ``processing_window_size`` is not
              an integer (booleans rejected).
    """
    if not isinstance(payload, dict):
        raise click.ClickException(
            f"MinerU API at {base_url} returned an invalid health payload "
            f"(expected a JSON object, got {type(payload).__name__})"
        )
    status = payload.get("status")
    if status != "healthy":
        raise click.ClickException(
            f"MinerU API at {base_url} is not healthy: {json.dumps(payload, ensure_ascii=False)}"
        )
    protocol_version = payload.get("protocol_version")
    if protocol_version != API_PROTOCOL_VERSION:
        raise click.ClickException(
            f"MinerU API at {base_url} returned protocol_version={protocol_version}, "
            f"expected {API_PROTOCOL_VERSION}"
        )
    max_concurrent_requests = payload.get("max_concurrent_requests")
    processing_window_size = payload.get("processing_window_size")
    # bool subclasses int, so JSON true/false would otherwise be accepted as 1/0.
    if not isinstance(max_concurrent_requests, int) or isinstance(max_concurrent_requests, bool):
        raise click.ClickException(
            f"MinerU API at {base_url} did not return a valid max_concurrent_requests"
        )
    if not isinstance(processing_window_size, int) or isinstance(processing_window_size, bool):
        raise click.ClickException(
            f"MinerU API at {base_url} did not return a valid processing_window_size"
        )
    return ServerHealth(
        base_url=base_url,
        max_concurrent_requests=max_concurrent_requests,
        # Never report a processing window smaller than one.
        processing_window_size=max(1, processing_window_size),
    )
async def fetch_server_health(
    client: httpx.AsyncClient,
    base_url: str,
) -> ServerHealth:
    """GET ``{base_url}/health`` and validate the response body.

    Raises:
        click.ClickException: on a non-200 status, a body that is not valid
            JSON, or a payload rejected by ``validate_server_health_payload``.
        httpx.HTTPError: transport-level failures propagate unchanged.
    """
    response = await client.get(f"{base_url}{HEALTH_ENDPOINT}")
    if response.status_code != 200:
        raise click.ClickException(
            f"Failed to query MinerU API health from {base_url}: "
            f"{response.status_code} {response_detail(response)}"
        )
    try:
        payload = response.json()
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError. Report it as a ClickException so
        # callers that retry on ClickException (e.g. the startup wait loop) also
        # handle a malformed body.
        raise click.ClickException(
            f"MinerU API at {base_url} returned a non-JSON health response: "
            f"{response_detail(response)}"
        ) from exc
    return validate_server_health_payload(payload, base_url)
async def wait_for_local_api_ready(
    client: httpx.AsyncClient,
    local_server: LocalAPIServer,
    timeout_seconds: float = LOCAL_API_STARTUP_TIMEOUT_SECONDS,
) -> ServerHealth:
    """Poll a started LocalAPIServer's health endpoint until it reports healthy.

    Health-check ClickExceptions and httpx errors are treated as "not ready
    yet" and retried every TASK_STATUS_POLL_INTERVAL_SECONDS.

    Raises:
        RuntimeError: if ``local_server`` has no base URL (it was never started).
        click.ClickException: if the subprocess exits before becoming healthy,
            or if ``timeout_seconds`` elapses. On timeout, the last observed
            error is appended to the message.
    """
    base_url = local_server.base_url
    if base_url is None:
        # Explicit check rather than `assert`, which `python -O` strips.
        raise RuntimeError("Local API server has not been started")
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_seconds
    last_error: str | None = None
    while loop.time() < deadline:
        process = local_server.process
        if process is not None and process.poll() is not None:
            raise click.ClickException(
                "Local mineru-api exited before becoming healthy."
            )
        try:
            return await fetch_server_health(client, base_url)
        except (click.ClickException, httpx.HTTPError) as exc:
            last_error = str(exc)
        await asyncio.sleep(TASK_STATUS_POLL_INTERVAL_SECONDS)
    message = "Timed out waiting for local mineru-api to become healthy."
    if last_error:
        message = f"{message} {last_error}"
    raise click.ClickException(message)
def build_parse_request_form_data(
lang_list: Sequence[str],
backend: str,
parse_method: str,
formula_enable: bool,
table_enable: bool,
server_url: Optional[str],
start_page_id: int,
end_page_id: Optional[int],
*,
return_md: bool,
return_middle_json: bool,
return_model_output: bool,
return_content_list: bool,
return_images: bool,
response_format_zip: bool,
return_original_file: bool,
) -> dict[str, str | list[str]]:
effective_lang_list = list(lang_list) or ["ch"]
data: dict[str, str | list[str]] = {
"lang_list": effective_lang_list,
"backend": backend,
"parse_method": parse_method,
"formula_enable": str(formula_enable).lower(),
"table_enable": str(table_enable).lower(),
"return_md": str(return_md).lower(),
"return_middle_json": str(return_middle_json).lower(),
"return_model_output": str(return_model_output).lower(),
"return_content_list": str(return_content_list).lower(),
"return_images": str(return_images).lower(),
"response_format_zip": str(response_format_zip).lower(),
"return_original_file": str(return_original_file).lower(),
"start_page_id": str(start_page_id),
"end_page_id": str(99999 if end_page_id is None else end_page_id),
}
if server_url:
data["server_url"] = server_url
return data
async def submit_parse_task(
    base_url: str,
    upload_assets: Sequence[UploadAsset],
    form_data: dict[str, str | list[str]],
) -> SubmitResponse:
    """Run ``submit_parse_task_sync`` in a worker thread and await its result.

    The synchronous httpx upload is offloaded via ``asyncio.to_thread`` so the
    event loop is not blocked while the files are sent.
    """
    def _submit() -> SubmitResponse:
        return submit_parse_task_sync(base_url, upload_assets, form_data)

    return await asyncio.to_thread(_submit)
def submit_parse_task_sync(
    base_url: str,
    upload_assets: Sequence[UploadAsset],
    form_data: dict[str, str | list[str]],
) -> SubmitResponse:
    """Upload files plus form fields to ``POST {base_url}/tasks``.

    Args:
        base_url: Normalized API base URL.
        upload_assets: Files to send, each as a multipart part named "files".
        form_data: Form fields, e.g. from ``build_parse_request_form_data``.

    Returns:
        SubmitResponse. ``file_names`` is ``()`` and ``queued_ahead`` is None
        when the server omits them or sends them with the wrong type.

    Raises:
        click.ClickException: if the server does not answer 202, or its body
            is not a JSON object with string task_id/status_url/result_url.
        OSError: if an upload file cannot be opened.
        httpx.HTTPError: transport failures propagate unchanged.
    """
    with httpx.Client(timeout=build_http_timeout(), follow_redirects=True) as sync_client:
        with ExitStack() as stack:
            # Every handle stays open until the multipart request has been sent.
            files = []
            for upload_asset in upload_assets:
                mime_type = (
                    mimetypes.guess_type(upload_asset.upload_name)[0]
                    or "application/octet-stream"
                )
                file_handle = stack.enter_context(open(upload_asset.path, "rb"))
                files.append(("files", (upload_asset.upload_name, file_handle, mime_type)))
            response = sync_client.post(
                f"{base_url}{TASKS_ENDPOINT}",
                data=form_data,
                files=files,
            )
    if response.status_code != 202:
        raise click.ClickException(
            f"Failed to submit parsing task: "
            f"{response.status_code} {response_detail(response)}"
        )
    try:
        payload = response.json()
    except ValueError as exc:
        raise click.ClickException("MinerU API returned an invalid task payload") from exc
    if not isinstance(payload, dict):
        raise click.ClickException("MinerU API returned an invalid task payload")

    task_id = payload.get("task_id")
    status_url = payload.get("status_url")
    result_url = payload.get("result_url")
    if (
        not isinstance(task_id, str)
        or not isinstance(status_url, str)
        or not isinstance(result_url, str)
    ):
        raise click.ClickException("MinerU API returned an invalid task payload")

    file_names = payload.get("file_names")
    normalized_file_names: tuple[str, ...] = ()
    if isinstance(file_names, list) and all(isinstance(name, str) for name in file_names):
        normalized_file_names = tuple(file_names)
    queued_ahead = payload.get("queued_ahead")
    if not isinstance(queued_ahead, int):
        queued_ahead = None
    return SubmitResponse(
        task_id=task_id,
        status_url=status_url,
        result_url=result_url,
        file_names=normalized_file_names,
        queued_ahead=queued_ahead,
    )
async def wait_for_task_result(
    client: httpx.AsyncClient,
    submit_response: SubmitResponse,
    task_label: str,
    *,
    status_callback: Optional[Callable[[str], None]] = None,
    status_snapshot_callback: Optional[Callable[[TaskStatusSnapshot], None]] = None,
    timeout_seconds: float = TASK_RESULT_TIMEOUT_SECONDS,
) -> None:
    """Poll a task's status URL until the task completes.

    While the status is "pending" or "processing", the optional callbacks are
    invoked after each poll and the loop sleeps TASK_STATUS_POLL_INTERVAL_SECONDS.

    Raises:
        click.ClickException: in any of these cases:
            - a status request returns non-200;
            - the body is not a JSON object;
            - the status is anything other than pending/processing/completed;
            - ``timeout_seconds`` elapses.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_seconds
    while loop.time() < deadline:
        response = await client.get(submit_response.status_url)
        if response.status_code != 200:
            raise click.ClickException(
                f"Failed to query task status for {task_label}: "
                f"{response.status_code} {response_detail(response)}"
            )
        try:
            payload = response.json()
        except ValueError as exc:
            raise click.ClickException(
                f"Invalid task status response for {task_label}: "
                f"{response_detail(response)}"
            ) from exc
        if not isinstance(payload, dict):
            raise click.ClickException(
                f"Invalid task status response for {task_label}: "
                f"{json.dumps(payload, ensure_ascii=False)}"
            )
        status = payload.get("status")
        # Tuple membership compares with ==, so an unhashable status (e.g. a list)
        # falls through to the failure branch instead of raising TypeError.
        if status in ("pending", "processing"):
            queued_ahead = payload.get("queued_ahead")
            if not isinstance(queued_ahead, int):
                queued_ahead = None
            if status_snapshot_callback is not None:
                status_snapshot_callback(
                    TaskStatusSnapshot(status=status, queued_ahead=queued_ahead)
                )
            if status_callback is not None:
                status_callback(status)
            await asyncio.sleep(TASK_STATUS_POLL_INTERVAL_SECONDS)
            continue
        if status == "completed":
            return
        raise click.ClickException(
            f"Task {submit_response.task_id} failed for {task_label}: "
            f"{json.dumps(payload, ensure_ascii=False)}"
        )
    raise click.ClickException(
        f"Timed out waiting for result of task {submit_response.task_id} "
        f"for {task_label}"
    )
async def download_result_zip(
    client: httpx.AsyncClient,
    submit_response: SubmitResponse,
    task_label: str,
) -> Path:
    """Download a finished task's ZIP result into a new temporary file.

    The file is created with ``tempfile.mkstemp`` and is not deleted here.
    If writing it fails, the partial file is removed before the error
    propagates.

    Raises:
        click.ClickException: on a non-200 status or a non-ZIP content type.
    """
    response = await client.get(submit_response.result_url)
    if response.status_code != 200:
        raise click.ClickException(
            f"Failed to download result ZIP for task {submit_response.task_id}: "
            f"{response.status_code} {response_detail(response)}"
        )
    content_type = response.headers.get("content-type", "")
    # MIME type names are case-insensitive, so compare in lower case.
    if "application/zip" not in content_type.lower():
        raise click.ClickException(
            f"Expected a ZIP result for {task_label}, "
            f"got content-type={content_type or 'unknown'}"
        )
    zip_fd, zip_name = tempfile.mkstemp(suffix=".zip", prefix="mineru_cli_result_")
    zip_path = Path(zip_name)
    try:
        # Write through the descriptor mkstemp already opened.
        with os.fdopen(zip_fd, "wb") as handle:
            handle.write(response.content)
    except BaseException:
        zip_path.unlink(missing_ok=True)
        raise
    return zip_path
def safe_extract_zip(zip_path: Path, output_dir: Path) -> None:
    """Extract *zip_path* into *output_dir*, refusing entries that would escape it.

    Raises:
        click.ClickException: if an entry is absolute, contains ``..``, or
            resolves outside ``output_dir``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    output_root = output_dir.resolve()
    with zipfile.ZipFile(zip_path, "r") as zip_file:
        for member in zip_file.infolist():
            member_path = PurePosixPath(member.filename)
            # Reject absolute names and parent-directory hops outright ("zip slip").
            if member_path.is_absolute() or ".." in member_path.parts:
                raise click.ClickException(
                    f"Refusing to extract unsafe ZIP entry: {member.filename}"
                )
            target_path = (output_root / Path(*member_path.parts)).resolve()
            # Second check after resolve() (symlinks, platform separators): the
            # target must still be output_root itself or live beneath it.
            if target_path != output_root and output_root not in target_path.parents:
                raise click.ClickException(
                    f"Refusing to extract unsafe ZIP entry: {member.filename}"
                )
            if member.is_dir():
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            target_path.parent.mkdir(parents=True, exist_ok=True)
            # Stream in chunks rather than loading the whole member into memory,
            # which matters for large or highly compressed entries.
            with zip_file.open(member, "r") as source, open(target_path, "wb") as handle:
                shutil.copyfileobj(source, handle)