feat: implement live task status rendering with thread-safe output handling

This commit is contained in:
myhloli
2026-03-26 11:19:40 +08:00
parent d4f1710e42
commit d2d1a35b32

View File

@@ -8,13 +8,14 @@ import socket
import subprocess
import sys
import tempfile
import threading
import time
import zipfile
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
from typing import Awaitable, Callable, Optional
from typing import Awaitable, Callable, Optional, TextIO
import click
import httpx
@@ -48,8 +49,6 @@ from .visualization import (
# Disable torch's cuDNN v8 API before any torch-dependent code runs.
# NOTE(review): the underlying cuDNN issue is not documented here — confirm.
os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"
# Log level is configurable via MINERU_LOG_LEVEL; defaults to INFO.
log_level = os.getenv("MINERU_LOG_LEVEL", "INFO").upper()
# Replace loguru's default handler with one honoring the configured level.
logger.remove()
logger.add(sys.stderr, level=log_level)

# REST endpoints exposed by the API server this CLI talks to.
HEALTH_ENDPOINT = "/health"
TASKS_ENDPOINT = "/tasks"
@@ -113,6 +112,150 @@ class TaskFailure:
message: str
@dataclass
class LiveTaskStatusState:
    """Mutable render state for one in-flight task shown by the live panel."""

    # Position of the task in the submission plan; primary sort key for the panel.
    task_index: int
    # Server-assigned identifier of the submitted task.
    task_id: str
    # Last status string observed for the task (set to "pending" on registration).
    status: str
    # Animation counter; advanced while the task status stays active.
    frame_step: int = 0
class LiveAwareStderrSink:
    """A stderr wrapper that keeps log output and the live panel from colliding.

    While a renderer is attached, every write temporarily clears the rendered
    panel, emits the log message, then draws the panel again — all under a
    single re-entrant lock shared with the renderer.
    """

    def __init__(self, stream: TextIO):
        self.stream = stream
        self.lock = threading.RLock()
        self.renderer: LiveTaskStatusRenderer | None = None

    def set_renderer(self, renderer: "LiveTaskStatusRenderer | None") -> None:
        """Attach a live panel renderer, or detach it by passing None."""
        with self.lock:
            self.renderer = renderer

    def isatty(self) -> bool:
        """Mirror the wrapped stream's isatty(); False when it has none."""
        probe = getattr(self.stream, "isatty", lambda: False)
        return bool(probe())

    def write(self, message: str) -> None:
        """Write a message, lifting and restoring the live panel around it."""
        with self.lock:
            active = self.renderer
            if active is None:
                self.stream.write(message)
                self.stream.flush()
                return
            # Lift the panel, let the log line through, then put it back.
            active.clear_locked()
            self.stream.write(message)
            self.stream.flush()
            active.render_locked()

    def flush(self) -> None:
        self.stream.flush()

    def stop(self) -> None:
        # Flush pending output when the sink is torn down.
        self.flush()
class LiveTaskStatusRenderer:
    """Draws a transient multi-line status panel for in-flight tasks on stderr.

    All public entry points acquire ``sink.lock`` so panel redraws never
    interleave with ordinary log writes going through the shared sink.
    Methods suffixed ``_locked`` assume the caller already holds that lock.
    """

    # Width (in cells) of the animated progress bar.
    BAR_WIDTH = 12
    # Number of consecutive "=" cells forming the moving runner.
    RUNNER_WIDTH = 4
    # Statuses that advance the animation frame on each update.
    ACTIVE_STATUSES = {"pending", "processing"}

    def __init__(self, sink: LiveAwareStderrSink):
        self.sink = sink
        # Number of panel lines currently drawn on the terminal; tells
        # clear_locked() how far to move the cursor back up.
        self._rendered_line_count = 0
        # task_id -> live state; rendered sorted by (task_index, task_id).
        self._task_states: dict[str, LiveTaskStatusState] = {}

    def register_task(self, task: PlannedTask, task_id: str) -> None:
        """Add a newly submitted task to the panel in the "pending" state."""
        with self.sink.lock:
            self._task_states[task_id] = LiveTaskStatusState(
                task_index=task.index,
                task_id=task_id,
                status="pending",
            )
            self.render_locked()

    def update_status(self, task_id: str, status: str) -> None:
        """Record a task's latest status and redraw; unknown ids are ignored."""
        with self.sink.lock:
            state = self._task_states.get(task_id)
            if state is None:
                return
            state.status = status
            if status in self.ACTIVE_STATUSES:
                # Advance the animation only while the task is still active.
                state.frame_step += 1
            self.render_locked()

    def remove_task(self, task_id: str) -> None:
        """Drop a finished task from the panel; no-op for unknown ids."""
        with self.sink.lock:
            if self._task_states.pop(task_id, None) is None:
                return
            self.render_locked()

    def close(self) -> None:
        """Forget all tasks and erase the panel (shutdown path)."""
        with self.sink.lock:
            self._task_states.clear()
            self.clear_locked()

    def snapshot_lines(self) -> list[str]:
        """Return the panel lines that would currently be rendered."""
        with self.sink.lock:
            return self._build_render_lines_locked()

    def clear_locked(self) -> None:
        """Erase the previously drawn panel. Caller must hold sink.lock.

        Moves the cursor up over the rendered lines, erases each one
        (ESC[2K) while stepping back down, then returns the cursor to the
        first erased line so the next render overwrites the same region.
        """
        if self._rendered_line_count <= 0:
            return
        # Cursor up N lines, then to column 0.
        self.sink.stream.write(f"\x1b[{self._rendered_line_count}A\r")
        for index in range(self._rendered_line_count):
            self.sink.stream.write("\x1b[2K")  # erase the entire line
            if index + 1 < self._rendered_line_count:
                self.sink.stream.write("\x1b[1B\r")  # down one line, column 0
        if self._rendered_line_count > 1:
            # Move back up to where the panel started.
            self.sink.stream.write(f"\x1b[{self._rendered_line_count - 1}A\r")
        self.sink.stream.flush()
        self._rendered_line_count = 0

    def render_locked(self) -> None:
        """Redraw the panel from current state. Caller must hold sink.lock."""
        self.clear_locked()
        lines = self._build_render_lines_locked()
        if lines:
            self.sink.stream.write("\n".join(lines))
            self.sink.stream.write("\n")
            self.sink.stream.flush()
        self._rendered_line_count = len(lines)

    def _build_render_lines_locked(self) -> list[str]:
        # Stable ordering: primarily by submission index, then by task id.
        states = sorted(
            self._task_states.values(),
            key=lambda state: (state.task_index, state.task_id),
        )
        return [
            f"{self._build_bar(state.frame_step)} status={state.status} | task_id={state.task_id}"
            for state in states
        ]

    @classmethod
    def _build_bar(cls, frame_step: int) -> str:
        """Build one animation frame: a wrapping "===>" runner in a fixed-width bar."""
        cells = [" "] * cls.BAR_WIDTH
        runner_start = frame_step % cls.BAR_WIDTH
        for offset in range(cls.RUNNER_WIDTH):
            position = (runner_start + offset) % cls.BAR_WIDTH
            cells[position] = "="
        head_position = (runner_start + cls.RUNNER_WIDTH - 1) % cls.BAR_WIDTH
        cells[head_position] = ">"
        return f"[{''.join(cells)}]"
def create_live_task_status_renderer(
    api_url: Optional[str],
) -> Optional[LiveTaskStatusRenderer]:
    """Attach a live status renderer to the shared stderr sink.

    Live rendering is enabled only when an API URL was supplied and stderr
    is an interactive terminal; otherwise any previously attached renderer
    is detached and None is returned.
    """
    enabled = api_url is not None and _stderr_sink.isatty()
    if not enabled:
        _stderr_sink.set_renderer(None)
        return None
    renderer = LiveTaskStatusRenderer(_stderr_sink)
    _stderr_sink.set_renderer(renderer)
    return renderer
# Module-level stderr sink shared by the logger and the live status renderer;
# its lock serializes log writes with panel redraws.
_stderr_sink = LiveAwareStderrSink(sys.stderr)
# Re-route loguru through the live-aware sink so log lines do not corrupt the
# in-place status panel.
logger.remove()
logger.add(_stderr_sink, level=log_level)
def build_http_timeout(
    *,
    connect: float = 10,
    read: float = 60,
    write: float = 300,
    pool: float = 30,
) -> httpx.Timeout:
    """Build the httpx timeout configuration used by the orchestrator client.

    The previous hard-coded values are now keyword-only parameters with
    identical defaults, so existing no-argument callers are unaffected while
    new call sites can tune individual phases (all values in seconds).

    Args:
        connect: Maximum time to establish a connection.
        read: Maximum time to wait for a chunk of the response body.
        write: Maximum time to send a chunk of the request body.
        pool: Maximum time to wait for a connection from the pool.

    Returns:
        An ``httpx.Timeout`` carrying the four per-phase limits.
    """
    return httpx.Timeout(connect=connect, read=read, write=write, pool=pool)
@@ -724,6 +867,7 @@ async def wait_for_task_result(
client: httpx.AsyncClient,
submit_response: SubmitResponse,
planned_task: PlannedTask,
live_renderer: Optional[LiveTaskStatusRenderer] = None,
timeout_seconds: float = TASK_RESULT_TIMEOUT_SECONDS,
) -> None:
deadline = asyncio.get_running_loop().time() + timeout_seconds
@@ -738,6 +882,8 @@ async def wait_for_task_result(
payload = response.json()
status = payload.get("status")
if status in {"pending", "processing"}:
if live_renderer is not None:
live_renderer.update_status(submit_response.task_id, status)
await asyncio.sleep(TASK_STATUS_POLL_INTERVAL_SECONDS)
continue
if status == "completed":
@@ -857,6 +1003,7 @@ async def run_planned_task(
visualization_context: Optional[VisualizationContext],
form_data: dict[str, str],
output_dir: Path,
live_renderer: Optional[LiveTaskStatusRenderer] = None,
) -> None:
logger.info(format_task_submission_message(planned_task, progress))
submit_response = await submit_task(
@@ -865,11 +1012,18 @@ async def run_planned_task(
planned_task=planned_task,
form_data=form_data,
)
await wait_for_task_result(
client=client,
submit_response=submit_response,
planned_task=planned_task,
)
if live_renderer is not None:
live_renderer.register_task(planned_task, submit_response.task_id)
try:
await wait_for_task_result(
client=client,
submit_response=submit_response,
planned_task=planned_task,
live_renderer=live_renderer,
)
finally:
if live_renderer is not None:
live_renderer.remove_task(submit_response.task_id)
zip_path = await download_result_zip(
client=client,
submit_response=submit_response,
@@ -938,6 +1092,7 @@ async def run_orchestrated_cli(
timeout = build_http_timeout()
local_server: LocalAPIServer | None = None
visualization_context: Optional[VisualizationContext] = None
live_renderer: Optional[LiveTaskStatusRenderer] = None
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as http_client:
try:
if api_url is None:
@@ -950,6 +1105,7 @@ async def run_orchestrated_cli(
http_client,
normalize_base_url(api_url),
)
live_renderer = create_live_task_status_renderer(api_url)
planned_tasks = plan_tasks(
documents=documents,
@@ -987,6 +1143,7 @@ async def run_orchestrated_cli(
visualization_context=visualization_context,
form_data=form_data,
output_dir=output_dir,
live_renderer=live_renderer,
),
)
if failures:
@@ -1002,7 +1159,12 @@ async def run_orchestrated_cli(
if local_server is not None:
local_server.stop()
finally:
await wait_for_visualization_jobs(visualization_context)
try:
await wait_for_visualization_jobs(visualization_context)
finally:
if live_renderer is not None:
live_renderer.close()
_stderr_sink.set_renderer(None)
@click.command()