mirror of
https://github.com/opendatalab/MinerU.git
synced 2026-03-27 02:58:54 +07:00
app.common依赖删除,pipeline_ocr重构
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -34,3 +34,4 @@ tmp
|
||||
ocr_demo
|
||||
|
||||
/app/common/__init__.py
|
||||
/magic_pdf/spark/__init__.py
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
from app.common.s3 import get_s3_config
|
||||
from magic_pdf.spark.s3 import get_s3_config
|
||||
from magic_pdf.libs.commons import join_path, read_file, json_dump_path
|
||||
|
||||
|
||||
|
||||
@@ -4,14 +4,11 @@ import os
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
|
||||
from app.common.s3 import get_s3_config
|
||||
from magic_pdf.pipeline_ocr import ocr_parse_pdf_core
|
||||
from magic_pdf.spark.s3 import get_s3_config
|
||||
from demo.demo_commons import get_json_from_local_or_s3
|
||||
from magic_pdf.dict2md.ocr_mkcontent import (
|
||||
ocr_mk_mm_markdown_with_para,
|
||||
ocr_mk_nlp_markdown,
|
||||
ocr_mk_mm_markdown,
|
||||
ocr_mk_mm_standard_format,
|
||||
ocr_mk_mm_markdown_with_para_and_pagination,
|
||||
make_standard_format_with_para
|
||||
)
|
||||
from magic_pdf.libs.commons import join_path, read_file
|
||||
@@ -67,12 +64,7 @@ def ocr_parse_core(book_name, pdf_bytes, ocr_pdf_model_info, start_page_id=0):
|
||||
save_path = join_path(save_tmp_path, "md")
|
||||
save_path_with_bookname = os.path.join(save_path, book_name)
|
||||
text_content_save_path = f"{save_path_with_bookname}/book.md"
|
||||
pdf_info_dict = parse_pdf_by_ocr(
|
||||
pdf_bytes,
|
||||
ocr_pdf_model_info,
|
||||
save_path,
|
||||
book_name,
|
||||
debug_mode=True)
|
||||
pdf_info_dict, parse_time = ocr_parse_pdf_core(pdf_bytes, ocr_pdf_model_info, book_name, start_page_id=start_page_id, debug_mode=True)
|
||||
|
||||
parent_dir = os.path.dirname(text_content_save_path)
|
||||
if not os.path.exists(parent_dir):
|
||||
|
||||
@@ -3,9 +3,6 @@ import sys
|
||||
import time
|
||||
from urllib.parse import quote
|
||||
|
||||
from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown, ocr_mk_nlp_markdown_with_para, \
|
||||
ocr_mk_mm_markdown_with_para_and_pagination, ocr_mk_mm_markdown_with_para, ocr_mk_mm_standard_format, \
|
||||
make_standard_format_with_para
|
||||
from magic_pdf.libs.commons import (
|
||||
read_file,
|
||||
join_path,
|
||||
@@ -15,34 +12,19 @@ from magic_pdf.libs.commons import (
|
||||
)
|
||||
from magic_pdf.libs.drop_reason import DropReason
|
||||
from magic_pdf.libs.json_compressor import JsonCompressor
|
||||
from magic_pdf.dict2md.mkcontent import mk_nlp_markdown, mk_universal_format
|
||||
from magic_pdf.dict2md.mkcontent import mk_universal_format
|
||||
from magic_pdf.pdf_parse_by_model import parse_pdf_by_model
|
||||
from magic_pdf.filter.pdf_classify_by_type import classify
|
||||
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
|
||||
from loguru import logger
|
||||
|
||||
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
|
||||
from magic_pdf.pdf_parse_for_train import parse_pdf_for_train
|
||||
from magic_pdf.spark.base import exception_handler, get_data_source
|
||||
from magic_pdf.spark.base import exception_handler, get_data_source, get_bookname, get_pdf_bytes
|
||||
from magic_pdf.train_utils.convert_to_train_format import convert_to_train_format
|
||||
from app.common.s3 import get_s3_config, get_s3_client
|
||||
from magic_pdf.spark.s3 import get_s3_config, get_s3_client
|
||||
|
||||
|
||||
|
||||
def get_data_type(jso: dict):
|
||||
data_type = jso.get("data_type")
|
||||
if data_type is None:
|
||||
data_type = jso.get("file_type")
|
||||
return data_type
|
||||
|
||||
|
||||
def get_bookid(jso: dict):
|
||||
book_id = jso.get("bookid")
|
||||
if book_id is None:
|
||||
book_id = jso.get("original_file_id")
|
||||
return book_id
|
||||
|
||||
|
||||
def meta_scan(jso: dict, doc_layout_check=True) -> dict:
|
||||
s3_pdf_path = jso.get("file_location")
|
||||
s3_config = get_s3_config(s3_pdf_path)
|
||||
@@ -310,17 +292,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
file_id = jso.get("file_id")
|
||||
book_name = f"{data_source}/{file_id}"
|
||||
|
||||
# 1.23.22已修复
|
||||
# if debug_mode:
|
||||
# pass
|
||||
# else:
|
||||
# if book_name == "zlib/zlib_21929367":
|
||||
# jso['need_drop'] = True
|
||||
# jso['drop_reason'] = DropReason.SPECIAL_PDF
|
||||
# return jso
|
||||
|
||||
junk_img_bojids = jso["pdf_meta"]["junk_img_bojids"]
|
||||
# total_page = jso['pdf_meta']['total_page']
|
||||
|
||||
# 增加检测 max_svgs 数量的检测逻辑,如果 max_svgs 超过3000则drop
|
||||
svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
|
||||
@@ -328,9 +300,6 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
if max_svgs > 3000:
|
||||
jso["need_drop"] = True
|
||||
jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
|
||||
# elif total_page > 1000:
|
||||
# jso['need_drop'] = True
|
||||
# jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
|
||||
else:
|
||||
try:
|
||||
save_path = s3_image_save_path
|
||||
@@ -459,79 +428,10 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
|
||||
2.再调用ocr_dropped_parse_pdf,对之前drop的pdf进行处理
|
||||
"""
|
||||
|
||||
|
||||
def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
|
||||
jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
|
||||
return jso
|
||||
|
||||
|
||||
# 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false
|
||||
def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
if not jso.get("need_drop", False):
|
||||
return jso
|
||||
else:
|
||||
jso = ocr_parse_pdf_core(
|
||||
jso, start_page_id=start_page_id, debug_mode=debug_mode
|
||||
)
|
||||
jso["need_drop"] = False
|
||||
return jso
|
||||
|
||||
|
||||
def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
# 检测debug开关
|
||||
if debug_mode:
|
||||
pass
|
||||
else: # 如果debug没开,则检测是否有needdrop字段
|
||||
if jso.get("need_drop", False):
|
||||
return jso
|
||||
|
||||
jso = ocr_parse_pdf_core(jso, start_page_id=start_page_id, debug_mode=debug_mode)
|
||||
return jso
|
||||
|
||||
|
||||
def ocr_parse_pdf_core(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
|
||||
s3_pdf_path = jso.get("file_location")
|
||||
s3_config = get_s3_config(s3_pdf_path)
|
||||
pdf_bytes = read_file(s3_pdf_path, s3_config)
|
||||
|
||||
model_output_json_list = jso.get("doc_layout_result")
|
||||
data_source = get_data_source(jso)
|
||||
file_id = jso.get("file_id")
|
||||
book_name = f"{data_source}/{file_id}"
|
||||
try:
|
||||
save_path = s3_image_save_path
|
||||
image_s3_config = get_s3_config(save_path)
|
||||
start_time = time.time() # 记录开始时间
|
||||
# 先打印一下book_name和解析开始的时间
|
||||
logger.info(
|
||||
f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
pdf_info_dict = parse_pdf_by_ocr(
|
||||
pdf_bytes,
|
||||
model_output_json_list,
|
||||
save_path,
|
||||
book_name,
|
||||
pdf_model_profile=None,
|
||||
image_s3_config=image_s3_config,
|
||||
start_page_id=start_page_id,
|
||||
debug_mode=debug_mode,
|
||||
)
|
||||
pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
|
||||
jso["pdf_intermediate_dict"] = pdf_info_dict
|
||||
end_time = time.time() # 记录完成时间
|
||||
parse_time = int(end_time - start_time) # 计算执行时间
|
||||
# 解析完成后打印一下book_name和耗时
|
||||
logger.info(
|
||||
f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
jso["parse_time"] = parse_time
|
||||
except Exception as e:
|
||||
jso = exception_handler(jso, e)
|
||||
return jso
|
||||
# def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
# jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
|
||||
# jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
|
||||
# return jso
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import sys
|
||||
import time
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown, ocr_mk_nlp_markdown_with_para, \
|
||||
ocr_mk_mm_markdown_with_para_and_pagination, ocr_mk_mm_markdown_with_para, ocr_mk_mm_standard_format, \
|
||||
make_standard_format_with_para
|
||||
from magic_pdf.libs.commons import join_path
|
||||
from magic_pdf.libs.commons import join_path, s3_image_save_path, formatted_time
|
||||
from magic_pdf.libs.json_compressor import JsonCompressor
|
||||
from magic_pdf.spark.base import get_data_source, exception_handler
|
||||
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
|
||||
from magic_pdf.spark.base import get_data_source, exception_handler, get_pdf_bytes, get_bookname
|
||||
from magic_pdf.spark.s3 import get_s3_config
|
||||
|
||||
|
||||
def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
|
||||
@@ -179,4 +182,74 @@ def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode
|
||||
jso["pdf_meta"] = ""
|
||||
except Exception as e:
|
||||
jso = exception_handler(jso, e)
|
||||
return jso
|
||||
return jso
|
||||
|
||||
|
||||
def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_id=0, debug_mode=False):
|
||||
save_path = s3_image_save_path
|
||||
image_s3_config = get_s3_config(save_path)
|
||||
start_time = time.time() # 记录开始时间
|
||||
# 先打印一下book_name和解析开始的时间
|
||||
logger.info(
|
||||
f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
pdf_info_dict = parse_pdf_by_ocr(
|
||||
pdf_bytes,
|
||||
model_output_json_list,
|
||||
save_path,
|
||||
book_name,
|
||||
pdf_model_profile=None,
|
||||
image_s3_config=image_s3_config,
|
||||
start_page_id=start_page_id,
|
||||
debug_mode=debug_mode,
|
||||
)
|
||||
end_time = time.time() # 记录完成时间
|
||||
parse_time = int(end_time - start_time) # 计算执行时间
|
||||
# 解析完成后打印一下book_name和耗时
|
||||
logger.info(
|
||||
f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
return pdf_info_dict, parse_time
|
||||
|
||||
|
||||
# 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false
|
||||
def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
if not jso.get("need_drop", False):
|
||||
return jso
|
||||
else:
|
||||
try:
|
||||
pdf_bytes = get_pdf_bytes(jso)
|
||||
model_output_json_list = jso.get("doc_layout_result")
|
||||
book_name = get_bookname(jso)
|
||||
pdf_info_dict, parse_time = ocr_parse_pdf_core(
|
||||
pdf_bytes, model_output_json_list, book_name, start_page_id=start_page_id, debug_mode=debug_mode
|
||||
)
|
||||
jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict)
|
||||
jso["parse_time"] = parse_time
|
||||
jso["need_drop"] = False
|
||||
except Exception as e:
|
||||
jso = exception_handler(jso, e)
|
||||
return jso
|
||||
|
||||
|
||||
def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
|
||||
# 检测debug开关
|
||||
if debug_mode:
|
||||
pass
|
||||
else: # 如果debug没开,则检测是否有needdrop字段
|
||||
if jso.get("need_drop", False):
|
||||
return jso
|
||||
try:
|
||||
pdf_bytes = get_pdf_bytes(jso)
|
||||
model_output_json_list = jso.get("doc_layout_result")
|
||||
book_name = get_bookname(jso)
|
||||
pdf_info_dict, parse_time = ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name,
|
||||
start_page_id=start_page_id, debug_mode=debug_mode)
|
||||
jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict)
|
||||
jso["parse_time"] = parse_time
|
||||
except Exception as e:
|
||||
jso = exception_handler(jso, e)
|
||||
return jso
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from magic_pdf.libs.commons import read_file
|
||||
from magic_pdf.libs.drop_reason import DropReason
|
||||
|
||||
from magic_pdf.spark.s3 import get_s3_config
|
||||
|
||||
|
||||
def get_data_source(jso: dict):
|
||||
data_source = jso.get("data_source")
|
||||
@@ -12,6 +14,20 @@ def get_data_source(jso: dict):
|
||||
return data_source
|
||||
|
||||
|
||||
def get_data_type(jso: dict):
|
||||
data_type = jso.get("data_type")
|
||||
if data_type is None:
|
||||
data_type = jso.get("file_type")
|
||||
return data_type
|
||||
|
||||
|
||||
def get_bookid(jso: dict):
|
||||
book_id = jso.get("bookid")
|
||||
if book_id is None:
|
||||
book_id = jso.get("original_file_id")
|
||||
return book_id
|
||||
|
||||
|
||||
def exception_handler(jso: dict, e):
|
||||
logger.exception(e)
|
||||
jso["need_drop"] = True
|
||||
@@ -19,3 +35,16 @@ def exception_handler(jso: dict, e):
|
||||
jso["exception"] = f"ERROR: {e}"
|
||||
return jso
|
||||
|
||||
|
||||
def get_bookname(jso: dict):
|
||||
data_source = get_data_source(jso)
|
||||
file_id = jso.get("file_id")
|
||||
book_name = f"{data_source}/{file_id}"
|
||||
return book_name
|
||||
|
||||
|
||||
def get_pdf_bytes(jso: dict):
|
||||
pdf_s3_path = jso.get("file_location")
|
||||
s3_config = get_s3_config(pdf_s3_path)
|
||||
pdf_bytes = read_file(pdf_s3_path, s3_config)
|
||||
return pdf_bytes
|
||||
@@ -2,10 +2,13 @@
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
|
||||
from app.common import s3_buckets, s3_clusters, get_cluster_name, s3_users
|
||||
import re
|
||||
import random
|
||||
from typing import List, Union
|
||||
try:
|
||||
from app.common import s3_buckets, s3_clusters, get_cluster_name, s3_users
|
||||
except ImportError:
|
||||
from magic_pdf.spark import s3_buckets, s3_clusters, get_cluster_name, s3_users
|
||||
|
||||
__re_s3_path = re.compile("^s3a?://([^/]+)(?:/(.*))?$")
|
||||
def get_s3_config(path: Union[str, List[str]], outside=False):
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import os
|
||||
from magic_pdf.libs.commons import fitz
|
||||
|
||||
from app.common.s3 import get_s3_config, get_s3_client
|
||||
from magic_pdf.spark.s3 import get_s3_config, get_s3_client
|
||||
from magic_pdf.libs.commons import join_path, json_dump_path, read_file, parse_bucket_key
|
||||
from loguru import logger
|
||||
|
||||
|
||||
Reference in New Issue
Block a user