diff --git a/libs/ktem/ktem/indexing/base.py b/libs/ktem/ktem/indexing/base.py
index 0057b90..787a154 100644
--- a/libs/ktem/ktem/indexing/base.py
+++ b/libs/ktem/ktem/indexing/base.py
@@ -1,7 +1,13 @@
 from kotaemon.base import BaseComponent
 
 
-class BaseIndex(BaseComponent):
+class BaseRetriever(BaseComponent):
+    pass
+
+
+class BaseIndexing(BaseComponent):
+    """The pipeline to index information into the data store"""
+
     def get_user_settings(self) -> dict:
         """Get the user settings for indexing
 
@@ -12,5 +18,8 @@ class BaseIndex(BaseComponent):
         return {}
 
     @classmethod
-    def get_pipeline(cls, setting: dict) -> "BaseIndex":
+    def get_pipeline(cls, settings: dict) -> "BaseIndexing":
+        raise NotImplementedError
+
+    def get_retrievers(self, settings: dict, **kwargs) -> list[BaseRetriever]:
         raise NotImplementedError
diff --git a/libs/ktem/ktem/indexing/file.py b/libs/ktem/ktem/indexing/file.py
index daf0feb..fc5ef1a 100644
--- a/libs/ktem/ktem/indexing/file.py
+++ b/libs/ktem/ktem/indexing/file.py
@@ -1,16 +1,35 @@
 from __future__ import annotations
 
 import shutil
+import warnings
+from collections import defaultdict
 from hashlib import sha256
 from pathlib import Path
+from typing import Optional
 
-from ktem.components import embeddings, filestorage_path, get_docstore, get_vectorstore
+from ktem.components import (
+    embeddings,
+    filestorage_path,
+    get_docstore,
+    get_vectorstore,
+    llms,
+)
 from ktem.db.models import Index, Source, SourceTargetRelation, engine
-from ktem.indexing.base import BaseIndex
+from ktem.indexing.base import BaseIndexing, BaseRetriever
 from ktem.indexing.exceptions import FileExistsError
-from kotaemon.indices import VectorIndexing
+from kotaemon.base import RetrievedDocument
+from kotaemon.indices import VectorIndexing, VectorRetrieval
 from kotaemon.indices.ingests import DocumentIngestor
+from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
+from llama_index.vector_stores import (
+    FilterCondition,
+    FilterOperator,
+    MetadataFilter,
+    MetadataFilters,
+)
+from llama_index.vector_stores.types import VectorStoreQueryMode
 from sqlmodel import Session, select
+from theflow.settings import settings
 
 USER_SETTINGS = {
     "index_parser": {
@@ -61,7 +80,109 @@ USER_SETTINGS = {
 }
 
 
-class IndexDocumentPipeline(BaseIndex):
+class DocumentRetrievalPipeline(BaseRetriever):
+    """Retrieve relevant documents
+
+    Args:
+        vector_retrieval: the retrieval pipeline that returns the relevant documents
+            given a text query
+        reranker: the reranking pipeline that re-ranks and filters the retrieved
+            documents
+        get_extra_table: if True, for each retrieved document, the pipeline will look
+            for surrounding tables (e.g. within the page)
+    """
+
+    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
+        doc_store=get_docstore(),
+        vector_store=get_vectorstore(),
+        embedding=embeddings.get_default(),
+    )
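+    # the default reranker chains Cohere reranking into LLM-based reranking;
+    # get_retrievers() replaces it with None when reranking is disabled in the
+    # user settings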
+    reranker: BaseReranking = CohereReranking.withx(
+        cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
+    ) >> LLMReranking.withx(llm=llms.get_lowest_cost())
+    get_extra_table: bool = False
+
+    def run(
+        self,
+        text: str,
+        top_k: int = 5,
+        mmr: bool = False,
+        doc_ids: Optional[list[str]] = None,
+    ) -> list[RetrievedDocument]:
+        """Retrieve document excerpts similar to the text
+
+        Args:
+            text: the text to retrieve similar documents for
+            top_k: number of documents to retrieve
+            mmr: whether to use MMR to re-rank the documents
+            doc_ids: list of document ids to constrain the retrieval
+        """
+        kwargs = {}
+        if doc_ids:
+            with Session(engine) as session:
+                stmt = select(Index).where(
+                    Index.relation_type == SourceTargetRelation.VECTOR,
+                    Index.source_id.in_(doc_ids),  # type: ignore
+                )
+                results = session.exec(stmt)
+                vs_ids = [r.target_id for r in results.all()]
+
+            kwargs["filters"] = MetadataFilters(
+                filters=[
+                    MetadataFilter(
+                        key="doc_id",
+                        value=vs_id,
+                        operator=FilterOperator.EQ,
+                    )
+                    for vs_id in vs_ids
+                ],
+                condition=FilterCondition.OR,
+            )
+
+        if mmr:
+            # TODO: double check that llama-index MMR works correctly
+            kwargs["mode"] = VectorStoreQueryMode.MMR
+            kwargs["mmr_threshold"] = 0.5
+
+        # retrieve, then rerank
+        docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
+        if self.get_from_path("reranker"):
+            docs = self.reranker(docs, query=text)
+
+        if not self.get_extra_table:
+            return docs
+
+        # retrieve extra nodes related to tables
+        table_pages = defaultdict(list)
+        retrieved_id = set([doc.doc_id for doc in docs])
+        for doc in docs:
+            if "page_label" not in doc.metadata:
+                continue
+            if "file_name" not in doc.metadata:
+                warnings.warn(
+                    "file_name not in metadata while page_label is in metadata: "
+                    f"{doc.metadata}"
+                )
+                continue  # avoid a KeyError on doc.metadata["file_name"] below
+            table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
+
+        queries = [
+            {"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
+            for fn, pls in table_pages.items()
+        ]
+        if queries:
+            extra_docs = self.vector_retrieval(
+                text="",
+                top_k=50,
+                where={"$or": queries},
+            )
+            for doc in extra_docs:
+                if doc.doc_id not in retrieved_id:
+                    docs.append(doc)
+
+        return docs
+
+
+class IndexDocumentPipeline(BaseIndexing):
     """Store the documents and index the content into vector store and doc store
 
     Args:
@@ -175,8 +296,29 @@ class IndexDocumentPipeline(BaseIndex):
         return USER_SETTINGS
 
     @classmethod
-    def get_pipeline(cls, setting) -> "IndexDocumentPipeline":
+    def get_pipeline(cls, settings) -> "IndexDocumentPipeline":
         """Get the pipeline based on the setting"""
         obj = cls()
-        obj.file_ingestor.pdf_mode = setting["index.index_parser"]
+        obj.file_ingestor.pdf_mode = settings["index.index_parser"]
         return obj
+
+    def get_retrievers(self, settings, **kwargs) -> list[BaseRetriever]:
+        """Get retriever objects associated with the index
+
+        Args:
+            settings: the settings of the app
+            kwargs: other arguments
+        """
+        retriever = DocumentRetrievalPipeline(
+            get_extra_table=settings["index.prioritize_table"]
+        )
+        if not settings["index.use_reranking"]:
+            retriever.reranker = None  # type: ignore
+
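+        # run-time overrides for the retriever's run(): keys with a leading "."
+        # address the retriever's own parameters, and temp=True keeps the
+        # overrides scoped to the next run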
+        kwargs = {
+            ".top_k": int(settings["index.num_retrieval"]),
+            ".mmr": settings["index.mmr"],
+            ".doc_ids": kwargs.get("files", None),
+        }
+        retriever.set_run(kwargs, temp=True)
+        return [retriever]
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index fd69b6b..ea97a35 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -36,6 +36,7 @@ class ChatPage(BasePage):
         ).then(
             fn=chat_fn,
             inputs=[
+                self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self.data_source.files,
                 self._app.settings_state,
@@ -64,6 +65,7 @@ class ChatPage(BasePage):
         ).then(
             fn=chat_fn,
             inputs=[
+                self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self.data_source.files,
                 self._app.settings_state,
diff --git a/libs/ktem/ktem/pages/chat/events.py b/libs/ktem/ktem/pages/chat/events.py
index fd9d595..2126e15 100644
--- a/libs/ktem/ktem/pages/chat/events.py
+++ b/libs/ktem/ktem/pages/chat/events.py
@@ -7,8 +7,7 @@ from typing import Optional
 import gradio as gr
 from ktem.components import llms, reasonings
 from ktem.db.models import Conversation, Source, engine
-from ktem.indexing.base import BaseIndex
-from ktem.reasoning.simple import DocumentRetrievalPipeline
+from ktem.indexing.base import BaseIndexing
 from sqlmodel import Session, select
 from theflow.settings import settings as app_settings
 from theflow.utils.modules import import_dotted_string
@@ -26,9 +25,15 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
         the pipeline objects
     """
+    # get retrievers
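+    # (the retrievers are now constructed by the indexing pipeline and handed
+    # to the reasoning pipeline, instead of being hard-coded in the reasoning
+    # module)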
{filenames}" + # "The result is a huge chunk of text related " + # "to your search but can also " + # "contain irrelevant info." + # ), + # component=retrieval_pipeline, + # postprocessor=lambda docs: "\n\n".join( + # [doc.text.replace("\n", " ") for doc in docs] + # ), + # ) + # tools.append(tool) elif tool == "google": from kotaemon.agents import GoogleSearchTool @@ -117,7 +113,7 @@ def create_pipeline(settings: dict, files: Optional[list] = None): return pipeline -async def chat_fn(chat_history, files, settings): +async def chat_fn(conversation_id, chat_history, files, settings): """Chat function""" chat_input = chat_history[-1][0] chat_history = chat_history[:-1] @@ -128,7 +124,7 @@ async def chat_fn(chat_history, files, settings): pipeline = create_pipeline(settings, files) pipeline.set_output_queue(queue) - asyncio.create_task(pipeline(chat_input, chat_history)) + asyncio.create_task(pipeline(chat_input, conversation_id, chat_history)) text, refs = "", "" while True: @@ -207,7 +203,7 @@ def index_fn(files, reindex: bool, selected_files, settings): gr.Info(f"Start indexing {len(files)} files...") # get the pipeline - indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False) + indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False) indexing_pipeline = indexing_cls.get_pipeline(settings) output_nodes, file_ids = indexing_pipeline(files, reindex=reindex) diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 06cac96..c41f41f 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -1,13 +1,11 @@ import asyncio import logging -import warnings from collections import defaultdict from functools import partial -from typing import Optional import tiktoken -from ktem.components import embeddings, get_docstore, get_vectorstore, llms -from ktem.db.models import Index, SourceTargetRelation, engine +from ktem.components import llms +from ktem.indexing.base import BaseRetriever from kotaemon.base import ( BaseComponent, Document, @@ -16,126 +14,13 @@ from kotaemon.base import ( RetrievedDocument, SystemMessage, ) -from kotaemon.indices import VectorRetrieval from kotaemon.indices.qa.citation import CitationPipeline -from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking from kotaemon.indices.splitters import TokenSplitter from kotaemon.llms import ChatLLM, PromptTemplate -from llama_index.vector_stores import ( - FilterCondition, - FilterOperator, - MetadataFilter, - MetadataFilters, -) -from llama_index.vector_stores.types import VectorStoreQueryMode -from sqlmodel import Session, select -from theflow.settings import settings logger = logging.getLogger(__name__) -class DocumentRetrievalPipeline(BaseComponent): - """Retrieve relevant document - - Args: - vector_retrieval: the retrieval pipeline that return the relevant documents - given a text query - reranker: the reranking pipeline that re-rank and filter the retrieved - documents - get_extra_table: if True, for each retrieved document, the pipeline will look - for surrounding tables (e.g. 
+    asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
     text, refs = "", ""
 
     while True:
@@ -207,7 +203,7 @@ def index_fn(files, reindex: bool, selected_files, settings):
     gr.Info(f"Start indexing {len(files)} files...")
 
     # get the pipeline
-    indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
+    indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
     indexing_pipeline = indexing_cls.get_pipeline(settings)
 
     output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index 06cac96..c41f41f 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -1,13 +1,11 @@
 import asyncio
 import logging
-import warnings
 from collections import defaultdict
 from functools import partial
-from typing import Optional
 
 import tiktoken
-from ktem.components import embeddings, get_docstore, get_vectorstore, llms
-from ktem.db.models import Index, SourceTargetRelation, engine
+from ktem.components import llms
+from ktem.indexing.base import BaseRetriever
 from kotaemon.base import (
     BaseComponent,
     Document,
@@ -16,126 +14,13 @@ from kotaemon.base import (
     RetrievedDocument,
     SystemMessage,
 )
-from kotaemon.indices import VectorRetrieval
 from kotaemon.indices.qa.citation import CitationPipeline
-from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
 from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import ChatLLM, PromptTemplate
-from llama_index.vector_stores import (
-    FilterCondition,
-    FilterOperator,
-    MetadataFilter,
-    MetadataFilters,
-)
-from llama_index.vector_stores.types import VectorStoreQueryMode
-from sqlmodel import Session, select
-from theflow.settings import settings
 
 logger = logging.getLogger(__name__)
 
 
-class DocumentRetrievalPipeline(BaseComponent):
-    """Retrieve relevant document
-
-    Args:
-        vector_retrieval: the retrieval pipeline that return the relevant documents
-            given a text query
-        reranker: the reranking pipeline that re-rank and filter the retrieved
-            documents
-        get_extra_table: if True, for each retrieved document, the pipeline will look
-            for surrounding tables (e.g. within the page)
-    """
-
-    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
-        doc_store=get_docstore(),
-        vector_store=get_vectorstore(),
-        embedding=embeddings.get_default(),
-    )
-    reranker: BaseReranking = CohereReranking.withx(
-        cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
-    ) >> LLMReranking.withx(llm=llms.get_lowest_cost())
-    get_extra_table: bool = False
-
-    def run(
-        self,
-        text: str,
-        top_k: int = 5,
-        mmr: bool = False,
-        doc_ids: Optional[list[str]] = None,
-    ) -> list[RetrievedDocument]:
-        """Retrieve document excerpts similar to the text
-
-        Args:
-            text: the text to retrieve similar documents
-            top_k: number of documents to retrieve
-            mmr: whether to use mmr to re-rank the documents
-            doc_ids: list of document ids to constraint the retrieval
-        """
-        kwargs = {}
-        if doc_ids:
-            with Session(engine) as session:
-                stmt = select(Index).where(
-                    Index.relation_type == SourceTargetRelation.VECTOR,
-                    Index.source_id.in_(doc_ids),  # type: ignore
-                )
-                results = session.exec(stmt)
-                vs_ids = [r.target_id for r in results.all()]
-
-            kwargs["filters"] = MetadataFilters(
-                filters=[
-                    MetadataFilter(
-                        key="doc_id",
-                        value=vs_id,
-                        operator=FilterOperator.EQ,
-                    )
-                    for vs_id in vs_ids
-                ],
-                condition=FilterCondition.OR,
-            )
-
-        if mmr:
-            # TODO: double check that llama-index MMR works correctly
-            kwargs["mode"] = VectorStoreQueryMode.MMR
-            kwargs["mmr_threshold"] = 0.5
-
-        # rerank
-        docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
-        if self.get_from_path("reranker"):
-            docs = self.reranker(docs, query=text)
-
-        if not self.get_extra_table:
-            return docs
-
-        # retrieve extra nodes relate to table
-        table_pages = defaultdict(list)
-        retrieved_id = set([doc.doc_id for doc in docs])
-        for doc in docs:
-            if "page_label" not in doc.metadata:
-                continue
-            if "file_name" not in doc.metadata:
-                warnings.warn(
-                    "file_name not in metadata while page_label is in metadata: "
-                    f"{doc.metadata}"
-                )
-            table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
-
-        queries = [
-            {"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
-            for fn, pls in table_pages.items()
-        ]
-        if queries:
-            extra_docs = self.vector_retrieval(
-                text="",
-                top_k=50,
-                where={"$or": queries},
-            )
-            for doc in extra_docs:
-                if doc.doc_id not in retrieved_id:
-                    docs.append(doc)
-
-        return docs
-
-
 class PrepareEvidencePipeline(BaseComponent):
     """Prepare the evidence text from the list of retrieved documents
 
@@ -338,22 +223,22 @@ class FullQAPipeline(BaseComponent):
     allow_extra = True
     params_publish = True
 
-    retrieval_pipeline: DocumentRetrievalPipeline = DocumentRetrievalPipeline.withx()
+    retrievers: list[BaseRetriever]
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
 
     async def run(  # type: ignore
-        self, question: str, history: list, **kwargs  # type: ignore
+        self, message: str, cid: str, history: list, **kwargs  # type: ignore
     ) -> Document:  # type: ignore
-        docs = self.retrieval_pipeline(text=question)
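+        # each index can contribute several retrievers; pool all of their
+        # results before preparing the evidence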
+        docs = []
+        for retriever in self.retrievers:
+            docs.extend(retriever(text=message))
         evidence_mode, evidence = self.evidence_pipeline(docs).content
         answer = await self.answering_pipeline(
-            question=question, evidence=evidence, evidence_mode=evidence_mode
+            question=message, evidence=evidence, evidence_mode=evidence_mode
         )
 
         # prepare citation
-        from collections import defaultdict
-
         spans = defaultdict(list)
         for fact_with_evidence in answer.metadata["citation"].answer:
             for quote in fact_with_evidence.substring_quote:
@@ -369,6 +254,7 @@ class FullQAPipeline(BaseComponent):
                     break
 
         id2docs = {doc.doc_id: doc for doc in docs}
+        lack_evidence = True
         for id, ss in spans.items():
             if not ss:
                 continue
@@ -391,31 +277,24 @@
                     )
                 }
             )
+            lack_evidence = False
+
+        if lack_evidence:
+            self.report_output({"evidence": "No evidence found"})
 
         self.report_output(None)
         return answer
 
     @classmethod
-    def get_pipeline(cls, settings, **kwargs):
+    def get_pipeline(cls, settings, retrievers, **kwargs):
         """Get the reasoning pipeline
 
-        Need a base pipeline implementation. Currently the drawback is that we want to
-        treat the retrievers as tools. Hence, the reasoning pipelie should just take
-        the already initiated tools (retrievers), and do not need to set such logic
-        here.
+        Args:
+            settings: the settings for the pipeline
+            retrievers: the retrievers to use
         """
-        pipeline = FullQAPipeline(get_extra_table=settings["index.prioritize_table"])
-        if not settings["index.use_reranking"]:
-            pipeline.retrieval_pipeline.reranker = None  # type: ignore
-
+        pipeline = FullQAPipeline(retrievers=retrievers)
         pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
 
-        kwargs = {
-            ".retrieval_pipeline.top_k": int(settings["index.num_retrieval"]),
-            ".retrieval_pipeline.mmr": settings["index.mmr"],
-            ".retrieval_pipeline.doc_ids": kwargs.get("files", None),
-        }
-        pipeline.set_run(kwargs, temp=True)
-
         return pipeline
 
     @classmethod