Compare commits

...

14 Commits

Author SHA1 Message Date
赵小蒙
cd8b2d2c78 修复import错误 2024-03-29 17:48:41 +08:00
赵小蒙
d35c49268d 修复import错误 2024-03-29 17:41:53 +08:00
赵小蒙
016cde3ece 修复init错误 2024-03-29 17:29:44 +08:00
赵小蒙
4b8dbd7cfb ocr_pdf_intermediate_dict_to_markdown_with_para支持mm和nlp双模式 2024-03-29 17:18:32 +08:00
赵小蒙
d6a5724b26 table_latex支持 2024-03-29 15:56:18 +08:00
赵小蒙
50a543ce0e s3配置信息路径更换 2024-03-29 14:57:12 +08:00
赵小蒙
575ca00e01 app.common依赖删除,pipeline_ocr重构 2024-03-29 14:04:57 +08:00
赵小蒙
7f0c734ff6 pipeline重构 2024-03-28 19:02:03 +08:00
赵小蒙
872cd73f4a pipeline重构 2024-03-28 17:02:16 +08:00
赵小蒙
7fcbae01fe demo重构 2024-03-28 17:01:53 +08:00
赵小蒙
752d620a0c Merge remote-tracking branch 'origin/master' 2024-03-28 14:45:11 +08:00
赵小蒙
fc10772503 ocr_construct_page_component 位置移动 2024-03-28 14:45:00 +08:00
liusilu
fd616c5778 Merge branch 'master' of https://github.com/myhloli/Magic-PDF 2024-03-28 13:35:22 +08:00
liusilu
acb9cbd6d2 add pdf tools 2024-03-28 13:35:12 +08:00
18 changed files with 959 additions and 435 deletions

1
.gitignore vendored
View File

@@ -34,3 +34,4 @@ tmp
ocr_demo
/app/common/__init__.py
/magic_pdf/config/__init__.py

View File

@@ -28,7 +28,9 @@ pip install -r requirements.txt
3.Run the main script
```sh
use demo/demo_test.py
use demo/text_demo.py
or
use demo/ocr_demo.py
```
### 版权说明

32
demo/demo_commons.py Normal file
View File

@@ -0,0 +1,32 @@
import json
from magic_pdf.spark.s3 import get_s3_config
from magic_pdf.libs.commons import join_path, read_file, json_dump_path
local_json_path = "Z:/format.json"
local_jsonl_path = "Z:/format.jsonl"
def get_json_from_local_or_s3(book_name=None):
if book_name is None:
with open(local_json_path, "r", encoding="utf-8") as json_file:
json_line = json_file.read()
json_object = json.loads(json_line)
else:
# error_log_path & json_dump_path
# 可配置从上述两个地址获取源json
json_path = join_path(json_dump_path, book_name + ".json")
s3_config = get_s3_config(json_path)
file_content = read_file(json_path, s3_config)
json_str = file_content.decode("utf-8")
# logger.info(json_str)
json_object = json.loads(json_str)
return json_object
def write_json_to_local(jso, book_name=None):
if book_name is None:
with open(local_json_path, "w", encoding="utf-8") as file:
file.write(json.dumps(jso, ensure_ascii=False))
else:
pass

View File

@@ -4,18 +4,14 @@ import os
from loguru import logger
from pathlib import Path
from app.common.s3 import get_s3_config
from demo.demo_test import get_json_from_local_or_s3
from magic_pdf.pipeline_ocr import ocr_parse_pdf_core
from magic_pdf.spark.s3 import get_s3_config
from demo.demo_commons import get_json_from_local_or_s3
from magic_pdf.dict2md.ocr_mkcontent import (
ocr_mk_mm_markdown_with_para,
ocr_mk_nlp_markdown,
ocr_mk_mm_markdown,
ocr_mk_mm_standard_format,
ocr_mk_mm_markdown_with_para_and_pagination,
make_standard_format_with_para
)
from magic_pdf.libs.commons import join_path
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
from magic_pdf.libs.commons import join_path, read_file
def save_markdown(markdown_text, input_filepath):
@@ -43,7 +39,8 @@ def ocr_local_parse(ocr_pdf_path, ocr_json_file_path):
ocr_pdf_model_info = read_json_file(ocr_json_file_path)
pth = Path(ocr_json_file_path)
book_name = pth.name
ocr_parse_core(book_name, ocr_pdf_path, ocr_pdf_model_info)
pdf_bytes = read_file(ocr_pdf_path, None)
ocr_parse_core(book_name, pdf_bytes, ocr_pdf_model_info)
except Exception as e:
logger.exception(e)
@@ -54,24 +51,19 @@ def ocr_online_parse(book_name, start_page_id=0, debug_mode=True):
# logger.info(json_object)
s3_pdf_path = json_object["file_location"]
s3_config = get_s3_config(s3_pdf_path)
pdf_bytes = read_file(s3_pdf_path, s3_config)
ocr_pdf_model_info = json_object.get("doc_layout_result")
ocr_parse_core(book_name, s3_pdf_path, ocr_pdf_model_info, s3_config=s3_config)
ocr_parse_core(book_name, pdf_bytes, ocr_pdf_model_info)
except Exception as e:
logger.exception(e)
def ocr_parse_core(book_name, ocr_pdf_path, ocr_pdf_model_info, start_page_id=0, s3_config=None):
def ocr_parse_core(book_name, pdf_bytes, ocr_pdf_model_info, start_page_id=0):
save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
save_path = join_path(save_tmp_path, "md")
save_path_with_bookname = os.path.join(save_path, book_name)
text_content_save_path = f"{save_path_with_bookname}/book.md"
pdf_info_dict = parse_pdf_by_ocr(
ocr_pdf_path,
s3_config,
ocr_pdf_model_info,
save_path,
book_name,
debug_mode=True)
pdf_info_dict, parse_time = ocr_parse_pdf_core(pdf_bytes, ocr_pdf_model_info, book_name, start_page_id=start_page_id, debug_mode=True)
parent_dir = os.path.dirname(text_content_save_path)
if not os.path.exists(parent_dir):

View File

@@ -5,7 +5,7 @@ from pathlib import Path
import click
from loguru import logger
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.commons import join_path, read_file
from magic_pdf.dict2md.mkcontent import mk_mm_markdown
from magic_pdf.pipeline import parse_pdf_by_model
@@ -21,9 +21,11 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
text_content_save_path = f"{save_path}/{book_name}/book.md"
# metadata_save_path = f"{save_path}/{book_name}/metadata.json"
pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
try:
paras_dict = parse_pdf_by_model(
s3_pdf_path, s3_pdf_profile, pdf_model_path, save_path, book_name, pdf_model_profile, start_page_num, debug_mode=debug_mode
pdf_bytes, pdf_model_path, save_path, book_name, pdf_model_profile, start_page_num, debug_mode=debug_mode
)
parent_dir = os.path.dirname(text_content_save_path)
if not os.path.exists(parent_dir):

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import click
from demo.demo_commons import get_json_from_local_or_s3, write_json_to_local, local_jsonl_path, local_json_path
from magic_pdf.dict2md.mkcontent import mk_mm_markdown
from magic_pdf.pipeline import (
meta_scan,
@@ -13,38 +14,10 @@ from magic_pdf.pipeline import (
pdf_intermediate_dict_to_markdown,
save_tables_to_s3,
)
from magic_pdf.libs.commons import join_path, read_file, json_dump_path
from app.common.s3 import get_s3_config
from magic_pdf.libs.commons import join_path
from loguru import logger
local_json_path = "Z:/format.json"
local_jsonl_path = "Z:/format.jsonl"
def get_json_from_local_or_s3(book_name=None):
if book_name is None:
with open(local_json_path, "r", encoding="utf-8") as json_file:
json_line = json_file.read()
json_object = json.loads(json_line)
else:
# error_log_path & json_dump_path
# 可配置从上述两个地址获取源json
json_path = join_path(json_dump_path, book_name + ".json")
s3_config = get_s3_config(json_path)
file_content = read_file(json_path, s3_config)
json_str = file_content.decode("utf-8")
# logger.info(json_str)
json_object = json.loads(json_str)
return json_object
def write_json_to_local(jso, book_name=None):
if book_name is None:
with open(local_json_path, "w", encoding="utf-8") as file:
file.write(json.dumps(jso, ensure_ascii=False))
else:
pass
def demo_parse_pdf(book_name=None, start_page_id=0, debug_mode=True):

View File

@@ -166,45 +166,65 @@ def mk_mm_markdown_1(para_dict: dict):
return content_text
def __insert_after_para(text, image_path, content_list):
def __insert_after_para(text, type, element, content_list):
"""
在content_list中找到text将image_path作为一个新的node插入到text后面
"""
for i, c in enumerate(content_list):
content_type = c.get("type")
if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get("text", ''):
img_node = {
"type": "image",
"img_path": image_path,
"img_alt":"",
"img_title":"",
"img_caption":""
}
content_list.insert(i+1, img_node)
if type == "image":
content_node = {
"type": "image",
"img_path": element.get("image_path"),
"img_alt": "",
"img_title": "",
"img_caption": "",
}
elif type == "table":
content_node = {
"type": "table",
"img_path": element.get("image_path"),
"table_latex": element.get("text"),
"table_title": "",
"table_caption": "",
"table_quality": element.get("quality"),
}
content_list.insert(i+1, content_node)
break
else:
logger.error(f"Can't find the location of image {image_path} in the markdown file, search target is {text}")
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}")
def __insert_before_para(text, image_path, content_list):
def __insert_before_para(text, type, element, content_list):
"""
在content_list中找到text将image_path作为一个新的node插入到text前面
"""
for i, c in enumerate(content_list):
content_type = c.get("type")
if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get("text", ''):
img_node = {
"type": "image",
"img_path": image_path,
"img_alt":"",
"img_title":"",
"img_caption":""
}
content_list.insert(i, img_node)
if type == "image":
content_node = {
"type": "image",
"img_path": element.get("image_path"),
"img_alt": "",
"img_title": "",
"img_caption": "",
}
elif type == "table":
content_node = {
"type": "table",
"img_path": element.get("image_path"),
"table_latex": element.get("text"),
"table_title": "",
"table_caption": "",
"table_quality": element.get("quality"),
}
content_list.insert(i, content_node)
break
else:
logger.error(f"Can't find the location of image {image_path} in the markdown file, search target is {text}")
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}")
def mk_universal_format(para_dict: dict):
@@ -220,9 +240,11 @@ def mk_universal_format(para_dict: dict):
all_page_images = []
all_page_images.extend(page_info.get("images",[]))
all_page_images.extend(page_info.get("image_backup", []) )
all_page_images.extend(page_info.get("tables",[]))
all_page_images.extend(page_info.get("table_backup",[]) )
# all_page_images.extend(page_info.get("tables",[]))
# all_page_images.extend(page_info.get("table_backup",[]) )
all_page_tables = []
all_page_tables.extend(page_info.get("tables", []))
if not para_blocks or not pymu_raw_blocks: # 只有图片的拼接的场景
for img in all_page_images:
content_node = {
@@ -233,6 +255,16 @@ def mk_universal_format(para_dict: dict):
"img_caption":""
}
page_lst.append(content_node) # TODO 图片顺序
for table in all_page_tables:
content_node = {
"type": "table",
"img_path": table['image_path'],
"table_latex": table.get("text"),
"table_title": "",
"table_caption": "",
"table_quality": table.get("quality"),
}
page_lst.append(content_node) # TODO 图片顺序
else:
for block in para_blocks:
item = block["paras"]
@@ -266,56 +298,65 @@ def mk_universal_format(para_dict: dict):
"""插入图片"""
for img in all_page_images:
imgbox = img['bbox']
img_content = f"{img['image_path']}"
# 先看在哪个block内
for block in pymu_raw_blocks:
bbox = block['bbox']
if bbox[0]-1 <= imgbox[0] < bbox[2]+1 and bbox[1]-1 <= imgbox[1] < bbox[3]+1:# 确定在这个大的block内然后进入逐行比较距离
for l in block['lines']:
line_box = l['bbox']
if line_box[0]-1 <= imgbox[0] < line_box[2]+1 and line_box[1]-1 <= imgbox[1] < line_box[3]+1: # 在line内的插入line前面
line_txt = "".join([s['text'] for s in l['spans']])
__insert_before_para(line_txt, img_content, content_lst)
break
break
else:# 在行与行之间
# 找到图片x0,y0与line的x0,y0最近的line
min_distance = 100000
min_line = None
for l in block['lines']:
line_box = l['bbox']
distance = math.sqrt((line_box[0] - imgbox[0])**2 + (line_box[1] - imgbox[1])**2)
if distance < min_distance:
min_distance = distance
min_line = l
if min_line:
line_txt = "".join([s['text'] for s in min_line['spans']])
img_h = imgbox[3] - imgbox[1]
if min_distance<img_h: # 文字在图片前面
__insert_after_para(line_txt, img_content, content_lst)
else:
__insert_before_para(line_txt, img_content, content_lst)
break
else:
logger.error(f"Can't find the location of image {img['image_path']} in the markdown file #1")
else:# 应当在两个block之间
# 找到上方最近的block如果上方没有就找大下方最近的block
top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, imgbox)
if top_txt_block:
line_txt = "".join([s['text'] for s in top_txt_block['lines'][-1]['spans']])
__insert_after_para(line_txt, img_content, content_lst)
else:
bottom_txt_block = find_bottom_nearest_text_bbox(pymu_raw_blocks, imgbox)
if bottom_txt_block:
line_txt = "".join([s['text'] for s in bottom_txt_block['lines'][0]['spans']])
__insert_before_para(line_txt, img_content, content_lst)
else: # TODO ,图片可能独占一列,这种情况上下是没有图片的
logger.error(f"Can't find the location of image {img['image_path']} in the markdown file #2")
insert_img_or_table("image", img, pymu_raw_blocks, content_lst)
"""插入表格"""
for table in all_page_tables:
insert_img_or_table("table", table, pymu_raw_blocks, content_lst)
# end for
return content_lst
def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
element_bbox = element['bbox']
# 先看在哪个block内
for block in pymu_raw_blocks:
bbox = block['bbox']
if bbox[0] - 1 <= element_bbox[0] < bbox[2] + 1 and bbox[1] - 1 <= element_bbox[1] < bbox[
3] + 1: # 确定在这个大的block内然后进入逐行比较距离
for l in block['lines']:
line_box = l['bbox']
if line_box[0] - 1 <= element_bbox[0] < line_box[2] + 1 and line_box[1] - 1 <= element_bbox[1] < line_box[
3] + 1: # 在line内的插入line前面
line_txt = "".join([s['text'] for s in l['spans']])
__insert_before_para(line_txt, type, element, content_lst)
break
break
else: # 在行与行之间
# 找到图片x0,y0与line的x0,y0最近的line
min_distance = 100000
min_line = None
for l in block['lines']:
line_box = l['bbox']
distance = math.sqrt((line_box[0] - element_bbox[0]) ** 2 + (line_box[1] - element_bbox[1]) ** 2)
if distance < min_distance:
min_distance = distance
min_line = l
if min_line:
line_txt = "".join([s['text'] for s in min_line['spans']])
img_h = element_bbox[3] - element_bbox[1]
if min_distance < img_h: # 文字在图片前面
__insert_after_para(line_txt, type, element, content_lst)
else:
__insert_before_para(line_txt, type, element, content_lst)
break
else:
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file #1")
else: # 应当在两个block之间
# 找到上方最近的block如果上方没有就找大下方最近的block
top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, element_bbox)
if top_txt_block:
line_txt = "".join([s['text'] for s in top_txt_block['lines'][-1]['spans']])
__insert_after_para(line_txt, type, element, content_lst)
else:
bottom_txt_block = find_bottom_nearest_text_bbox(pymu_raw_blocks, element_bbox)
if bottom_txt_block:
line_txt = "".join([s['text'] for s in bottom_txt_block['lines'][0]['spans']])
__insert_before_para(line_txt, type, element, content_lst)
else: # TODO ,图片可能独占一列,这种情况上下是没有图片的
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file #2")
def mk_mm_markdown(content_list):
"""
基于同一格式的内容列表构造markdown含图片
@@ -348,6 +389,8 @@ def mk_nlp_markdown(content_list):
content_md.append(c.get("text"))
elif content_type == "equation":
content_md.append(f"$$\n{c.get('latex')}\n$$")
elif content_type == "table":
content_md.append(f"$$\n{c.get('table_latex')}\n$$")
elif content_type in UNI_FORMAT_TEXT_TYPE:
content_md.append(f"{'#'*int(content_type[1])} {c.get('text')}")
return "\n\n".join(content_md)

View File

@@ -53,7 +53,7 @@ from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.equations_replace import combine_chars_to_pymudict, remove_chars_in_text_blocks, replace_equations_in_textblock
from magic_pdf.pre_proc.pdf_pre_filter import pdf_filter
from magic_pdf.pre_proc.detect_footer_header_by_statistics import drop_footer_header
from magic_pdf.pre_proc.construct_paras import construct_page_component
from magic_pdf.pre_proc.construct_page_dict import construct_page_component
from magic_pdf.pre_proc.fix_image import combine_images, fix_image_vertical, fix_seperated_image, include_img_title
from magic_pdf.post_proc.pdf_post_filter import pdf_post_filter
from magic_pdf.pre_proc.remove_rotate_bbox import get_side_boundry, remove_rotate_side_textblock, remove_side_blank_block
@@ -71,8 +71,7 @@ paraMergeException_msg = ParaMergeException().message
def parse_pdf_by_model(
s3_pdf_path,
s3_pdf_profile,
pdf_bytes,
pdf_model_output,
save_path,
book_name,
@@ -83,7 +82,7 @@ def parse_pdf_by_model(
junk_img_bojids=[],
debug_mode=False,
):
pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
md_bookname_save_path = ""
book_name = sanitize_filename(book_name)

View File

@@ -18,6 +18,7 @@ from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.safe_filename import sanitize_filename
from magic_pdf.para.para_split import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component
from magic_pdf.pre_proc.detect_footer_by_model import parse_footers
from magic_pdf.pre_proc.detect_footnote import parse_footnotes_by_model
from magic_pdf.pre_proc.detect_header import parse_headers
@@ -33,32 +34,9 @@ from magic_pdf.pre_proc.ocr_span_list_modify import remove_spans_by_bboxes, remo
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox
def construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
images, tables, interline_equations, inline_equations,
dropped_text_block, dropped_image_block, dropped_table_block, dropped_equation_block,
need_remove_spans_bboxes_dict):
return_dict = {
'preproc_blocks': blocks,
'layout_bboxes': layout_bboxes,
'page_idx': page_id,
'page_size': [page_w, page_h],
'_layout_tree': layout_tree,
'images': images,
'tables': tables,
'interline_equations': interline_equations,
'inline_equations': inline_equations,
'droped_text_block': dropped_text_block,
'droped_image_block': dropped_image_block,
'droped_table_block': dropped_table_block,
'dropped_equation_block': dropped_equation_block,
'droped_bboxes': need_remove_spans_bboxes_dict,
}
return return_dict
def parse_pdf_by_ocr(
pdf_path,
s3_pdf_profile,
pdf_bytes,
pdf_model_output,
save_path,
book_name,
@@ -68,7 +46,7 @@ def parse_pdf_by_ocr(
end_page_id=None,
debug_mode=False,
):
pdf_bytes = read_file(pdf_path, s3_pdf_profile)
save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
book_name = sanitize_filename(book_name)
md_bookname_save_path = ""
@@ -254,7 +232,7 @@ def parse_pdf_by_ocr(
dropped_equation_block.append(span)
'''构造pdf_info_dict'''
page_info = construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
page_info = ocr_construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
images, tables, interline_equations, inline_equations,
dropped_text_block, dropped_image_block, dropped_table_block,
dropped_equation_block,

View File

@@ -75,7 +75,7 @@ from magic_pdf.pre_proc.equations_replace import (
)
from magic_pdf.pre_proc.pdf_pre_filter import pdf_filter
from magic_pdf.pre_proc.detect_footer_header_by_statistics import drop_footer_header
from magic_pdf.pre_proc.construct_paras import construct_page_component
from magic_pdf.pre_proc.construct_page_dict import construct_page_component
from magic_pdf.pre_proc.fix_image import (
combine_images,
fix_image_vertical,

View File

@@ -3,9 +3,6 @@ import sys
import time
from urllib.parse import quote
from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown, ocr_mk_nlp_markdown_with_para, \
ocr_mk_mm_markdown_with_para_and_pagination, ocr_mk_mm_markdown_with_para, ocr_mk_mm_standard_format, \
make_standard_format_with_para
from magic_pdf.libs.commons import (
read_file,
join_path,
@@ -15,34 +12,19 @@ from magic_pdf.libs.commons import (
)
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.dict2md.mkcontent import mk_nlp_markdown, mk_universal_format
from magic_pdf.dict2md.mkcontent import mk_universal_format
from magic_pdf.pdf_parse_by_model import parse_pdf_by_model
from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
from loguru import logger
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
from magic_pdf.pdf_parse_for_train import parse_pdf_for_train
from magic_pdf.spark.base import exception_handler, get_data_source
from magic_pdf.train_utils.convert_to_train_format import convert_to_train_format
from app.common.s3 import get_s3_config, get_s3_client
from magic_pdf.spark.s3 import get_s3_config, get_s3_client
def get_data_type(jso: dict):
data_type = jso.get("data_type")
if data_type is None:
data_type = jso.get("file_type")
return data_type
def get_bookid(jso: dict):
book_id = jso.get("bookid")
if book_id is None:
book_id = jso.get("original_file_id")
return book_id
def meta_scan(jso: dict, doc_layout_check=True) -> dict:
s3_pdf_path = jso.get("file_location")
s3_config = get_s3_config(s3_pdf_path)
@@ -304,22 +286,13 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
# 开始正式逻辑
s3_pdf_path = jso.get("file_location")
s3_config = get_s3_config(s3_pdf_path)
pdf_bytes = read_file(s3_pdf_path, s3_config)
model_output_json_list = jso.get("doc_layout_result")
data_source = get_data_source(jso)
file_id = jso.get("file_id")
book_name = f"{data_source}/{file_id}"
# 1.23.22已修复
# if debug_mode:
# pass
# else:
# if book_name == "zlib/zlib_21929367":
# jso['need_drop'] = True
# jso['drop_reason'] = DropReason.SPECIAL_PDF
# return jso
junk_img_bojids = jso["pdf_meta"]["junk_img_bojids"]
# total_page = jso['pdf_meta']['total_page']
# 增加检测 max_svgs 数量的检测逻辑,如果 max_svgs 超过3000则drop
svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
@@ -327,9 +300,6 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if max_svgs > 3000:
jso["need_drop"] = True
jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
# elif total_page > 1000:
# jso['need_drop'] = True
# jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
else:
try:
save_path = s3_image_save_path
@@ -341,8 +311,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
file=sys.stderr,
)
pdf_info_dict = parse_pdf_by_model(
s3_pdf_path,
s3_config,
pdf_bytes,
model_output_json_list,
save_path,
book_name,
@@ -373,18 +342,6 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
return jso
"""
统一处理逻辑
1.先调用parse_pdf对文本类pdf进行处理
2.再调用ocr_dropped_parse_pdf,对之前drop的pdf进行处理
"""
def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
return jso
def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> dict:
# 检测debug开关
if debug_mode:
@@ -465,242 +422,16 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
return jso
# 专门用来跑被drop的pdf跑完之后需要把need_drop字段置为false
def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if not jso.get("need_drop", False):
return jso
else:
jso = ocr_parse_pdf_core(
jso, start_page_id=start_page_id, debug_mode=debug_mode
)
jso["need_drop"] = False
return jso
"""
统一处理逻辑
1.先调用parse_pdf对文本类pdf进行处理
2.再调用ocr_dropped_parse_pdf,对之前drop的pdf进行处理
"""
def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
# 检测debug开关
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
return jso
jso = ocr_parse_pdf_core(jso, start_page_id=start_page_id, debug_mode=debug_mode)
return jso
def ocr_parse_pdf_core(jso: dict, start_page_id=0, debug_mode=False) -> dict:
s3_pdf_path = jso.get("file_location")
s3_config = get_s3_config(s3_pdf_path)
model_output_json_list = jso.get("doc_layout_result")
data_source = get_data_source(jso)
file_id = jso.get("file_id")
book_name = f"{data_source}/{file_id}"
try:
save_path = s3_image_save_path
image_s3_config = get_s3_config(save_path)
start_time = time.time() # 记录开始时间
# 先打印一下book_name和解析开始的时间
logger.info(
f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
file=sys.stderr,
)
pdf_info_dict = parse_pdf_by_ocr(
s3_pdf_path,
s3_config,
model_output_json_list,
save_path,
book_name,
pdf_model_profile=None,
image_s3_config=image_s3_config,
start_page_id=start_page_id,
debug_mode=debug_mode,
)
pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
jso["pdf_intermediate_dict"] = pdf_info_dict
end_time = time.time() # 记录完成时间
parse_time = int(end_time - start_time) # 计算执行时间
# 解析完成后打印一下book_name和耗时
logger.info(
f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
file=sys.stderr,
)
jso["parse_time"] = parse_time
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
markdown_content = ocr_mk_mm_markdown(pdf_intermediate_dict)
jso["content"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_markdown_with_para(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
# markdown_content = ocr_mk_mm_markdown_with_para(pdf_intermediate_dict)
markdown_content = ocr_mk_nlp_markdown_with_para(pdf_intermediate_dict)
jso["content"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_markdown_with_para_and_pagination(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
markdown_content = ocr_mk_mm_markdown_with_para_and_pagination(pdf_intermediate_dict)
jso["content"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
# jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
# jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa(
jso: dict, debug_mode=False
) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
markdown_content = ocr_mk_mm_markdown_with_para(pdf_intermediate_dict)
jso["content_ocr"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["mid_json_ocr"] = pdf_intermediate_dict
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_standard_format(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
standard_format = ocr_mk_mm_standard_format(pdf_intermediate_dict)
jso["content_list"] = standard_format
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
standard_format = make_standard_format_with_para(pdf_intermediate_dict)
jso["content_list"] = standard_format
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
# def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
# jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
# jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
# return jso
if __name__ == "__main__":

259
magic_pdf/pipeline_ocr.py Normal file
View File

@@ -0,0 +1,259 @@
import sys
import time
from loguru import logger
from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown, ocr_mk_nlp_markdown_with_para, \
ocr_mk_mm_markdown_with_para_and_pagination, ocr_mk_mm_markdown_with_para, ocr_mk_mm_standard_format, \
make_standard_format_with_para
from magic_pdf.libs.commons import join_path, s3_image_save_path, formatted_time
from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
from magic_pdf.spark.base import get_data_source, exception_handler, get_pdf_bytes, get_bookname
from magic_pdf.spark.s3 import get_s3_config
def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
markdown_content = ocr_mk_mm_markdown(pdf_intermediate_dict)
jso["content"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_markdown_with_para(jso: dict, mode, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
if mode == "mm":
markdown_content = ocr_mk_mm_markdown_with_para(pdf_intermediate_dict)
elif mode == "nlp":
markdown_content = ocr_mk_nlp_markdown_with_para(pdf_intermediate_dict)
jso["content"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_markdown_with_para_and_pagination(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
markdown_content = ocr_mk_mm_markdown_with_para_and_pagination(pdf_intermediate_dict)
jso["content"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
# jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
# jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa(
jso: dict, debug_mode=False
) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
markdown_content = ocr_mk_mm_markdown_with_para(pdf_intermediate_dict)
jso["content_ocr"] = markdown_content
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["mid_json_ocr"] = pdf_intermediate_dict
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_standard_format(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
standard_format = ocr_mk_mm_standard_format(pdf_intermediate_dict)
jso["content_list"] = standard_format
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
return jso
try:
pdf_intermediate_dict = jso["pdf_intermediate_dict"]
# 将 pdf_intermediate_dict 解压
pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
standard_format = make_standard_format_with_para(pdf_intermediate_dict)
jso["content_list"] = standard_format
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}",
file=sys.stderr,
)
# 把无用的信息清空
jso["doc_layout_result"] = ""
jso["pdf_intermediate_dict"] = ""
jso["pdf_meta"] = ""
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_id=0, debug_mode=False):
save_path = s3_image_save_path
image_s3_config = get_s3_config(save_path)
start_time = time.time() # 记录开始时间
# 先打印一下book_name和解析开始的时间
logger.info(
f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
file=sys.stderr,
)
pdf_info_dict = parse_pdf_by_ocr(
pdf_bytes,
model_output_json_list,
save_path,
book_name,
pdf_model_profile=None,
image_s3_config=image_s3_config,
start_page_id=start_page_id,
debug_mode=debug_mode,
)
end_time = time.time() # 记录完成时间
parse_time = int(end_time - start_time) # 计算执行时间
# 解析完成后打印一下book_name和耗时
logger.info(
f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
file=sys.stderr,
)
return pdf_info_dict, parse_time
# 专门用来跑被drop的pdf跑完之后需要把need_drop字段置为false
def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if not jso.get("need_drop", False):
return jso
else:
try:
pdf_bytes = get_pdf_bytes(jso)
model_output_json_list = jso.get("doc_layout_result")
book_name = get_bookname(jso)
pdf_info_dict, parse_time = ocr_parse_pdf_core(
pdf_bytes, model_output_json_list, book_name, start_page_id=start_page_id, debug_mode=debug_mode
)
jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict)
jso["parse_time"] = parse_time
jso["need_drop"] = False
except Exception as e:
jso = exception_handler(jso, e)
return jso
def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
# 检测debug开关
if debug_mode:
pass
else: # 如果debug没开则检测是否有needdrop字段
if jso.get("need_drop", False):
return jso
try:
pdf_bytes = get_pdf_bytes(jso)
model_output_json_list = jso.get("doc_layout_result")
book_name = get_bookname(jso)
pdf_info_dict, parse_time = ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name,
start_page_id=start_page_id, debug_mode=debug_mode)
jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict)
jso["parse_time"] = parse_time
except Exception as e:
jso = exception_handler(jso, e)
return jso

View File

@@ -28,3 +28,26 @@ def construct_page_component(page_id, image_info, table_info, text_blocks_prepr
return_dict['footnote_bboxes_tmp'] = footnote_bboxes_tmp
return return_dict
def ocr_construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
images, tables, interline_equations, inline_equations,
dropped_text_block, dropped_image_block, dropped_table_block, dropped_equation_block,
need_remove_spans_bboxes_dict):
return_dict = {
'preproc_blocks': blocks,
'layout_bboxes': layout_bboxes,
'page_idx': page_id,
'page_size': [page_w, page_h],
'_layout_tree': layout_tree,
'images': images,
'tables': tables,
'interline_equations': interline_equations,
'inline_equations': inline_equations,
'droped_text_block': dropped_text_block,
'droped_image_block': dropped_image_block,
'droped_table_block': dropped_table_block,
'dropped_equation_block': dropped_equation_block,
'droped_bboxes': need_remove_spans_bboxes_dict,
}
return return_dict

View File

View File

@@ -1,9 +1,11 @@
from loguru import logger
from magic_pdf.libs.commons import read_file
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.spark.s3 import get_s3_config
def get_data_source(jso: dict):
data_source = jso.get("data_source")
@@ -12,6 +14,20 @@ def get_data_source(jso: dict):
return data_source
def get_data_type(jso: dict):
data_type = jso.get("data_type")
if data_type is None:
data_type = jso.get("file_type")
return data_type
def get_bookid(jso: dict):
book_id = jso.get("bookid")
if book_id is None:
book_id = jso.get("original_file_id")
return book_id
def exception_handler(jso: dict, e):
logger.exception(e)
jso["need_drop"] = True
@@ -19,3 +35,16 @@ def exception_handler(jso: dict, e):
jso["exception"] = f"ERROR: {e}"
return jso
def get_bookname(jso: dict):
data_source = get_data_source(jso)
file_id = jso.get("file_id")
book_name = f"{data_source}/{file_id}"
return book_name
def get_pdf_bytes(jso: dict):
pdf_s3_path = jso.get("file_location")
s3_config = get_s3_config(pdf_s3_path)
pdf_bytes = read_file(pdf_s3_path, s3_config)
return pdf_bytes

View File

@@ -2,10 +2,14 @@
import boto3
from botocore.client import Config
from app.common import s3_buckets, s3_clusters, get_cluster_name, s3_users
import re
import random
from typing import List, Union
try:
from app.config import s3_buckets, s3_clusters, s3_users
from app.common.runtime import get_cluster_name
except ImportError:
from magic_pdf.config import s3_buckets, s3_clusters, get_cluster_name, s3_users
__re_s3_path = re.compile("^s3a?://([^/]+)(?:/(.*))?$")
def get_s3_config(path: Union[str, List[str]], outside=False):

View File

@@ -0,0 +1,456 @@
import json
import pandas as pd
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
import argparse
from sklearn.metrics import classification_report
from collections import Counter
from sklearn import metrics
from pandas import isnull
def indicator_cal(json_standard,json_test):
json_standard = pd.DataFrame(json_standard)
json_test = pd.DataFrame(json_test)
'''数据集总体指标'''
a=json_test[['id','mid_json']]
b=json_standard[['id','mid_json','pass_label']]
a=a.drop_duplicates(subset='id',keep='first')
a.index=range(len(a))
b=b.drop_duplicates(subset='id',keep='first')
b.index=range(len(b))
outer_merge=pd.merge(a,b,on='id',how='outer')
outer_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
standard_exist=outer_merge.standard_mid_json.apply(lambda x: not isnull(x))
test_exist=outer_merge.test_mid_json.apply(lambda x: not isnull(x))
overall_report = {}
overall_report['accuracy']=metrics.accuracy_score(standard_exist,test_exist)
overall_report['precision']=metrics.precision_score(standard_exist,test_exist)
overall_report['recall']=metrics.recall_score(standard_exist,test_exist)
overall_report['f1_score']=metrics.f1_score(standard_exist,test_exist)
inner_merge=pd.merge(a,b,on='id',how='inner')
inner_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
json_standard = inner_merge['standard_mid_json']#check一下是否对齐
json_test = inner_merge['test_mid_json']
'''批量读取中间生成的json文件'''
test_inline_equations=[]
test_interline_equations=[]
test_inline_euqations_bboxs=[]
test_interline_equations_bboxs=[]
test_dropped_text_bboxes=[]
test_dropped_text_tag=[]
test_dropped_image_bboxes=[]
test_dropped_table_bboxes=[]
test_preproc_num=[]#阅读顺序
test_para_num=[]
test_para_text=[]
for i in json_test:
mid_json=pd.DataFrame(i)
mid_json=mid_json.iloc[:,:-1]
for j1 in mid_json.loc['inline_equations',:]:
page_in_text=[]
page_in_bbox=[]
for k1 in j1:
page_in_text.append(k1['latex_text'])
page_in_bbox.append(k1['bbox'])
test_inline_equations.append(page_in_text)
test_inline_euqations_bboxs.append(page_in_bbox)
for j2 in mid_json.loc['interline_equations',:]:
page_in_text=[]
page_in_bbox=[]
for k2 in j2:
page_in_text.append(k2['latex_text'])
page_in_bbox.append(k2['bbox'])
test_interline_equations.append(page_in_text)
test_interline_equations_bboxs.append(page_in_bbox)
for j3 in mid_json.loc['droped_text_block',:]:
page_in_bbox=[]
page_in_tag=[]
for k3 in j3:
page_in_bbox.append(k3['bbox'])
#如果k3中存在tag这个key
if 'tag' in k3.keys():
page_in_tag.append(k3['tag'])
else:
page_in_tag.append('None')
test_dropped_text_tag.append(page_in_tag)
test_dropped_text_bboxes.append(page_in_bbox)
for j4 in mid_json.loc['droped_image_block',:]:
test_dropped_image_bboxes.append(j4)
for j5 in mid_json.loc['droped_table_block',:]:
test_dropped_table_bboxes.append(j5)
for j6 in mid_json.loc['preproc_blocks',:]:
page_in=[]
for k6 in j6:
page_in.append(k6['number'])
test_preproc_num.append(page_in)
test_pdf_text=[]
for j7 in mid_json.loc['para_blocks',:]:
test_para_num.append(len(j7))
for k7 in j7:
test_pdf_text.append(k7['text'])
test_para_text.append(test_pdf_text)
standard_inline_equations=[]
standard_interline_equations=[]
standard_inline_euqations_bboxs=[]
standard_interline_equations_bboxs=[]
standard_dropped_text_bboxes=[]
standard_dropped_text_tag=[]
standard_dropped_image_bboxes=[]
standard_dropped_table_bboxes=[]
standard_preproc_num=[]#阅读顺序
standard_para_num=[]
standard_para_text=[]
for i in json_standard:
mid_json=pd.DataFrame(i)
mid_json=mid_json.iloc[:,:-1]
for j1 in mid_json.loc['inline_equations',:]:
page_in_text=[]
page_in_bbox=[]
for k1 in j1:
page_in_text.append(k1['latex_text'])
page_in_bbox.append(k1['bbox'])
standard_inline_equations.append(page_in_text)
standard_inline_euqations_bboxs.append(page_in_bbox)
for j2 in mid_json.loc['interline_equations',:]:
page_in_text=[]
page_in_bbox=[]
for k2 in j2:
page_in_text.append(k2['latex_text'])
page_in_bbox.append(k2['bbox'])
standard_interline_equations.append(page_in_text)
standard_interline_equations_bboxs.append(page_in_bbox)
for j3 in mid_json.loc['droped_text_block',:]:
page_in_bbox=[]
page_in_tag=[]
for k3 in j3:
page_in_bbox.append(k3['bbox'])
if 'tag' in k3.keys():
page_in_tag.append(k3['tag'])
else:
page_in_tag.append('None')
standard_dropped_text_bboxes.append(page_in_bbox)
standard_dropped_text_tag.append(page_in_tag)
for j4 in mid_json.loc['droped_image_block',:]:
standard_dropped_image_bboxes.append(j4)
for j5 in mid_json.loc['droped_table_block',:]:
standard_dropped_table_bboxes.append(j5)
for j6 in mid_json.loc['preproc_blocks',:]:
page_in=[]
for k6 in j6:
page_in.append(k6['number'])
standard_preproc_num.append(page_in)
standard_pdf_text=[]
for j7 in mid_json.loc['para_blocks',:]:
standard_para_num.append(len(j7))
for k7 in j7:
standard_pdf_text.append(k7['text'])
standard_para_text.append(standard_pdf_text)
"""
在计算指标之前最好先确认基本统计信息是否一致
"""
'''
计算pdf之间的总体编辑距离和bleu
这里只计算正例的pdf
'''
test_para_text=np.asarray(test_para_text, dtype = object)[inner_merge['pass_label']=='yes']
standard_para_text=np.asarray(standard_para_text, dtype = object)[inner_merge['pass_label']=='yes']
pdf_dis=[]
pdf_bleu=[]
for a,b in zip(test_para_text,standard_para_text):
a1=[ ''.join(i) for i in a]
b1=[ ''.join(i) for i in b]
pdf_dis.append(Levenshtein_Distance(a1,b1))
pdf_bleu.append(sentence_bleu([a1],b1))
overall_report['pdf间的平均编辑距离']=np.mean(pdf_dis)
overall_report['pdf间的平均bleu']=np.mean(pdf_bleu)
'''行内公式和行间公式的编辑距离和bleu'''
inline_equations_edit_bleu=equations_indicator(test_inline_euqations_bboxs,standard_inline_euqations_bboxs,test_inline_equations,standard_inline_equations)
interline_equations_edit_bleu=equations_indicator(test_interline_equations_bboxs,standard_interline_equations_bboxs,test_interline_equations,standard_interline_equations)
'''行内公式bbox匹配相关指标'''
inline_equations_bbox_report=bbox_match_indicator(test_inline_euqations_bboxs,standard_inline_euqations_bboxs)
'''行间公式bbox匹配相关指标'''
interline_equations_bbox_report=bbox_match_indicator(test_interline_equations_bboxs,standard_interline_equations_bboxs)
'''可以先检查page和bbox数量是否一致'''
'''dropped_text_block的bbox匹配相关指标'''
test_text_bbox=[]
standard_text_bbox=[]
test_tag=[]
standard_tag=[]
index=0
for a,b in zip(test_dropped_text_bboxes,standard_dropped_text_bboxes):
test_page_tag=[]
standard_page_tag=[]
test_page_bbox=[]
standard_page_bbox=[]
if len(a)==0 and len(b)==0:
pass
else:
for i in range(len(b)):
judge=0
standard_page_tag.append(standard_dropped_text_tag[index][i])
standard_page_bbox.append(1)
for j in range(len(a)):
if bbox_offset(b[i],a[j]):
judge=1
test_page_tag.append(test_dropped_text_tag[index][j])
test_page_bbox.append(1)
break
if judge==0:
test_page_tag.append('None')
test_page_bbox.append(0)
if len(test_dropped_text_tag[index])+test_page_tag.count('None')>len(standard_dropped_text_tag[index]):#有多删的情况出现
test_page_tag1=test_page_tag.copy()
if 'None' in test_page_tag:
test_page_tag1=test_page_tag1.remove('None')
else:
test_page_tag1=test_page_tag
diff=list((Counter(test_dropped_text_tag[index]) - Counter(test_page_tag1)).elements())
test_page_tag.extend(diff)
standard_page_tag.extend(['None']*len(diff))
test_page_bbox.extend([1]*len(diff))
standard_page_bbox.extend([0]*len(diff))
test_tag.extend(test_page_tag)
standard_tag.extend(standard_page_tag)
test_text_bbox.extend(test_page_bbox)
standard_text_bbox.extend(standard_page_bbox)
index+=1
text_block_report = {}
text_block_report['accuracy']=metrics.accuracy_score(standard_text_bbox,test_text_bbox)
text_block_report['precision']=metrics.precision_score(standard_text_bbox,test_text_bbox)
text_block_report['recall']=metrics.recall_score(standard_text_bbox,test_text_bbox)
text_block_report['f1_score']=metrics.f1_score(standard_text_bbox,test_text_bbox)
'''删除的text_block的tag的准确率,召回率和f1-score'''
text_block_tag_report = classification_report(y_true=standard_tag , y_pred=test_tag,output_dict=True)
del text_block_tag_report['None']
del text_block_tag_report["macro avg"]
del text_block_tag_report["weighted avg"]
'''dropped_image_block的bbox匹配相关指标'''
'''有数据格式不一致的问题'''
image_block_report=bbox_match_indicator(test_dropped_image_bboxes,standard_dropped_image_bboxes)
'''dropped_table_block的bbox匹配相关指标'''
table_block_report=bbox_match_indicator(test_dropped_table_bboxes,standard_dropped_table_bboxes)
'''阅读顺序编辑距离的均值'''
preproc_num_dis=[]
for a,b in zip(test_preproc_num,standard_preproc_num):
preproc_num_dis.append(Levenshtein_Distance(a,b))
preproc_num_edit=np.mean(preproc_num_dis)
'''分段准确率'''
test_para_num=np.array(test_para_num)
standard_para_num=np.array(standard_para_num)
acc_para=np.mean(test_para_num==standard_para_num)
output=pd.DataFrame()
output['总体指标']=[overall_report]
output['行内公式平均编辑距离']=[inline_equations_edit_bleu[0]]
output['行内公式平均bleu']=[inline_equations_edit_bleu[1]]
output['行间公式平均编辑距离']=[interline_equations_edit_bleu[0]]
output['行间公式平均bleu']=[interline_equations_edit_bleu[1]]
output['行内公式识别相关指标']=[inline_equations_bbox_report]
output['行间公式识别相关指标']=[interline_equations_bbox_report]
output['阅读顺序平均编辑距离']=[preproc_num_edit]
output['分段准确率']=[acc_para]
output['删除的text block的相关指标']=[text_block_report]
output['删除的image block的相关指标']=[image_block_report]
output['删除的table block的相关指标']=[table_block_report]
output['删除的text block的tag相关指标']=[text_block_tag_report]
return output
"""
计算编辑距离
"""
def Levenshtein_Distance(str1, str2):
matrix = [[ i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
for i in range(1, len(str1)+1):
for j in range(1, len(str2)+1):
if(str1[i-1] == str2[j-1]):
d = 0
else:
d = 1
matrix[i][j] = min(matrix[i-1][j]+1, matrix[i][j-1]+1, matrix[i-1][j-1]+d)
return matrix[len(str1)][len(str2)]
'''
计算bbox偏移量是否符合标准的函数
'''
def bbox_offset(b_t,b_s):
'''b_t是test_doc里的bbox,b_s是standard_doc里的bbox'''
x1_t,y1_t,x2_t,y2_t=b_t
x1_s,y1_s,x2_s,y2_s=b_s
x1=max(x1_t,x1_s)
x2=min(x2_t,x2_s)
y1=max(y1_t,y1_s)
y2=min(y2_t,y2_s)
area_overlap=(x2-x1)*(y2-y1)
area_t=(x2_t-x1_t)*(y2_t-y1_t)+(x2_s-x1_s)*(y2_s-y1_s)-area_overlap
if area_t-area_overlap==0 or area_overlap/(area_t-area_overlap)>0.95:
return True
else:
return False
'''bbox匹配和对齐函数输出相关指标'''
'''输入的是以page为单位的bbox列表'''
def bbox_match_indicator(test_bbox_list,standard_bbox_list):
test_bbox=[]
standard_bbox=[]
for a,b in zip(test_bbox_list,standard_bbox_list):
test_page_bbox=[]
standard_page_bbox=[]
if len(a)==0 and len(b)==0:
pass
else:
for i in b:
if len(i)!=4:
continue
else:
judge=0
standard_page_bbox.append(1)
for j in a:
if bbox_offset(i,j):
judge=1
test_page_bbox.append(1)
break
if judge==0:
test_page_bbox.append(0)
diff_num=len(a)+test_page_bbox.count(0)-len(b)
if diff_num>0:#有多删的情况出现
test_page_bbox.extend([1]*diff_num)
standard_page_bbox.extend([0]*diff_num)
test_bbox.extend(test_page_bbox)
standard_bbox.extend(standard_page_bbox)
block_report = {}
block_report['accuracy']=metrics.accuracy_score(standard_bbox,test_bbox)
block_report['precision']=metrics.precision_score(standard_bbox,test_bbox)
block_report['recall']=metrics.recall_score(standard_bbox,test_bbox)
block_report['f1_score']=metrics.f1_score(standard_bbox,test_bbox)
return block_report
'''公式编辑距离和bleu'''
def equations_indicator(test_euqations_bboxs,standard_euqations_bboxs,test_equations,standard_equations):
test_match_equations=[]
standard_match_equations=[]
index=0
for a,b in zip(test_euqations_bboxs,standard_euqations_bboxs):
if len(a)==0 and len(b)==0:
pass
else:
for i in range(len(b)):
for j in range(len(a)):
if bbox_offset(b[i],a[j]):
standard_match_equations.append(standard_equations[index][i])
test_match_equations.append(test_equations[index][j])
break
index+=1
dis=[]
bleu=[]
for a,b in zip(test_match_equations,standard_match_equations):
if len(a)==0 and len(b)==0:
continue
else:
if a==b:
dis.append(0)
bleu.append(1)
else:
dis.append(Levenshtein_Distance(a,b))
bleu.append(sentence_bleu([a],b))
equations_edit=np.mean(dis)
equations_bleu=np.mean(bleu)
return (equations_edit,equations_bleu)
parser = argparse.ArgumentParser()
parser.add_argument('--test', type=str)
parser.add_argument('--standard', type=str)
args = parser.parse_args()
pdf_json_test = args.test
pdf_json_standard = args.standard
if __name__ == '__main__':
pdf_json_test = [json.loads(line)
for line in open(pdf_json_test, 'r', encoding='utf-8')]
pdf_json_standard = [json.loads(line)
for line in open(pdf_json_standard, 'r', encoding='utf-8')]
overall_indicator=indicator_cal(pdf_json_standard,pdf_json_test)
'''计算的指标输出到overall_indicator_output.json中'''
overall_indicator.to_json('overall_indicator_output.json',orient='records',lines=True,force_ascii=False)

View File

@@ -3,7 +3,7 @@ import json
import os
from magic_pdf.libs.commons import fitz
from app.common.s3 import get_s3_config, get_s3_client
from magic_pdf.spark.s3 import get_s3_config, get_s3_client
from magic_pdf.libs.commons import join_path, json_dump_path, read_file, parse_bucket_key
from loguru import logger