mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 11:08:32 +07:00
feat: implement live task status rendering with thread-safe output handling
This commit is contained in:
@@ -8,13 +8,14 @@ import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import zipfile
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Awaitable, Callable, Optional
|
||||
from typing import Awaitable, Callable, Optional, TextIO
|
||||
|
||||
import click
|
||||
import httpx
|
||||
@@ -48,8 +49,6 @@ from .visualization import (
|
||||
|
||||
os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"
|
||||
log_level = os.getenv("MINERU_LOG_LEVEL", "INFO").upper()
|
||||
logger.remove()
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
HEALTH_ENDPOINT = "/health"
|
||||
TASKS_ENDPOINT = "/tasks"
|
||||
@@ -113,6 +112,150 @@ class TaskFailure:
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
class LiveTaskStatusState:
    """Mutable per-task state backing one line of the live status display."""

    # Position of the task in the planned batch; primary sort key so the
    # display lists tasks in submission order.
    task_index: int
    # Server-assigned task identifier; lookup key and secondary sort key.
    task_id: str
    # Most recent status string reported by the API (e.g. "pending",
    # "processing").
    status: str
    # Animation counter; advanced on each active-status update to move the
    # progress-bar runner one cell.
    frame_step: int = 0
|
||||
|
||||
|
||||
class LiveAwareStderrSink:
    """Stderr proxy that coordinates plain log writes with a live renderer.

    Every write happens under a re-entrant lock shared with the renderer,
    so the in-place status lines can be erased before a log message is
    emitted and repainted right after it — keeping both readable on the
    same terminal.
    """

    def __init__(self, stream: TextIO):
        # Underlying stream that actually receives all output.
        self.stream = stream
        # Re-entrant: renderer callbacks may write through us while held.
        self.lock = threading.RLock()
        # Currently attached renderer, or None when live display is off.
        self.renderer: "LiveTaskStatusRenderer | None" = None

    def set_renderer(self, renderer: "LiveTaskStatusRenderer | None") -> None:
        """Attach *renderer* to coordinate with, or detach by passing None."""
        with self.lock:
            self.renderer = renderer

    def isatty(self) -> bool:
        """True when the wrapped stream reports a TTY; False if it cannot say."""
        isatty_fn = getattr(self.stream, "isatty", lambda: False)
        return bool(isatty_fn())

    def write(self, message: str) -> None:
        """Emit *message*, hiding and restoring the live display around it."""
        with self.lock:
            active = self.renderer
            if active is None:
                self.stream.write(message)
                self.stream.flush()
                return
            # Erase the status lines, let the message land, then repaint
            # the status block beneath it.
            active.clear_locked()
            self.stream.write(message)
            self.stream.flush()
            active.render_locked()

    def flush(self) -> None:
        """Flush the wrapped stream."""
        self.stream.flush()

    def stop(self) -> None:
        """Shutdown hook for the sink; currently just a final flush."""
        self.flush()
|
||||
|
||||
|
||||
class LiveTaskStatusRenderer:
    """Paints a multi-line, in-place task status display on the terminal.

    All public methods acquire ``self.sink.lock`` (the RLock shared with the
    stderr sink) so renders never interleave with log writes.  Methods with
    the ``_locked`` suffix assume the caller already holds that lock.
    """

    # Width of the animated progress bar, in character cells.
    BAR_WIDTH = 12
    # Length of the moving "===>" runner inside the bar.
    RUNNER_WIDTH = 4
    # Statuses that advance the bar animation on each status update.
    ACTIVE_STATUSES = {"pending", "processing"}

    def __init__(self, sink: LiveAwareStderrSink):
        self.sink = sink
        # Number of status lines currently painted; clear_locked() uses it
        # to know how far to move the cursor back up.
        self._rendered_line_count = 0
        # task_id -> live state for every task currently displayed.
        self._task_states: dict[str, LiveTaskStatusState] = {}

    def register_task(self, task: PlannedTask, task_id: str) -> None:
        """Start displaying *task* (initially "pending") and repaint."""
        with self.sink.lock:
            self._task_states[task_id] = LiveTaskStatusState(
                task_index=task.index,
                task_id=task_id,
                status="pending",
            )
            self.render_locked()

    def update_status(self, task_id: str, status: str) -> None:
        """Record *status* for *task_id* and repaint; unknown ids are ignored."""
        with self.sink.lock:
            state = self._task_states.get(task_id)
            if state is None:
                return
            state.status = status
            if status in self.ACTIVE_STATUSES:
                # Advance the animation only while the task is still running.
                state.frame_step += 1
            self.render_locked()

    def remove_task(self, task_id: str) -> None:
        """Drop *task_id* from the display and repaint; unknown ids are ignored."""
        with self.sink.lock:
            if self._task_states.pop(task_id, None) is None:
                return
            self.render_locked()

    def close(self) -> None:
        """Forget every task and erase the status block from the terminal."""
        with self.sink.lock:
            self._task_states.clear()
            self.clear_locked()

    def snapshot_lines(self) -> list[str]:
        """Return the lines the display would currently show (for inspection)."""
        with self.sink.lock:
            return self._build_render_lines_locked()

    def clear_locked(self) -> None:
        """Erase the previously painted status lines.

        Caller must hold ``self.sink.lock``.  No-op when nothing is painted.
        """
        if self._rendered_line_count <= 0:
            return

        # Cursor up to the first status line (ESC[nA), back to column 0.
        self.sink.stream.write(f"\x1b[{self._rendered_line_count}A\r")
        for index in range(self._rendered_line_count):
            # Blank the whole line (ESC[2K), stepping down one row (ESC[1B)
            # between lines so each painted line gets cleared.
            self.sink.stream.write("\x1b[2K")
            if index + 1 < self._rendered_line_count:
                self.sink.stream.write("\x1b[1B\r")
        if self._rendered_line_count > 1:
            # Return to the top cleared row so the next render starts there.
            self.sink.stream.write(f"\x1b[{self._rendered_line_count - 1}A\r")
        self.sink.stream.flush()
        self._rendered_line_count = 0

    def render_locked(self) -> None:
        """Repaint the status block from scratch.

        Caller must hold ``self.sink.lock``.
        """
        self.clear_locked()
        lines = self._build_render_lines_locked()
        if lines:
            self.sink.stream.write("\n".join(lines))
            self.sink.stream.write("\n")
            self.sink.stream.flush()
        self._rendered_line_count = len(lines)

    def _build_render_lines_locked(self) -> list[str]:
        """Build one display line per task, ordered by batch index then id."""
        states = sorted(
            self._task_states.values(),
            key=lambda state: (state.task_index, state.task_id),
        )
        return [
            f"{self._build_bar(state.frame_step)} status={state.status} | task_id={state.task_id}"
            for state in states
        ]

    @classmethod
    def _build_bar(cls, frame_step: int) -> str:
        """Return a fixed-width ASCII bar with an "===>" runner at *frame_step*.

        The runner wraps around the bar modulo ``BAR_WIDTH``, so successive
        frame steps produce a looping animation.
        """
        cells = [" "] * cls.BAR_WIDTH
        runner_start = frame_step % cls.BAR_WIDTH
        for offset in range(cls.RUNNER_WIDTH):
            position = (runner_start + offset) % cls.BAR_WIDTH
            cells[position] = "="
        # The leading edge of the runner is drawn as ">".
        head_position = (runner_start + cls.RUNNER_WIDTH - 1) % cls.BAR_WIDTH
        cells[head_position] = ">"
        return f"[{''.join(cells)}]"
|
||||
|
||||
|
||||
def create_live_task_status_renderer(
    api_url: Optional[str],
) -> Optional[LiveTaskStatusRenderer]:
    """Attach a live status renderer to the shared stderr sink when possible.

    Live rendering is enabled only when an explicit API URL is given and
    stderr is a real TTY; otherwise any previously attached renderer is
    detached and None is returned.
    """
    live_capable = api_url is not None and _stderr_sink.isatty()
    if not live_capable:
        _stderr_sink.set_renderer(None)
        return None

    live_renderer = LiveTaskStatusRenderer(_stderr_sink)
    _stderr_sink.set_renderer(live_renderer)
    return live_renderer
|
||||
|
||||
|
||||
# Route all loguru output through the live-aware sink so log lines and the
# live status display share stderr without clobbering each other.  The
# earlier plain sys.stderr handler is removed and replaced here.
_stderr_sink = LiveAwareStderrSink(sys.stderr)
logger.remove()
logger.add(_stderr_sink, level=log_level)
|
||||
|
||||
|
||||
def build_http_timeout() -> httpx.Timeout:
    """Return the httpx timeout policy used for all client requests.

    Write and read budgets are generous (uploads and result polling can be
    slow); connect and pool acquisition stay short.
    """
    budgets = {"connect": 10, "read": 60, "write": 300, "pool": 30}
    return httpx.Timeout(**budgets)
|
||||
|
||||
@@ -724,6 +867,7 @@ async def wait_for_task_result(
|
||||
client: httpx.AsyncClient,
|
||||
submit_response: SubmitResponse,
|
||||
planned_task: PlannedTask,
|
||||
live_renderer: Optional[LiveTaskStatusRenderer] = None,
|
||||
timeout_seconds: float = TASK_RESULT_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
deadline = asyncio.get_running_loop().time() + timeout_seconds
|
||||
@@ -738,6 +882,8 @@ async def wait_for_task_result(
|
||||
payload = response.json()
|
||||
status = payload.get("status")
|
||||
if status in {"pending", "processing"}:
|
||||
if live_renderer is not None:
|
||||
live_renderer.update_status(submit_response.task_id, status)
|
||||
await asyncio.sleep(TASK_STATUS_POLL_INTERVAL_SECONDS)
|
||||
continue
|
||||
if status == "completed":
|
||||
@@ -857,6 +1003,7 @@ async def run_planned_task(
|
||||
visualization_context: Optional[VisualizationContext],
|
||||
form_data: dict[str, str],
|
||||
output_dir: Path,
|
||||
live_renderer: Optional[LiveTaskStatusRenderer] = None,
|
||||
) -> None:
|
||||
logger.info(format_task_submission_message(planned_task, progress))
|
||||
submit_response = await submit_task(
|
||||
@@ -865,11 +1012,18 @@ async def run_planned_task(
|
||||
planned_task=planned_task,
|
||||
form_data=form_data,
|
||||
)
|
||||
await wait_for_task_result(
|
||||
client=client,
|
||||
submit_response=submit_response,
|
||||
planned_task=planned_task,
|
||||
)
|
||||
if live_renderer is not None:
|
||||
live_renderer.register_task(planned_task, submit_response.task_id)
|
||||
try:
|
||||
await wait_for_task_result(
|
||||
client=client,
|
||||
submit_response=submit_response,
|
||||
planned_task=planned_task,
|
||||
live_renderer=live_renderer,
|
||||
)
|
||||
finally:
|
||||
if live_renderer is not None:
|
||||
live_renderer.remove_task(submit_response.task_id)
|
||||
zip_path = await download_result_zip(
|
||||
client=client,
|
||||
submit_response=submit_response,
|
||||
@@ -938,6 +1092,7 @@ async def run_orchestrated_cli(
|
||||
timeout = build_http_timeout()
|
||||
local_server: LocalAPIServer | None = None
|
||||
visualization_context: Optional[VisualizationContext] = None
|
||||
live_renderer: Optional[LiveTaskStatusRenderer] = None
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as http_client:
|
||||
try:
|
||||
if api_url is None:
|
||||
@@ -950,6 +1105,7 @@ async def run_orchestrated_cli(
|
||||
http_client,
|
||||
normalize_base_url(api_url),
|
||||
)
|
||||
live_renderer = create_live_task_status_renderer(api_url)
|
||||
|
||||
planned_tasks = plan_tasks(
|
||||
documents=documents,
|
||||
@@ -987,6 +1143,7 @@ async def run_orchestrated_cli(
|
||||
visualization_context=visualization_context,
|
||||
form_data=form_data,
|
||||
output_dir=output_dir,
|
||||
live_renderer=live_renderer,
|
||||
),
|
||||
)
|
||||
if failures:
|
||||
@@ -1002,7 +1159,12 @@ async def run_orchestrated_cli(
|
||||
if local_server is not None:
|
||||
local_server.stop()
|
||||
finally:
|
||||
await wait_for_visualization_jobs(visualization_context)
|
||||
try:
|
||||
await wait_for_visualization_jobs(visualization_context)
|
||||
finally:
|
||||
if live_renderer is not None:
|
||||
live_renderer.close()
|
||||
_stderr_sink.set_renderer(None)
|
||||
|
||||
|
||||
@click.command()
|
||||
|
||||
Reference in New Issue
Block a user