Refactor the index component and update the MVP insurance accordingly (#90)
Refactor the `kotaemon/pipelines` module into `kotaemon/indices` and create the VectorIndex. Note: for now I place `qa` inside `kotaemon/indices`, since `qa` is currently only used in RAG. `qa` could also become an independent `kotaemon/qa` module; since this is easy to change later, I am going with the first option for now and will revisit it if needed.
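To see how the refactored pieces are meant to compose, here is a minimal, hedged sketch of the intended flow (the `vector_store`, `doc_store`, and `embedding` instances are placeholders for whatever concrete classes `kotaemon.storages` and `kotaemon.embeddings` provide; the rest matches the diff below):

```python
# Sketch only: vector_store / doc_store / embedding are placeholder
# instances of BaseVectorStore / BaseDocumentStore / BaseEmbeddings.
from kotaemon.indices import VectorIndexing

indexing = VectorIndexing(
    vector_store=vector_store,  # a BaseVectorStore
    doc_store=doc_store,        # a BaseDocumentStore (needed for retrieval)
    embedding=embedding,        # a BaseEmbeddings
)
indexing(["some text", "some other text"])  # embed and store

retrieval = indexing.to_retrieval_pipeline(top_k=2)
docs = retrieval("query text")  # -> list[RetrievedDocument]
```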
parent 8e3a1d193f
commit e34b1e4c6d
knowledgehub/indices/__init__.py  (new file, +3)
@@ -0,0 +1,3 @@
from .vectorindex import VectorIndexing, VectorRetrieval

__all__ = ["VectorIndexing", "VectorRetrieval"]
knowledgehub/indices/base.py
@@ -5,7 +5,7 @@ from typing import Any, Type

 from llama_index.node_parser.interface import NodeParser

-from ..base import BaseComponent, Document
+from kotaemon.base import BaseComponent, Document, RetrievedDocument


 class DocTransformer(BaseComponent):
@@ -26,7 +26,7 @@ class DocTransformer(BaseComponent):
         ...


-class LlamaIndexMixin:
+class LlamaIndexDocTransformerMixin:
     """Allow automatically wrapping a Llama-index component into a kotaemon component

     Example:
@@ -70,3 +70,23 @@ class LlamaIndexMixin:
         """
         docs = self._obj(documents, **kwargs)  # type: ignore
         return [Document.from_dict(doc.to_dict()) for doc in docs]
+
+
+class BaseIndexing(BaseComponent):
+    """Define the base interface for indexing pipeline"""
+
+    def to_retrieval_pipeline(self, **kwargs):
+        """Convert the indexing pipeline to a retrieval pipeline"""
+        raise NotImplementedError
+
+    def to_qa_pipeline(self, **kwargs):
+        """Convert the indexing pipeline to a QA pipeline"""
+        raise NotImplementedError
+
+
+class BaseRetrieval(BaseComponent):
+    """Define the base interface for retrieval pipeline"""
+
+    @abstractmethod
+    def run(self, *args, **kwargs) -> list[RetrievedDocument]:
+        ...
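The `LlamaIndexDocTransformerMixin` above is how every llama-index transformer in this commit gets wrapped. As a hedged sketch (not part of this commit), a custom wrapper would follow the same pattern; `SentenceSplitter` is one llama-index class that could be wrapped this way:

```python
# Sketch only, not part of this commit: wrapping another llama-index
# transformer the same way TitleExtractor and TokenSplitter do below.
from kotaemon.indices.base import DocTransformer, LlamaIndexDocTransformerMixin


class SentenceSplitterWrapper(LlamaIndexDocTransformerMixin, DocTransformer):
    def _get_li_class(self):
        # imported lazily, mirroring the other wrappers in this commit
        from llama_index.text_splitter import SentenceSplitter

        return SentenceSplitter
```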
knowledgehub/indices/extractors/doc_parsers.py
@@ -1,18 +1,18 @@
-from ..base import DocTransformer, LlamaIndexMixin
+from ..base import DocTransformer, LlamaIndexDocTransformerMixin


 class BaseDocParser(DocTransformer):
     ...


-class TitleExtractor(LlamaIndexMixin, BaseDocParser):
+class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
     def _get_li_class(self):
         from llama_index.extractors import TitleExtractor

         return TitleExtractor


-class SummaryExtractor(LlamaIndexMixin, BaseDocParser):
+class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
     def _get_li_class(self):
         from llama_index.extractors import SummaryExtractor
knowledgehub/indices/ingests/__init__.py  (new file, +3)
@@ -0,0 +1,3 @@
from .files import DocumentIngestor

__all__ = ["DocumentIngestor"]
knowledgehub/indices/ingests/files.py  (new file, +75)
@@ -0,0 +1,75 @@
from pathlib import Path

from llama_index.readers.base import BaseReader
from theflow import Param

from kotaemon.base import BaseComponent, Document
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import (
    AutoReader,
    DirectoryReader,
    MathpixPDFReader,
    OCRReader,
    PandasExcelReader,
)


class DocumentIngestor(BaseComponent):
    """Ingest common office document types into Documents for indexing

    Supported document types:
        - pdf
        - xlsx
        - docx
    """

    pdf_mode: str = "normal"  # "normal", "mathpix" or "ocr"
    doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
    text_splitter: BaseSplitter = TokenSplitter.withx(
        chunk_size=1024,
        chunk_overlap=256,
    )

    def _get_reader(self, input_files: list[str | Path]):
        """Get the appropriate readers for the input files based on file extension"""
        file_extractor: dict[str, AutoReader | BaseReader] = {
            ".xlsx": PandasExcelReader(),
        }

        if self.pdf_mode == "normal":
            file_extractor[".pdf"] = AutoReader("UnstructuredReader")
        elif self.pdf_mode == "ocr":
            file_extractor[".pdf"] = OCRReader()
        else:
            file_extractor[".pdf"] = MathpixPDFReader()

        main_reader = DirectoryReader(
            input_files=input_files,
            file_extractor=file_extractor,
        )

        return main_reader

    def run(self, file_paths: list[str | Path] | str | Path) -> list[Document]:
        """Ingest the file paths into Documents

        Args:
            file_paths: list of file paths or a single file path

        Returns:
            list of parsed Documents
        """
        if not isinstance(file_paths, list):
            file_paths = [file_paths]

        documents = self._get_reader(input_files=file_paths)()
        nodes = self.text_splitter(documents)
        self.log_progress(".num_docs", num_docs=len(nodes))

        # run the document parsers, if any
        if self.doc_parsers:
            for parser in self.doc_parsers:
                nodes = parser(nodes)

        return nodes
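`DocumentIngestor` is callable directly on file paths; a minimal usage sketch (the path below is a placeholder):

```python
# Placeholder path; any pdf/xlsx/docx works per the docstring above.
ingestor = DocumentIngestor(pdf_mode="ocr")
nodes = ingestor("path/to/report.pdf")  # -> list[Document] chunks of ~1024 tokens
```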
knowledgehub/indices/qa/__init__.py  (new file, +7)
@@ -0,0 +1,7 @@
from .citation import CitationPipeline
from .text_based import CitationQAPipeline

__all__ = [
    "CitationPipeline",
    "CitationQAPipeline",
]
knowledgehub/indices/qa/citation.py  (new file, +106)
@@ -0,0 +1,106 @@
from typing import Iterator, List

from pydantic import BaseModel, Field

from kotaemon.base import BaseComponent
from kotaemon.base.schema import HumanMessage, SystemMessage
from kotaemon.llms import BaseLLM


class FactWithEvidence(BaseModel):
    """Class representing a single statement.

    Each fact has a body and a list of sources.
    If there are multiple facts, make sure to break them apart
    so that each one only uses a set of sources that are relevant to it.
    """

    fact: str = Field(..., description="Body of the sentence, as part of a response")
    substring_quote: List[str] = Field(
        ...,
        description=(
            "Each source should be a direct quote from the context, "
            "as a substring of the original content"
        ),
    )

    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
        import regex

        minor = quote
        major = context

        errs_ = 0
        s = regex.search(f"({minor}){{e<={errs_}}}", major)
        while s is None and errs_ <= errs:
            errs_ += 1
            s = regex.search(f"({minor}){{e<={errs_}}}", major)

        if s is not None:
            yield from s.spans()

    def get_spans(self, context: str) -> Iterator[str]:
        for quote in self.substring_quote:
            yield from self._get_span(quote, context)


class QuestionAnswer(BaseModel):
    """A question and its answer as a list of facts, each of which should have
    a source. Each sentence contains a body and a list of sources."""

    question: str = Field(..., description="Question that was asked")
    answer: List[FactWithEvidence] = Field(
        ...,
        description=(
            "Body of the answer, each fact should be "
            "its separate object with a body and a list of sources"
        ),
    )


class CitationPipeline(BaseComponent):
    """Citation pipeline to extract cited evidence from the source
    (based on the input question)"""

    llm: BaseLLM

    def run(
        self,
        context: str,
        question: str,
    ) -> QuestionAnswer:
        schema = QuestionAnswer.schema()
        function = {
            "name": schema["title"],
            "description": schema["description"],
            "parameters": schema,
        }
        llm_kwargs = {
            "functions": [function],
            "function_call": {"name": function["name"]},
        }
        messages = [
            SystemMessage(
                content=(
                    "You are a world class algorithm to answer "
                    "questions with correct and exact citations."
                )
            ),
            HumanMessage(content="Answer question using the following context"),
            HumanMessage(content=context),
            HumanMessage(content=f"Question: {question}"),
            HumanMessage(
                content=(
                    "Tips: Make sure to cite your sources, "
                    "and use the exact words from the context."
                )
            ),
        ]

        llm_output = self.llm(messages, **llm_kwargs)
        function_output = llm_output.messages[0].additional_kwargs["function_call"][
            "arguments"
        ]
        output = QuestionAnswer.parse_raw(function_output)

        return output
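The pipeline drives OpenAI-style function calling: the `QuestionAnswer` pydantic schema is passed as the function signature, and the model's `function_call` arguments are parsed back into pydantic objects. A hedged usage sketch (`llm` is assumed to be a `BaseLLM` that supports function calling, e.g. `AzureChatOpenAI`; the context and question are made-up examples):

```python
# Sketch; `llm` must support OpenAI-style function calling.
pipeline = CitationPipeline(llm=llm)
qa = pipeline(context="Alice was born in 1990.", question="When was Alice born?")
for fact in qa.answer:
    print(fact.fact, "->", fact.substring_quote)
```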
knowledgehub/indices/qa/text_based.py  (new file, +62)
@@ -0,0 +1,62 @@
import os

from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate

from .citation import CitationPipeline


class CitationQAPipeline(BaseComponent):
    """Answer questions from a text corpus, with citations"""

    qa_prompt_template: PromptTemplate = PromptTemplate(
        'Answer the following question: "{question}". '
        "The context is: \n{context}\nAnswer: "
    )
    llm: BaseLLM = AzureChatOpenAI.withx(
        azure_endpoint="https://bleh-dummy.openai.azure.com/",
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        openai_api_version="2023-07-01-preview",
        deployment_name="dummy-q2-16k",
        temperature=0,
        request_timeout=60,
    )

    def _format_doc_text(self, text: str) -> str:
        """Format the text of each document"""
        return text.replace("\n", " ")

    def _format_retrieved_context(self, documents: list[RetrievedDocument]) -> str:
        """Format the texts of all documents into a single context string"""
        matched_texts: list[str] = [
            self._format_doc_text(doc.text) for doc in documents
        ]
        return "\n\n".join(matched_texts)

    def run(
        self,
        question: str,
        documents: list[RetrievedDocument],
        use_citation: bool = False,
        **kwargs
    ) -> Document:
        # format the retrieved documents into the context string
        context = self._format_retrieved_context(documents)
        self.log_progress(".context", context=context)

        # generate the answer
        prompt = self.qa_prompt_template.populate(
            context=context,
            question=question,
        )
        self.log_progress(".prompt", prompt=prompt)
        answer_text = self.llm(prompt).text
        if use_citation:
            # run the citation pipeline on the same context
            citation_pipeline = CitationPipeline(llm=self.llm)
            citation = citation_pipeline(context=context, question=question)
        else:
            citation = None

        answer = Document(text=answer_text, metadata={"citation": citation})
        return answer
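Putting the two together: `CitationQAPipeline` formats the retrieved documents into the prompt and optionally runs `CitationPipeline` on the same context. A sketch (`llm` and the `retrieved` list are assumed to come from elsewhere):

```python
# Sketch; `retrieved` is a list[RetrievedDocument] from any retriever.
qa = CitationQAPipeline(llm=llm)
answer = qa(question="What is covered?", documents=retrieved, use_citation=True)
print(answer.text)
print(answer.metadata["citation"])
```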
knowledgehub/indices/splitters/__init__.py
@@ -1,4 +1,4 @@
-from ..base import DocTransformer, LlamaIndexMixin
+from ..base import DocTransformer, LlamaIndexDocTransformerMixin


 class BaseSplitter(DocTransformer):
@@ -7,14 +7,14 @@ class BaseSplitter(DocTransformer):
     ...


-class TokenSplitter(LlamaIndexMixin, BaseSplitter):
+class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
     def _get_li_class(self):
         from llama_index.text_splitter import TokenTextSplitter

         return TokenTextSplitter


-class SentenceWindowSplitter(LlamaIndexMixin, BaseSplitter):
+class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
     def _get_li_class(self):
         from llama_index.node_parser import SentenceWindowNodeParser
knowledgehub/indices/vectorindex.py  (new file, +185)
@@ -0,0 +1,185 @@
from __future__ import annotations

import uuid
from pathlib import Path
from typing import Optional, Sequence, cast

from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.storages import BaseDocumentStore, BaseVectorStore

from .base import BaseIndexing, BaseRetrieval
from .rankings import BaseReranking

VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"


class VectorIndexing(BaseIndexing):
    """Ingest the document, run it through the embedding, and store the
    embedding in a vector store.

    This pipeline supports the following set of inputs:
        - List of documents
        - List of texts
    """

    vector_store: BaseVectorStore
    doc_store: Optional[BaseDocumentStore] = None
    embedding: BaseEmbeddings

    def to_retrieval_pipeline(self, *args, **kwargs):
        """Convert the indexing pipeline to a retrieval pipeline"""
        return VectorRetrieval(
            vector_store=self.vector_store,
            doc_store=self.doc_store,
            embedding=self.embedding,
            **kwargs,
        )

    def to_qa_pipeline(self, *args, **kwargs):
        """Convert the indexing pipeline to a QA pipeline"""
        from .qa import CitationQAPipeline

        return TextVectorQA(
            retrieving_pipeline=self.to_retrieval_pipeline(**kwargs),
            qa_pipeline=CitationQAPipeline(**kwargs),
        )

    def run(self, text: str | list[str] | Document | list[Document]) -> None:
        input_: list[Document] = []
        if not isinstance(text, list):
            text = [text]

        for item in cast(list, text):
            if isinstance(item, str):
                input_.append(Document(text=item, id_=str(uuid.uuid4())))
            elif isinstance(item, Document):
                input_.append(item)
            else:
                raise ValueError(
                    f"Invalid input type {type(item)}, should be str or Document"
                )

        embeddings = self.embedding(input_)
        self.vector_store.add(
            embeddings=embeddings,
            ids=[t.id_ for t in input_],
        )
        if self.doc_store:
            self.doc_store.add(input_)

    def save(
        self,
        path: str | Path,
        vectorstore_fname: str = VECTOR_STORE_FNAME,
        docstore_fname: str = DOC_STORE_FNAME,
    ):
        """Save the whole state of the indexing pipeline (vector store and all
        necessary information) to disk

        Args:
            path (str): path to save the state
        """
        if isinstance(path, str):
            path = Path(path)
        self.vector_store.save(path / vectorstore_fname)
        if self.doc_store:
            self.doc_store.save(path / docstore_fname)

    def load(
        self,
        path: str | Path,
        vectorstore_fname: str = VECTOR_STORE_FNAME,
        docstore_fname: str = DOC_STORE_FNAME,
    ):
        """Load all information from disk into an object"""
        if isinstance(path, str):
            path = Path(path)
        self.vector_store.load(path / vectorstore_fname)
        if self.doc_store:
            self.doc_store.load(path / docstore_fname)


class VectorRetrieval(BaseRetrieval):
    """Retrieve a list of documents from a vector store"""

    vector_store: BaseVectorStore
    doc_store: Optional[BaseDocumentStore] = None
    embedding: BaseEmbeddings
    rerankers: Sequence[BaseReranking] = []
    top_k: int = 1

    def run(
        self, text: str | Document, top_k: Optional[int] = None, **kwargs
    ) -> list[RetrievedDocument]:
        """Retrieve a list of documents from the vector store

        Args:
            text: the text to retrieve similar documents for
            top_k: number of top similar documents to return

        Returns:
            list[RetrievedDocument]: list of retrieved documents
        """
        if top_k is None:
            top_k = self.top_k

        if self.doc_store is None:
            raise ValueError(
                "doc_store is not provided. Please provide a doc_store to "
                "retrieve the documents"
            )

        emb: list[float] = self.embedding(text)[0].embedding
        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
        docs = self.doc_store.get(ids)
        result = [
            RetrievedDocument(**doc.to_dict(), score=score)
            for doc, score in zip(docs, scores)
        ]
        # use the additional rerankers to re-order the document list
        if self.rerankers:
            for reranker in self.rerankers:
                result = reranker(documents=result, query=text)

        return result

    def save(
        self,
        path: str | Path,
        vectorstore_fname: str = VECTOR_STORE_FNAME,
        docstore_fname: str = DOC_STORE_FNAME,
    ):
        """Save the whole state of the retrieval pipeline (vector store and all
        necessary information) to disk

        Args:
            path (str): path to save the state
        """
        if isinstance(path, str):
            path = Path(path)
        self.vector_store.save(path / vectorstore_fname)
        if self.doc_store:
            self.doc_store.save(path / docstore_fname)

    def load(
        self,
        path: str | Path,
        vectorstore_fname: str = VECTOR_STORE_FNAME,
        docstore_fname: str = DOC_STORE_FNAME,
    ):
        """Load all information from disk into an object"""
        if isinstance(path, str):
            path = Path(path)
        self.vector_store.load(path / vectorstore_fname)
        if self.doc_store:
            self.doc_store.load(path / docstore_fname)


class TextVectorQA(BaseComponent):
    retrieving_pipeline: BaseRetrieval
    qa_pipeline: BaseComponent

    def run(self, question, **kwargs):
        retrieved_documents = self.retrieving_pipeline(question, **kwargs)
        return self.qa_pipeline(question, retrieved_documents, **kwargs)
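Since both `VectorIndexing` and `VectorRetrieval` expose `save`/`load`, the index state can be persisted and restored in another process; a hedged sketch (the directory path is a placeholder, and the store/embedding instances are assumed to be constructed the same way as for indexing):

```python
# Sketch; "./index_state" is a placeholder directory.
indexing.save("./index_state")

retrieval = VectorRetrieval(
    vector_store=vector_store, doc_store=doc_store, embedding=embedding, top_k=3
)
retrieval.load("./index_state")  # restore the saved state elsewhere
docs = retrieval("an example query")
```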