diff --git a/libs/kotaemon/kotaemon/indices/ingests/files.py b/libs/kotaemon/kotaemon/indices/ingests/files.py
index ed00e5c..75f944e 100644
--- a/libs/kotaemon/kotaemon/indices/ingests/files.py
+++ b/libs/kotaemon/kotaemon/indices/ingests/files.py
@@ -7,6 +7,7 @@ from kotaemon.base import BaseComponent, Document, Param
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import (
+ AdobeReader,
DirectoryReader,
MathpixPDFReader,
OCRReader,
@@ -41,7 +42,7 @@ class DocumentIngestor(BaseComponent):
The default file extractors are stored in `KH_DEFAULT_FILE_EXTRACTORS`
"""
- pdf_mode: str = "normal" # "normal", "mathpix", "ocr"
+ pdf_mode: str = "normal" # "normal", "mathpix", "ocr", "multimodal"
doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
text_splitter: BaseSplitter = TokenSplitter.withx(
chunk_size=1024,
@@ -61,6 +62,8 @@ class DocumentIngestor(BaseComponent):
pass # use default loader of llama-index which is pypdf
elif self.pdf_mode == "ocr":
file_extractors[".pdf"] = OCRReader()
+ elif self.pdf_mode == "multimodal":
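+ # requires the optional pdfservices-sdk dependency and Adobe PDF Services
+ # credentials (PDF_SERVICES_CLIENT_ID / PDF_SERVICES_CLIENT_SECRET)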
+ file_extractors[".pdf"] = AdobeReader()
else:
file_extractors[".pdf"] = MathpixPDFReader()
diff --git a/libs/kotaemon/kotaemon/loaders/__init__.py b/libs/kotaemon/kotaemon/loaders/__init__.py
index d742b52..28cb5f3 100644
--- a/libs/kotaemon/kotaemon/loaders/__init__.py
+++ b/libs/kotaemon/kotaemon/loaders/__init__.py
@@ -1,3 +1,4 @@
+from .adobe_loader import AdobeReader
from .base import AutoReader, BaseReader
from .composite_loader import DirectoryReader
from .docx_loader import DocxReader
@@ -17,4 +18,5 @@ __all__ = [
"UnstructuredReader",
"DocxReader",
"HtmlReader",
+ "AdobeReader",
]
diff --git a/libs/kotaemon/kotaemon/loaders/adobe_loader.py b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
new file mode 100644
index 0000000..dd8cbc9
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
@@ -0,0 +1,187 @@
+import logging
+import os
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from decouple import config
+from llama_index.readers.base import BaseReader
+
+from kotaemon.base import Document
+
+from .utils.adobe import (
+ generate_figure_captions,
+ load_json,
+ parse_figure_paths,
+ parse_table_paths,
+ request_adobe_service,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_VLM_ENDPOINT = (
+ "{0}openai/deployments/{1}/chat/completions?api-version={2}".format(
+ config("AZURE_OPENAI_ENDPOINT", default=""),
+ "gpt-4-vision",
+ config("OPENAI_API_VERSION", default=""),
+ )
+)
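+# The default endpoint above targets an Azure OpenAI chat/completions deployment named
+# "gpt-4-vision"; AZURE_OPENAI_ENDPOINT and OPENAI_API_VERSION are read from the
+# environment via python-decouple.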
+
+
+class AdobeReader(BaseReader):
+ """Read PDF using Adobe's PDF Services.
+ Extracts text, tables, and figures with high accuracy.
+
+ Example:
+ ```python
+ >>> from kotaemon.loaders import AdobeReader
+ >>> reader = AdobeReader()
+ >>> documents = reader.load_data("path/to/pdf")
+ ```
+ Args:
+ vlm_endpoint: URL of the Vision Language Model endpoint. If not provided,
+ the default `kotaemon.loaders.adobe_loader.DEFAULT_VLM_ENDPOINT` is used.
+
+ max_figures_to_caption: maximum number of figures to caption.
+ The rest are indexed without captions.
+ """
+
+ def __init__(
+ self,
+ vlm_endpoint: Optional[str] = None,
+ max_figures_to_caption: int = 100,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ """Init params"""
+ super().__init__(*args)
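+ # Adobe Extract labels each element with a "Path"; paths ending in /Table[...] or
+ # /Figure[...] mark tables and figures (consumed in load_data below)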
+ self.table_regex = r"/Table(\[\d+\])?$"
+ self.figure_regex = r"/Figure(\[\d+\])?$"
+ self.vlm_endpoint = vlm_endpoint or DEFAULT_VLM_ENDPOINT
+ self.max_figures_to_caption = max_figures_to_caption
+
+ def load_data(
+ self, file: Path, extra_info: Optional[Dict] = None, **kwargs
+ ) -> List[Document]:
+ """Load data by calling the Adobe API
+
+ Args:
+ file (Path): Path to the PDF file
+
+ Returns:
+ List[Document]: list of documents extracted from the PDF file,
+ includes 3 types: text, table, and image
+
+ """
+
+ filename = file.name
+ filepath = str(Path(file).resolve())
+ output_path = request_adobe_service(file_path=str(file), output_path="")
+ results_path = os.path.join(output_path, "structuredData.json")
+
+ if not os.path.exists(results_path):
+ logger.error("Failed to parse the document.")
+ return []
+
+ data = load_json(results_path)
+
+ texts = defaultdict(list)
+ tables = []
+ figures = []
+
+ elements = data["elements"]
+ for item_id, item in enumerate(elements):
+ page_number = item.get("Page", -1) + 1
+ item_path = item["Path"]
+ item_text = item.get("Text", "")
+
+ file_paths = [
+ Path(output_path) / path for path in item.get("filePaths", [])
+ ]
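+ # use the text of the preceding element as a title hint for table/figure captions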
+ prev_item = elements[item_id - 1]
+ title = prev_item.get("Text", "")
+
+ if re.search(self.table_regex, item_path):
+ table_content = parse_table_paths(file_paths)
+ if not table_content:
+ continue
+ table_caption = (
+ table_content.replace("|", "").replace("---", "")
+ + f"\n(Table in Page {page_number}. {title})"
+ )
+ tables.append((page_number, table_content, table_caption))
+
+ elif re.search(self.figure_regex, item_path):
+ figure_caption = (
+ item_text + f"\n(Figure in Page {page_number}. {title})"
+ )
+ figure_content = parse_figure_paths(file_paths)
+ if not figure_content:
+ continue
+ figures.append([page_number, figure_content, figure_caption])
+
+ else:
+ if item_text and "Table" not in item_path and "Figure" not in item_path:
+ texts[page_number].append(item_text)
+
+ # get figure caption using GPT-4V
+ figure_captions = generate_figure_captions(
+ self.vlm_endpoint,
+ [item[1] for item in figures],
+ self.max_figures_to_caption,
+ )
+ for item, caption in zip(figures, figure_captions):
+ # update figure caption
+ item[2] += " " + caption
+
+ # Wrap elements with Document
+ documents = []
+
+ # join plain text elements
+ for page_number, txts in texts.items():
+ documents.append(
+ Document(
+ text="\n".join(txts),
+ metadata={
+ "page_label": page_number,
+ "file_name": filename,
+ "file_path": filepath,
+ },
+ )
+ )
+
+ # table elements
+ for page_number, table_content, table_caption in tables:
+ documents.append(
+ Document(
+ text=table_caption,
+ metadata={
+ "table_origin": table_content,
+ "type": "table",
+ "page_label": page_number,
+ "file_name": filename,
+ "file_path": filepath,
+ },
+ metadata_template="",
+ metadata_seperator="",
+ )
+ )
+
+ # figure elements
+ for page_number, figure_content, figure_caption in figures:
+ documents.append(
+ Document(
+ text=figure_caption,
+ metadata={
+ "image_origin": figure_content,
+ "type": "image",
+ "page_label": page_number,
+ "file_name": filename,
+ "file_path": filepath,
+ },
+ metadata_template="",
+ metadata_seperator="",
+ )
+ )
+ return documents
diff --git a/libs/kotaemon/kotaemon/loaders/utils/adobe.py b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
new file mode 100644
index 0000000..a780c45
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
@@ -0,0 +1,248 @@
+# requires: pip install pdfservices-sdk==2.3.0
+
+import base64
+import json
+import logging
+import os
+import tempfile
+import zipfile
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List, Union
+
+import pandas as pd
+from decouple import config
+
+from kotaemon.loaders.utils.gpt4v import generate_gpt4v
+
+logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
+
+
+def request_adobe_service(file_path: str, output_path: str = "") -> str:
+ """Call the Adobe PDF Services extraction API and unzip the results.
+ Args:
+ file_path (str): path to the pdf file
+ output_path (str): path to store the results
+
+ Returns:
+ output_path (str): path to the results
+
+ """
+ try:
+ from adobe.pdfservices.operation.auth.credentials import Credentials
+ from adobe.pdfservices.operation.exception.exceptions import (
+ SdkException,
+ ServiceApiException,
+ ServiceUsageException,
+ )
+ from adobe.pdfservices.operation.execution_context import ExecutionContext
+ from adobe.pdfservices.operation.io.file_ref import FileRef
+ from adobe.pdfservices.operation.pdfops.extract_pdf_operation import (
+ ExtractPDFOperation,
+ )
+ from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import ( # noqa: E501
+ ExtractElementType,
+ )
+ from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import ( # noqa: E501
+ ExtractPDFOptions,
+ )
+ from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import ( # noqa: E501
+ ExtractRenditionsElementType,
+ )
+ except ImportError:
+ raise ImportError(
+ "pdfservices-sdk is not installed. "
+ "Please install it by running `pip install pdfservices-sdk"
+ "@git+https://github.com/niallcm/pdfservices-python-sdk.git"
+ "@bump-and-unfreeze-requirements`"
+ )
+
+ if not output_path:
+ output_path = tempfile.mkdtemp()
+
+ try:
+ # Initial setup, create credentials instance.
+ credentials = (
+ Credentials.service_principal_credentials_builder()
+ .with_client_id(config("PDF_SERVICES_CLIENT_ID", default=""))
+ .with_client_secret(config("PDF_SERVICES_CLIENT_SECRET", default=""))
+ .build()
+ )
+
+ # Create an ExecutionContext using credentials
+ # and create a new operation instance.
+ execution_context = ExecutionContext.create(credentials)
+ extract_pdf_operation = ExtractPDFOperation.create_new()
+
+ # Set operation input from a source file.
+ source = FileRef.create_from_local_file(file_path)
+ extract_pdf_operation.set_input(source)
+
+ # Build ExtractPDF options and set them into the operation
+ extract_pdf_options: ExtractPDFOptions = (
+ ExtractPDFOptions.builder()
+ .with_elements_to_extract(
+ [ExtractElementType.TEXT, ExtractElementType.TABLES]
+ )
+ .with_elements_to_extract_renditions(
+ [
+ ExtractRenditionsElementType.TABLES,
+ ExtractRenditionsElementType.FIGURES,
+ ]
+ )
+ .build()
+ )
+ extract_pdf_operation.set_options(extract_pdf_options)
+
+ # Execute the operation.
+ result: FileRef = extract_pdf_operation.execute(execution_context)
+
+ # Save the result to the specified location.
+ zip_file_path = os.path.join(
+ output_path, "ExtractTextTableWithFigureTableRendition.zip"
+ )
+ result.save_as(zip_file_path)
+ # Open the ZIP file
+ with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+ # Extract all contents to the destination folder
+ zip_ref.extractall(output_path)
+ except (ServiceApiException, ServiceUsageException, SdkException):
+ logging.exception("Exception encountered while executing operation")
+
+ return output_path
+
+
+def make_markdown_table(table_as_list: List[str]) -> str:
+ """
+ Convert table from python list representation to markdown format.
+ The input list consists of rows of tables, the first row is the header.
+
+ Args:
+ table_as_list: list of table rows
+ Example: [["Name", "Age", "Height"],
+ ["Jake", 20, 5'10],
+ ["Mary", 21, 5'7]]
+ Returns:
+ markdown representation of the table
+ """
+ markdown = "\n" + str("| ")
+
+ for e in table_as_list[0]:
+ to_add = " " + str(e) + str(" |")
+ markdown += to_add
+ markdown += "\n"
+
+ markdown += "| "
+ for i in range(len(table_as_list[0])):
+ markdown += str("--- | ")
+ markdown += "\n"
+
+ for entry in table_as_list[1:]:
+ markdown += str("| ")
+ for e in entry:
+ to_add = str(e) + str(" | ")
+ markdown += to_add
+ markdown += "\n"
+
+ return markdown + "\n"
+
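+# e.g. make_markdown_table([["Name", "Age"], ["Jake", 20]]) returns roughly
+# "| Name | Age |\n| --- | --- |\n| Jake | 20 |" (exact spacing may differ)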
+
+def load_json(input_path: Union[str, Path]) -> dict:
+ """Load json file"""
+ with open(input_path, "r") as fi:
+ data = json.load(fi)
+
+ return data
+
+
+def load_excel(input_path: Union[str, Path]) -> str:
+ """Load excel file and convert to markdown"""
+
+ df = pd.read_excel(input_path).fillna("")
+ # Convert dataframe to a list of rows
+ row_list = [df.columns.values.tolist()] + df.values.tolist()
+
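+ # pandas names blank header cells "Unnamed: <n>"; clear them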
+ for item_id, item in enumerate(row_list[0]):
+ if "Unnamed" in item:
+ row_list[0][item_id] = ""
+
+ for row in row_list:
+ for item_id, item in enumerate(row):
+ row[item_id] = str(item).replace("_x000D_", " ").replace("\n", " ").strip()
+
+ markdown_str = make_markdown_table(row_list)
+ return markdown_str
+
+
+def encode_image_base64(image_path: Union[str, Path]) -> Union[bytes, str]:
+ """Convert image to base64"""
+
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+def parse_table_paths(file_paths: List[Path]) -> str:
+ """Read the table stored in an excel file given the file path"""
+
+ content = ""
+ for path in file_paths:
+ if path.suffix == ".xlsx":
+ content = load_excel(path)
+ break
+ return content
+
+
+def parse_figure_paths(file_paths: List[Path]) -> Union[bytes, str]:
+ """Read and convert an image to base64 given the image path"""
+
+ content = ""
+ for path in file_paths:
+ if path.suffix == ".png":
+ base64_image = encode_image_base64(path)
+ content = f"data:image/png;base64,{base64_image}" # type: ignore
+ break
+ return content
+
+
+def generate_single_figure_caption(vlm_endpoint: str, figure: str) -> str:
+ """Summarize a single figure using GPT-4V"""
+ if figure:
+ output = generate_gpt4v(
+ endpoint=vlm_endpoint,
+ prompt="Provide a short 2 sentence summary of this image?",
+ images=figure,
+ )
+ if "sorry" in output.lower():
+ output = ""
+ else:
+ output = ""
+ return output
+
+
+def generate_figure_captions(
+ vlm_endpoint: str, figures: List, max_figures_to_process: int
+) -> List:
+ """Summarize several figures using GPT-4V.
+ Args:
+ vlm_endpoint (str): endpoint to the vision language model service
+ figures (List): list of base64 images
+ max_figures_to_process (int): the maximum number of figures to summarize;
+ the rest are ignored.
+
+ Returns:
+ results (List[str]): list of all figure captions and empty strings for
+ ignored figures.
+ """
+ to_gen_figures = figures[:max_figures_to_process]
+ other_figures = figures[max_figures_to_process:]
+
+ with ThreadPoolExecutor() as executor:
+ futures = [
+ # pass the figure as an argument; a bare lambda would late-bind the loop variable
+ executor.submit(generate_single_figure_caption, vlm_endpoint, figure)
+ for figure in to_gen_figures
+ ]
+
+ results = [future.result() for future in futures]
+ return results + [""] * len(other_figures)
diff --git a/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py
new file mode 100644
index 0000000..1e219d6
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py
@@ -0,0 +1,96 @@
+import json
+from typing import Any, List
+
+import requests
+from decouple import config
+
+
+def generate_gpt4v(
+ endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+) -> str:
+ # OpenAI API Key
+ api_key = config("AZURE_OPENAI_API_KEY", default="")
+ headers = {"Content-Type": "application/json", "api-key": api_key}
+
+ if isinstance(images, str):
+ images = [images]
+
+ payload = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ ]
+ + [
+ {
+ "type": "image_url",
+ "image_url": {"url": image},
+ }
+ for image in images
+ ],
+ }
+ ],
+ "max_tokens": max_tokens,
+ }
+
+ try:
+ response = requests.post(endpoint, headers=headers, json=payload)
+ output = response.json()
+ output = output["choices"][0]["message"]["content"]
+ except Exception:
+ output = ""
+ return output
+
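+# Example (illustrative): with AZURE_OPENAI_API_KEY set and `endpoint` pointing to a
+# GPT-4 Vision chat/completions deployment:
+#   caption = generate_gpt4v(endpoint, "data:image/png;base64,....", "Describe this image")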
+
+def stream_gpt4v(
+ endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+) -> Any:
+ # OpenAI API Key
+ api_key = config("AZURE_OPENAI_API_KEY", default="")
+ headers = {"Content-Type": "application/json", "api-key": api_key}
+
+ if isinstance(images, str):
+ images = [images]
+
+ payload = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ ]
+ + [
+ {
+ "type": "image_url",
+ "image_url": {"url": image},
+ }
+ for image in images
+ ],
+ }
+ ],
+ "max_tokens": max_tokens,
+ "stream": True,
+ }
+ try:
+ response = requests.post(endpoint, headers=headers, json=payload, stream=True)
+ assert response.status_code == 200, str(response.content)
+ output = ""
+ for line in response.iter_lines():
+ if line:
+ # strip the SSE "data: " prefix (9 bytes when a UTF-8 BOM precedes it, 6 otherwise)
+ if line.startswith(b"\xef\xbb\xbf"):
+ line = line[9:]
+ else:
+ line = line[6:]
+ try:
+ if line == b"[DONE]":
+ break
+ line = json.loads(line.decode("utf-8"))
+ except Exception:
+ break
+ if len(line["choices"]):
+ output += line["choices"][0]["delta"].get("content", "")
+ yield line["choices"][0]["delta"].get("content", "")
+ except Exception:
+ output = ""
+ return output
diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml
index e1e3028..73c3e8a 100644
--- a/libs/kotaemon/pyproject.toml
+++ b/libs/kotaemon/pyproject.toml
@@ -60,6 +60,7 @@ adv = [
"cohere",
"elasticsearch",
"llama-cpp-python",
+ "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
]
dev = [
"ipython",
@@ -69,6 +70,7 @@ dev = [
"flake8",
"sphinx",
"coverage",
+ "python-decouple"
]
all = ["kotaemon[adv,dev]"]
diff --git a/libs/kotaemon/tests/_test_multimodal_reader.py b/libs/kotaemon/tests/_test_multimodal_reader.py
new file mode 100644
index 0000000..b07786f
--- /dev/null
+++ b/libs/kotaemon/tests/_test_multimodal_reader.py
@@ -0,0 +1,21 @@
+# TODO: This test is broken and should be rewritten
+from pathlib import Path
+
+from kotaemon.loaders import AdobeReader
+
+# from dotenv import load_dotenv
+
+
+input_file = Path(__file__).parent / "resources" / "multimodal.pdf"
+
+# load_dotenv()
+
+
+def test_adobe_reader():
+ reader = AdobeReader()
+ documents = reader.load_data(input_file)
+ table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
+ assert len(table_docs) == 2
+
+ figure_docs = [doc for doc in documents if doc.metadata.get("type", "") == "image"]
+ assert len(figure_docs) == 2
diff --git a/libs/kotaemon/tests/resources/multimodal.pdf b/libs/kotaemon/tests/resources/multimodal.pdf
new file mode 100644
index 0000000..29c2bdc
Binary files /dev/null and b/libs/kotaemon/tests/resources/multimodal.pdf differ
diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py
index a3589fe..33ba88f 100644
--- a/libs/ktem/flowsettings.py
+++ b/libs/ktem/flowsettings.py
@@ -124,6 +124,11 @@ if config("LOCAL_MODEL", default=""):
KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
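+# Azure OpenAI GPT-4 Vision endpoint, consumed by the multimodal answer path in
+# ktem.reasoning.simple when figure evidence is retrieved.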
+KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
+ config("AZURE_OPENAI_ENDPOINT", default=""),
+ config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
+ config("OPENAI_API_VERSION", default=""),
+)
SETTINGS_APP = {
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index 1d813f5..68b3a4d 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -378,6 +378,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
("PDF text parser", "normal"),
("Mathpix", "mathpix"),
("Advanced ocr", "ocr"),
+ ("Multimodal parser", "multimodal"),
],
"component": "dropdown",
},
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index 47f7c92..1627522 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -1,11 +1,14 @@
import asyncio
+import html
import logging
+import re
from collections import defaultdict
from functools import partial
import tiktoken
from ktem.components import llms
from ktem.reasoning.base import BaseReasoning
+from theflow.settings import settings as flowsettings
from kotaemon.base import (
BaseComponent,
@@ -18,9 +21,15 @@ from kotaemon.base import (
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
+from kotaemon.loaders.utils.gpt4v import stream_gpt4v
logger = logging.getLogger(__name__)
+EVIDENCE_MODE_TEXT = 0
+EVIDENCE_MODE_TABLE = 1
+EVIDENCE_MODE_CHATBOT = 2
+EVIDENCE_MODE_FIGURE = 3
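+# EVIDENCE_MODE_FIGURE routes answering through the vision-language model endpoint
+# (stream_gpt4v) instead of the regular chat LLM.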
+
class PrepareEvidencePipeline(BaseComponent):
"""Prepare the evidence text from the list of retrieved documents
@@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent):
def run(self, docs: list[RetrievedDocument]) -> Document:
evidence = ""
table_found = 0
- evidence_mode = 0
+ evidence_mode = EVIDENCE_MODE_TEXT
for _id, retrieved_item in enumerate(docs):
retrieved_content = ""
@@ -55,7 +64,7 @@ class PrepareEvidencePipeline(BaseComponent):
if page:
source += f" (Page {page})"
if retrieved_item.metadata.get("type", "") == "table":
- evidence_mode = 1 # table
+ evidence_mode = EVIDENCE_MODE_TABLE
if table_found < 5:
retrieved_content = retrieved_item.metadata.get("table_origin", "")
if retrieved_content not in evidence:
@@ -66,13 +75,23 @@ class PrepareEvidencePipeline(BaseComponent):
+ "\n<br>"
)
elif retrieved_item.metadata.get("type", "") == "chatbot":
- evidence_mode = 2 # chatbot
+ evidence_mode = EVIDENCE_MODE_CHATBOT
retrieved_content = retrieved_item.metadata["window"]
evidence += (
f"
Chatbot scenario from {filename} (Row {page})\n"
+ retrieved_content
+ "\n
"
)
+ elif retrieved_item.metadata.get("type", "") == "image":
+ evidence_mode = EVIDENCE_MODE_FIGURE
+ retrieved_content = retrieved_item.metadata.get("image_origin", "")
+ retrieved_caption = html.escape(retrieved_item.get_content())
+ evidence += (
+ f"<br><b>Figure from {source}</b>\n"
+ + f"<img width='85%' src='{retrieved_content}' "
+ + f"alt='{retrieved_caption}'/>"
+ + "\n<br>"
+ )
else:
if "window" in retrieved_item.metadata:
retrieved_content = retrieved_item.metadata["window"]
@@ -90,12 +109,13 @@ class PrepareEvidencePipeline(BaseComponent):
print(retrieved_item.metadata)
print("Score", retrieved_item.metadata.get("relevance_score", None))
- # trim context by trim_len
- print("len (original)", len(evidence))
- if evidence:
- texts = self.trim_func([Document(text=evidence)])
- evidence = texts[0].text
- print("len (trimmed)", len(evidence))
+ if evidence_mode != EVIDENCE_MODE_FIGURE:
+ # trim context by trim_len
+ print("len (original)", len(evidence))
+ if evidence:
+ texts = self.trim_func([Document(text=evidence)])
+ evidence = texts[0].text
+ print("len (trimmed)", len(evidence))
print(f"PrepareEvidence with input {docs}\nOutput: {evidence}\n")
@@ -134,6 +154,16 @@ DEFAULT_QA_CHATBOT_PROMPT = (
"Answer:"
)
+DEFAULT_QA_FIGURE_PROMPT = (
+ "Use the given context: texts, tables, and figures below to answer the question. "
+ "If you don't know the answer, just say that you don't know. "
+ "Give answer in {lang}.\n\n"
+ "Context: \n"
+ "{context}\n"
+ "Question: {question}\n"
+ "Answer: "
+)
+
class AnswerWithContextPipeline(BaseComponent):
"""Answer the question based on the evidence
@@ -151,6 +181,7 @@ class AnswerWithContextPipeline(BaseComponent):
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
+ vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
)
@@ -158,6 +189,7 @@ class AnswerWithContextPipeline(BaseComponent):
qa_template: str = DEFAULT_QA_TEXT_PROMPT
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
+ qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
enable_citation: bool = False
system_prompt: str = ""
@@ -188,18 +220,30 @@ class AnswerWithContextPipeline(BaseComponent):
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
"""
- if evidence_mode == 0:
+ if evidence_mode == EVIDENCE_MODE_TEXT:
prompt_template = PromptTemplate(self.qa_template)
- elif evidence_mode == 1:
+ elif evidence_mode == EVIDENCE_MODE_TABLE:
prompt_template = PromptTemplate(self.qa_table_template)
+ elif evidence_mode == EVIDENCE_MODE_FIGURE:
+ prompt_template = PromptTemplate(self.qa_figure_template)
else:
prompt_template = PromptTemplate(self.qa_chatbot_template)
- prompt = prompt_template.populate(
- context=evidence,
- question=question,
- lang=self.lang,
- )
+ images = []
+ if evidence_mode == EVIDENCE_MODE_FIGURE:
+ # isolate image from evidence
+ evidence, images = self.extract_evidence_images(evidence)
+ prompt = prompt_template.populate(
+ context=evidence,
+ question=question,
+ lang=self.lang,
+ )
+ else:
+ prompt = prompt_template.populate(
+ context=evidence,
+ question=question,
+ lang=self.lang,
+ )
citation_task = None
if evidence and self.enable_citation:
@@ -208,23 +252,29 @@ class AnswerWithContextPipeline(BaseComponent):
)
print("Citation task created")
- messages = []
- if self.system_prompt:
- messages.append(SystemMessage(content=self.system_prompt))
- messages.append(HumanMessage(content=prompt))
-
output = ""
- try:
- # try streaming first
- print("Trying LLM streaming")
- for text in self.llm.stream(messages):
- output += text.text
- self.report_output({"output": text.text})
+ if evidence_mode == EVIDENCE_MODE_FIGURE:
+ for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
+ output += text
+ self.report_output({"output": text})
await asyncio.sleep(0)
- except NotImplementedError:
- print("Streaming is not supported, falling back to normal processing")
- output = self.llm(messages).text
- self.report_output({"output": output})
+ else:
+ messages = []
+ if self.system_prompt:
+ messages.append(SystemMessage(content=self.system_prompt))
+ messages.append(HumanMessage(content=prompt))
+
+ try:
+ # try streaming first
+ print("Trying LLM streaming")
+ for text in self.llm.stream(messages):
+ output += text.text
+ self.report_output({"output": text.text})
+ await asyncio.sleep(0)
+ except NotImplementedError:
+ print("Streaming is not supported, falling back to normal processing")
+ output = self.llm(messages).text
+ self.report_output({"output": output})
# retrieve the citation
print("Waiting for citation task")
@@ -237,6 +287,13 @@ class AnswerWithContextPipeline(BaseComponent):
return answer
+ def extract_evidence_images(self, evidence: str):
+ """Util function to extract and isolate images from context/evidence"""
+ image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
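+ # matches base64 data URIs embedded as <img src='...'> tags in the figure evidence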
+ matches = re.findall(image_pattern, evidence)
+ context = re.sub(image_pattern, "", evidence)
+ return context, matches
+
class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer"""