修复codereview建议

This commit is contained in:
Magic_yuan
2025-10-17 13:04:49 +08:00
parent 08a89aeca1
commit e7d8bf097a
4 changed files with 137 additions and 62 deletions

View File

@@ -532,11 +532,13 @@ python start_all.py --cleanup-old-files-days 0
| 指标 | v1.x | v2.0 | 提升 |
|-----|------|------|-----|
| 任务响应延迟 | 5-10秒 (调度器触发) | 0.5秒 (Worker主动拉取) | **10-20倍** |
| 任务响应延迟<sup>※</sup> | 5-10秒 (调度器轮询) | 0.5秒 (Worker主动拉取) | **10-20倍** |
| 并发安全性 | 基础锁机制 | 原子操作 + 状态检查 | **可靠性提升** |
| 多GPU效率 | 有时会出现显存冲突 | 完全隔离,无冲突 | **稳定性提升** |
| 系统开销 | 调度器持续运行 | 可选监控(5分钟) | **资源节省** |
※ 任务响应延迟指任务从添加到被 Worker 开始处理的时间间隔。v1.x 主要受调度器轮询间隔影响,非测量端到端处理时间。实际端到端响应时间还会受任务类型和系统负载等因素影响。
## 📝 核心依赖
```txt

View File

@@ -10,6 +10,8 @@ import json
import sys
import time
import threading
import signal
import atexit
from pathlib import Path
import litserve as ls
from loguru import logger
@@ -87,7 +89,7 @@ class MinerUWorkerAPI(ls.LitAPI):
os.environ['CUDA_VISIBLE_DEVICES'] = device_id
# 设置为 cuda:0因为对进程来说只能看到一张卡逻辑ID变为0
os.environ['MINERU_DEVICE_MODE'] = 'cuda:0'
device_mode = 'cuda:0'
device_mode = os.environ['MINERU_DEVICE_MODE']
logger.info(f"🔒 CUDA_VISIBLE_DEVICES={device_id} (Physical GPU {device_id} → Logical GPU 0)")
else:
# 配置 MinerU 环境
@@ -126,6 +128,26 @@ class MinerUWorkerAPI(ls.LitAPI):
self.worker_thread.start()
logger.info(f"🔄 Worker loop started (poll_interval={self.poll_interval}s)")
def teardown(self):
    """Gracefully shut down the worker.

    Clears the ``running`` flag and then waits for the worker thread to
    finish whatever task it is currently processing before it exits.
    This avoids the incomplete task handling or inconsistent database
    writes that killing a daemon thread could otherwise cause.
    """
    thread = self.worker_thread
    # Nothing to do unless the worker loop was enabled and is still alive.
    if not (self.enable_worker_loop and thread and thread.is_alive()):
        return
    logger.info(f"🛑 Shutting down worker {self.worker_id}...")
    self.running = False
    # Allow the thread up to poll_interval * 2 seconds to wrap up.
    timeout = self.poll_interval * 2
    thread.join(timeout=timeout)
    if thread.is_alive():
        logger.warning(f"⚠️ Worker thread did not stop within {timeout}s, forcing exit")
    else:
        logger.info(f"✅ Worker {self.worker_id} shut down gracefully")
def _worker_loop(self):
"""
Worker 主循环:持续拉取并处理任务
@@ -309,7 +331,7 @@ class MinerUWorkerAPI(ls.LitAPI):
try:
clean_memory()
except Exception as e:
logger.debug(f"Memory cleanup: {e}")
logger.debug(f"Memory cleanup failed for task {task_id}: {e}")
def _parse_with_markitdown(self, file_path: Path, file_name: str,
output_path: Path):
@@ -450,6 +472,24 @@ def start_litserve_workers(
timeout=False, # 不设置超时
)
# 注册优雅关闭处理器
def graceful_shutdown(signum=None, frame=None):
    """Handle a shutdown signal by stopping workers gracefully.

    Registered for SIGINT/SIGTERM; `signum`/`frame` are the standard
    signal-handler arguments and are unused here.
    """
    logger.info("🛑 Received shutdown signal, gracefully stopping workers...")
    # NOTE: LitServe creates multiple worker instances per device.
    # This `api` object is only the template; the actual worker instances
    # are managed by LitServe, and teardown is invoked in each worker
    # process.
    if hasattr(api, 'teardown'):
        api.teardown()
    sys.exit(0)
# 注册信号处理器Ctrl+C 等)
signal.signal(signal.SIGINT, graceful_shutdown)
signal.signal(signal.SIGTERM, graceful_shutdown)
# 注册 atexit 处理器(正常退出时调用)
atexit.register(lambda: api.teardown() if hasattr(api, 'teardown') else None)
logger.info(f"✅ LitServe worker pool initialized")
logger.info(f"📡 Listening on: http://0.0.0.0:{port}/predict")
if enable_worker_loop:

View File

@@ -103,12 +103,13 @@ class TaskDB:
''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
return task_id
def get_next_task(self, worker_id: str) -> Optional[Dict]:
def get_next_task(self, worker_id: str, max_retries: int = 3) -> Optional[Dict]:
"""
获取下一个待处理任务(原子操作,防止并发冲突)
Args:
worker_id: Worker ID
max_retries: 当任务被其他 worker 抢走时的最大重试次数默认3次
Returns:
task: 任务字典,如果没有任务返回 None
@@ -117,39 +118,97 @@ class TaskDB:
1. 使用 BEGIN IMMEDIATE 立即获取写锁
2. UPDATE 时检查 status = 'pending' 防止重复拉取
3. 检查 rowcount 确保更新成功
4. 如果任务被抢走,立即重试而不是返回 None避免不必要的等待
"""
with self.get_cursor() as cursor:
# 使用事务确保原子性
cursor.execute('BEGIN IMMEDIATE')
# 按优先级和创建时间获取任务
cursor.execute('''
SELECT * FROM tasks
WHERE status = 'pending'
ORDER BY priority DESC, created_at ASC
LIMIT 1
''')
task = cursor.fetchone()
if task:
# 立即标记为 processing并确保状态仍是 pending
for attempt in range(max_retries):
with self.get_cursor() as cursor:
# 使用事务确保原子性
cursor.execute('BEGIN IMMEDIATE')
# 按优先级和创建时间获取任务
cursor.execute('''
UPDATE tasks
SET status = 'processing',
started_at = CURRENT_TIMESTAMP,
worker_id = ?
WHERE task_id = ? AND status = 'pending'
''', (worker_id, task['task_id']))
SELECT * FROM tasks
WHERE status = 'pending'
ORDER BY priority DESC, created_at ASC
LIMIT 1
''')
# 检查是否更新成功(防止被其他 worker 抢走)
if cursor.rowcount == 0:
# 任务被其他进程抢走了,返回 None
# 调用方会在下一次循环中重新获取
task = cursor.fetchone()
if task:
# 立即标记为 processing并确保状态仍是 pending
cursor.execute('''
UPDATE tasks
SET status = 'processing',
started_at = CURRENT_TIMESTAMP,
worker_id = ?
WHERE task_id = ? AND status = 'pending'
''', (worker_id, task['task_id']))
# 检查是否更新成功(防止被其他 worker 抢走)
if cursor.rowcount == 0:
# 任务被其他进程抢走了,立即重试
# 因为队列中可能还有其他待处理任务
continue
return dict(task)
else:
# 队列中没有待处理任务,返回 None
return None
return dict(task)
return None
# 重试次数用尽,仍未获取到任务(高并发场景)
return None
def _build_update_clauses(self, status: str, result_path: str = None,
error_message: str = None, worker_id: str = None,
task_id: str = None):
"""
构建 UPDATE 和 WHERE 子句的辅助方法
Args:
status: 新状态
result_path: 结果路径(可选)
error_message: 错误信息(可选)
worker_id: Worker ID可选
task_id: 任务ID可选
Returns:
tuple: (update_clauses, update_params, where_clauses, where_params)
"""
update_clauses = ['status = ?']
update_params = [status]
where_clauses = []
where_params = []
# 添加 task_id 条件(如果提供)
if task_id:
where_clauses.append('task_id = ?')
where_params.append(task_id)
# 处理 completed 状态
if status == 'completed':
update_clauses.append('completed_at = CURRENT_TIMESTAMP')
if result_path:
update_clauses.append('result_path = ?')
update_params.append(result_path)
# 只更新正在处理的任务
where_clauses.append("status = 'processing'")
if worker_id:
where_clauses.append('worker_id = ?')
where_params.append(worker_id)
# 处理 failed 状态
elif status == 'failed':
update_clauses.append('completed_at = CURRENT_TIMESTAMP')
if error_message:
update_clauses.append('error_message = ?')
update_params.append(error_message)
# 只更新正在处理的任务
where_clauses.append("status = 'processing'")
if worker_id:
where_clauses.append('worker_id = ?')
where_params.append(worker_id)
return update_clauses, update_params, where_clauses, where_params
def update_task_status(self, task_id: str, status: str,
result_path: str = None, error_message: str = None,
@@ -173,35 +232,9 @@ class TaskDB:
3. 返回 False 表示任务被其他进程修改了
"""
with self.get_cursor() as cursor:
# 分离 UPDATE 和 WHERE 的参数,确保顺序正确
update_clauses = ['status = ?']
update_params = [status]
where_clauses = ['task_id = ?']
where_params = [task_id]
# 处理 completed 状态
if status == 'completed':
update_clauses.append('completed_at = CURRENT_TIMESTAMP')
if result_path:
update_clauses.append('result_path = ?')
update_params.append(result_path)
# 只更新正在处理的任务
where_clauses.append("status = 'processing'")
if worker_id:
where_clauses.append('worker_id = ?')
where_params.append(worker_id)
# 处理 failed 状态
elif status == 'failed':
update_clauses.append('completed_at = CURRENT_TIMESTAMP')
if error_message:
update_clauses.append('error_message = ?')
update_params.append(error_message)
# 只更新正在处理的任务
where_clauses.append("status = 'processing'")
if worker_id:
where_clauses.append('worker_id = ?')
where_params.append(worker_id)
# 使用辅助方法构建 UPDATE 和 WHERE 子句
update_clauses, update_params, where_clauses, where_params = \
self._build_update_clauses(status, result_path, error_message, worker_id, task_id)
# 合并参数:先 UPDATE 部分,再 WHERE 部分
all_params = update_params + where_params

View File

@@ -151,7 +151,7 @@ class TaskScheduler:
# 4. 定期清理旧任务文件和记录
cleanup_counter += 1
# 每24小时清理一次假设 monitor_interval = 300s
# 每24小时清理一次基于当前监控间隔计算
cleanup_interval_cycles = (24 * 3600) / self.monitor_interval
if cleanup_counter >= cleanup_interval_cycles:
cleanup_counter = 0