* add OCRReader, MathPixReader and ExcelReader * update test case for ocr reader * reformat * minor fix
336 lines
10 KiB
Python
336 lines
10 KiB
Python
import csv
|
|
from io import StringIO
|
|
from typing import List, Optional, Tuple
|
|
|
|
|
|
def check_col_conflicts(
    col_a: List[str], col_b: List[str], thres: float = 0.15
) -> bool:
    """Tell whether two columns collide on too many rows to be merged.

    A row counts as a conflict when both columns hold non-empty content in
    it (to be used with merge_cols).

    Args:
        col_a: column A (list of str)
        col_b: column B (list of str), must have the same length as col_a
        thres: fraction of the non-empty cells of col_a that is allowed to
            overlap with non-empty cells of col_b

    Returns:
        True if the number of conflicting rows is greater than the
        threshold, False otherwise
    """
    # The threshold is relative to how many cells of column A are filled.
    filled_in_a = sum(1 for cell in col_a if cell)
    assert len(col_a) == len(col_b)
    overlaps = sum(1 for left, right in zip(col_a, col_b) if left and right)
    return overlaps > filled_in_a * thres
|
|
|
|
|
|
def merge_cols(col_a: List[str], col_b: List[str]) -> List[str]:
    """Fold the content of column B into column A, row by row.

    Every non-empty cell of col_b is appended to the cell of col_a in the
    same row, separated by a single space. col_a is modified in place.

    Args:
        col_a: column A (list of str), updated in place
        col_b: column B (list of str)

    Returns:
        merged column (the same list object as col_a)
    """
    for idx, left in enumerate(col_a):
        right = col_b[idx]
        if not right:
            continue
        col_a[idx] = " ".join((left, right))
    return col_a
|
|
|
|
|
|
def add_index_col(csv_rows: List[List[str]]) -> List[List[str]]:
    """Prepend a row-index column to the table csv_rows.

    A new header row is inserted on top: its first cell is "row id" and the
    remaining cells are blank (one per column of the first input row).
    Every input row is then prefixed with its 1-based position.

    Args:
        csv_rows: input table (must contain at least one row)

    Returns:
        output table with index column
    """
    header = ["row id"] + [""] * len(csv_rows[0])
    numbered = [[str(pos)] + row for pos, row in enumerate(csv_rows, start=1)]
    return [header] + numbered
|
|
|
|
|
|
def compress_csv(csv_rows: List[List[str]]) -> List[List[str]]:
    """Compress table csv_rows by merging sparse columns (merge_cols).

    Columns are scanned left to right. A column is merged into the closest
    kept column on its left when check_col_conflicts reports no conflict
    between the two; otherwise it is kept and becomes the new merge target.

    Rows shorter than the widest row are treated as if padded with empty
    cells, so ragged input neither raises nor loses the extra cells of
    longer rows.

    Args:
        csv_rows: input table

    Returns:
        output: compressed table (a new list of rows; the input rows are
        not modified). An empty table yields an empty list.
    """
    if not csv_rows:
        return []

    width = max(len(row) for row in csv_rows)
    if width == 0:
        # Only empty rows: there is no column to merge.
        return [[] for _ in csv_rows]

    # Transpose to columns, padding short rows with empty cells.
    csv_cols = [
        [row[c_id] if c_id < len(row) else "" for row in csv_rows]
        for c_id in range(width)
    ]

    # kept_cols[-1] is always the current merge target.
    kept_cols = [csv_cols[0]]
    for col in csv_cols[1:]:
        if check_col_conflicts(kept_cols[-1], col):
            kept_cols.append(col)
        else:
            kept_cols[-1] = merge_cols(kept_cols[-1], col)

    # Transpose back to rows.
    return [[col[r_id] for col in kept_cols] for r_id in range(len(csv_rows))]
|
|
|
|
|
|
def _get_rect_iou(gt_box: List[tuple], pd_box: List[tuple], iou_type=0) -> int:
|
|
"""Intersection over union on layout rectangle
|
|
|
|
Args:
|
|
gt_box: List[tuple]
|
|
A list contains bounding box coordinates of ground truth
|
|
pd_box: List[tuple]
|
|
A list contains bounding box coordinates of prediction
|
|
iou_type: int
|
|
0: intersection / union, normal IOU
|
|
1: intersection / min(areas), useful when boxes are under/over-segmented
|
|
|
|
Input format: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
|
Annotation for each element in bbox:
|
|
(x1, y1) (x2, y1)
|
|
+-------+
|
|
| |
|
|
| |
|
|
+-------+
|
|
(x1, y2) (x2, y2)
|
|
|
|
Returns:
|
|
Intersection over union value
|
|
"""
|
|
|
|
assert iou_type in [0, 1], "Only support 0: origin iou, 1: intersection / min(area)"
|
|
|
|
# determine the (x, y)-coordinates of the intersection rectangle
|
|
# gt_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
|
# pd_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
|
x_left = max(gt_box[0][0], pd_box[0][0])
|
|
y_top = max(gt_box[0][1], pd_box[0][1])
|
|
x_right = min(gt_box[2][0], pd_box[2][0])
|
|
y_bottom = min(gt_box[2][1], pd_box[2][1])
|
|
|
|
# compute the area of intersection rectangle
|
|
interArea = max(0, x_right - x_left) * max(0, y_bottom - y_top)
|
|
|
|
# compute the area of both the prediction and ground-truth
|
|
# rectangles
|
|
gt_area = (gt_box[2][0] - gt_box[0][0]) * (gt_box[2][1] - gt_box[0][1])
|
|
pd_area = (pd_box[2][0] - pd_box[0][0]) * (pd_box[2][1] - pd_box[0][1])
|
|
|
|
# compute the intersection over union by taking the intersection
|
|
# area and dividing it by the sum of prediction + ground-truth
|
|
# areas - the interesection area
|
|
if iou_type == 0:
|
|
iou = interArea / float(gt_area + pd_area - interArea)
|
|
elif iou_type == 1:
|
|
iou = interArea / max(min(gt_area, pd_area), 1)
|
|
|
|
# return the intersection over union value
|
|
return iou
|
|
|
|
|
|
def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
    """Collect the OCR text lines that fall inside each table region.

    Entries of table_list whose "type" is not "table" are skipped. For each
    remaining table, an OCR entry is assigned to it when the overlap of
    their "location" boxes, measured as intersection / min(area), is
    greater than 0.8.

    Args:
        ocr_list: list of OCR output in Casia format (Flax); each entry is
            read through its "location" and "text" keys
        table_list: list of table output in Casia format (Flax); each entry
            is read through its "type" and "location" keys

    Returns:
        A list with one item per table region, each item being the list of
        OCR texts matched to that region (in ocr_list order)
    """
    table_regions = (item for item in table_list if item["type"] == "table")
    return [
        [
            ocr["text"]
            for ocr in ocr_list
            if _get_rect_iou(region["location"], ocr["location"], iou_type=1) > 0.8
        ]
        for region in table_regions
    ]
|
|
|
|
|
|
def make_markdown_table(array: List[List[str]]) -> str:
    """Render table rows given as lists into a markdown table string.

    The table is first compressed with compress_csv and then given a
    leading "row id" column with add_index_col, so the rendered header is
    the index header row and every input row becomes a numbered body row.

    Args:
        array: Python list with rows of table as lists
            Example Input:
                [["Name", "Age", "Height"],
                 ["Jake", 20, 5'10],
                 ["Mary", 21, 5'7]]

    Returns:
        String to put into a .md file
    """
    table = add_index_col(compress_csv(array))
    header, body = table[0], table[1:]

    # Header cells carry a leading space and a trailing " |".
    header_line = "| " + "".join(" " + str(cell) + " |" for cell in header)
    # One "--- | " marker per column.
    separator_line = "| " + "--- | " * len(header)
    body_lines = [
        "| " + "".join(str(cell) + " | " for cell in entry) for entry in body
    ]

    rendered = [header_line, separator_line] + body_lines
    # Leading blank line, newline after every line, and a final blank line.
    return "\n" + "".join(line + "\n" for line in rendered) + "\n"
|
|
|
|
|
|
def parse_csv_string_to_list(csv_str: str) -> List[List[str]]:
    """Parse a comma-delimited CSV string into a list of rows.

    Args:
        csv_str: input CSV string

    Returns:
        Output table in list format (one list of cell strings per CSV row)
    """
    return list(csv.reader(StringIO(csv_str), delimiter=","))
|
|
|
|
|
|
def format_cell(cell: str, length_limit: Optional[int] = None) -> str:
    """Clean up cell content: flatten line breaks and enforce a length limit.

    Args:
        cell: input cell text
        length_limit: limit of text length. A falsy value (None or 0)
            means the text is not truncated.

    Returns:
        new cell text
    """
    flattened = " ".join(cell.split("\n"))
    return flattened[:length_limit] if length_limit else flattened
|
|
|
|
|
|
def extract_tables_from_csv_string(
    csv_content: str, table_texts: List[List[str]]
) -> Tuple[List[str], str]:
    """Extract list of table from FullOCR output
    (csv_content) with the specified table_texts

    A CSV row is assigned to a table when more than half of its non-empty
    cells contain at least one of the table's texts as a substring. Each
    row is assigned to at most one table (the first one that matches).
    Rows without any non-empty cell are never assigned to a table.

    Args:
        csv_content: CSV output from FullOCR pipeline
        table_texts: list of table texts extracted
            from get_table_from_ocr()

    Returns:
        Tuple of (list of tables rendered as markdown, non-table text).
        The non-table text joins the cells of every unassigned row with
        spaces and the rows with newlines.
    """
    rows = parse_csv_string_to_list(csv_content)
    used_row_ids = set()
    table_csv_list = []
    for table in table_texts:
        # An empty text is a substring of every cell and would make every
        # row match, so only non-empty table texts take part in matching.
        table_cells = [cell for cell in table if cell]
        cur_rows = []
        for row_id, row in enumerate(rows):
            if row_id in used_row_ids:
                continue
            filled_cells = [cell for cell in row if cell]
            if not filled_cells:
                # Blank row: nothing to match (and avoids dividing by zero).
                continue
            matched = sum(
                any(text in cell for text in table_cells) for cell in filled_cells
            )
            if matched / len(filled_cells) > 0.5:
                used_row_ids.add(row_id)
                cur_rows.append([format_cell(cell) for cell in row])
        if cur_rows:
            table_csv_list.append(make_markdown_table(cur_rows))
        else:
            print("table not matched", table)

    non_table_rows = [
        row for row_id, row in enumerate(rows) if row_id not in used_row_ids
    ]
    non_table_text = "\n".join(
        " ".join(format_cell(cell) for cell in row) for row in non_table_rows
    )
    return table_csv_list, non_table_text
|
|
|
|
|
|
def strip_special_chars_markdown(text: str) -> str:
    """Strip special characters from input text in markdown table format

    Removes, in this order, every "|", every ":---:" and every "---".
    """
    cleaned = text
    for token in ("|", ":---:", "---"):
        cleaned = cleaned.replace(token, "")
    return cleaned
|
|
|
|
|
|
def markdown_to_list(
    markdown_text: str, pad_to_max_col: Optional[bool] = True
) -> List[List[str]]:
    """Parse the table rows found in markdown text into lists of cells.

    Only lines containing " | " are treated as table lines. For each such
    line the surrounding whitespace and one leading / trailing "|" are
    removed, and the rest is split on "|". Cell text is kept as is (inner
    whitespace around cells is not stripped). Rows made only of empty
    cells and rows containing "---" (header separator) are dropped.

    Args:
        markdown_text: input markdown text
        pad_to_max_col: when true, shorter rows are right-padded with empty
            cells so that every row has as many cells as the longest one

    Returns:
        Table rows as lists of cell strings; an empty list when the text
        contains no table row.
    """
    rows = []
    for line in markdown_text.split("\n"):
        if " | " not in line:
            continue
        tmp = line.strip()

        # Get rid of leading and trailing '|'
        if tmp.startswith("|"):
            tmp = tmp[1:]
        if tmp.endswith("|"):
            tmp = tmp[:-1]

        clean_line = tmp.split("|")
        if all(c == "" for c in clean_line):
            continue
        # Get rid of syntactical sugar to indicate header (2nd row)
        if "---" in " ".join(clean_line):
            continue
        rows.append(clean_line)

    if not rows:
        # No table content: nothing to measure or pad.
        return rows

    if pad_to_max_col:
        max_cols = max(len(row) for row in rows)
        rows = [row + [""] * (max_cols - len(row)) for row in rows]
    return rows
|
|
|
|
|
|
def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
    """Convert markdown text to list of non-table spans and table spans

    The text is scanned line by line (each line stripped of surrounding
    whitespace). Consecutive lines starting with "|" form a table span;
    consecutive other lines form a non-table span. The span still open when
    the input ends is emitted as well, except for a trailing non-table span
    made only of blank lines.

    Args:
        text: input markdown text

    Returns:
        list of table spans and non-table spans
    """
    tables: List[List[str]] = []
    texts: List[List[str]] = []

    cur_table: List[str] = []
    cur_text: List[str] = []
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if line.startswith("|"):
            # A table line closes the text span collected so far.
            if cur_text:
                texts.append(cur_text)
                cur_text = []
            cur_table.append(line)
        else:
            # A non-table line closes the table collected so far.
            if cur_table:
                tables.append(cur_table)
                cur_table = []
            cur_text.append(line)

    # Flush the span that was still open at end of input; without this the
    # last table (or the last text span) of the document would be lost.
    if cur_table:
        tables.append(cur_table)
    if any(cur_text):
        texts.append(cur_text)

    table_texts = ["\n".join(table_lines) for table_lines in tables]
    non_table_texts = ["\n".join(text_lines) for text_lines in texts]
    return table_texts, non_table_texts
|