@dataclass
class LiveTaskStatusState:
    """Display state for one in-flight task shown in the live status block."""

    # `index` of the planned task; primary sort key when rendering.
    task_index: int
    # Identifier of the submitted task; also the key in the renderer's map.
    task_id: str
    # Most recent status string recorded for the task.
    status: str
    # Animation counter, advanced on every update that reports an active status.
    frame_step: int = 0


class LiveAwareStderrSink:
    """Stream-like sink that keeps a live status block below regular output.

    While a renderer is attached, every ``write`` first erases the status
    block, emits the message, and then redraws the block underneath it.
    Without a renderer the sink is a plain pass-through to ``stream``.
    """

    def __init__(self, stream: TextIO):
        self.stream = stream
        # Shared with the renderer so message writes and status redraws are
        # serialized on the same lock.
        self.lock = threading.RLock()
        self.renderer: LiveTaskStatusRenderer | None = None

    def set_renderer(self, renderer: "LiveTaskStatusRenderer | None") -> None:
        """Attach ``renderer`` (or detach with ``None``) under the sink lock."""
        with self.lock:
            self.renderer = renderer

    def isatty(self) -> bool:
        """Return whether the wrapped stream reports being a terminal.

        Streams without an ``isatty`` attribute are treated as non-terminals.
        """
        return bool(getattr(self.stream, "isatty", lambda: False)())

    def write(self, message: str) -> None:
        """Write ``message`` to the stream, keeping the status block intact.

        With a renderer attached, a non-empty message that does not end in a
        newline is terminated with one. Otherwise the status block would be
        drawn on the same row as the partial message, and the next clear
        would erase that row together with the message text.
        """
        with self.lock:
            renderer = self.renderer
            if renderer is not None:
                renderer.clear_locked()
            self.stream.write(message)
            if renderer is not None and message and not message.endswith("\n"):
                self.stream.write("\n")
            self.stream.flush()
            if renderer is not None:
                renderer.render_locked()

    def flush(self) -> None:
        """Flush the wrapped stream."""
        self.stream.flush()

    def stop(self) -> None:
        """Flush pending output.

        NOTE(review): presumably invoked by the logging framework when the
        sink is removed -- confirm against the logger's sink protocol.
        """
        self.flush()


class LiveTaskStatusRenderer:
    """Draws one animated status line per active task through the sink's stream.

    Methods suffixed ``_locked`` do not take the sink lock themselves; the
    caller is expected to hold ``sink.lock``. All other public methods
    acquire it.
    """

    # Number of cells between the brackets of the progress bar.
    BAR_WIDTH = 12
    # Number of cells occupied by the moving "===>" runner.
    RUNNER_WIDTH = 4
    # Statuses for which the animation frame advances on update.
    ACTIVE_STATUSES = {"pending", "processing"}

    def __init__(self, sink: LiveAwareStderrSink):
        self.sink = sink
        # How many status rows are currently drawn on the terminal.
        self._rendered_line_count = 0
        self._task_states: dict[str, LiveTaskStatusState] = {}

    def register_task(self, task: "PlannedTask", task_id: str) -> None:
        """Start tracking ``task_id`` with status ``pending`` and redraw."""
        with self.sink.lock:
            self._task_states[task_id] = LiveTaskStatusState(
                task_index=task.index,
                task_id=task_id,
                status="pending",
            )
            self.render_locked()

    def update_status(self, task_id: str, status: str) -> None:
        """Record a new status for a tracked task and redraw.

        Unknown task ids are ignored. The animation frame only advances for
        statuses listed in ``ACTIVE_STATUSES``.
        """
        with self.sink.lock:
            state = self._task_states.get(task_id)
            if state is None:
                return
            state.status = status
            if status in self.ACTIVE_STATUSES:
                state.frame_step += 1
            self.render_locked()

    def remove_task(self, task_id: str) -> None:
        """Stop tracking ``task_id`` and redraw; no-op if it is not tracked."""
        with self.sink.lock:
            if self._task_states.pop(task_id, None) is None:
                return
            self.render_locked()

    def close(self) -> None:
        """Forget every task and erase the status block from the terminal."""
        with self.sink.lock:
            self._task_states.clear()
            self.clear_locked()

    def snapshot_lines(self) -> list[str]:
        """Return the status lines that would be drawn right now."""
        with self.sink.lock:
            return self._build_render_lines_locked()

    def clear_locked(self) -> None:
        """Erase the currently drawn status rows (caller holds the sink lock).

        Assumes the cursor sits on the row directly below the status block,
        which is where ``render_locked`` leaves it. The cursor ends on the
        first erased row, ready for new output.
        """
        count = self._rendered_line_count
        if count <= 0:
            return

        # Move to the top row, erase each row while stepping down, then climb
        # back to the top row. Emitted as one write instead of many small ones.
        parts = [
            f"\x1b[{count}A\r",
            "\x1b[1B\r".join(["\x1b[2K"] * count),
        ]
        if count > 1:
            parts.append(f"\x1b[{count - 1}A\r")
        self.sink.stream.write("".join(parts))
        self.sink.stream.flush()
        self._rendered_line_count = 0

    def render_locked(self) -> None:
        """Redraw the status block from scratch (caller holds the sink lock)."""
        self.clear_locked()
        lines = self._build_render_lines_locked()
        if lines:
            # Trailing newline leaves the cursor on the row below the block,
            # which is what clear_locked relies on.
            self.sink.stream.write("\n".join(lines) + "\n")
            self.sink.stream.flush()
        self._rendered_line_count = len(lines)

    def _build_render_lines_locked(self) -> list[str]:
        """Build one line per tracked task, ordered by (task_index, task_id)."""
        states = sorted(
            self._task_states.values(),
            key=lambda state: (state.task_index, state.task_id),
        )
        return [
            f"{self._build_bar(state.frame_step)} status={state.status} | task_id={state.task_id}"
            for state in states
        ]

    @classmethod
    def _build_bar(cls, frame_step: int) -> str:
        """Return the bracketed bar for ``frame_step``.

        The runner is ``RUNNER_WIDTH`` cells wide, ends in ``>``, and wraps
        around the right edge of the bar.
        """
        cells = [" "] * cls.BAR_WIDTH
        runner_start = frame_step % cls.BAR_WIDTH
        for offset in range(cls.RUNNER_WIDTH):
            position = (runner_start + offset) % cls.BAR_WIDTH
            cells[position] = "="
        head_position = (runner_start + cls.RUNNER_WIDTH - 1) % cls.BAR_WIDTH
        cells[head_position] = ">"
        return f"[{''.join(cells)}]"
def create_live_task_status_renderer(
    api_url: Optional[str],
) -> Optional[LiveTaskStatusRenderer]:
    """Attach a live status renderer to the shared stderr sink when applicable.

    A renderer is created only when ``api_url`` is given and the shared sink
    reports a terminal. In every other case any attached renderer is detached
    and ``None`` is returned.
    """
    # Short-circuit keeps the isatty() probe from running when no URL is given.
    live_mode = api_url is not None and _stderr_sink.isatty()
    renderer = LiveTaskStatusRenderer(_stderr_sink) if live_mode else None
    _stderr_sink.set_renderer(renderer)
    return renderer


# Replace the logger's existing handlers with the live-aware sink so that log
# lines and the status block share one writer.
_stderr_sink = LiveAwareStderrSink(sys.stderr)
logger.remove()
logger.add(_stderr_sink, level=log_level)
planned_task=planned_task, - ) + if live_renderer is not None: + live_renderer.register_task(planned_task, submit_response.task_id) + try: + await wait_for_task_result( + client=client, + submit_response=submit_response, + planned_task=planned_task, + live_renderer=live_renderer, + ) + finally: + if live_renderer is not None: + live_renderer.remove_task(submit_response.task_id) zip_path = await download_result_zip( client=client, submit_response=submit_response, @@ -938,6 +1092,7 @@ async def run_orchestrated_cli( timeout = build_http_timeout() local_server: LocalAPIServer | None = None visualization_context: Optional[VisualizationContext] = None + live_renderer: Optional[LiveTaskStatusRenderer] = None async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as http_client: try: if api_url is None: @@ -950,6 +1105,7 @@ async def run_orchestrated_cli( http_client, normalize_base_url(api_url), ) + live_renderer = create_live_task_status_renderer(api_url) planned_tasks = plan_tasks( documents=documents, @@ -987,6 +1143,7 @@ async def run_orchestrated_cli( visualization_context=visualization_context, form_data=form_data, output_dir=output_dir, + live_renderer=live_renderer, ), ) if failures: @@ -1002,7 +1159,12 @@ async def run_orchestrated_cli( if local_server is not None: local_server.stop() finally: - await wait_for_visualization_jobs(visualization_context) + try: + await wait_for_visualization_jobs(visualization_context) + finally: + if live_renderer is not None: + live_renderer.close() + _stderr_sink.set_renderer(None) @click.command()