Migrate the MVP into kotaemon (#108)

- Migrate the MVP into kotaemon.
- Preliminarily include the pipeline within the chatbot interface.
- Organize MVP as an application.

Todo:

- Add an info panel to view the agents' planning steps; this first requires fixing the streaming of agents' output.

Resolves: #60
Resolves: #61
Resolves: #62
This commit is contained in:
Duc Nguyen (john)
2024-01-10 15:28:09 +07:00
committed by GitHub
parent 230328c62f
commit 5a9d6f75be
31 changed files with 273 additions and 92 deletions

View File

@@ -11,6 +11,7 @@ from kotaemon.loaders import (
MathpixPDFReader,
OCRReader,
PandasExcelReader,
UnstructuredReader,
)
@@ -19,8 +20,16 @@ class DocumentIngestor(BaseComponent):
Document types:
- pdf
- xlsx
- docx
- xlsx, xls
- docx, doc
Args:
pdf_mode: mode for pdf extraction, one of "normal", "mathpix", "ocr"
- normal: parse pdf text
- mathpix: parse pdf text using mathpix
- ocr: parse pdf image using flax
doc_parsers: list of document parsers to parse the document
text_splitter: splitter to split the document into text nodes
"""
pdf_mode: str = "normal" # "normal", "mathpix", "ocr"
@@ -34,6 +43,9 @@ class DocumentIngestor(BaseComponent):
"""Get appropriate readers for the input files based on file extension"""
file_extractor: dict[str, AutoReader | BaseReader] = {
".xlsx": PandasExcelReader(),
".docx": UnstructuredReader(),
".xls": UnstructuredReader(),
".doc": UnstructuredReader(),
}
if self.pdf_mode == "normal":

View File

@@ -64,11 +64,7 @@ class CitationPipeline(BaseComponent):
llm: BaseLLM
def run(
self,
context: str,
question: str,
) -> QuestionAnswer:
def run(self, context: str, question: str):
schema = QuestionAnswer.schema()
function = {
"name": schema["title"],

View File

@@ -1,6 +1,6 @@
import os
from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
from .citation import CitationPipeline
@@ -21,6 +21,9 @@ class CitationQAPipeline(BaseComponent):
temperature=0,
request_timeout=60,
)
citation_pipeline: CitationPipeline = Node(
default_callback=lambda self: CitationPipeline(llm=self.llm)
)
def _format_doc_text(self, text: str) -> str:
"""Format the text of each document"""
@@ -52,9 +55,7 @@ class CitationQAPipeline(BaseComponent):
self.log_progress(".prompt", prompt=prompt)
answer_text = self.llm(prompt).text
if use_citation:
# run citation pipeline
citation_pipeline = CitationPipeline(llm=self.llm)
citation = citation_pipeline(context=context, question=question)
citation = self.citation_pipeline(context=context, question=question)
else:
citation = None

View File

@@ -23,17 +23,18 @@ class CohereReranking(BaseReranking):
)
cohere_client = cohere.Client(self.cohere_api_key)
compressed_docs: list[Document] = []
# output documents
compressed_docs = []
if len(documents) > 0: # to avoid empty api call
_docs = [d.content for d in documents]
results = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
)
for r in results:
doc = documents[r.index]
doc.metadata["relevance_score"] = r.relevance_score
compressed_docs.append(doc)
if not documents: # to avoid empty api call
return compressed_docs
_docs = [d.content for d in documents]
results = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
)
for r in results:
doc = documents[r.index]
doc.metadata["relevance_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@@ -29,8 +29,19 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(self, window_size: int = 3, **params):
super().__init__(window_size=window_size, **params)
def __init__(
self,
window_size: int = 3,
window_metadata_key: str = "window",
original_text_metadata_key: str = "original_text",
**params,
):
super().__init__(
window_size=window_size,
window_metadata_key=window_metadata_key,
original_text_metadata_key=original_text_metadata_key,
**params,
)
def _get_li_class(self):
from llama_index.node_parser import SentenceWindowNodeParser

View File

@@ -62,7 +62,7 @@ class VectorIndexing(BaseIndexing):
embeddings = self.embedding(input_)
self.vector_store.add(
embeddings=embeddings,
ids=[t.id_ for t in input_],
ids=[t.doc_id for t in input_],
)
if self.doc_store:
self.doc_store.add(input_)
@@ -99,7 +99,7 @@ class VectorRetrieval(BaseRetrieval):
)
emb: list[float] = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
docs = self.doc_store.get(ids)
result = [
RetrievedDocument(**doc.to_dict(), score=score)