From 575ca00e01e8ab0b09af98f553d45c20e4e71caa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E5=B0=8F=E8=92=99?= Date: Fri, 29 Mar 2024 14:04:57 +0800 Subject: [PATCH] =?UTF-8?q?app.common=E4=BE=9D=E8=B5=96=E5=88=A0=E9=99=A4?= =?UTF-8?q?=EF=BC=8Cpipeline=5Focr=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + demo/demo_commons.py | 2 +- demo/ocr_demo.py | 14 +--- magic_pdf/pipeline.py | 114 ++------------------------ magic_pdf/pipeline_ocr.py | 79 +++++++++++++++++- magic_pdf/spark/base.py | 31 ++++++- {app/common => magic_pdf/spark}/s3.py | 5 +- tests/test_commons.py | 2 +- 8 files changed, 123 insertions(+), 125 deletions(-) rename {app/common => magic_pdf/spark}/s3.py (93%) diff --git a/.gitignore b/.gitignore index 97329d1f..c29ab947 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ tmp ocr_demo /app/common/__init__.py +/magic_pdf/spark/__init__.py diff --git a/demo/demo_commons.py b/demo/demo_commons.py index bdc80bfd..ffadd80e 100644 --- a/demo/demo_commons.py +++ b/demo/demo_commons.py @@ -1,6 +1,6 @@ import json -from app.common.s3 import get_s3_config +from magic_pdf.spark.s3 import get_s3_config from magic_pdf.libs.commons import join_path, read_file, json_dump_path diff --git a/demo/ocr_demo.py b/demo/ocr_demo.py index c21362a7..5c34dd3e 100644 --- a/demo/ocr_demo.py +++ b/demo/ocr_demo.py @@ -4,14 +4,11 @@ import os from loguru import logger from pathlib import Path -from app.common.s3 import get_s3_config +from magic_pdf.pipeline_ocr import ocr_parse_pdf_core +from magic_pdf.spark.s3 import get_s3_config from demo.demo_commons import get_json_from_local_or_s3 from magic_pdf.dict2md.ocr_mkcontent import ( ocr_mk_mm_markdown_with_para, - ocr_mk_nlp_markdown, - ocr_mk_mm_markdown, - ocr_mk_mm_standard_format, - ocr_mk_mm_markdown_with_para_and_pagination, make_standard_format_with_para ) from magic_pdf.libs.commons import join_path, read_file @@ -67,12 +64,7 @@ def ocr_parse_core(book_name, pdf_bytes, ocr_pdf_model_info, start_page_id=0): save_path = join_path(save_tmp_path, "md") save_path_with_bookname = os.path.join(save_path, book_name) text_content_save_path = f"{save_path_with_bookname}/book.md" - pdf_info_dict = parse_pdf_by_ocr( - pdf_bytes, - ocr_pdf_model_info, - save_path, - book_name, - debug_mode=True) + pdf_info_dict, parse_time = ocr_parse_pdf_core(pdf_bytes, ocr_pdf_model_info, book_name, start_page_id=start_page_id, debug_mode=True) parent_dir = os.path.dirname(text_content_save_path) if not os.path.exists(parent_dir): diff --git a/magic_pdf/pipeline.py b/magic_pdf/pipeline.py index 911a14d6..e1cdec39 100644 --- a/magic_pdf/pipeline.py +++ b/magic_pdf/pipeline.py @@ -3,9 +3,6 @@ import sys import time from urllib.parse import quote -from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown, ocr_mk_nlp_markdown_with_para, \ - ocr_mk_mm_markdown_with_para_and_pagination, ocr_mk_mm_markdown_with_para, ocr_mk_mm_standard_format, \ - make_standard_format_with_para from magic_pdf.libs.commons import ( read_file, join_path, @@ -15,34 +12,19 @@ from magic_pdf.libs.commons import ( ) from magic_pdf.libs.drop_reason import DropReason from magic_pdf.libs.json_compressor import JsonCompressor -from magic_pdf.dict2md.mkcontent import mk_nlp_markdown, mk_universal_format +from magic_pdf.dict2md.mkcontent import mk_universal_format from magic_pdf.pdf_parse_by_model import parse_pdf_by_model from magic_pdf.filter.pdf_classify_by_type import classify from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan from loguru import logger -from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr from magic_pdf.pdf_parse_for_train import parse_pdf_for_train -from magic_pdf.spark.base import exception_handler, get_data_source +from magic_pdf.spark.base import exception_handler, get_data_source, get_bookname, get_pdf_bytes from magic_pdf.train_utils.convert_to_train_format import convert_to_train_format -from app.common.s3 import get_s3_config, get_s3_client +from magic_pdf.spark.s3 import get_s3_config, get_s3_client -def get_data_type(jso: dict): - data_type = jso.get("data_type") - if data_type is None: - data_type = jso.get("file_type") - return data_type - - -def get_bookid(jso: dict): - book_id = jso.get("bookid") - if book_id is None: - book_id = jso.get("original_file_id") - return book_id - - def meta_scan(jso: dict, doc_layout_check=True) -> dict: s3_pdf_path = jso.get("file_location") s3_config = get_s3_config(s3_pdf_path) @@ -310,17 +292,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: file_id = jso.get("file_id") book_name = f"{data_source}/{file_id}" - # 1.23.22已修复 - # if debug_mode: - # pass - # else: - # if book_name == "zlib/zlib_21929367": - # jso['need_drop'] = True - # jso['drop_reason'] = DropReason.SPECIAL_PDF - # return jso - junk_img_bojids = jso["pdf_meta"]["junk_img_bojids"] - # total_page = jso['pdf_meta']['total_page'] # 增加检测 max_svgs 数量的检测逻辑,如果 max_svgs 超过3000则drop svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"] @@ -328,9 +300,6 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: if max_svgs > 3000: jso["need_drop"] = True jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS - # elif total_page > 1000: - # jso['need_drop'] = True - # jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES else: try: save_path = s3_image_save_path @@ -459,79 +428,10 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d 2.再调用ocr_dropped_parse_pdf,对之前drop的pdf进行处理 """ - -def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: - jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode) - jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode) - return jso - - -# 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false -def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: - if not jso.get("need_drop", False): - return jso - else: - jso = ocr_parse_pdf_core( - jso, start_page_id=start_page_id, debug_mode=debug_mode - ) - jso["need_drop"] = False - return jso - - -def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: - # 检测debug开关 - if debug_mode: - pass - else: # 如果debug没开,则检测是否有needdrop字段 - if jso.get("need_drop", False): - return jso - - jso = ocr_parse_pdf_core(jso, start_page_id=start_page_id, debug_mode=debug_mode) - return jso - - -def ocr_parse_pdf_core(jso: dict, start_page_id=0, debug_mode=False) -> dict: - - s3_pdf_path = jso.get("file_location") - s3_config = get_s3_config(s3_pdf_path) - pdf_bytes = read_file(s3_pdf_path, s3_config) - - model_output_json_list = jso.get("doc_layout_result") - data_source = get_data_source(jso) - file_id = jso.get("file_id") - book_name = f"{data_source}/{file_id}" - try: - save_path = s3_image_save_path - image_s3_config = get_s3_config(save_path) - start_time = time.time() # 记录开始时间 - # 先打印一下book_name和解析开始的时间 - logger.info( - f"book_name is:{book_name},start_time is:{formatted_time(start_time)}", - file=sys.stderr, - ) - pdf_info_dict = parse_pdf_by_ocr( - pdf_bytes, - model_output_json_list, - save_path, - book_name, - pdf_model_profile=None, - image_s3_config=image_s3_config, - start_page_id=start_page_id, - debug_mode=debug_mode, - ) - pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict) - jso["pdf_intermediate_dict"] = pdf_info_dict - end_time = time.time() # 记录完成时间 - parse_time = int(end_time - start_time) # 计算执行时间 - # 解析完成后打印一下book_name和耗时 - logger.info( - f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}", - file=sys.stderr, - ) - jso["parse_time"] = parse_time - except Exception as e: - jso = exception_handler(jso, e) - return jso +# def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: +# jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode) +# jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode) +# return jso if __name__ == "__main__": diff --git a/magic_pdf/pipeline_ocr.py b/magic_pdf/pipeline_ocr.py index 6b6c7a41..124a0ccd 100644 --- a/magic_pdf/pipeline_ocr.py +++ b/magic_pdf/pipeline_ocr.py @@ -1,13 +1,16 @@ import sys +import time from loguru import logger from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown, ocr_mk_nlp_markdown_with_para, \ ocr_mk_mm_markdown_with_para_and_pagination, ocr_mk_mm_markdown_with_para, ocr_mk_mm_standard_format, \ make_standard_format_with_para -from magic_pdf.libs.commons import join_path +from magic_pdf.libs.commons import join_path, s3_image_save_path, formatted_time from magic_pdf.libs.json_compressor import JsonCompressor -from magic_pdf.spark.base import get_data_source, exception_handler +from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr +from magic_pdf.spark.base import get_data_source, exception_handler, get_pdf_bytes, get_bookname +from magic_pdf.spark.s3 import get_s3_config def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict: @@ -179,4 +182,74 @@ def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode jso["pdf_meta"] = "" except Exception as e: jso = exception_handler(jso, e) - return jso \ No newline at end of file + return jso + + +def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_id=0, debug_mode=False): + save_path = s3_image_save_path + image_s3_config = get_s3_config(save_path) + start_time = time.time() # 记录开始时间 + # 先打印一下book_name和解析开始的时间 + logger.info( + f"book_name is:{book_name},start_time is:{formatted_time(start_time)}", + file=sys.stderr, + ) + pdf_info_dict = parse_pdf_by_ocr( + pdf_bytes, + model_output_json_list, + save_path, + book_name, + pdf_model_profile=None, + image_s3_config=image_s3_config, + start_page_id=start_page_id, + debug_mode=debug_mode, + ) + end_time = time.time() # 记录完成时间 + parse_time = int(end_time - start_time) # 计算执行时间 + # 解析完成后打印一下book_name和耗时 + logger.info( + f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}", + file=sys.stderr, + ) + + return pdf_info_dict, parse_time + + +# 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false +def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: + if not jso.get("need_drop", False): + return jso + else: + try: + pdf_bytes = get_pdf_bytes(jso) + model_output_json_list = jso.get("doc_layout_result") + book_name = get_bookname(jso) + pdf_info_dict, parse_time = ocr_parse_pdf_core( + pdf_bytes, model_output_json_list, book_name, start_page_id=start_page_id, debug_mode=debug_mode + ) + jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict) + jso["parse_time"] = parse_time + jso["need_drop"] = False + except Exception as e: + jso = exception_handler(jso, e) + return jso + + +def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: + # 检测debug开关 + if debug_mode: + pass + else: # 如果debug没开,则检测是否有needdrop字段 + if jso.get("need_drop", False): + return jso + try: + pdf_bytes = get_pdf_bytes(jso) + model_output_json_list = jso.get("doc_layout_result") + book_name = get_bookname(jso) + pdf_info_dict, parse_time = ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, + start_page_id=start_page_id, debug_mode=debug_mode) + jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict) + jso["parse_time"] = parse_time + except Exception as e: + jso = exception_handler(jso, e) + return jso diff --git a/magic_pdf/spark/base.py b/magic_pdf/spark/base.py index c08d1a9f..43b786a1 100644 --- a/magic_pdf/spark/base.py +++ b/magic_pdf/spark/base.py @@ -1,9 +1,11 @@ - from loguru import logger +from magic_pdf.libs.commons import read_file from magic_pdf.libs.drop_reason import DropReason +from magic_pdf.spark.s3 import get_s3_config + def get_data_source(jso: dict): data_source = jso.get("data_source") @@ -12,6 +14,20 @@ def get_data_source(jso: dict): return data_source +def get_data_type(jso: dict): + data_type = jso.get("data_type") + if data_type is None: + data_type = jso.get("file_type") + return data_type + + +def get_bookid(jso: dict): + book_id = jso.get("bookid") + if book_id is None: + book_id = jso.get("original_file_id") + return book_id + + def exception_handler(jso: dict, e): logger.exception(e) jso["need_drop"] = True @@ -19,3 +35,16 @@ def exception_handler(jso: dict, e): jso["exception"] = f"ERROR: {e}" return jso + +def get_bookname(jso: dict): + data_source = get_data_source(jso) + file_id = jso.get("file_id") + book_name = f"{data_source}/{file_id}" + return book_name + + +def get_pdf_bytes(jso: dict): + pdf_s3_path = jso.get("file_location") + s3_config = get_s3_config(pdf_s3_path) + pdf_bytes = read_file(pdf_s3_path, s3_config) + return pdf_bytes \ No newline at end of file diff --git a/app/common/s3.py b/magic_pdf/spark/s3.py similarity index 93% rename from app/common/s3.py rename to magic_pdf/spark/s3.py index 329ee75c..aa4a9391 100644 --- a/app/common/s3.py +++ b/magic_pdf/spark/s3.py @@ -2,10 +2,13 @@ import boto3 from botocore.client import Config -from app.common import s3_buckets, s3_clusters, get_cluster_name, s3_users import re import random from typing import List, Union +try: + from app.common import s3_buckets, s3_clusters, get_cluster_name, s3_users +except ImportError: + from magic_pdf.spark import s3_buckets, s3_clusters, get_cluster_name, s3_users __re_s3_path = re.compile("^s3a?://([^/]+)(?:/(.*))?$") def get_s3_config(path: Union[str, List[str]], outside=False): diff --git a/tests/test_commons.py b/tests/test_commons.py index 0732b0a4..e41576e1 100644 --- a/tests/test_commons.py +++ b/tests/test_commons.py @@ -3,7 +3,7 @@ import json import os from magic_pdf.libs.commons import fitz -from app.common.s3 import get_s3_config, get_s3_client +from magic_pdf.spark.s3 import get_s3_config, get_s3_client from magic_pdf.libs.commons import join_path, json_dump_path, read_file, parse_bucket_key from loguru import logger