Relate the retrievers to the indexer

This commit is contained in:
trducng 2024-01-27 16:39:40 +07:00
parent 9b586466ff
commit c6637ca56e
5 changed files with 220 additions and 192 deletions

View File

@ -1,7 +1,13 @@
from kotaemon.base import BaseComponent
class BaseIndex(BaseComponent):
class BaseRetriever(BaseComponent):
pass
class BaseIndexing(BaseComponent):
"""The pipeline to index information into the data store"""
def get_user_settings(self) -> dict:
"""Get the user settings for indexing
@ -12,5 +18,8 @@ class BaseIndex(BaseComponent):
return {}
@classmethod
def get_pipeline(cls, setting: dict) -> "BaseIndex":
def get_pipeline(cls, settings: dict) -> "BaseIndexing":
raise NotImplementedError
def get_retrievers(self, settings: dict, **kwargs) -> list[BaseRetriever]:
raise NotImplementedError

View File

@ -1,16 +1,35 @@
from __future__ import annotations
import shutil
import warnings
from collections import defaultdict
from hashlib import sha256
from pathlib import Path
from typing import Optional
from ktem.components import embeddings, filestorage_path, get_docstore, get_vectorstore
from ktem.components import (
embeddings,
filestorage_path,
get_docstore,
get_vectorstore,
llms,
)
from ktem.db.models import Index, Source, SourceTargetRelation, engine
from ktem.indexing.base import BaseIndex
from ktem.indexing.base import BaseIndexing, BaseRetriever
from ktem.indexing.exceptions import FileExistsError
from kotaemon.indices import VectorIndexing
from kotaemon.base import RetrievedDocument
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
from llama_index.vector_stores import (
FilterCondition,
FilterOperator,
MetadataFilter,
MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
from sqlmodel import Session, select
from theflow.settings import settings
USER_SETTINGS = {
"index_parser": {
@ -61,7 +80,109 @@ USER_SETTINGS = {
}
class IndexDocumentPipeline(BaseIndex):
class DocumentRetrievalPipeline(BaseRetriever):
"""Retrieve relevant document
Args:
vector_retrieval: the retrieval pipeline that return the relevant documents
given a text query
reranker: the reranking pipeline that re-rank and filter the retrieved
documents
get_extra_table: if True, for each retrieved document, the pipeline will look
for surrounding tables (e.g. within the page)
"""
vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
doc_store=get_docstore(),
vector_store=get_vectorstore(),
embedding=embeddings.get_default(),
)
reranker: BaseReranking = CohereReranking.withx(
cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
) >> LLMReranking.withx(llm=llms.get_lowest_cost())
get_extra_table: bool = False
def run(
self,
text: str,
top_k: int = 5,
mmr: bool = False,
doc_ids: Optional[list[str]] = None,
) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text
Args:
text: the text to retrieve similar documents
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
doc_ids: list of document ids to constraint the retrieval
"""
kwargs = {}
if doc_ids:
with Session(engine) as session:
stmt = select(Index).where(
Index.relation_type == SourceTargetRelation.VECTOR,
Index.source_id.in_(doc_ids), # type: ignore
)
results = session.exec(stmt)
vs_ids = [r.target_id for r in results.all()]
kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
)
if mmr:
# TODO: double check that llama-index MMR works correctly
kwargs["mode"] = VectorStoreQueryMode.MMR
kwargs["mmr_threshold"] = 0.5
# rerank
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
if self.get_from_path("reranker"):
docs = self.reranker(docs, query=text)
if not self.get_extra_table:
return docs
# retrieve extra nodes relate to table
table_pages = defaultdict(list)
retrieved_id = set([doc.doc_id for doc in docs])
for doc in docs:
if "page_label" not in doc.metadata:
continue
if "file_name" not in doc.metadata:
warnings.warn(
"file_name not in metadata while page_label is in metadata: "
f"{doc.metadata}"
)
table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
queries = [
{"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
for fn, pls in table_pages.items()
]
if queries:
extra_docs = self.vector_retrieval(
text="",
top_k=50,
where={"$or": queries},
)
for doc in extra_docs:
if doc.doc_id not in retrieved_id:
docs.append(doc)
return docs
class IndexDocumentPipeline(BaseIndexing):
"""Store the documents and index the content into vector store and doc store
Args:
@ -175,8 +296,29 @@ class IndexDocumentPipeline(BaseIndex):
return USER_SETTINGS
@classmethod
def get_pipeline(cls, setting) -> "IndexDocumentPipeline":
def get_pipeline(cls, settings) -> "IndexDocumentPipeline":
"""Get the pipeline based on the setting"""
obj = cls()
obj.file_ingestor.pdf_mode = setting["index.index_parser"]
obj.file_ingestor.pdf_mode = settings["index.index_parser"]
return obj
def get_retrievers(self, settings, **kwargs) -> list[BaseRetriever]:
"""Get retriever objects associated with the index
Args:
settings: the settings of the app
kwargs: other arguments
"""
retriever = DocumentRetrievalPipeline(
get_extra_table=settings["index.prioritize_table"]
)
if not settings["index.use_reranking"]:
retriever.reranker = None # type: ignore
kwargs = {
".top_k": int(settings["index.num_retrieval"]),
".mmr": settings["index.mmr"],
".doc_ids": kwargs.get("files", None),
}
retriever.set_run(kwargs, temp=True)
return [retriever]

View File

@ -36,6 +36,7 @@ class ChatPage(BasePage):
).then(
fn=chat_fn,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self.data_source.files,
self._app.settings_state,
@ -64,6 +65,7 @@ class ChatPage(BasePage):
).then(
fn=chat_fn,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self.data_source.files,
self._app.settings_state,

View File

@ -7,8 +7,7 @@ from typing import Optional
import gradio as gr
from ktem.components import llms, reasonings
from ktem.db.models import Conversation, Source, engine
from ktem.indexing.base import BaseIndex
from ktem.reasoning.simple import DocumentRetrievalPipeline
from ktem.indexing.base import BaseIndexing
from sqlmodel import Session, select
from theflow.settings import settings as app_settings
from theflow.utils.modules import import_dotted_string
@ -26,9 +25,15 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
the pipeline objects
"""
# get retrievers
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
retrievers = indexing_cls.get_pipeline(settings).get_retrievers(
settings, files=files
)
reasoning_mode = settings["reasoning.use"]
reasoning_cls = reasonings[reasoning_mode]
pipeline = reasoning_cls.get_pipeline(settings, files=files)
pipeline = reasoning_cls.get_pipeline(settings, retrievers, files=files)
if settings["reasoning.use"] in ["rewoo", "react"]:
from kotaemon.agents import ReactAgent, RewooAgent
@ -49,47 +54,38 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
from kotaemon.agents import LLMTool
tools.append(LLMTool(llm=llm))
elif tool == "docsearch":
from kotaemon.agents import ComponentTool
# elif tool == "docsearch":
# pass
filenames = ""
if files:
with Session(engine) as session:
statement = select(Source).where(
Source.id.in_(files) # type: ignore
)
results = session.exec(statement).all()
filenames = (
"The file names are: "
+ " ".join([result.name for result in results])
+ ". "
)
# filenames = ""
# if files:
# with Session(engine) as session:
# statement = select(Source).where(
# Source.id.in_(files) # type: ignore
# )
# results = session.exec(statement).all()
# filenames = (
# "The file names are: "
# + " ".join([result.name for result in results])
# + ". "
# )
retrieval_pipeline = DocumentRetrievalPipeline()
retrieval_pipeline.set_run(
{
".top_k": int(settings["retrieval_number"]),
".mmr": settings["retrieval_mmr"],
".doc_ids": files,
},
temp=True,
)
tool = ComponentTool(
name="docsearch",
description=(
"A vector store that searches for similar and "
"related content "
f"in a document. {filenames}"
"The result is a huge chunk of text related "
"to your search but can also "
"contain irrelevant info."
),
component=retrieval_pipeline,
postprocessor=lambda docs: "\n\n".join(
[doc.text.replace("\n", " ") for doc in docs]
),
)
tools.append(tool)
# tool = ComponentTool(
# name="docsearch",
# description=(
# "A vector store that searches for similar and "
# "related content "
# f"in a document. {filenames}"
# "The result is a huge chunk of text related "
# "to your search but can also "
# "contain irrelevant info."
# ),
# component=retrieval_pipeline,
# postprocessor=lambda docs: "\n\n".join(
# [doc.text.replace("\n", " ") for doc in docs]
# ),
# )
# tools.append(tool)
elif tool == "google":
from kotaemon.agents import GoogleSearchTool
@ -117,7 +113,7 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
return pipeline
async def chat_fn(chat_history, files, settings):
async def chat_fn(conversation_id, chat_history, files, settings):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
@ -128,7 +124,7 @@ async def chat_fn(chat_history, files, settings):
pipeline = create_pipeline(settings, files)
pipeline.set_output_queue(queue)
asyncio.create_task(pipeline(chat_input, chat_history))
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
text, refs = "", ""
while True:
@ -207,7 +203,7 @@ def index_fn(files, reindex: bool, selected_files, settings):
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
indexing_pipeline = indexing_cls.get_pipeline(settings)
output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)

View File

@ -1,13 +1,11 @@
import asyncio
import logging
import warnings
from collections import defaultdict
from functools import partial
from typing import Optional
import tiktoken
from ktem.components import embeddings, get_docstore, get_vectorstore, llms
from ktem.db.models import Index, SourceTargetRelation, engine
from ktem.components import llms
from ktem.indexing.base import BaseRetriever
from kotaemon.base import (
BaseComponent,
Document,
@ -16,126 +14,13 @@ from kotaemon.base import (
RetrievedDocument,
SystemMessage,
)
from kotaemon.indices import VectorRetrieval
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
from llama_index.vector_stores import (
FilterCondition,
FilterOperator,
MetadataFilter,
MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
from sqlmodel import Session, select
from theflow.settings import settings
logger = logging.getLogger(__name__)
class DocumentRetrievalPipeline(BaseComponent):
"""Retrieve relevant document
Args:
vector_retrieval: the retrieval pipeline that return the relevant documents
given a text query
reranker: the reranking pipeline that re-rank and filter the retrieved
documents
get_extra_table: if True, for each retrieved document, the pipeline will look
for surrounding tables (e.g. within the page)
"""
vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
doc_store=get_docstore(),
vector_store=get_vectorstore(),
embedding=embeddings.get_default(),
)
reranker: BaseReranking = CohereReranking.withx(
cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
) >> LLMReranking.withx(llm=llms.get_lowest_cost())
get_extra_table: bool = False
def run(
self,
text: str,
top_k: int = 5,
mmr: bool = False,
doc_ids: Optional[list[str]] = None,
) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text
Args:
text: the text to retrieve similar documents
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
doc_ids: list of document ids to constraint the retrieval
"""
kwargs = {}
if doc_ids:
with Session(engine) as session:
stmt = select(Index).where(
Index.relation_type == SourceTargetRelation.VECTOR,
Index.source_id.in_(doc_ids), # type: ignore
)
results = session.exec(stmt)
vs_ids = [r.target_id for r in results.all()]
kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
)
if mmr:
# TODO: double check that llama-index MMR works correctly
kwargs["mode"] = VectorStoreQueryMode.MMR
kwargs["mmr_threshold"] = 0.5
# rerank
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
if self.get_from_path("reranker"):
docs = self.reranker(docs, query=text)
if not self.get_extra_table:
return docs
# retrieve extra nodes relate to table
table_pages = defaultdict(list)
retrieved_id = set([doc.doc_id for doc in docs])
for doc in docs:
if "page_label" not in doc.metadata:
continue
if "file_name" not in doc.metadata:
warnings.warn(
"file_name not in metadata while page_label is in metadata: "
f"{doc.metadata}"
)
table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
queries = [
{"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
for fn, pls in table_pages.items()
]
if queries:
extra_docs = self.vector_retrieval(
text="",
top_k=50,
where={"$or": queries},
)
for doc in extra_docs:
if doc.doc_id not in retrieved_id:
docs.append(doc)
return docs
class PrepareEvidencePipeline(BaseComponent):
"""Prepare the evidence text from the list of retrieved documents
@ -338,22 +223,22 @@ class FullQAPipeline(BaseComponent):
allow_extra = True
params_publish = True
retrieval_pipeline: DocumentRetrievalPipeline = DocumentRetrievalPipeline.withx()
retrievers: list[BaseRetriever]
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
async def run( # type: ignore
self, question: str, history: list, **kwargs # type: ignore
self, message: str, cid: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
docs = self.retrieval_pipeline(text=question)
docs = []
for retriever in self.retrievers:
docs.extend(retriever(text=message))
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = await self.answering_pipeline(
question=question, evidence=evidence, evidence_mode=evidence_mode
question=message, evidence=evidence, evidence_mode=evidence_mode
)
# prepare citation
from collections import defaultdict
spans = defaultdict(list)
for fact_with_evidence in answer.metadata["citation"].answer:
for quote in fact_with_evidence.substring_quote:
@ -369,6 +254,7 @@ class FullQAPipeline(BaseComponent):
break
id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True
for id, ss in spans.items():
if not ss:
continue
@ -391,31 +277,24 @@ class FullQAPipeline(BaseComponent):
)
}
)
lack_evidence = False
if lack_evidence:
self.report_output({"evidence": "No evidence found"})
self.report_output(None)
return answer
@classmethod
def get_pipeline(cls, settings, **kwargs):
def get_pipeline(cls, settings, retrievers, **kwargs):
"""Get the reasoning pipeline
Need a base pipeline implementation. Currently the drawback is that we want to
treat the retrievers as tools. Hence, the reasoning pipelie should just take
the already initiated tools (retrievers), and do not need to set such logic
here.
Args:
settings: the settings for the pipeline
retrievers: the retrievers to use
"""
pipeline = FullQAPipeline(get_extra_table=settings["index.prioritize_table"])
if not settings["index.use_reranking"]:
pipeline.retrieval_pipeline.reranker = None # type: ignore
pipeline = FullQAPipeline(retrievers=retrievers)
pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
kwargs = {
".retrieval_pipeline.top_k": int(settings["index.num_retrieval"]),
".retrieval_pipeline.mmr": settings["index.mmr"],
".retrieval_pipeline.doc_ids": kwargs.get("files", None),
}
pipeline.set_run(kwargs, temp=True)
return pipeline
@classmethod