Improve kotaemon based on insights from projects (#147)
- Include static files in the package.
- More reliable information panel: faster, and no longer breaks randomly.
- Add directory upload.
- Enable zip files to be uploaded.
- Allow setting the endpoint for the OCR reader via an environment variable.
parent e1cf970a3d
commit 033e7e05cc
@@ -67,7 +67,7 @@ class DocumentIngestor(BaseComponent):
         main_reader = DirectoryReader(
             input_files=input_files,
-            file_extractor=file_extractors, # type: ignore
+            file_extractor=file_extractors,
         )

         return main_reader
@@ -85,7 +85,9 @@ class DocumentIngestor(BaseComponent):
             file_paths = [file_paths]

         documents = self._get_reader(input_files=file_paths)()
+        print(f"Read {len(file_paths)} files into {len(documents)} documents.")
         nodes = self.text_splitter(documents)
+        print(f"Transform {len(documents)} documents into {len(nodes)} nodes.")
         self.log_progress(".num_docs", num_docs=len(nodes))

         # document parsers call
@@ -59,12 +59,15 @@ class VectorIndexing(BaseIndexing):
                     f"Invalid input type {type(item)}, should be str or Document"
                 )

+        print(f"Getting embeddings for {len(input_)} nodes")
         embeddings = self.embedding(input_)
+        print("Adding embeddings to vector store")
         self.vector_store.add(
             embeddings=embeddings,
             ids=[t.doc_id for t in input_],
         )
         if self.doc_store:
+            print("Adding documents to doc store")
             self.doc_store.add(input_)
@@ -1,18 +1,34 @@
+import logging
+import os
 from pathlib import Path
 from typing import List, Optional
 from uuid import uuid4

 import requests
 from llama_index.readers.base import BaseReader
+from tenacity import after_log, retry, stop_after_attempt, wait_fixed, wait_random

 from kotaemon.base import Document

 from .utils.pdf_ocr import parse_ocr_output, read_pdf_unstructured
 from .utils.table import strip_special_chars_markdown

+logger = logging.getLogger(__name__)
+
 DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"


+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_fixed(5) + wait_random(0, 2),
+    after=after_log(logger, logging.DEBUG),
+)
+def tenacious_api_post(url, **kwargs):
+    resp = requests.post(url=url, **kwargs)
+    resp.raise_for_status()
+    return resp
+
+
 class OCRReader(BaseReader):
     """Read PDF using OCR, with high focus on table extraction
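The `tenacious_api_post` helper above is part of what makes the information panel stop breaking randomly: transient OCR-endpoint failures now get three attempts with a jittered pause in between. A minimal sketch of that tenacity policy against a made-up flaky function (note it really sleeps between tries):

```python
import logging

from tenacity import after_log, retry, stop_after_attempt, wait_fixed, wait_random

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

attempts = {"n": 0}


@retry(
    stop=stop_after_attempt(3),              # give up after the third try
    wait=wait_fixed(5) + wait_random(0, 2),  # 5s base wait plus 0-2s of jitter
    after=after_log(logger, logging.DEBUG),  # log every failed attempt
)
def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient failure")  # simulated outage
    return "ok"


print(flaky_call())  # succeeds on the third attempt; re-raises if all 3 fail
```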
@@ -24,17 +40,20 @@ class OCRReader(BaseReader):
     ```

     Args:
-        endpoint: URL to FullOCR endpoint. Defaults to
+        endpoint: URL to FullOCR endpoint. If not provided, will look for
+            environment variable `OCR_READER_ENDPOINT` or use the default
+            `kotaemon.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT`
             (http://127.0.0.1:8000/v2/ai/infer/)
         use_ocr: whether to use OCR to read text (e.g: from images, tables) in the PDF
             If False, only the table and text within table cells will be extracted.
     """

-    def __init__(self, endpoint: str = DEFAULT_OCR_ENDPOINT, use_ocr=True):
+    def __init__(self, endpoint: Optional[str] = None, use_ocr=True):
         """Init the OCR reader with OCR endpoint (FullOCR pipeline)"""
         super().__init__()
-        self.ocr_endpoint = endpoint
+        self.ocr_endpoint = endpoint or os.getenv(
+            "OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT
+        )
         self.use_ocr = use_ocr

     def load_data(
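With `__init__` resolving the endpoint lazily, deployments can point the reader at their own OCR service without code changes. A usage sketch, assuming the module path given in the docstring and an illustrative URL:

```python
import os

# Resolution order: explicit argument > OCR_READER_ENDPOINT > DEFAULT_OCR_ENDPOINT.
os.environ["OCR_READER_ENDPOINT"] = "http://ocr.internal:8000/v2/ai/infer/"  # made-up URL

from kotaemon.loaders.ocr_loader import OCRReader

reader = OCRReader()  # picks up http://ocr.internal:8000/v2/ai/infer/
reader = OCRReader(endpoint="http://localhost:9000/infer/")  # argument wins over env
```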
@@ -62,7 +81,7 @@ class OCRReader(BaseReader):
             ocr_results = kwargs["response_content"]
         else:
             # call original API
-            resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
+            resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
             ocr_results = resp.json()["result"]

         debug_path = kwargs.pop("debug_path", None)
@@ -26,6 +26,7 @@ dependencies = [
     "click",
     "pandas",
     "trogon",
+    "tenacity",
 ]
 readme = "README.md"
 license = { text = "MIT License" }
libs/ktem/MANIFEST.in (new file)
@@ -0,0 +1,4 @@
+include ktem/assets/css/*.css
+include ktem/assets/img/*.svg
+include ktem/assets/js/*.js
+include ktem/assets/md/*.md
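A MANIFEST.in alone only affects source distributions; it takes effect for installed packages together with the `include-package-data = true` switch in `libs/ktem/pyproject.toml` at the end of this diff. A quick post-install sanity check, sketched with `importlib.resources` (assumes Python >= 3.9 and an installed ktem):

```python
from importlib import resources

# If the MANIFEST.in / include-package-data pair works, this lists the shipped
# CSS files instead of printing an empty list (or raising for a missing dir).
css_dir = resources.files("ktem").joinpath("assets/css")
print([entry.name for entry in css_dir.iterdir() if entry.name.endswith(".css")])
```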
@@ -44,3 +44,16 @@ footer {
 mark {
     background-color: #1496bb;
 }
+
+/* clpse */
+.clpse {
+    background-color: var(--background-fill-secondary);
+    font-weight: bold;
+    cursor: pointer;
+    padding: 3px;
+    width: 100%;
+    border: none;
+    text-align: left;
+    outline: none;
+}
@@ -4,3 +4,16 @@ main_parent.childNodes[0].classList.add("header-bar");
 main_parent.style = "padding: 0; margin: 0";
 main_parent.parentNode.style = "gap: 0";
 main_parent.parentNode.parentNode.style = "padding: 0";
+
+// clpse
+globalThis.clpseFn = (id) => {
+    var obj = document.getElementById('clpse-btn-' + id);
+    obj.classList.toggle("clpse-active");
+    var content = obj.nextElementSibling;
+    if (content.style.display === "none") {
+        content.style.display = "block";
+    } else {
+        content.style.display = "none";
+    }
+}
@@ -16,7 +16,6 @@ from ktem.components import (
 )
 from ktem.db.models import Index, Source, SourceTargetRelation, engine
 from ktem.indexing.base import BaseIndexing, BaseRetriever
-from ktem.indexing.exceptions import FileExistsError
 from llama_index.vector_stores import (
     FilterCondition,
     FilterOperator,
@@ -241,7 +240,7 @@ class IndexDocumentPipeline(BaseIndexing):
                 to_index.append(abs_path)

         if errors:
-            raise FileExistsError(
+            print(
                 "Files already exist. Please rename/remove them or enable reindex.\n"
                 f"{errors}"
             )
@@ -258,14 +257,18 @@ class IndexDocumentPipeline(BaseIndexing):

         # extract the files
         nodes = self.file_ingestor(to_index)
+        print("Extracted", len(to_index), "files into", len(nodes), "nodes")
         for node in nodes:
             file_path = str(node.metadata["file_path"])
             node.source = file_to_source[file_path].id

         # index the files
+        print("Indexing the files into vector store")
         self.indexing_vector_pipeline(nodes)
+        print("Finishing indexing the files into vector store")

         # persist to the index
+        print("Persisting the vector and the document into index")
         file_ids = []
         with Session(engine) as session:
             for source in file_to_source.values():
@@ -291,6 +294,8 @@ class IndexDocumentPipeline(BaseIndexing):
                 session.add(index)
             session.commit()

+        print("Finishing persisting the vector and the document into index")
+        print(f"{len(nodes)} nodes are indexed")
         return nodes, file_ids

     def get_user_settings(self) -> dict:
@@ -4,9 +4,16 @@ from ktem.app import BasePage
 from .chat_panel import ChatPanel
 from .control import ConversationControl
 from .data_source import DataSource
-from .events import chat_fn, index_fn, is_liked, load_files, update_data_source
+from .events import (
+    chat_fn,
+    index_files_from_dir,
+    index_fn,
+    is_liked,
+    load_files,
+    update_data_source,
+)
 from .report import ReportIssue
-from .upload import FileUpload
+from .upload import DirectoryUpload, FileUpload


 class ChatPage(BasePage):
@@ -20,12 +27,13 @@ class ChatPage(BasePage):
                 self.chat_control = ConversationControl(self._app)
                 self.data_source = DataSource(self._app)
                 self.file_upload = FileUpload(self._app)
+                self.dir_upload = DirectoryUpload(self._app)
                 self.report_issue = ReportIssue(self._app)
             with gr.Column(scale=6):
                 self.chat_panel = ChatPanel(self._app)
             with gr.Column(scale=3):
                 with gr.Accordion(label="Information panel", open=True):
-                    self.info_panel = gr.Markdown(elem_id="chat-info-panel")
+                    self.info_panel = gr.HTML(elem_id="chat-info-panel")

     def on_register_events(self):
         self.chat_panel.submit_btn.click(
@@ -141,6 +149,17 @@ class ChatPage(BasePage):
             outputs=[self.file_upload.file_output, self.data_source.files],
         )

+        self.dir_upload.upload_button.click(
+            fn=index_files_from_dir,
+            inputs=[
+                self.dir_upload.path,
+                self.dir_upload.reindex,
+                self.data_source.files,
+                self._app.settings_state,
+            ],
+            outputs=[self.dir_upload.file_output, self.data_source.files],
+        )
+
         self._app.app.load(
             lambda: gr.update(choices=load_files()),
             inputs=None,
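The new click handler follows Gradio's standard wiring: component values are passed to `fn` positionally in the order of `inputs`, and the return values are written back into `outputs`. A self-contained toy version of the same pattern (names here are illustrative, not ktem's):

```python
import gradio as gr


def handle(path: str, reindex: bool) -> str:
    # Stand-in for index_files_from_dir: just echo what would be indexed.
    return f"would index '{path}' (reindex={reindex})"


with gr.Blocks() as demo:
    path = gr.Textbox(placeholder="Directory path...")
    reindex = gr.Checkbox(label="Force reindex file")
    button = gr.Button("Upload and Index")
    output = gr.Textbox(label="Result")
    button.click(fn=handle, inputs=[path, reindex], outputs=[output])

# demo.launch()  # uncomment to try it locally
```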
@@ -2,7 +2,7 @@ import asyncio
 import os
 import tempfile
 from copy import deepcopy
-from typing import Optional
+from typing import Optional, Type

 import gradio as gr
 from ktem.components import llms, reasonings
@@ -127,14 +127,18 @@ async def chat_fn(conversation_id, chat_history, files, settings):
     asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
     text, refs = "", ""

+    len_ref = -1  # for logging purpose
+
     while True:
         try:
             response = queue.get_nowait()
         except Exception:
             await asyncio.sleep(0)
+            yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
             continue

         if response is None:
             queue.task_done()
+            print("Chat completed")
             break

         if "output" in response:
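The loop drains an asyncio queue that the reasoning pipeline fills, with `None` as the end-of-stream sentinel; the added `yield` in the `except` branch is what keeps the panel responsive ("Thinking ...") instead of stalling silently. The same producer/consumer pattern in miniature (a standalone sketch, not kotaemon code):

```python
import asyncio


async def producer(queue: asyncio.Queue) -> None:
    for chunk in ["Hel", "lo ", "world"]:
        await queue.put({"output": chunk})
    await queue.put(None)  # sentinel: generation finished


async def consumer() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(producer(queue))
    text = ""
    while True:
        try:
            response = queue.get_nowait()
        except asyncio.QueueEmpty:
            await asyncio.sleep(0)  # give the producer a turn; UI stays live
            continue
        if response is None:
            queue.task_done()
            break
        text += response["output"]
        print(text)  # in the app, this is a `yield` back to Gradio
    await task


asyncio.run(consumer())
```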
@@ -142,7 +146,11 @@ async def chat_fn(conversation_id, chat_history, files, settings):
         if "evidence" in response:
             refs += response["evidence"]

         yield "", chat_history + [(chat_input, text)], refs
+        if len(refs) > len_ref:
+            print(f"Len refs: {len(refs)}")
+            len_ref = len(refs)
+
     yield "", chat_history + [(chat_input, text)], refs


 def is_liked(convo_id, liked: gr.LikeData):
@@ -203,7 +211,9 @@ def index_fn(files, reindex: bool, selected_files, settings):
     gr.Info(f"Start indexing {len(files)} files...")

     # get the pipeline
-    indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
+    indexing_cls: Type[BaseIndexing] = import_dotted_string(
+        app_settings.KH_INDEX, safe=False
+    )
     indexing_pipeline = indexing_cls.get_pipeline(settings)

     output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
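`import_dotted_string` turns the `KH_INDEX` setting, a dotted-path string, into the actual indexing class, which is what lets users plug in their own pipeline. A rough, simplified equivalent for illustration (the real helper also takes the `safe` flag seen above):

```python
import importlib


def import_dotted_string(dotted: str):
    """Resolve 'package.module.Attr' to the attribute it names (sketch)."""
    module_path, _, attr = dotted.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


cls = import_dotted_string("collections.OrderedDict")
print(cls)  # <class 'collections.OrderedDict'>
```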
@@ -225,5 +235,71 @@ def index_fn(files, reindex: bool, selected_files, settings):

     return (
         gr.update(value=file_path, visible=True),
-        gr.update(value=output, choices=file_list),
+        gr.update(value=output, choices=file_list),  # unnecessary
     )
+
+
+def index_files_from_dir(folder_path, reindex, selected_files, settings):
+    """This should be constructable by users
+
+    It means that the users can build their own index.
+    Build your own index:
+        - Input:
+            - Type: based on the type, then there are ranges of. Use can select multiple
+              panels:
+                - Panels
+                - Data sources
+                - Include patterns
+                - Exclude patterns
+            - Indexing functions. Can be a list of indexing functions. Each declared
+              function is:
+                - Condition (the source that will go through this indexing function)
+                - Function (the pipeline that run this)
+        - Output: artifacts that can be used to -> this is the artifacts that we wish
+        - Build the UI
+            - Upload page: fixed standard, based on the type
+            - Read page: fixed standard, based on the type
+            - Delete page: fixed standard, based on the type
+        - Build the index function
+        - Build the chat function
+
+    Step:
+        1. Decide on the artifacts
+        2. Implement the transformation from artifacts to UI
+    """
+    if not folder_path:
+        return
+
+    import fnmatch
+    from pathlib import Path
+
+    include_patterns: list[str] = []
+    exclude_patterns: list[str] = ["*.png", "*.gif", "*/.*"]
+    if include_patterns and exclude_patterns:
+        raise ValueError("Cannot have both include and exclude patterns")
+
+    # clean up the include patterns
+    for idx in range(len(include_patterns)):
+        if include_patterns[idx].startswith("*"):
+            include_patterns[idx] = str(Path.cwd() / "**" / include_patterns[idx])
+        else:
+            include_patterns[idx] = str(Path.cwd() / include_patterns[idx].strip("/"))
+
+    # clean up the exclude patterns
+    for idx in range(len(exclude_patterns)):
+        if exclude_patterns[idx].startswith("*"):
+            exclude_patterns[idx] = str(Path.cwd() / "**" / exclude_patterns[idx])
+        else:
+            exclude_patterns[idx] = str(Path.cwd() / exclude_patterns[idx].strip("/"))
+
+    # get the files
+    files: list[str] = [str(p) for p in Path(folder_path).glob("**/*.*")]
+    if include_patterns:
+        for p in include_patterns:
+            files = fnmatch.filter(names=files, pat=p)
+
+    if exclude_patterns:
+        for p in exclude_patterns:
+            files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
+
+    return index_fn(files, reindex, selected_files, settings)
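The include/exclude filtering relies on `fnmatch`, where `*` also matches path separators: `*.png` therefore excludes images at any depth, and `*/.*` drops anything inside hidden directories. The behavior on toy paths:

```python
import fnmatch

files = ["/data/a.pdf", "/data/img/logo.png", "/data/.hidden/notes.txt"]

exclude_patterns = ["*.png", "*/.*"]
for pat in exclude_patterns:
    files = [f for f in files if not fnmatch.fnmatch(f, pat)]

print(files)  # ['/data/a.pdf'] -- the image and the hidden-dir file are dropped
```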
@@ -32,22 +32,46 @@ class FileUpload(BasePage):
             )
             with gr.Accordion("Advanced indexing options", open=False):
                 with gr.Row():
-                    with gr.Column():
-                        self.reindex = gr.Checkbox(
-                            value=False, label="Force reindex file", container=False
-                        )
-                    with gr.Column():
-                        self.parser = gr.Dropdown(
-                            choices=[
-                                ("PDF text parser", "normal"),
-                                ("lib-table", "table"),
-                                ("lib-table + OCR", "ocr"),
-                                ("MathPix", "mathpix"),
-                            ],
-                            value="normal",
-                            label="Use advance PDF parser (table+layout preserving)",
-                            container=True,
-                        )
+                    self.reindex = gr.Checkbox(
+                        value=False, label="Force reindex file", container=False
+                    )

             self.upload_button = gr.Button("Upload and Index")
             self.file_output = gr.File(
                 visible=False, label="Output files (debug purpose)"
             )
+
+
+class DirectoryUpload(BasePage):
+    def __init__(self, app):
+        self._app = app
+        self._supported_file_types = [
+            "image",
+            ".pdf",
+            ".txt",
+            ".csv",
+            ".xlsx",
+            ".doc",
+            ".docx",
+            ".pptx",
+            ".html",
+            ".zip",
+        ]
+        self.on_building_ui()
+
+    def on_building_ui(self):
+        with gr.Accordion(label="Directory upload", open=False):
+            gr.Markdown(
+                f"Supported file types: {', '.join(self._supported_file_types)}",
+            )
+            self.path = gr.Textbox(
+                placeholder="Directory path...", lines=1, max_lines=1, container=False
+            )
+            with gr.Accordion("Advanced indexing options", open=False):
+                with gr.Row():
+                    self.reindex = gr.Checkbox(
+                        value=False, label="Force reindex file", container=False
+                    )
+
+            self.upload_button = gr.Button("Upload and Index")
+            self.file_output = gr.File(
@@ -106,7 +106,7 @@ DEFAULT_QA_TEXT_PROMPT = (
     "Use the following pieces of context to answer the question at the end. "
     "If you don't know the answer, just say that you don't know, don't try to "
     "make up an answer. Keep the answer as concise as possible. Give answer in "
-    "{lang}. {system}\n\n"
+    "{lang}.\n\n"
     "{context}\n"
     "Question: {question}\n"
     "Helpful Answer:"
@@ -116,7 +116,7 @@ DEFAULT_QA_TABLE_PROMPT = (
     "List all rows (row number) from the table context that related to the question, "
     "then provide detail answer with clear explanation and citations. "
     "If you don't know the answer, just say that you don't know, "
-    "don't try to make up an answer. Give answer in {lang}. {system}\n\n"
+    "don't try to make up an answer. Give answer in {lang}.\n\n"
     "Context:\n"
     "{context}\n"
     "Question: {question}\n"
@@ -127,7 +127,7 @@ DEFAULT_QA_CHATBOT_PROMPT = (
     "Pick the most suitable chatbot scenarios to answer the question at the end, "
     "output the provided answer text. If you don't know the answer, "
     "just say that you don't know. Keep the answer as concise as possible. "
-    "Give answer in {lang}. {system}\n\n"
+    "Give answer in {lang}.\n\n"
     "Context:\n"
     "{context}\n"
     "Question: {question}\n"
@@ -198,13 +198,12 @@ class AnswerWithContextPipeline(BaseComponent):
             context=evidence,
             question=question,
             lang=self.lang,
-            system=self.system_prompt,
         )

-        messages = [
-            SystemMessage(content="You are a helpful assistant"),
-            HumanMessage(content=prompt),
-        ]
+        messages = []
+        if self.system_prompt:
+            messages.append(SystemMessage(content=self.system_prompt))
+        messages.append(HumanMessage(content=prompt))
         output = ""
         for text in self.llm(messages):
             output += text.text
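Dropping `{system}` from the templates and emitting a real `SystemMessage` means the system prompt is only sent when one is configured, and it travels as its own turn rather than being interpolated into the user prompt. A self-contained sketch of the new construction (stand-in dataclasses; kotaemon ships its own message types):

```python
from dataclasses import dataclass


@dataclass
class SystemMessage:
    content: str


@dataclass
class HumanMessage:
    content: str


def build_messages(prompt: str, system_prompt: str = "") -> list:
    messages: list = []
    if system_prompt:  # only add a system turn when one is configured
        messages.append(SystemMessage(content=system_prompt))
    messages.append(HumanMessage(content=prompt))
    return messages


print(build_messages("What is in the table?", "This is a question answering system"))
```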
@@ -316,11 +315,19 @@ class FullQAPipeline(BaseComponent):
             settings: the settings for the pipeline
             retrievers: the retrievers to use
         """
+        _id = cls.get_info()["id"]
+
         pipeline = FullQAPipeline(retrievers=retrievers)
         pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
         pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
             settings["reasoning.lang"], "English"
         )
+        pipeline.answering_pipeline.system_prompt = settings[
+            f"reasoning.options.{_id}.system_prompt"
+        ]
+        pipeline.answering_pipeline.qa_template = settings[
+            f"reasoning.options.{_id}.qa_prompt"
+        ]
         return pipeline

     @classmethod
@@ -345,10 +352,6 @@ class FullQAPipeline(BaseComponent):
                 "value": True,
                 "component": "checkbox",
             },
-            "system_prompt": {
-                "name": "System Prompt",
-                "value": "This is a question answering system",
-            },
             "citation_llm": {
                 "name": "LLM for citation",
                 "value": citation_llm,
@@ -361,6 +364,14 @@ class FullQAPipeline(BaseComponent):
                 "component": "dropdown",
                 "choices": main_llm_choices,
             },
+            "system_prompt": {
+                "name": "System Prompt",
+                "value": "This is a question answering system",
+            },
+            "qa_prompt": {
+                "name": "QA Prompt (contains {context}, {question}, {lang})",
+                "value": DEFAULT_QA_TEXT_PROMPT,
+            },
         }

     @classmethod
@@ -3,7 +3,7 @@ requires = ["setuptools >= 61.0"]
 build-backend = "setuptools.build_meta"

 [tool.setuptools]
-include-package-data = false
+include-package-data = true
 packages.find.include = ["ktem*"]
 packages.find.exclude = ["tests*", "env*"]