Refactor fast_api.py for logging and concurrency

This commit is contained in:
Xiaomeng Zhao
2026-01-29 16:39:50 +08:00
committed by GitHub
parent df66af3f97
commit c77edb27bc

View File

@@ -10,12 +10,16 @@ import zipfile
import shutil
from pathlib import Path
import glob
from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form
from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse, FileResponse
from typing import List, Optional
from loguru import logger
from fastapi import BackgroundTasks
log_level = os.getenv("MINERU_LOG_LEVEL", "INFO").upper()
logger.remove() # 移除默认handler
logger.add(sys.stderr, level=log_level) # 添加新handler
from base64 import b64encode
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
@@ -23,11 +27,6 @@ from mineru.utils.cli_parser import arg_parse
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
from mineru.version import __version__
log_level = os.getenv("MINERU_LOG_LEVEL", "INFO").upper()
logger.remove() # 移除默认handler
logger.add(sys.stderr, level=log_level) # 添加新handler
# 并发控制器
_request_semaphore: Optional[asyncio.Semaphore] = None
@@ -35,7 +34,8 @@ _request_semaphore: Optional[asyncio.Semaphore] = None
# 并发控制依赖函数
async def limit_concurrency():
if _request_semaphore is not None:
if _request_semaphore.locked():
# 检查信号量是否已用尽,如果是则拒绝请求
if _request_semaphore._value == 0:
raise HTTPException(
status_code=503,
detail=f"Server is at maximum capacity: {os.getenv('MINERU_API_MAX_CONCURRENT_REQUESTS', 'unset')}. Please try again later.",
@@ -86,7 +86,7 @@ def sanitize_filename(filename: str) -> str:
移除路径遍历字符, 保留 Unicode 字母、数字、._-
禁止隐藏文件
"""
sanitized = re.sub(r"[/\\\.]{2,}|[/\\]", "", filename)
sanitized = re.sub(r"[/\\.]{2,}|[/\\]", "", filename)
sanitized = re.sub(r"[^\w.-]", "_", sanitized, flags=re.UNICODE)
if sanitized.startswith("."):
sanitized = "_" + sanitized[1:]
@@ -287,6 +287,10 @@ async def parse_pdf(
parse_dir = os.path.join(
unique_dir, pdf_name, f"hybrid_{parse_method}"
)
else:
# Unknown backend: skip this file
logger.warning(f"Unknown backend type: {backend}, skipping {pdf_name}")
continue
if not os.path.exists(parse_dir):
continue
@@ -318,7 +322,7 @@ async def parse_pdf(
zf.write(
path,
arcname=os.path.join(
safe_pdf_name, os.path.basename(path)
safe_pdf_name, f"{safe_pdf_name}_model.json"
),
)
@@ -368,6 +372,10 @@ async def parse_pdf(
parse_dir = os.path.join(
unique_dir, pdf_name, f"hybrid_{parse_method}"
)
else:
# Unknown backend: skip this file
logger.warning(f"Unknown backend type: {backend}, skipping {pdf_name}")
continue
if os.path.exists(parse_dir):
if return_md: