Add new OCRReader with PDF+OCR text merging (#66)
This change speeds up OCR extraction by allowing bypassing OCR for texts that are irrelevant (not in table). --------- Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
This commit is contained in:
parent
d79b3744cb
commit
4704e2c11a
|
@ -7,78 +7,81 @@ from llama_index.readers.base import BaseReader
|
||||||
|
|
||||||
from kotaemon.documents import Document
|
from kotaemon.documents import Document
|
||||||
|
|
||||||
from .utils.table import (
|
from .utils.pdf_ocr import parse_ocr_output, read_pdf_unstructured
|
||||||
extract_tables_from_csv_string,
|
from .utils.table import strip_special_chars_markdown
|
||||||
get_table_from_ocr,
|
|
||||||
strip_special_chars_markdown,
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"
|
DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"
|
||||||
|
|
||||||
|
|
||||||
class OCRReader(BaseReader):
|
class OCRReader(BaseReader):
|
||||||
def __init__(self, endpoint: str = DEFAULT_OCR_ENDPOINT):
|
def __init__(self, endpoint: str = DEFAULT_OCR_ENDPOINT, use_ocr=True):
|
||||||
"""Init the OCR reader with OCR endpoint (FullOCR pipeline)
|
"""Init the OCR reader with OCR endpoint (FullOCR pipeline)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
endpoint: URL to FullOCR endpoint. Defaults to OCR_ENDPOINT.
|
endpoint: URL to FullOCR endpoint. Defaults to OCR_ENDPOINT.
|
||||||
|
use_ocr: whether to use OCR to read text
|
||||||
|
(e.g: from images, tables) in the PDF
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ocr_endpoint = endpoint
|
self.ocr_endpoint = endpoint
|
||||||
|
self.use_ocr = use_ocr
|
||||||
|
|
||||||
def load_data(
|
def load_data(
|
||||||
self,
|
self,
|
||||||
file: Path,
|
file_path: Path,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
|
"""Load data using OCR reader
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (Path): Path to PDF file
|
||||||
|
debug_path (Path): Path to store debug image output
|
||||||
|
artifact_path (Path): Path to OCR endpoints artifacts directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Document]: list of documents extracted from the PDF file
|
||||||
|
"""
|
||||||
# create input params for the requests
|
# create input params for the requests
|
||||||
content = open(file, "rb")
|
content = open(file_path, "rb")
|
||||||
files = {"input": content}
|
files = {"input": content}
|
||||||
data = {"job_id": uuid4()}
|
data = {"job_id": uuid4(), "table_only": not self.use_ocr}
|
||||||
|
|
||||||
# init list of output documents
|
debug_path = kwargs.pop("debug_path", None)
|
||||||
documents = []
|
artifact_path = kwargs.pop("artifact_path", None)
|
||||||
all_table_csv_list = []
|
|
||||||
all_non_table_texts = []
|
|
||||||
|
|
||||||
# call the API from FullOCR endpoint
|
# call the API from FullOCR endpoint
|
||||||
if "response_content" in kwargs:
|
if "response_content" in kwargs:
|
||||||
# overriding response content if specified
|
# overriding response content if specified
|
||||||
results = kwargs["response_content"]
|
ocr_results = kwargs["response_content"]
|
||||||
else:
|
else:
|
||||||
# call original API
|
# call original API
|
||||||
resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
|
resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
|
||||||
results = resp.json()["result"]
|
ocr_results = resp.json()["result"]
|
||||||
|
|
||||||
for _id, each in enumerate(results):
|
# read PDF through normal reader (unstructured)
|
||||||
csv_content = each["csv_string"]
|
pdf_page_items = read_pdf_unstructured(file_path)
|
||||||
table = each["json"]["table"]
|
# merge PDF text output with OCR output
|
||||||
ocr = each["json"]["ocr"]
|
tables, texts = parse_ocr_output(
|
||||||
|
ocr_results,
|
||||||
# using helper function to extract list of table texts from FullOCR output
|
pdf_page_items,
|
||||||
table_texts = get_table_from_ocr(ocr, table)
|
debug_path=debug_path,
|
||||||
# extract the formatted CSV table from specified text
|
artifact_path=artifact_path,
|
||||||
csv_list, non_table_text = extract_tables_from_csv_string(
|
)
|
||||||
csv_content, table_texts
|
|
||||||
)
|
|
||||||
all_table_csv_list.extend([(csv, _id) for csv in csv_list])
|
|
||||||
all_non_table_texts.append((non_table_text, _id))
|
|
||||||
|
|
||||||
# create output Document with metadata from table
|
# create output Document with metadata from table
|
||||||
documents = [
|
documents = [
|
||||||
Document(
|
Document(
|
||||||
text=strip_special_chars_markdown(csv),
|
text=strip_special_chars_markdown(table_text),
|
||||||
metadata={
|
metadata={
|
||||||
"table_origin": csv,
|
"table_origin": table_text,
|
||||||
"type": "table",
|
"type": "table",
|
||||||
"page_label": page_id + 1,
|
"page_label": page_id + 1,
|
||||||
"source": file.name,
|
"source": file_path.name,
|
||||||
},
|
},
|
||||||
metadata_template="",
|
metadata_template="",
|
||||||
metadata_seperator="",
|
metadata_seperator="",
|
||||||
)
|
)
|
||||||
for csv, page_id in all_table_csv_list
|
for page_id, table_text in tables
|
||||||
]
|
]
|
||||||
# create Document from non-table text
|
# create Document from non-table text
|
||||||
documents.extend(
|
documents.extend(
|
||||||
|
@ -87,10 +90,10 @@ class OCRReader(BaseReader):
|
||||||
text=non_table_text,
|
text=non_table_text,
|
||||||
metadata={
|
metadata={
|
||||||
"page_label": page_id + 1,
|
"page_label": page_id + 1,
|
||||||
"source": file.name,
|
"source": file_path.name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for non_table_text, page_id in all_non_table_texts
|
for page_id, non_table_text in texts
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
144
knowledgehub/loaders/utils/box.py
Normal file
144
knowledgehub/loaders/utils/box.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def bbox_to_points(box: List[int]):
|
||||||
|
"""Convert bounding box to list of points"""
|
||||||
|
x1, y1, x2, y2 = box
|
||||||
|
return [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||||
|
|
||||||
|
|
||||||
|
def points_to_bbox(points: List[Tuple[int, int]]):
|
||||||
|
"""Convert list of points to bounding box"""
|
||||||
|
all_x = [p[0] for p in points]
|
||||||
|
all_y = [p[1] for p in points]
|
||||||
|
return [min(all_x), min(all_y), max(all_x), max(all_y)]
|
||||||
|
|
||||||
|
|
||||||
|
def scale_points(points: List[Tuple[int, int]], scale_factor: float = 1.0):
|
||||||
|
"""Scale points by a scale factor"""
|
||||||
|
return [(int(pos[0] * scale_factor), int(pos[1] * scale_factor)) for pos in points]
|
||||||
|
|
||||||
|
|
||||||
|
def union_points(points: List[Tuple[int, int]]):
|
||||||
|
"""Return union bounding box of list of points"""
|
||||||
|
all_x = [p[0] for p in points]
|
||||||
|
all_y = [p[1] for p in points]
|
||||||
|
bbox = (min(all_x), min(all_y), max(all_x), max(all_y))
|
||||||
|
return bbox
|
||||||
|
|
||||||
|
|
||||||
|
def scale_box(box: List[int], scale_factor: float = 1.0):
|
||||||
|
"""Scale box by a scale factor"""
|
||||||
|
return [int(pos * scale_factor) for pos in box]
|
||||||
|
|
||||||
|
|
||||||
|
def box_h(box: List[int]):
|
||||||
|
"Return box height"
|
||||||
|
return box[3] - box[1]
|
||||||
|
|
||||||
|
|
||||||
|
def box_w(box: List[int]):
|
||||||
|
"Return box width"
|
||||||
|
return box[2] - box[0]
|
||||||
|
|
||||||
|
|
||||||
|
def box_area(box: List[int]):
|
||||||
|
"Return box area"
|
||||||
|
x1, y1, x2, y2 = box
|
||||||
|
return (x2 - x1) * (y2 - y1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rect_iou(gt_box: List[tuple], pd_box: List[tuple], iou_type=0) -> int:
|
||||||
|
"""Intersection over union on layout rectangle
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gt_box: List[tuple]
|
||||||
|
A list contains bounding box coordinates of ground truth
|
||||||
|
pd_box: List[tuple]
|
||||||
|
A list contains bounding box coordinates of prediction
|
||||||
|
iou_type: int
|
||||||
|
0: intersection / union, normal IOU
|
||||||
|
1: intersection / min(areas), useful when boxes are under/over-segmented
|
||||||
|
|
||||||
|
Input format: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||||
|
Annotation for each element in bbox:
|
||||||
|
(x1, y1) (x2, y1)
|
||||||
|
+-------+
|
||||||
|
| |
|
||||||
|
| |
|
||||||
|
+-------+
|
||||||
|
(x1, y2) (x2, y2)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Intersection over union value
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert iou_type in [0, 1], "Only support 0: origin iou, 1: intersection / min(area)"
|
||||||
|
|
||||||
|
# determine the (x, y)-coordinates of the intersection rectangle
|
||||||
|
# gt_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||||
|
# pd_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||||
|
x_left = max(gt_box[0][0], pd_box[0][0])
|
||||||
|
y_top = max(gt_box[0][1], pd_box[0][1])
|
||||||
|
x_right = min(gt_box[2][0], pd_box[2][0])
|
||||||
|
y_bottom = min(gt_box[2][1], pd_box[2][1])
|
||||||
|
|
||||||
|
# compute the area of intersection rectangle
|
||||||
|
interArea = max(0, x_right - x_left) * max(0, y_bottom - y_top)
|
||||||
|
|
||||||
|
# compute the area of both the prediction and ground-truth
|
||||||
|
# rectangles
|
||||||
|
gt_area = (gt_box[2][0] - gt_box[0][0]) * (gt_box[2][1] - gt_box[0][1])
|
||||||
|
pd_area = (pd_box[2][0] - pd_box[0][0]) * (pd_box[2][1] - pd_box[0][1])
|
||||||
|
|
||||||
|
# compute the intersection over union by taking the intersection
|
||||||
|
# area and dividing it by the sum of prediction + ground-truth
|
||||||
|
# areas - the interesection area
|
||||||
|
if iou_type == 0:
|
||||||
|
iou = interArea / float(gt_area + pd_area - interArea)
|
||||||
|
elif iou_type == 1:
|
||||||
|
iou = interArea / max(min(gt_area, pd_area), 1)
|
||||||
|
|
||||||
|
# return the intersection over union value
|
||||||
|
return iou
|
||||||
|
|
||||||
|
|
||||||
|
def sort_funsd_reading_order(lines: List[dict], box_key_name: str = "box"):
|
||||||
|
"""Sort cell list to create the right reading order using their locations
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lines: list of cells to sort
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of cell lists in the right reading order that contain
|
||||||
|
no key or start with a key and contain no other key
|
||||||
|
"""
|
||||||
|
sorted_list = []
|
||||||
|
|
||||||
|
if len(lines) == 0:
|
||||||
|
return lines
|
||||||
|
|
||||||
|
while len(lines) > 1:
|
||||||
|
topleft_line = lines[0]
|
||||||
|
for line in lines[1:]:
|
||||||
|
topleft_line_pos = topleft_line[box_key_name]
|
||||||
|
topleft_line_center_y = (topleft_line_pos[1] + topleft_line_pos[3]) / 2
|
||||||
|
x1, y1, x2, y2 = line[box_key_name]
|
||||||
|
box_center_x = (x1 + x2) / 2
|
||||||
|
box_center_y = (y1 + y2) / 2
|
||||||
|
cell_h = y2 - y1
|
||||||
|
if box_center_y <= topleft_line_center_y - cell_h / 2:
|
||||||
|
topleft_line = line
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
box_center_x < topleft_line_pos[2]
|
||||||
|
and box_center_y < topleft_line_pos[3]
|
||||||
|
):
|
||||||
|
topleft_line = line
|
||||||
|
continue
|
||||||
|
sorted_list.append(topleft_line)
|
||||||
|
lines.remove(topleft_line)
|
||||||
|
|
||||||
|
sorted_list.append(lines[0])
|
||||||
|
|
||||||
|
return sorted_list
|
295
knowledgehub/loaders/utils/pdf_ocr.py
Normal file
295
knowledgehub/loaders/utils/pdf_ocr.py
Normal file
|
@ -0,0 +1,295 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from .box import (
|
||||||
|
bbox_to_points,
|
||||||
|
box_area,
|
||||||
|
box_h,
|
||||||
|
box_w,
|
||||||
|
get_rect_iou,
|
||||||
|
points_to_bbox,
|
||||||
|
scale_box,
|
||||||
|
scale_points,
|
||||||
|
sort_funsd_reading_order,
|
||||||
|
union_points,
|
||||||
|
)
|
||||||
|
from .table import table_cells_to_markdown
|
||||||
|
|
||||||
|
IOU_THRES = 0.5
|
||||||
|
PADDING_THRES = 1.1
|
||||||
|
|
||||||
|
|
||||||
|
def read_pdf_unstructured(input_path: Union[Path, str]):
|
||||||
|
"""Convert PDF from specified path to list of text items with
|
||||||
|
location information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: path to input file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict page_number: list of text boxes
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from unstructured.partition.auto import partition
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install unstructured PDF reader \
|
||||||
|
`pip install unstructured[pdf]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
page_items = defaultdict(list)
|
||||||
|
items = partition(input_path)
|
||||||
|
for item in items:
|
||||||
|
page_number = item.metadata.page_number
|
||||||
|
bbox = points_to_bbox(item.metadata.coordinates.points)
|
||||||
|
coord_system = item.metadata.coordinates.system
|
||||||
|
max_w, max_h = coord_system.width, coord_system.height
|
||||||
|
page_items[page_number - 1].append(
|
||||||
|
{
|
||||||
|
"text": item.text,
|
||||||
|
"box": bbox,
|
||||||
|
"location": bbox_to_points(bbox),
|
||||||
|
"page_shape": (max_w, max_h),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return page_items
|
||||||
|
|
||||||
|
|
||||||
|
def merge_ocr_and_pdf_texts(
|
||||||
|
ocr_list: List[dict], pdf_text_list: List[dict], debug_info=None
|
||||||
|
):
|
||||||
|
"""Merge PDF and OCR text using IOU overlaping location
|
||||||
|
Args:
|
||||||
|
ocr_list: List of OCR items {"text", "box", "location"}
|
||||||
|
pdf_text_list: List of PDF items {"text", "box", "location"}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined list of PDF text and non-overlap OCR text
|
||||||
|
"""
|
||||||
|
not_matched_ocr = []
|
||||||
|
|
||||||
|
# check for debug info
|
||||||
|
if debug_info is not None:
|
||||||
|
cv2, debug_im = debug_info
|
||||||
|
|
||||||
|
for ocr_item in ocr_list:
|
||||||
|
matched = False
|
||||||
|
for pdf_item in pdf_text_list:
|
||||||
|
if (
|
||||||
|
get_rect_iou(ocr_item["location"], pdf_item["location"], iou_type=1)
|
||||||
|
> IOU_THRES
|
||||||
|
):
|
||||||
|
matched = True
|
||||||
|
break
|
||||||
|
|
||||||
|
color = (255, 0, 0)
|
||||||
|
if not matched:
|
||||||
|
ocr_item["matched"] = False
|
||||||
|
not_matched_ocr.append(ocr_item)
|
||||||
|
color = (0, 255, 255)
|
||||||
|
|
||||||
|
if debug_info is not None:
|
||||||
|
cv2.rectangle(
|
||||||
|
debug_im,
|
||||||
|
ocr_item["location"][0],
|
||||||
|
ocr_item["location"][2],
|
||||||
|
color=color,
|
||||||
|
thickness=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if debug_info is not None:
|
||||||
|
for pdf_item in pdf_text_list:
|
||||||
|
cv2.rectangle(
|
||||||
|
debug_im,
|
||||||
|
pdf_item["location"][0],
|
||||||
|
pdf_item["location"][2],
|
||||||
|
color=(0, 255, 0),
|
||||||
|
thickness=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pdf_text_list + not_matched_ocr
|
||||||
|
|
||||||
|
|
||||||
|
def merge_table_cell_and_ocr(
|
||||||
|
table_list: List[dict], ocr_list: List[dict], pdf_list: List[dict], debug_info=None
|
||||||
|
):
|
||||||
|
"""Merge table items with OCR text using IOU overlaping location
|
||||||
|
Args:
|
||||||
|
table_list: List of table items
|
||||||
|
"type": ("table", "cell", "text"), "text", "box", "location"}
|
||||||
|
ocr_list: List of OCR items {"text", "box", "location"}
|
||||||
|
pdf_list: List of PDF items {"text", "box", "location"}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
all_table_cells: List of tables, each of table is reprented
|
||||||
|
by list of cells with combined text from OCR
|
||||||
|
not_matched_items: List of PDF text which is not overlapped by table region
|
||||||
|
"""
|
||||||
|
# check for debug info
|
||||||
|
if debug_info is not None:
|
||||||
|
cv2, debug_im = debug_info
|
||||||
|
|
||||||
|
cell_list = [item for item in table_list if item["type"] == "cell"]
|
||||||
|
table_list = [item for item in table_list if item["type"] == "table"]
|
||||||
|
|
||||||
|
# sort table by area
|
||||||
|
table_list = sorted(table_list, key=lambda item: box_area(item["bbox"]))
|
||||||
|
|
||||||
|
all_tables = []
|
||||||
|
matched_pdf_ids = []
|
||||||
|
matched_cell_ids = []
|
||||||
|
|
||||||
|
for table in table_list:
|
||||||
|
if debug_info is not None:
|
||||||
|
cv2.rectangle(
|
||||||
|
debug_im,
|
||||||
|
table["location"][0],
|
||||||
|
table["location"][2],
|
||||||
|
color=[0, 0, 255],
|
||||||
|
thickness=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
cur_table_cells = []
|
||||||
|
for cell_id, cell in enumerate(cell_list):
|
||||||
|
if cell_id in matched_cell_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if get_rect_iou(
|
||||||
|
table["location"], cell["location"], iou_type=1
|
||||||
|
) > IOU_THRES and box_area(table["bbox"]) > box_area(cell["bbox"]):
|
||||||
|
color = [128, 0, 128]
|
||||||
|
# cell matched to table
|
||||||
|
for item_list, item_type in [(pdf_list, "pdf"), (ocr_list, "ocr")]:
|
||||||
|
cell["ocr"] = []
|
||||||
|
for item_id, item in enumerate(item_list):
|
||||||
|
if item_type == "pdf" and item_id in matched_pdf_ids:
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
get_rect_iou(item["location"], cell["location"], iou_type=1)
|
||||||
|
> IOU_THRES
|
||||||
|
):
|
||||||
|
cell["ocr"].append(item)
|
||||||
|
if item_type == "pdf":
|
||||||
|
matched_pdf_ids.append(item_id)
|
||||||
|
|
||||||
|
if len(cell["ocr"]) > 0:
|
||||||
|
# check if union of matched ocr does
|
||||||
|
# not extend over cell boundary,
|
||||||
|
# if True, continue to use OCR_list to match
|
||||||
|
all_box_points_in_cell = []
|
||||||
|
for item in cell["ocr"]:
|
||||||
|
all_box_points_in_cell.extend(item["location"])
|
||||||
|
union_box = union_points(all_box_points_in_cell)
|
||||||
|
cell_okay = (
|
||||||
|
box_h(union_box) <= box_h(cell["bbox"]) * PADDING_THRES
|
||||||
|
and box_w(union_box) <= box_w(cell["bbox"]) * PADDING_THRES
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cell_okay = False
|
||||||
|
|
||||||
|
if cell_okay:
|
||||||
|
if item_type == "pdf":
|
||||||
|
color = [255, 0, 255]
|
||||||
|
break
|
||||||
|
|
||||||
|
if debug_info is not None:
|
||||||
|
cv2.rectangle(
|
||||||
|
debug_im,
|
||||||
|
cell["location"][0],
|
||||||
|
cell["location"][2],
|
||||||
|
color=color,
|
||||||
|
thickness=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
matched_cell_ids.append(cell_id)
|
||||||
|
cur_table_cells.append(cell)
|
||||||
|
|
||||||
|
all_tables.append(cur_table_cells)
|
||||||
|
|
||||||
|
not_matched_items = [
|
||||||
|
item for _id, item in enumerate(pdf_list) if _id not in matched_pdf_ids
|
||||||
|
]
|
||||||
|
if debug_info is not None:
|
||||||
|
for item in not_matched_items:
|
||||||
|
cv2.rectangle(
|
||||||
|
debug_im,
|
||||||
|
item["location"][0],
|
||||||
|
item["location"][2],
|
||||||
|
color=[128, 128, 128],
|
||||||
|
thickness=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_tables, not_matched_items
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ocr_output(
|
||||||
|
ocr_page_items: List[dict],
|
||||||
|
pdf_page_items: Dict[int, List[dict]],
|
||||||
|
artifact_path: Optional[str] = None,
|
||||||
|
debug_path: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Main function to combine OCR output and PDF text to
|
||||||
|
form list of table / non-table regions
|
||||||
|
Args:
|
||||||
|
ocr_page_items: List of OCR items by page
|
||||||
|
pdf_page_items: Dict of PDF texts (page number as key)
|
||||||
|
debug_path: If specified, use OpenCV to plot debug image and save to debug_path
|
||||||
|
"""
|
||||||
|
all_tables = []
|
||||||
|
all_texts = []
|
||||||
|
|
||||||
|
for page_id, page in enumerate(ocr_page_items):
|
||||||
|
ocr_list = page["json"]["ocr"]
|
||||||
|
table_list = page["json"]["table"]
|
||||||
|
page_shape = page["image_shape"]
|
||||||
|
pdf_item_list = pdf_page_items[page_id]
|
||||||
|
|
||||||
|
# create bbox additional information
|
||||||
|
for item in ocr_list:
|
||||||
|
item["box"] = points_to_bbox(item["location"])
|
||||||
|
|
||||||
|
# re-scale pdf items according to new image size
|
||||||
|
for item in pdf_item_list:
|
||||||
|
scale_factor = page_shape[0] / item["page_shape"][0]
|
||||||
|
item["box"] = scale_box(item["box"], scale_factor=scale_factor)
|
||||||
|
item["location"] = scale_points(item["location"], scale_factor=scale_factor)
|
||||||
|
|
||||||
|
# if using debug mode, openCV must be installed
|
||||||
|
if debug_path and artifact_path is not None:
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install openCV first to use OCRReader debug mode"
|
||||||
|
)
|
||||||
|
image_path = Path(artifact_path) / page["image"]
|
||||||
|
image = cv2.imread(str(image_path))
|
||||||
|
debug_info = (cv2, image)
|
||||||
|
else:
|
||||||
|
debug_info = None
|
||||||
|
|
||||||
|
new_pdf_list = merge_ocr_and_pdf_texts(
|
||||||
|
ocr_list, pdf_item_list, debug_info=debug_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# sort by reading order
|
||||||
|
ocr_list = sort_funsd_reading_order(ocr_list)
|
||||||
|
new_pdf_list = sort_funsd_reading_order(new_pdf_list)
|
||||||
|
|
||||||
|
all_table_cells, non_table_text_list = merge_table_cell_and_ocr(
|
||||||
|
table_list, ocr_list, new_pdf_list, debug_info=debug_info
|
||||||
|
)
|
||||||
|
|
||||||
|
table_texts = [table_cells_to_markdown(cells) for cells in all_table_cells]
|
||||||
|
all_tables.extend([(page_id, text) for text in table_texts])
|
||||||
|
all_texts.append(
|
||||||
|
(page_id, " ".join(item["text"] for item in non_table_text_list))
|
||||||
|
)
|
||||||
|
|
||||||
|
# export debug image to debug_path
|
||||||
|
if debug_path:
|
||||||
|
cv2.imwrite(str(Path(debug_path) / "page_{}.png".format(page_id)), image)
|
||||||
|
|
||||||
|
return all_tables, all_texts
|
|
@ -2,6 +2,8 @@ import csv
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from .box import get_rect_iou
|
||||||
|
|
||||||
|
|
||||||
def check_col_conflicts(
|
def check_col_conflicts(
|
||||||
col_a: List[str], col_b: List[str], thres: float = 0.15
|
col_a: List[str], col_b: List[str], thres: float = 0.15
|
||||||
|
@ -77,61 +79,6 @@ def compress_csv(csv_rows: List[List[str]]) -> List[List[str]]:
|
||||||
return csv_rows
|
return csv_rows
|
||||||
|
|
||||||
|
|
||||||
def _get_rect_iou(gt_box: List[tuple], pd_box: List[tuple], iou_type=0) -> int:
|
|
||||||
"""Intersection over union on layout rectangle
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gt_box: List[tuple]
|
|
||||||
A list contains bounding box coordinates of ground truth
|
|
||||||
pd_box: List[tuple]
|
|
||||||
A list contains bounding box coordinates of prediction
|
|
||||||
iou_type: int
|
|
||||||
0: intersection / union, normal IOU
|
|
||||||
1: intersection / min(areas), useful when boxes are under/over-segmented
|
|
||||||
|
|
||||||
Input format: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
|
||||||
Annotation for each element in bbox:
|
|
||||||
(x1, y1) (x2, y1)
|
|
||||||
+-------+
|
|
||||||
| |
|
|
||||||
| |
|
|
||||||
+-------+
|
|
||||||
(x1, y2) (x2, y2)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Intersection over union value
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert iou_type in [0, 1], "Only support 0: origin iou, 1: intersection / min(area)"
|
|
||||||
|
|
||||||
# determine the (x, y)-coordinates of the intersection rectangle
|
|
||||||
# gt_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
|
||||||
# pd_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
|
||||||
x_left = max(gt_box[0][0], pd_box[0][0])
|
|
||||||
y_top = max(gt_box[0][1], pd_box[0][1])
|
|
||||||
x_right = min(gt_box[2][0], pd_box[2][0])
|
|
||||||
y_bottom = min(gt_box[2][1], pd_box[2][1])
|
|
||||||
|
|
||||||
# compute the area of intersection rectangle
|
|
||||||
interArea = max(0, x_right - x_left) * max(0, y_bottom - y_top)
|
|
||||||
|
|
||||||
# compute the area of both the prediction and ground-truth
|
|
||||||
# rectangles
|
|
||||||
gt_area = (gt_box[2][0] - gt_box[0][0]) * (gt_box[2][1] - gt_box[0][1])
|
|
||||||
pd_area = (pd_box[2][0] - pd_box[0][0]) * (pd_box[2][1] - pd_box[0][1])
|
|
||||||
|
|
||||||
# compute the intersection over union by taking the intersection
|
|
||||||
# area and dividing it by the sum of prediction + ground-truth
|
|
||||||
# areas - the interesection area
|
|
||||||
if iou_type == 0:
|
|
||||||
iou = interArea / float(gt_area + pd_area - interArea)
|
|
||||||
elif iou_type == 1:
|
|
||||||
iou = interArea / max(min(gt_area, pd_area), 1)
|
|
||||||
|
|
||||||
# return the intersection over union value
|
|
||||||
return iou
|
|
||||||
|
|
||||||
|
|
||||||
def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
|
def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
|
||||||
"""Get list of text lines belong to table regions specified by table_list
|
"""Get list of text lines belong to table regions specified by table_list
|
||||||
|
|
||||||
|
@ -148,7 +95,7 @@ def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
|
||||||
continue
|
continue
|
||||||
cur_table_texts = []
|
cur_table_texts = []
|
||||||
for ocr in ocr_list:
|
for ocr in ocr_list:
|
||||||
_iou = _get_rect_iou(table["location"], ocr["location"], iou_type=1)
|
_iou = get_rect_iou(table["location"], ocr["location"], iou_type=1)
|
||||||
if _iou > 0.8:
|
if _iou > 0.8:
|
||||||
cur_table_texts.append(ocr["text"])
|
cur_table_texts.append(ocr["text"])
|
||||||
table_texts.append(cur_table_texts)
|
table_texts.append(cur_table_texts)
|
||||||
|
@ -272,33 +219,6 @@ def strip_special_chars_markdown(text: str) -> str:
|
||||||
return text.replace("|", "").replace(":---:", "").replace("---", "")
|
return text.replace("|", "").replace(":---:", "").replace("---", "")
|
||||||
|
|
||||||
|
|
||||||
def markdown_to_list(markdown_text: str, pad_to_max_col: Optional[bool] = True):
|
|
||||||
rows = []
|
|
||||||
lines = markdown_text.split("\n")
|
|
||||||
markdown_lines = [line.strip() for line in lines if " | " in line]
|
|
||||||
|
|
||||||
for row in markdown_lines:
|
|
||||||
tmp = row
|
|
||||||
# Get rid of leading and trailing '|'
|
|
||||||
if tmp.startswith("|"):
|
|
||||||
tmp = tmp[1:]
|
|
||||||
if tmp.endswith("|"):
|
|
||||||
tmp = tmp[:-1]
|
|
||||||
|
|
||||||
# Split line and ignore column whitespace
|
|
||||||
clean_line = tmp.split("|")
|
|
||||||
if not all(c == "" for c in clean_line):
|
|
||||||
# Append clean row data to rows variable
|
|
||||||
rows.append(clean_line)
|
|
||||||
|
|
||||||
# Get rid of syntactical sugar to indicate header (2nd row)
|
|
||||||
rows = [row for row in rows if "---" not in " ".join(row)]
|
|
||||||
max_cols = max(len(row) for row in rows)
|
|
||||||
if pad_to_max_col:
|
|
||||||
rows = [row + [""] * (max_cols - len(row)) for row in rows]
|
|
||||||
return rows
|
|
||||||
|
|
||||||
|
|
||||||
def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
|
def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
|
||||||
"""Convert markdown text to list of non-table spans and table spans
|
"""Convert markdown text to list of non-table spans and table spans
|
||||||
|
|
||||||
|
@ -333,3 +253,36 @@ def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
|
||||||
table_texts = ["\n".join(table) for table in tables]
|
table_texts = ["\n".join(table) for table in tables]
|
||||||
non_table_texts = ["\n".join(text) for text in texts]
|
non_table_texts = ["\n".join(text) for text in texts]
|
||||||
return table_texts, non_table_texts
|
return table_texts, non_table_texts
|
||||||
|
|
||||||
|
|
||||||
|
def table_cells_to_markdown(cells: List[dict]):
|
||||||
|
"""Convert list of cells with attached text to Markdown table"""
|
||||||
|
|
||||||
|
if len(cells) == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
all_row_ids = []
|
||||||
|
all_col_ids = []
|
||||||
|
for cell in cells:
|
||||||
|
all_row_ids.extend(cell["rows"])
|
||||||
|
all_col_ids.extend(cell["columns"])
|
||||||
|
|
||||||
|
num_rows, num_cols = max(all_row_ids) + 1, max(all_col_ids) + 1
|
||||||
|
table_rows = [["" for c in range(num_cols)] for r in range(num_rows)]
|
||||||
|
|
||||||
|
# start filling in the grid
|
||||||
|
for cell in cells:
|
||||||
|
cell_text = " ".join(item["text"] for item in cell["ocr"])
|
||||||
|
start_row_id, end_row_id = cell["rows"]
|
||||||
|
start_col_id, end_col_id = cell["columns"]
|
||||||
|
span_cell = end_row_id != start_row_id or end_col_id != start_col_id
|
||||||
|
|
||||||
|
# do not repeat long text in span cell to prevent context length issue
|
||||||
|
if span_cell and len(cell_text.replace(" ", "")) < 20 and start_row_id > 0:
|
||||||
|
for row in range(start_row_id, end_row_id + 1):
|
||||||
|
for col in range(start_col_id, end_col_id + 1):
|
||||||
|
table_rows[row][col] += cell_text + " "
|
||||||
|
else:
|
||||||
|
table_rows[start_row_id][start_col_id] += cell_text + " "
|
||||||
|
|
||||||
|
return make_markdown_table(table_rows)
|
||||||
|
|
|
@ -70,9 +70,11 @@ class ReaderIndexingPipeline(BaseComponent):
|
||||||
embedding=self.embedding,
|
embedding=self.embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
text_splitter: SimpleNodeParser = SimpleNodeParser.withx(
|
@Node.auto(depends_on=["chunk_size", "chunk_overlap"])
|
||||||
chunk_size=1024, chunk_overlap=256
|
def text_splitter(self) -> SimpleNodeParser:
|
||||||
)
|
return SimpleNodeParser(
|
||||||
|
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from theflow import Node, Param
|
from theflow import Node
|
||||||
from theflow.utils.modules import ObjectInitDeclaration as _
|
from theflow.utils.modules import ObjectInitDeclaration as _
|
||||||
|
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
|
@ -43,8 +43,8 @@ class QuestionAnsweringPipeline(BaseComponent):
|
||||||
request_timeout=60,
|
request_timeout=60,
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_store: Param[InMemoryVectorStore] = Param(_(InMemoryVectorStore))
|
vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
|
||||||
doc_store: Param[InMemoryDocumentStore] = Param(_(InMemoryDocumentStore))
|
doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
|
||||||
|
|
||||||
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
|
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
|
|
BIN
tests/resources/7810d908b0ff4ce381dcab873196d133.jpg
Normal file
BIN
tests/resources/7810d908b0ff4ce381dcab873196d133.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 322 KiB |
File diff suppressed because one or more lines are too long
BIN
tests/resources/table.pdf
Normal file
BIN
tests/resources/table.pdf
Normal file
Binary file not shown.
|
@ -5,7 +5,7 @@ import pytest
|
||||||
|
|
||||||
from kotaemon.loaders import MathpixPDFReader, OCRReader, PandasExcelReader
|
from kotaemon.loaders import MathpixPDFReader, OCRReader, PandasExcelReader
|
||||||
|
|
||||||
input_file = Path(__file__).parent / "resources" / "dummy.pdf"
|
input_file = Path(__file__).parent / "resources" / "table.pdf"
|
||||||
input_file_excel = Path(__file__).parent / "resources" / "dummy.xlsx"
|
input_file_excel = Path(__file__).parent / "resources" / "dummy.xlsx"
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ def test_ocr_reader(fullocr_output):
|
||||||
reader = OCRReader()
|
reader = OCRReader()
|
||||||
documents = reader.load_data(input_file, response_content=fullocr_output)
|
documents = reader.load_data(input_file, response_content=fullocr_output)
|
||||||
table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
|
table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
|
||||||
assert len(table_docs) == 4
|
assert len(table_docs) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_mathpix_reader(mathpix_output):
|
def test_mathpix_reader(mathpix_output):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user