Feat/add multimodal loader (#5)

* Add the Adobe reader as the multimodal loader
* Allow FullQAPipeline to reason on figures
* fix: move the Adobe import to avoid ImportError; notify users whenever they run the AdobeReader

Co-authored-by: cin-albert <albert@cinnamon.is>

parent a3bf728400
commit e67a25c0bd
@@ -7,6 +7,7 @@ from kotaemon.base import BaseComponent, Document, Param
 from kotaemon.indices.extractors import BaseDocParser
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 from kotaemon.loaders import (
+    AdobeReader,
     DirectoryReader,
     MathpixPDFReader,
     OCRReader,
@@ -41,7 +42,7 @@ class DocumentIngestor(BaseComponent):
     The default file extractors are stored in `KH_DEFAULT_FILE_EXTRACTORS`
     """

-    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr"
+    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr", "multimodal"
     doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
     text_splitter: BaseSplitter = TokenSplitter.withx(
         chunk_size=1024,
@@ -61,6 +62,8 @@ class DocumentIngestor(BaseComponent):
             pass  # use default loader of llama-index which is pypdf
         elif self.pdf_mode == "ocr":
             file_extractors[".pdf"] = OCRReader()
+        elif self.pdf_mode == "multimodal":
+            file_extractors[".pdf"] = AdobeReader()
         else:
             file_extractors[".pdf"] = MathpixPDFReader()
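A minimal usage sketch of the new mode (the `DocumentIngestor` import path is an assumption; the module defining it is not named in this diff excerpt):

```python
# Hypothetical import path -- only the class name appears in the hunks above.
from kotaemon.indices.ingests import DocumentIngestor

# With pdf_mode="multimodal", .pdf files are routed to AdobeReader instead of the
# pypdf default ("normal"), MathpixPDFReader ("mathpix"), or OCRReader ("ocr").
ingestor = DocumentIngestor(pdf_mode="multimodal")
```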
@@ -1,3 +1,4 @@
+from .adobe_loader import AdobeReader
 from .base import AutoReader, BaseReader
 from .composite_loader import DirectoryReader
 from .docx_loader import DocxReader
@@ -17,4 +18,5 @@ __all__ = [
     "UnstructuredReader",
     "DocxReader",
     "HtmlReader",
+    "AdobeReader",
 ]
libs/kotaemon/kotaemon/loaders/adobe_loader.py (new file, 187 lines)
@@ -0,0 +1,187 @@
+import logging
+import os
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from decouple import config
+from llama_index.readers.base import BaseReader
+
+from kotaemon.base import Document
+
+from .utils.adobe import (
+    generate_figure_captions,
+    load_json,
+    parse_figure_paths,
+    parse_table_paths,
+    request_adobe_service,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_VLM_ENDPOINT = (
+    "{0}openai/deployments/{1}/chat/completions?api-version={2}".format(
+        config("AZURE_OPENAI_ENDPOINT", default=""),
+        "gpt-4-vision",
+        config("OPENAI_API_VERSION", default=""),
+    )
+)
+
+
+class AdobeReader(BaseReader):
+    """Read PDF using Adobe's PDF Services.
+
+    Able to extract text, tables, and figures with high accuracy.
+
+    Example:
+        ```python
+        >> from kotaemon.loaders import AdobeReader
+        >> reader = AdobeReader()
+        >> documents = reader.load_data("path/to/pdf")
+        ```
+
+    Args:
+        endpoint: URL to the Vision Language Model endpoint. If not provided,
+            defaults to `kotaemon.loaders.adobe_loader.DEFAULT_VLM_ENDPOINT`
+        max_figures_to_caption: an int deciding how many figures will be captioned.
+            The rest are ignored (indexed without captions).
+    """
+
+    def __init__(
+        self,
+        vlm_endpoint: Optional[str] = None,
+        max_figures_to_caption: int = 100,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        """Init params"""
+        super().__init__(*args)
+        self.table_regex = r"/Table(\[\d+\])?$"
+        self.figure_regex = r"/Figure(\[\d+\])?$"
+        self.vlm_endpoint = vlm_endpoint or DEFAULT_VLM_ENDPOINT
+        self.max_figures_to_caption = max_figures_to_caption
+
+    def load_data(
+        self, file: Path, extra_info: Optional[Dict] = None, **kwargs
+    ) -> List[Document]:
+        """Load data by calling Adobe's API
+
+        Args:
+            file (Path): Path to the PDF file
+
+        Returns:
+            List[Document]: list of documents extracted from the PDF file,
+                including 3 types: text, table, and image
+        """
+        filename = file.name
+        filepath = str(Path(file).resolve())
+        output_path = request_adobe_service(file_path=str(file), output_path="")
+        results_path = os.path.join(output_path, "structuredData.json")
+
+        if not os.path.exists(results_path):
+            logger.exception("Failed to parse the document.")
+            return []
+
+        data = load_json(results_path)
+
+        texts = defaultdict(list)
+        tables = []
+        figures = []
+
+        elements = data["elements"]
+        for item_id, item in enumerate(elements):
+            page_number = item.get("Page", -1) + 1
+            item_path = item["Path"]
+            item_text = item.get("Text", "")
+
+            file_paths = [
+                Path(output_path) / path for path in item.get("filePaths", [])
+            ]
+            prev_item = elements[item_id - 1]
+            title = prev_item.get("Text", "")
+
+            if re.search(self.table_regex, item_path):
+                table_content = parse_table_paths(file_paths)
+                if not table_content:
+                    continue
+                table_caption = (
+                    table_content.replace("|", "").replace("---", "")
+                    + f"\n(Table in Page {page_number}. {title})"
+                )
+                tables.append((page_number, table_content, table_caption))
+            elif re.search(self.figure_regex, item_path):
+                figure_caption = (
+                    item_text + f"\n(Figure in Page {page_number}. {title})"
+                )
+                figure_content = parse_figure_paths(file_paths)
+                if not figure_content:
+                    continue
+                figures.append([page_number, figure_content, figure_caption])
+            else:
+                if item_text and "Table" not in item_path and "Figure" not in item_path:
+                    texts[page_number].append(item_text)
+
+        # get figure captions using GPT-4V
+        figure_captions = generate_figure_captions(
+            self.vlm_endpoint,
+            [item[1] for item in figures],
+            self.max_figures_to_caption,
+        )
+        for item, caption in zip(figures, figure_captions):
+            # update figure caption
+            item[2] += " " + caption
+
+        # Wrap elements with Document
+        documents = []
+
+        # join plain text elements
+        for page_number, txts in texts.items():
+            documents.append(
+                Document(
+                    text="\n".join(txts),
+                    metadata={
+                        "page_label": page_number,
+                        "file_name": filename,
+                        "file_path": filepath,
+                    },
+                )
+            )
+
+        # table elements
+        for page_number, table_content, table_caption in tables:
+            documents.append(
+                Document(
+                    text=table_caption,
+                    metadata={
+                        "table_origin": table_content,
+                        "type": "table",
+                        "page_label": page_number,
+                        "file_name": filename,
+                        "file_path": filepath,
+                    },
+                    metadata_template="",
+                    metadata_seperator="",
+                )
+            )
+
+        # figure elements
+        for page_number, figure_content, figure_caption in figures:
+            documents.append(
+                Document(
+                    text=figure_caption,
+                    metadata={
+                        "image_origin": figure_content,
+                        "type": "image",
+                        "page_label": page_number,
+                        "file_name": filename,
+                        "file_path": filepath,
+                    },
+                    metadata_template="",
+                    metadata_seperator="",
+                )
+            )
+        return documents
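A hedged usage note for the loader above: it reads credentials and endpoint pieces from the environment via python-decouple, so something like the following should work once `AZURE_OPENAI_ENDPOINT`, `OPENAI_API_VERSION`, `PDF_SERVICES_CLIENT_ID`, and `PDF_SERVICES_CLIENT_SECRET` are set (the file name below is a placeholder):

```python
from pathlib import Path

from kotaemon.loaders import AdobeReader

# vlm_endpoint falls back to DEFAULT_VLM_ENDPOINT when omitted
reader = AdobeReader(max_figures_to_caption=10)
documents = reader.load_data(Path("report.pdf"))  # placeholder path

# Each Document carries a "type" marker for downstream routing
tables = [d for d in documents if d.metadata.get("type") == "table"]
figures = [d for d in documents if d.metadata.get("type") == "image"]
```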
libs/kotaemon/kotaemon/loaders/utils/adobe.py (new file, 248 lines)
@@ -0,0 +1,248 @@
+# need pip install pdfservices-sdk==2.3.0
+
+import base64
+import json
+import logging
+import os
+import tempfile
+import zipfile
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List, Union
+
+import pandas as pd
+from decouple import config
+
+from kotaemon.loaders.utils.gpt4v import generate_gpt4v
+
+logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
+
+
+def request_adobe_service(file_path: str, output_path: str = "") -> str:
+    """Main function to call the Adobe service and unzip the results.
+
+    Args:
+        file_path (str): path to the pdf file
+        output_path (str): path to store the results
+
+    Returns:
+        output_path (str): path to the results
+    """
+    try:
+        from adobe.pdfservices.operation.auth.credentials import Credentials
+        from adobe.pdfservices.operation.exception.exceptions import (
+            SdkException,
+            ServiceApiException,
+            ServiceUsageException,
+        )
+        from adobe.pdfservices.operation.execution_context import ExecutionContext
+        from adobe.pdfservices.operation.io.file_ref import FileRef
+        from adobe.pdfservices.operation.pdfops.extract_pdf_operation import (
+            ExtractPDFOperation,
+        )
+        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import (  # noqa: E501
+            ExtractElementType,
+        )
+        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import (  # noqa: E501
+            ExtractPDFOptions,
+        )
+        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import (  # noqa: E501
+            ExtractRenditionsElementType,
+        )
+    except ImportError:
+        raise ImportError(
+            "pdfservices-sdk is not installed. "
+            "Please install it by running `pip install pdfservices-sdk"
+            "@git+https://github.com/niallcm/pdfservices-python-sdk.git"
+            "@bump-and-unfreeze-requirements`"
+        )
+
+    if not output_path:
+        output_path = tempfile.mkdtemp()
+
+    try:
+        # Initial setup, create credentials instance.
+        credentials = (
+            Credentials.service_principal_credentials_builder()
+            .with_client_id(config("PDF_SERVICES_CLIENT_ID", default=""))
+            .with_client_secret(config("PDF_SERVICES_CLIENT_SECRET", default=""))
+            .build()
+        )
+
+        # Create an ExecutionContext using credentials
+        # and create a new operation instance.
+        execution_context = ExecutionContext.create(credentials)
+        extract_pdf_operation = ExtractPDFOperation.create_new()
+
+        # Set operation input from a source file.
+        source = FileRef.create_from_local_file(file_path)
+        extract_pdf_operation.set_input(source)
+
+        # Build ExtractPDF options and set them into the operation
+        extract_pdf_options: ExtractPDFOptions = (
+            ExtractPDFOptions.builder()
+            .with_elements_to_extract(
+                [ExtractElementType.TEXT, ExtractElementType.TABLES]
+            )
+            .with_elements_to_extract_renditions(
+                [
+                    ExtractRenditionsElementType.TABLES,
+                    ExtractRenditionsElementType.FIGURES,
+                ]
+            )
+            .build()
+        )
+        extract_pdf_operation.set_options(extract_pdf_options)
+
+        # Execute the operation.
+        result: FileRef = extract_pdf_operation.execute(execution_context)
+
+        # Save the result to the specified location.
+        zip_file_path = os.path.join(
+            output_path, "ExtractTextTableWithFigureTableRendition.zip"
+        )
+        result.save_as(zip_file_path)
+        # Open the ZIP file
+        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+            # Extract all contents to the destination folder
+            zip_ref.extractall(output_path)
+    except (ServiceApiException, ServiceUsageException, SdkException):
+        logging.exception("Exception encountered while executing operation")
+
+    return output_path
+
+
+def make_markdown_table(table_as_list: List[str]) -> str:
+    """
+    Convert a table from python list representation to markdown format.
+    The input list consists of the rows of the table; the first row is the header.
+
+    Args:
+        table_as_list: list of table rows
+            Example: [["Name", "Age", "Height"],
+                      ["Jake", 20, 5'10],
+                      ["Mary", 21, 5'7]]
+    Returns:
+        markdown representation of the table
+    """
+    markdown = "\n" + str("| ")
+
+    # header row
+    for e in table_as_list[0]:
+        to_add = " " + str(e) + str(" |")
+        markdown += to_add
+    markdown += "\n"
+
+    # separator row
+    markdown += "| "
+    for i in range(len(table_as_list[0])):
+        markdown += str("--- | ")
+    markdown += "\n"
+
+    # data rows
+    for entry in table_as_list[1:]:
+        markdown += str("| ")
+        for e in entry:
+            to_add = str(e) + str(" | ")
+            markdown += to_add
+        markdown += "\n"
+
+    return markdown + "\n"
+
+
+def load_json(input_path: Union[str, Path]) -> dict:
+    """Load json file"""
+    with open(input_path, "r") as fi:
+        data = json.load(fi)
+
+    return data
+
+
+def load_excel(input_path: Union[str, Path]) -> str:
+    """Load excel file and convert to markdown"""
+
+    df = pd.read_excel(input_path).fillna("")
+    # Convert dataframe to a list of rows
+    row_list = [df.columns.values.tolist()] + df.values.tolist()
+
+    for item_id, item in enumerate(row_list[0]):
+        if "Unnamed" in item:
+            row_list[0][item_id] = ""
+
+    for row in row_list:
+        for item_id, item in enumerate(row):
+            row[item_id] = str(item).replace("_x000D_", " ").replace("\n", " ").strip()
+
+    markdown_str = make_markdown_table(row_list)
+    return markdown_str
+
+
+def encode_image_base64(image_path: Union[str, Path]) -> Union[bytes, str]:
+    """Convert image to base64"""
+
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+def parse_table_paths(file_paths: List[Path]) -> str:
+    """Read the table stored in an excel file given the file path"""
+
+    content = ""
+    for path in file_paths:
+        if path.suffix == ".xlsx":
+            content = load_excel(path)
+            break
+    return content
+
+
+def parse_figure_paths(file_paths: List[Path]) -> Union[bytes, str]:
+    """Read and convert an image to base64 given the image path"""
+
+    content = ""
+    for path in file_paths:
+        if path.suffix == ".png":
+            base64_image = encode_image_base64(path)
+            content = f"data:image/png;base64,{base64_image}"  # type: ignore
+            break
+    return content
+
+
+def generate_single_figure_caption(vlm_endpoint: str, figure: str) -> str:
+    """Summarize a single figure using GPT-4V"""
+    if figure:
+        output = generate_gpt4v(
+            endpoint=vlm_endpoint,
+            prompt="Provide a short 2 sentence summary of this image?",
+            images=figure,
+        )
+        if "sorry" in output.lower():
+            output = ""
+    else:
+        output = ""
+    return output
+
+
+def generate_figure_captions(
+    vlm_endpoint: str, figures: List, max_figures_to_process: int
+) -> List:
+    """Summarize several figures using GPT-4V.
+
+    Args:
+        vlm_endpoint (str): endpoint to the vision language model service
+        figures (List): list of base64 images
+        max_figures_to_process (int): the maximum number of figures to be
+            summarized; the rest are ignored.
+
+    Returns:
+        results (List[str]): list of all figure captions, with empty strings for
+            the ignored figures.
+    """
+    to_gen_figures = figures[:max_figures_to_process]
+    other_figures = figures[max_figures_to_process:]
+
+    with ThreadPoolExecutor() as executor:
+        futures = [
+            # pass the figure as an argument; a bare lambda would late-bind the
+            # loop variable and caption the same figure repeatedly
+            executor.submit(generate_single_figure_caption, vlm_endpoint, figure)
+            for figure in to_gen_figures
+        ]
+
+    results = [future.result() for future in futures]
+    return results + [""] * len(other_figures)
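To make the table helper's contract concrete, a small self-contained example (output shown approximately; the exact spacing follows from the string concatenation above):

```python
from kotaemon.loaders.utils.adobe import make_markdown_table

rows = [["Name", "Age"], ["Jake", 20]]
print(make_markdown_table(rows))
# |  Name | Age |
# | --- | --- |
# | Jake | 20 |
```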
libs/kotaemon/kotaemon/loaders/utils/gpt4v.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+import json
+from typing import Any, List
+
+import requests
+from decouple import config
+
+
+def generate_gpt4v(
+    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+) -> str:
+    # OpenAI API Key
+    api_key = config("AZURE_OPENAI_API_KEY", default="")
+    headers = {"Content-Type": "application/json", "api-key": api_key}
+
+    if isinstance(images, str):
+        images = [images]
+
+    payload = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                ]
+                + [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": image},
+                    }
+                    for image in images
+                ],
+            }
+        ],
+        "max_tokens": max_tokens,
+    }
+
+    try:
+        response = requests.post(endpoint, headers=headers, json=payload)
+        output = response.json()
+        output = output["choices"][0]["message"]["content"]
+    except Exception:
+        output = ""
+    return output
+
+
+def stream_gpt4v(
+    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+) -> Any:
+    # OpenAI API Key
+    api_key = config("AZURE_OPENAI_API_KEY", default="")
+    headers = {"Content-Type": "application/json", "api-key": api_key}
+
+    if isinstance(images, str):
+        images = [images]
+
+    payload = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                ]
+                + [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": image},
+                    }
+                    for image in images
+                ],
+            }
+        ],
+        "max_tokens": max_tokens,
+        "stream": True,
+    }
+    try:
+        response = requests.post(endpoint, headers=headers, json=payload, stream=True)
+        assert response.status_code == 200, str(response.content)
+        output = ""
+        for line in response.iter_lines():
+            if line:
+                if line.startswith(b"\xef\xbb\xbf"):
+                    # strip the UTF-8 BOM plus the "data: " SSE prefix (9 bytes)
+                    line = line[9:]
+                else:
+                    # strip the "data: " SSE prefix
+                    line = line[6:]
+                try:
+                    if line == b"[DONE]":  # bytes comparison; line is not decoded yet
+                        break
+                    line = json.loads(line.decode("utf-8"))
+                except Exception:
+                    break
+                if len(line["choices"]):
+                    output += line["choices"][0]["delta"].get("content", "")
+                    yield line["choices"][0]["delta"].get("content", "")
+    except Exception:
+        output = ""
+    return output
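A hedged smoke test for `generate_gpt4v` (the endpoint shape follows `DEFAULT_VLM_ENDPOINT` in adobe_loader.py; the resource name and api-version below are placeholders, and `AZURE_OPENAI_API_KEY` must be set in the environment):

```python
from kotaemon.loaders.utils.gpt4v import generate_gpt4v

endpoint = (
    "https://my-resource.openai.azure.com/"            # placeholder resource
    "openai/deployments/gpt-4-vision/chat/completions"
    "?api-version=2023-12-01-preview"                  # placeholder api version
)
caption = generate_gpt4v(
    endpoint=endpoint,
    images="data:image/png;base64,iVBORw0KG...",  # data URI, as parse_figure_paths produces
    prompt="Provide a short 2 sentence summary of this image?",
)
print(caption)  # empty string if the request failed (errors are swallowed)
```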
@@ -60,6 +60,7 @@ adv = [
     "cohere",
     "elasticsearch",
     "llama-cpp-python",
+    "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
 ]
 dev = [
     "ipython",
@@ -69,6 +70,7 @@ dev = [
     "flake8",
     "sphinx",
     "coverage",
+    "python-decouple"
 ]
 all = ["kotaemon[adv,dev]"]
libs/kotaemon/tests/_test_multimodal_reader.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+# TODO: This test is broken and should be rewritten
+from pathlib import Path
+
+from kotaemon.loaders import AdobeReader
+
+# from dotenv import load_dotenv
+
+
+input_file = Path(__file__).parent / "resources" / "multimodal.pdf"
+
+# load_dotenv()
+
+
+def test_adobe_reader():
+    reader = AdobeReader()
+    documents = reader.load_data(input_file)
+    table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
+    assert len(table_docs) == 2
+
+    figure_docs = [doc for doc in documents if doc.metadata.get("type", "") == "image"]
+    assert len(figure_docs) == 2
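Note: the leading underscore in `_test_multimodal_reader.py` keeps pytest's default `test_*.py` collection from picking this file up, consistent with the TODO marking it as broken; it would have to be run explicitly (e.g. `pytest libs/kotaemon/tests/_test_multimodal_reader.py`) once credentials are configured.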
libs/kotaemon/tests/resources/multimodal.pdf (new binary file; content not shown)
@@ -124,6 +124,11 @@ if config("LOCAL_MODEL", default=""):

 KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
+KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
+    config("AZURE_OPENAI_ENDPOINT", default=""),
+    config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
+    config("OPENAI_API_VERSION", default=""),
+)

 SETTINGS_APP = {
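For reference, these are the configuration keys the new code reads via python-decouple; the values below are placeholders, except the deployment-name default shown in the hunk above:

```python
# .env sketch (placeholder values, not real credentials):
# AZURE_OPENAI_ENDPOINT=https://my-resource.openai.azure.com  # Azure OpenAI base URL
# OPENAI_VISION_DEPLOYMENT_NAME=gpt-4-vision                  # default from the hunk above
# OPENAI_API_VERSION=2023-12-01-preview                       # placeholder version string
# AZURE_OPENAI_API_KEY=...                                    # read in loaders/utils/gpt4v.py
# PDF_SERVICES_CLIENT_ID=...                                  # read in loaders/utils/adobe.py
# PDF_SERVICES_CLIENT_SECRET=...
```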
@@ -378,6 +378,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
                     ("PDF text parser", "normal"),
                     ("Mathpix", "mathpix"),
                     ("Advanced ocr", "ocr"),
+                    ("Multimodal parser", "multimodal"),
                 ],
                 "component": "dropdown",
             },
@@ -1,11 +1,14 @@
 import asyncio
+import html
 import logging
+import re
 from collections import defaultdict
 from functools import partial

 import tiktoken
 from ktem.components import llms
 from ktem.reasoning.base import BaseReasoning
+from theflow.settings import settings as flowsettings

 from kotaemon.base import (
     BaseComponent,
@@ -18,9 +21,15 @@ from kotaemon.base import (
 from kotaemon.indices.qa.citation import CitationPipeline
 from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import ChatLLM, PromptTemplate
+from kotaemon.loaders.utils.gpt4v import stream_gpt4v

 logger = logging.getLogger(__name__)

+EVIDENCE_MODE_TEXT = 0
+EVIDENCE_MODE_TABLE = 1
+EVIDENCE_MODE_CHATBOT = 2
+EVIDENCE_MODE_FIGURE = 3
+

 class PrepareEvidencePipeline(BaseComponent):
     """Prepare the evidence text from the list of retrieved documents
@@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent):
     def run(self, docs: list[RetrievedDocument]) -> Document:
         evidence = ""
         table_found = 0
-        evidence_mode = 0
+        evidence_mode = EVIDENCE_MODE_TEXT

         for _id, retrieved_item in enumerate(docs):
             retrieved_content = ""
@@ -55,7 +64,7 @@ class PrepareEvidencePipeline(BaseComponent):
             if page:
                 source += f" (Page {page})"
             if retrieved_item.metadata.get("type", "") == "table":
-                evidence_mode = 1  # table
+                evidence_mode = EVIDENCE_MODE_TABLE
                 if table_found < 5:
                     retrieved_content = retrieved_item.metadata.get("table_origin", "")
                     if retrieved_content not in evidence:
@@ -66,13 +75,23 @@ class PrepareEvidencePipeline(BaseComponent):
                             + "\n<br>"
                         )
             elif retrieved_item.metadata.get("type", "") == "chatbot":
-                evidence_mode = 2  # chatbot
+                evidence_mode = EVIDENCE_MODE_CHATBOT
                 retrieved_content = retrieved_item.metadata["window"]
                 evidence += (
                     f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                     + retrieved_content
                     + "\n<br>"
                 )
+            elif retrieved_item.metadata.get("type", "") == "image":
+                evidence_mode = EVIDENCE_MODE_FIGURE
+                retrieved_content = retrieved_item.metadata.get("image_origin", "")
+                retrieved_caption = html.escape(retrieved_item.get_content())
+                evidence += (
+                    f"<br><b>Figure from {source}</b>\n"
+                    + f"<img width='85%' src='{retrieved_content}' "
+                    + f"alt='{retrieved_caption}'/>"
+                    + "\n<br>"
+                )
             else:
                 if "window" in retrieved_item.metadata:
                     retrieved_content = retrieved_item.metadata["window"]
@@ -90,6 +109,7 @@ class PrepareEvidencePipeline(BaseComponent):
             print(retrieved_item.metadata)
             print("Score", retrieved_item.metadata.get("relevance_score", None))

+        if evidence_mode != EVIDENCE_MODE_FIGURE:
             # trim context by trim_len
             print("len (original)", len(evidence))
             if evidence:
@@ -134,6 +154,16 @@ DEFAULT_QA_CHATBOT_PROMPT = (
     "Answer:"
 )

+DEFAULT_QA_FIGURE_PROMPT = (
+    "Use the given context: texts, tables, and figures below to answer the question. "
+    "If you don't know the answer, just say that you don't know. "
+    "Give answer in {lang}.\n\n"
+    "Context: \n"
+    "{context}\n"
+    "Question: {question}\n"
+    "Answer: "
+)
+

 class AnswerWithContextPipeline(BaseComponent):
     """Answer the question based on the evidence
@@ -151,6 +181,7 @@ class AnswerWithContextPipeline(BaseComponent):
     """

     llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
+    vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
     citation_pipeline: CitationPipeline = Node(
         default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
     )
@@ -158,6 +189,7 @@ class AnswerWithContextPipeline(BaseComponent):
     qa_template: str = DEFAULT_QA_TEXT_PROMPT
     qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
     qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
+    qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT

     enable_citation: bool = False
     system_prompt: str = ""
@@ -188,13 +220,25 @@ class AnswerWithContextPipeline(BaseComponent):
                 (determined by retrieval pipeline)
             evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
         """
-        if evidence_mode == 0:
+        if evidence_mode == EVIDENCE_MODE_TEXT:
             prompt_template = PromptTemplate(self.qa_template)
-        elif evidence_mode == 1:
+        elif evidence_mode == EVIDENCE_MODE_TABLE:
             prompt_template = PromptTemplate(self.qa_table_template)
+        elif evidence_mode == EVIDENCE_MODE_FIGURE:
+            prompt_template = PromptTemplate(self.qa_figure_template)
         else:
             prompt_template = PromptTemplate(self.qa_chatbot_template)

+        images = []
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            # isolate images from the evidence
+            evidence, images = self.extract_evidence_images(evidence)
+            prompt = prompt_template.populate(
+                context=evidence,
+                question=question,
+                lang=self.lang,
+            )
+        else:
             prompt = prompt_template.populate(
                 context=evidence,
                 question=question,
@@ -208,12 +252,18 @@ class AnswerWithContextPipeline(BaseComponent):
         )
         print("Citation task created")

+        output = ""
+        if evidence_mode == EVIDENCE_MODE_FIGURE:
+            for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
+                output += text
+                self.report_output({"output": text})
+                await asyncio.sleep(0)
+        else:
             messages = []
             if self.system_prompt:
                 messages.append(SystemMessage(content=self.system_prompt))
             messages.append(HumanMessage(content=prompt))

-        output = ""
             try:
                 # try streaming first
                 print("Trying LLM streaming")
@@ -237,6 +287,13 @@ class AnswerWithContextPipeline(BaseComponent):

         return answer

+    def extract_evidence_images(self, evidence: str):
+        """Util function to extract and isolate images from context/evidence"""
+        image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
+        matches = re.findall(image_pattern, evidence)
+        context = re.sub(image_pattern, "", evidence)
+        return context, matches
+

 class FullQAPipeline(BaseReasoning):
     """Question answering pipeline. Handle from question to answer"""
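To see what `extract_evidence_images` does with the markup that PrepareEvidencePipeline emits, a self-contained illustration (the sample strings are made up):

```python
import re

image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
evidence = (
    "<br><b>Figure from x.pdf (Page 2)</b>\n"
    "<img width='85%' src='data:image/png;base64,iVBORw0KG' alt='a chart'/>\n<br>"
)
matches = re.findall(image_pattern, evidence)  # ["data:image/png;base64,iVBORw0KG"]
context = re.sub(image_pattern, "", evidence)  # same markup with the src=... removed
# context feeds the figure prompt; matches are passed to stream_gpt4v as images
```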