Feat/add multimodal loader (#5)
* Add Adobe reader as the multimodal loader
* Allow FullQAPipeline to reason on figures
* fix: move the adobe import to avoid ImportError, notify users whenever they run the AdobeReader

---------

Co-authored-by: cin-albert <albert@cinnamon.is>
parent a3bf728400
commit e67a25c0bd
@@ -7,6 +7,7 @@ from kotaemon.base import BaseComponent, Document, Param
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import (
    AdobeReader,
    DirectoryReader,
    MathpixPDFReader,
    OCRReader,
@@ -41,7 +42,7 @@ class DocumentIngestor(BaseComponent):
    The default file extractors are stored in `KH_DEFAULT_FILE_EXTRACTORS`
    """

    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr"
    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr", "multimodal"
    doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
    text_splitter: BaseSplitter = TokenSplitter.withx(
        chunk_size=1024,
@@ -61,6 +62,8 @@ class DocumentIngestor(BaseComponent):
            pass  # use default loader of llama-index which is pypdf
        elif self.pdf_mode == "ocr":
            file_extractors[".pdf"] = OCRReader()
        elif self.pdf_mode == "multimodal":
            file_extractors[".pdf"] = AdobeReader()
        else:
            file_extractors[".pdf"] = MathpixPDFReader()

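With this change, the multimodal path is selected through the ingestor's `pdf_mode`. A minimal usage sketch, assuming `DocumentIngestor` is importable from `kotaemon.indices.ingests` and is invoked as a callable (the input path is a placeholder):

```python
from kotaemon.indices.ingests import DocumentIngestor

# "multimodal" routes .pdf files to AdobeReader instead of the default pypdf loader
ingestor = DocumentIngestor(pdf_mode="multimodal")
documents = ingestor(["path/to/report.pdf"])  # placeholder path
```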
@@ -1,3 +1,4 @@
from .adobe_loader import AdobeReader
from .base import AutoReader, BaseReader
from .composite_loader import DirectoryReader
from .docx_loader import DocxReader
@@ -17,4 +18,5 @@ __all__ = [
    "UnstructuredReader",
    "DocxReader",
    "HtmlReader",
    "AdobeReader",
]
187 libs/kotaemon/kotaemon/loaders/adobe_loader.py Normal file
@@ -0,0 +1,187 @@
import logging
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional

from decouple import config
from llama_index.readers.base import BaseReader

from kotaemon.base import Document

from .utils.adobe import (
    generate_figure_captions,
    load_json,
    parse_figure_paths,
    parse_table_paths,
    request_adobe_service,
)

logger = logging.getLogger(__name__)

DEFAULT_VLM_ENDPOINT = (
    "{0}openai/deployments/{1}/chat/completions?api-version={2}".format(
        config("AZURE_OPENAI_ENDPOINT", default=""),
        "gpt-4-vision",
        config("OPENAI_API_VERSION", default=""),
    )
)

class AdobeReader(BaseReader):
    """Read PDF using Adobe's PDF Services.
    Able to extract text, tables, and figures with high accuracy.

    Example:
        ```python
        >> from kotaemon.loaders import AdobeReader
        >> reader = AdobeReader()
        >> documents = reader.load_data("path/to/pdf")
        ```

    Args:
        vlm_endpoint: URL to the Vision Language Model endpoint. If not provided,
            will use the default `kotaemon.loaders.adobe_loader.DEFAULT_VLM_ENDPOINT`

        max_figures_to_caption: an int deciding how many figures will be captioned.
            The rest will be ignored (indexed without captions).
    """

    def __init__(
        self,
        vlm_endpoint: Optional[str] = None,
        max_figures_to_caption: int = 100,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Init params"""
        super().__init__(*args)
        self.table_regex = r"/Table(\[\d+\])?$"
        self.figure_regex = r"/Figure(\[\d+\])?$"
        self.vlm_endpoint = vlm_endpoint or DEFAULT_VLM_ENDPOINT
        self.max_figures_to_caption = max_figures_to_caption

    def load_data(
        self, file: Path, extra_info: Optional[Dict] = None, **kwargs
    ) -> List[Document]:
        """Load data by calling Adobe's API

        Args:
            file (Path): Path to the PDF file

        Returns:
            List[Document]: list of documents extracted from the PDF file,
                including 3 types: text, table, and image

        """

        filename = file.name
        filepath = str(Path(file).resolve())
        output_path = request_adobe_service(file_path=str(file), output_path="")
        results_path = os.path.join(output_path, "structuredData.json")

        if not os.path.exists(results_path):
            logger.exception("Failed to parse the document.")
            return []

        data = load_json(results_path)

        texts = defaultdict(list)
        tables = []
        figures = []

        elements = data["elements"]
        for item_id, item in enumerate(elements):
            page_number = item.get("Page", -1) + 1
            item_path = item["Path"]
            item_text = item.get("Text", "")

            file_paths = [
                Path(output_path) / path for path in item.get("filePaths", [])
            ]
            prev_item = elements[item_id - 1]
            title = prev_item.get("Text", "")

            if re.search(self.table_regex, item_path):
                table_content = parse_table_paths(file_paths)
                if not table_content:
                    continue
                table_caption = (
                    table_content.replace("|", "").replace("---", "")
                    + f"\n(Table in Page {page_number}. {title})"
                )
                tables.append((page_number, table_content, table_caption))

            elif re.search(self.figure_regex, item_path):
                figure_caption = (
                    item_text + f"\n(Figure in Page {page_number}. {title})"
                )
                figure_content = parse_figure_paths(file_paths)
                if not figure_content:
                    continue
                figures.append([page_number, figure_content, figure_caption])

            else:
                if item_text and "Table" not in item_path and "Figure" not in item_path:
                    texts[page_number].append(item_text)

        # get figure captions using GPT-4V
        figure_captions = generate_figure_captions(
            self.vlm_endpoint,
            [item[1] for item in figures],
            self.max_figures_to_caption,
        )
        for item, caption in zip(figures, figure_captions):
            # update figure caption
            item[2] += " " + caption

        # Wrap elements with Document
        documents = []

        # join plain text elements
        for page_number, txts in texts.items():
            documents.append(
                Document(
                    text="\n".join(txts),
                    metadata={
                        "page_label": page_number,
                        "file_name": filename,
                        "file_path": filepath,
                    },
                )
            )

        # table elements
        for page_number, table_content, table_caption in tables:
            documents.append(
                Document(
                    text=table_caption,
                    metadata={
                        "table_origin": table_content,
                        "type": "table",
                        "page_label": page_number,
                        "file_name": filename,
                        "file_path": filepath,
                    },
                    metadata_template="",
                    metadata_seperator="",
                )
            )

        # figure elements
        for page_number, figure_content, figure_caption in figures:
            documents.append(
                Document(
                    text=figure_caption,
                    metadata={
                        "image_origin": figure_content,
                        "type": "image",
                        "page_label": page_number,
                        "file_name": filename,
                        "file_path": filepath,
                    },
                    metadata_template="",
                    metadata_seperator="",
                )
            )
        return documents
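For context on the two regexes the reader uses: each element in Adobe Extract's `structuredData.json` carries a structure `Path`, and only paths ending in a table or figure node are routed to the table/figure branches above. A small illustrative check, using hypothetical path strings:

```python
import re

table_regex = r"/Table(\[\d+\])?$"
figure_regex = r"/Figure(\[\d+\])?$"

# Hypothetical element paths in the shape Adobe Extract's structuredData.json uses
assert re.search(table_regex, "//Document/Table")            # first table
assert re.search(table_regex, "//Document/Table[2]")         # later tables are indexed
assert re.search(figure_regex, "//Document/Figure[3]")
assert not re.search(table_regex, "//Document/Table[2]/TR")  # nested rows are skipped
```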
248 libs/kotaemon/kotaemon/loaders/utils/adobe.py Normal file
@@ -0,0 +1,248 @@
# need pip install pdfservices-sdk==2.3.0

import base64
import json
import logging
import os
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Union

import pandas as pd
from decouple import config

from kotaemon.loaders.utils.gpt4v import generate_gpt4v

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

def request_adobe_service(file_path: str, output_path: str = "") -> str:
    """Main function to call the Adobe service, and unzip the results.

    Args:
        file_path (str): path to the pdf file
        output_path (str): path to store the results

    Returns:
        output_path (str): path to the results

    """
    try:
        from adobe.pdfservices.operation.auth.credentials import Credentials
        from adobe.pdfservices.operation.exception.exceptions import (
            SdkException,
            ServiceApiException,
            ServiceUsageException,
        )
        from adobe.pdfservices.operation.execution_context import ExecutionContext
        from adobe.pdfservices.operation.io.file_ref import FileRef
        from adobe.pdfservices.operation.pdfops.extract_pdf_operation import (
            ExtractPDFOperation,
        )
        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import (  # noqa: E501
            ExtractElementType,
        )
        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import (  # noqa: E501
            ExtractPDFOptions,
        )
        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import (  # noqa: E501
            ExtractRenditionsElementType,
        )
    except ImportError:
        raise ImportError(
            "pdfservices-sdk is not installed. "
            "Please install it by running `pip install pdfservices-sdk"
            "@git+https://github.com/niallcm/pdfservices-python-sdk.git"
            "@bump-and-unfreeze-requirements`"
        )

    if not output_path:
        output_path = tempfile.mkdtemp()

    try:
        # Initial setup, create credentials instance.
        credentials = (
            Credentials.service_principal_credentials_builder()
            .with_client_id(config("PDF_SERVICES_CLIENT_ID", default=""))
            .with_client_secret(config("PDF_SERVICES_CLIENT_SECRET", default=""))
            .build()
        )

        # Create an ExecutionContext using credentials
        # and create a new operation instance.
        execution_context = ExecutionContext.create(credentials)
        extract_pdf_operation = ExtractPDFOperation.create_new()

        # Set operation input from a source file.
        source = FileRef.create_from_local_file(file_path)
        extract_pdf_operation.set_input(source)

        # Build ExtractPDF options and set them into the operation
        extract_pdf_options: ExtractPDFOptions = (
            ExtractPDFOptions.builder()
            .with_elements_to_extract(
                [ExtractElementType.TEXT, ExtractElementType.TABLES]
            )
            .with_elements_to_extract_renditions(
                [
                    ExtractRenditionsElementType.TABLES,
                    ExtractRenditionsElementType.FIGURES,
                ]
            )
            .build()
        )
        extract_pdf_operation.set_options(extract_pdf_options)

        # Execute the operation.
        result: FileRef = extract_pdf_operation.execute(execution_context)

        # Save the result to the specified location.
        zip_file_path = os.path.join(
            output_path, "ExtractTextTableWithFigureTableRendition.zip"
        )
        result.save_as(zip_file_path)
        # Open the ZIP file
        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            # Extract all contents to the destination folder
            zip_ref.extractall(output_path)
    except (ServiceApiException, ServiceUsageException, SdkException):
        logging.exception("Exception encountered while executing operation")

    return output_path

def make_markdown_table(table_as_list: List[str]) -> str:
    """
    Convert table from python list representation to markdown format.
    The input list consists of rows of the table; the first row is the header.

    Args:
        table_as_list: list of table rows
            Example: [["Name", "Age", "Height"],
                      ["Jake", 20, 5'10],
                      ["Mary", 21, 5'7]]
    Returns:
        markdown representation of the table
    """
    markdown = "\n" + str("| ")

    for e in table_as_list[0]:
        to_add = " " + str(e) + str(" |")
        markdown += to_add
    markdown += "\n"

    markdown += "| "
    for i in range(len(table_as_list[0])):
        markdown += str("--- | ")
    markdown += "\n"

    for entry in table_as_list[1:]:
        markdown += str("| ")
        for e in entry:
            to_add = str(e) + str(" | ")
            markdown += to_add
        markdown += "\n"

    return markdown + "\n"

def load_json(input_path: Union[str, Path]) -> dict:
    """Load json file"""
    with open(input_path, "r") as fi:
        data = json.load(fi)

    return data

def load_excel(input_path: Union[str, Path]) -> str:
    """Load excel file and convert to markdown"""

    df = pd.read_excel(input_path).fillna("")
    # Convert dataframe to a list of rows
    row_list = [df.columns.values.tolist()] + df.values.tolist()

    for item_id, item in enumerate(row_list[0]):
        if "Unnamed" in item:
            row_list[0][item_id] = ""

    for row in row_list:
        for item_id, item in enumerate(row):
            row[item_id] = str(item).replace("_x000D_", " ").replace("\n", " ").strip()

    markdown_str = make_markdown_table(row_list)
    return markdown_str

def encode_image_base64(image_path: Union[str, Path]) -> Union[bytes, str]:
    """Convert image to base64"""

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

def parse_table_paths(file_paths: List[Path]) -> str:
    """Read the table stored in an excel file given the file path"""

    content = ""
    for path in file_paths:
        if path.suffix == ".xlsx":
            content = load_excel(path)
            break
    return content


def parse_figure_paths(file_paths: List[Path]) -> Union[bytes, str]:
    """Read and convert an image to base64 given the image path"""

    content = ""
    for path in file_paths:
        if path.suffix == ".png":
            base64_image = encode_image_base64(path)
            content = f"data:image/png;base64,{base64_image}"  # type: ignore
            break
    return content

def generate_single_figure_caption(vlm_endpoint: str, figure: str) -> str:
    """Summarize a single figure using GPT-4V"""
    if figure:
        output = generate_gpt4v(
            endpoint=vlm_endpoint,
            prompt="Provide a short 2 sentence summary of this image?",
            images=figure,
        )
        if "sorry" in output.lower():
            output = ""
    else:
        output = ""
    return output

def generate_figure_captions(
    vlm_endpoint: str, figures: List, max_figures_to_process: int
) -> List:
    """Summarize several figures using GPT-4V.

    Args:
        vlm_endpoint (str): endpoint to the vision language model service
        figures (List): list of base64 images
        max_figures_to_process (int): the maximum number of figures to summarize;
            the rest are ignored.

    Returns:
        results (List[str]): list of all figure captions and empty strings for
            ignored figures.
    """
    to_gen_figures = figures[:max_figures_to_process]
    other_figures = figures[max_figures_to_process:]

    with ThreadPoolExecutor() as executor:
        # pass the figure as an argument to avoid the late-binding closure pitfall
        futures = [
            executor.submit(generate_single_figure_caption, vlm_endpoint, figure)
            for figure in to_gen_figures
        ]

        results = [future.result() for future in futures]
    return results + [""] * len(other_figures)
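As a quick illustration of `make_markdown_table` defined above (the rows come from its docstring example; the rendered output shown in the comments is approximate in its whitespace):

```python
rows = [
    ["Name", "Age", "Height"],
    ["Jake", 20, "5'10"],
    ["Mary", 21, "5'7"],
]
print(make_markdown_table(rows))
# Expected shape of the output (whitespace is approximate):
# |  Name | Age | Height |
# | --- | --- | --- |
# | Jake | 20 | 5'10 |
# | Mary | 21 | 5'7 |
```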
96 libs/kotaemon/kotaemon/loaders/utils/gpt4v.py Normal file
@@ -0,0 +1,96 @@
import json
from typing import Any, List

import requests
from decouple import config


def generate_gpt4v(
    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
) -> str:
    # OpenAI API Key
    api_key = config("AZURE_OPENAI_API_KEY", default="")
    headers = {"Content-Type": "application/json", "api-key": api_key}

    if isinstance(images, str):
        images = [images]

    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ]
                + [
                    {
                        "type": "image_url",
                        "image_url": {"url": image},
                    }
                    for image in images
                ],
            }
        ],
        "max_tokens": max_tokens,
    }

    try:
        response = requests.post(endpoint, headers=headers, json=payload)
        output = response.json()
        output = output["choices"][0]["message"]["content"]
    except Exception:
        output = ""
    return output

def stream_gpt4v(
    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
) -> Any:
    # OpenAI API Key
    api_key = config("AZURE_OPENAI_API_KEY", default="")
    headers = {"Content-Type": "application/json", "api-key": api_key}

    if isinstance(images, str):
        images = [images]

    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ]
                + [
                    {
                        "type": "image_url",
                        "image_url": {"url": image},
                    }
                    for image in images
                ],
            }
        ],
        "max_tokens": max_tokens,
        "stream": True,
    }
    try:
        response = requests.post(endpoint, headers=headers, json=payload, stream=True)
        assert response.status_code == 200, str(response.content)
        output = ""
        for line in response.iter_lines():
            if line:
                # strip the BOM (if present) and the "data: " prefix of each SSE line
                if line.startswith(b"\xef\xbb\xbf"):
                    line = line[9:]
                else:
                    line = line[6:]
                try:
                    if line == b"[DONE]":
                        break
                    line = json.loads(line.decode("utf-8"))
                except Exception:
                    break
                if len(line["choices"]):
                    output += line["choices"][0]["delta"].get("content", "")
                    yield line["choices"][0]["delta"].get("content", "")
    except Exception:
        output = ""
    return output
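A minimal sketch of calling the non-streaming helper directly. The endpoint follows the shape of `DEFAULT_VLM_ENDPOINT` above; the resource URL, deployment name, API version, and image file are placeholders, and `AZURE_OPENAI_API_KEY` must be set for the request to succeed:

```python
from kotaemon.loaders.utils.adobe import encode_image_base64
from kotaemon.loaders.utils.gpt4v import generate_gpt4v

# Placeholder Azure OpenAI vision deployment; adjust to your own resource
endpoint = (
    "https://my-resource.openai.azure.com/"
    "openai/deployments/gpt-4-vision/chat/completions?api-version=2023-12-01-preview"
)
image = f"data:image/png;base64,{encode_image_base64('figure.png')}"  # placeholder file
caption = generate_gpt4v(endpoint, images=image, prompt="Describe this figure.")
print(caption)  # empty string if the request fails
```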
@@ -60,6 +60,7 @@ adv = [
    "cohere",
    "elasticsearch",
    "llama-cpp-python",
    "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
]
dev = [
    "ipython",
@@ -69,6 +70,7 @@ dev = [
    "flake8",
    "sphinx",
    "coverage",
    "python-decouple"
]
all = ["kotaemon[adv,dev]"]
21 libs/kotaemon/tests/_test_multimodal_reader.py Normal file
@@ -0,0 +1,21 @@
# TODO: This test is broken and should be rewritten
from pathlib import Path

from kotaemon.loaders import AdobeReader

# from dotenv import load_dotenv


input_file = Path(__file__).parent / "resources" / "multimodal.pdf"

# load_dotenv()


def test_adobe_reader():
    reader = AdobeReader()
    documents = reader.load_data(input_file)
    table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
    assert len(table_docs) == 2

    figure_docs = [doc for doc in documents if doc.metadata.get("type", "") == "image"]
    assert len(figure_docs) == 2
BIN libs/kotaemon/tests/resources/multimodal.pdf Normal file
Binary file not shown.
@@ -124,6 +124,11 @@ if config("LOCAL_MODEL", default=""):


KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
    config("AZURE_OPENAI_ENDPOINT", default=""),
    config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
    config("OPENAI_API_VERSION", default=""),
)


SETTINGS_APP = {
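For reference, these are the configuration values the new code paths read through `decouple.config` in this diff; everything below is a placeholder, not a default shipped with the project:

```python
# Placeholders only - substitute your own Azure / Adobe credentials (e.g. via a .env file)
AZURE_OPENAI_ENDPOINT = "https://my-resource.openai.azure.com/"  # used to build KH_VLM_ENDPOINT
OPENAI_VISION_DEPLOYMENT_NAME = "gpt-4-vision"                   # Azure deployment of the vision model
OPENAI_API_VERSION = "2023-12-01-preview"                        # placeholder API version
AZURE_OPENAI_API_KEY = "<your-azure-openai-key>"                 # read in utils/gpt4v.py
PDF_SERVICES_CLIENT_ID = "<adobe-client-id>"                     # read in utils/adobe.py
PDF_SERVICES_CLIENT_SECRET = "<adobe-client-secret>"             # read in utils/adobe.py
```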
@@ -378,6 +378,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
                    ("PDF text parser", "normal"),
                    ("Mathpix", "mathpix"),
                    ("Advanced ocr", "ocr"),
                    ("Multimodal parser", "multimodal"),
                ],
                "component": "dropdown",
            },
@@ -1,11 +1,14 @@
import asyncio
import html
import logging
import re
from collections import defaultdict
from functools import partial

import tiktoken
from ktem.components import llms
from ktem.reasoning.base import BaseReasoning
from theflow.settings import settings as flowsettings

from kotaemon.base import (
    BaseComponent,
@@ -18,9 +21,15 @@ from kotaemon.base import (
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
from kotaemon.loaders.utils.gpt4v import stream_gpt4v

logger = logging.getLogger(__name__)

EVIDENCE_MODE_TEXT = 0
EVIDENCE_MODE_TABLE = 1
EVIDENCE_MODE_CHATBOT = 2
EVIDENCE_MODE_FIGURE = 3


class PrepareEvidencePipeline(BaseComponent):
    """Prepare the evidence text from the list of retrieved documents
@@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent):
    def run(self, docs: list[RetrievedDocument]) -> Document:
        evidence = ""
        table_found = 0
        evidence_mode = 0
        evidence_mode = EVIDENCE_MODE_TEXT

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
@@ -55,7 +64,7 @@ class PrepareEvidencePipeline(BaseComponent):
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                evidence_mode = 1  # table
                evidence_mode = EVIDENCE_MODE_TABLE
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
@@ -66,13 +75,23 @@ class PrepareEvidencePipeline(BaseComponent):
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                evidence_mode = 2  # chatbot
                evidence_mode = EVIDENCE_MODE_CHATBOT
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            elif retrieved_item.metadata.get("type", "") == "image":
                evidence_mode = EVIDENCE_MODE_FIGURE
                retrieved_content = retrieved_item.metadata.get("image_origin", "")
                retrieved_caption = html.escape(retrieved_item.get_content())
                evidence += (
                    f"<br><b>Figure from {source}</b>\n"
                    + f"<img width='85%' src='{retrieved_content}' "
                    + f"alt='{retrieved_caption}'/>"
                    + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
@@ -90,6 +109,7 @@ class PrepareEvidencePipeline(BaseComponent):
            print(retrieved_item.metadata)
            print("Score", retrieved_item.metadata.get("relevance_score", None))

        if evidence_mode != EVIDENCE_MODE_FIGURE:
            # trim context by trim_len
            print("len (original)", len(evidence))
            if evidence:
@@ -134,6 +154,16 @@ DEFAULT_QA_CHATBOT_PROMPT = (
    "Answer:"
)

DEFAULT_QA_FIGURE_PROMPT = (
    "Use the given context: texts, tables, and figures below to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Give answer in {lang}.\n\n"
    "Context: \n"
    "{context}\n"
    "Question: {question}\n"
    "Answer: "
)


class AnswerWithContextPipeline(BaseComponent):
    """Answer the question based on the evidence
@@ -151,6 +181,7 @@ class AnswerWithContextPipeline(BaseComponent):
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
    vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
    citation_pipeline: CitationPipeline = Node(
        default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
    )
@@ -158,6 +189,7 @@ class AnswerWithContextPipeline(BaseComponent):
    qa_template: str = DEFAULT_QA_TEXT_PROMPT
    qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
    qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
    qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT

    enable_citation: bool = False
    system_prompt: str = ""
@@ -188,13 +220,25 @@ class AnswerWithContextPipeline(BaseComponent):
                (determined by retrieval pipeline)
            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
        """
        if evidence_mode == 0:
        if evidence_mode == EVIDENCE_MODE_TEXT:
            prompt_template = PromptTemplate(self.qa_template)
        elif evidence_mode == 1:
        elif evidence_mode == EVIDENCE_MODE_TABLE:
            prompt_template = PromptTemplate(self.qa_table_template)
        elif evidence_mode == EVIDENCE_MODE_FIGURE:
            prompt_template = PromptTemplate(self.qa_figure_template)
        else:
            prompt_template = PromptTemplate(self.qa_chatbot_template)

        images = []
        if evidence_mode == EVIDENCE_MODE_FIGURE:
            # isolate image from evidence
            evidence, images = self.extract_evidence_images(evidence)
            prompt = prompt_template.populate(
                context=evidence,
                question=question,
                lang=self.lang,
            )
        else:
            prompt = prompt_template.populate(
                context=evidence,
                question=question,
@@ -208,12 +252,18 @@ class AnswerWithContextPipeline(BaseComponent):
            )
            print("Citation task created")

        output = ""
        if evidence_mode == EVIDENCE_MODE_FIGURE:
            for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
                output += text
                self.report_output({"output": text})
                await asyncio.sleep(0)
        else:
            messages = []
            if self.system_prompt:
                messages.append(SystemMessage(content=self.system_prompt))
            messages.append(HumanMessage(content=prompt))

            output = ""
            try:
                # try streaming first
                print("Trying LLM streaming")
@@ -237,6 +287,13 @@ class AnswerWithContextPipeline(BaseComponent):

        return answer

    def extract_evidence_images(self, evidence: str):
        """Util function to extract and isolate images from context/evidence"""
        image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
        matches = re.findall(image_pattern, evidence)
        context = re.sub(image_pattern, "", evidence)
        return context, matches


class FullQAPipeline(BaseReasoning):
    """Question answering pipeline. Handle from question to answer"""
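To make the figure-evidence flow concrete, here is a self-contained sketch of the regex used by `extract_evidence_images`, applied to an evidence string shaped the way `PrepareEvidencePipeline` builds it (the file name and base64 payload are placeholders):

```python
import re

image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
evidence = (
    "<br><b>Figure from report.pdf (Page 3)</b>\n"
    "<img width='85%' src='data:image/png;base64,iVBORw0KG...' alt='Example figure'/>\n<br>"
)
matches = re.findall(image_pattern, evidence)  # -> ["data:image/png;base64,iVBORw0KG..."]
context = re.sub(image_pattern, "", evidence)  # evidence with the data URL stripped, kept as text context
```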