From 9945afdf6f428009b0ad0aa81ea17a553fa5f297 Mon Sep 17 00:00:00 2001
From: "Tuan Anh Nguyen Dang (Tadashi_Cin)"
Date: Wed, 15 Nov 2023 16:03:51 +0700
Subject: [PATCH] Add Reranker implementation and integration in Retrieving
 pipeline (#77)

* Add base Reranker

* Add LLM Reranker

* Add Cohere Reranker

* Add integration of Rerankers in Retrieving pipeline
---
 .gitignore                           |   4 +
 knowledgehub/pipelines/ingest.py     |  17 +++-
 knowledgehub/pipelines/qa.py         |   9 ++-
 knowledgehub/pipelines/reranking.py  | 114 +++++++++++++++++++++++++++
 knowledgehub/pipelines/retrieving.py |   9 ++-
 tests/test_reranking.py              |  62 +++++++++++++++
 6 files changed, 207 insertions(+), 8 deletions(-)
 create mode 100644 knowledgehub/pipelines/reranking.py
 create mode 100644 tests/test_reranking.py

diff --git a/.gitignore b/.gitignore
index 2122691..19ecc56 100644
--- a/.gitignore
+++ b/.gitignore
@@ -446,6 +446,9 @@ $RECYCLE.BIN/
 # Windows shortcuts
 *.lnk
 
+# PDF files
+*.pdf
+
 .theflow/
 
 # End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
@@ -459,3 +462,4 @@ logs/
 S.gpg-agent*
 .vscode/settings.json
 examples/example1/assets
+storage/*
diff --git a/knowledgehub/pipelines/ingest.py b/knowledgehub/pipelines/ingest.py
index d3f19b8..c486c73 100644
--- a/knowledgehub/pipelines/ingest.py
+++ b/knowledgehub/pipelines/ingest.py
@@ -1,7 +1,8 @@
 import os
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union
 
+from llama_index.node_parser.extractors import MetadataExtractor
 from llama_index.readers.base import BaseReader
 from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
@@ -18,6 +19,7 @@ from kotaemon.loaders import (
 from kotaemon.parsers.splitter import SimpleNodeParser
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
+from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.storages import (
     BaseDocumentStore,
@@ -43,12 +45,14 @@ class ReaderIndexingPipeline(BaseComponent):
     chunk_overlap: int = 256
     vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
     doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    metadata_extractor: Optional[MetadataExtractor] = None
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        chunk_size=16,
     )
 
     def get_reader(self, input_files: List[Union[str, Path]]):
@@ -79,7 +83,9 @@ class ReaderIndexingPipeline(BaseComponent):
     @Node.auto(depends_on=["chunk_size", "chunk_overlap"])
     def text_splitter(self) -> SimpleNodeParser:
         return SimpleNodeParser(
-            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            metadata_extractor=self.metadata_extractor,
         )
 
     def run(
@@ -111,12 +117,15 @@ class ReaderIndexingPipeline(BaseComponent):
         else:
             self.indexing_vector_pipeline.load(file_storage_path)
 
-    def to_retrieving_pipeline(self, top_k=3):
+    def to_retrieving_pipeline(
+        self, top_k=3, rerankers: Sequence[BaseRerankingPipeline] = []
+    ):
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
            doc_store=self.doc_store,
             embedding=self.embedding,
             top_k=top_k,
+            rerankers=rerankers,
         )
         return retrieving_pipeline
 
diff --git a/knowledgehub/pipelines/qa.py b/knowledgehub/pipelines/qa.py
index 0763535..93a7147 100644
--- a/knowledgehub/pipelines/qa.py
+++ b/knowledgehub/pipelines/qa.py
@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Sequence
 
 from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
@@ -11,6 +11,7 @@ from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
+from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
 from kotaemon.storages import (
@@ -39,7 +40,7 @@ class QuestionAnsweringPipeline(BaseComponent):
     )
 
     llm: AzureChatOpenAI = AzureChatOpenAI.withx(
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2-gpt35",
@@ -49,11 +50,12 @@ class QuestionAnsweringPipeline(BaseComponent):
 
     vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
     doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    rerankers: Sequence[BaseRerankingPipeline] = []
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
     )
 
@@ -72,6 +74,7 @@ class QuestionAnsweringPipeline(BaseComponent):
             doc_store=self.doc_store,
             embedding=self.embedding,
             top_k=self.retrieval_top_k,
+            rerankers=self.rerankers,
         )
         # load persistent from selected path
         collection_name = file_names_to_collection_name(self.file_name_list)
diff --git a/knowledgehub/pipelines/reranking.py b/knowledgehub/pipelines/reranking.py
new file mode 100644
index 0000000..8c4c20f
--- /dev/null
+++ b/knowledgehub/pipelines/reranking.py
@@ -0,0 +1,114 @@
+import os
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional, Union
+
+from langchain.output_parsers.boolean import BooleanOutputParser
+
+from ..base import BaseComponent
+from ..base.schema import Document
+from ..llms import PromptTemplate
+from ..llms.chats.base import ChatLLM
+from ..llms.completions.base import LLM
+
+BaseLLM = Union[ChatLLM, LLM]
+
+
+class BaseRerankingPipeline(BaseComponent):
+    @abstractmethod
+    def run(self, documents: List[Document], query: str) -> List[Document]:
+        """Main method to transform list of documents
+        (re-ranking, filtering, etc.)"""
+        ...
+
+
+class CohereReranking(BaseRerankingPipeline):
+    model_name: str = "rerank-multilingual-v2.0"
+    cohere_api_key: Optional[str] = None
+    top_k: int = 1
+
+    def run(self, documents: List[Document], query: str) -> List[Document]:
+        """Use Cohere Reranker model to re-order documents
+        with their relevance score"""
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Please install Cohere: `pip install cohere` to use Cohere Reranking"
+            )
+
+        cohere_api_key = (
+            self.cohere_api_key if self.cohere_api_key else os.environ["COHERE_API_KEY"]
+        )
+        cohere_client = cohere.Client(cohere_api_key)
+
+        # output documents
+        compressed_docs = []
+        if len(documents) > 0:  # to avoid empty api call
+            _docs = [d.content for d in documents]
+            results = cohere_client.rerank(
+                model=self.model_name, query=query, documents=_docs, top_n=self.top_k
+            )
+            for r in results:
+                doc = documents[r.index]
+                doc.metadata["relevance_score"] = r.relevance_score
+                compressed_docs.append(doc)
+
+        return compressed_docs
+
+
+RERANK_PROMPT_TEMPLATE = """Given the following question and context,
+return YES if the context is relevant to the question and NO if it isn't.
+
+> Question: {question}
+> Context:
+>>>
+{context}
+>>>
+> Relevant (YES / NO):"""
+
+
+class LLMReranking(BaseRerankingPipeline):
+    llm: BaseLLM
+    prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
+    top_k: int = 3
+    concurrent: bool = True
+
+    def run(
+        self,
+        documents: List[Document],
+        query: str,
+    ) -> List[Document]:
+        """Filter down documents based on their relevance to the query."""
+        filtered_docs = []
+        output_parser = BooleanOutputParser()
+
+        if self.concurrent:
+            with ThreadPoolExecutor() as executor:
+                futures = []
+                for doc in documents:
+                    _prompt = self.prompt_template.populate(
+                        question=query, context=doc.get_content()
+                    )
+                    # bind the current prompt as a default argument; a bare
+                    # closure would late-bind and submit the last prompt only
+                    futures.append(executor.submit(lambda p=_prompt: self.llm(p).text))
+
+                results = [future.result() for future in futures]
+        else:
+            results = []
+            for doc in documents:
+                _prompt = self.prompt_template.populate(
+                    question=query, context=doc.get_content()
+                )
+                results.append(self.llm(_prompt).text)
+
+        # use Boolean parser to extract relevancy output from LLM
+        results = [output_parser.parse(result) for result in results]
+        for include_doc, doc in zip(results, documents):
+            if include_doc:
+                filtered_docs.append(doc)
+
+        # prevent returning empty result
+        if len(filtered_docs) == 0:
+            filtered_docs = documents[: self.top_k]
+
+        return filtered_docs
diff --git a/knowledgehub/pipelines/retrieving.py b/knowledgehub/pipelines/retrieving.py
index 7003391..67bf26b 100644
--- a/knowledgehub/pipelines/retrieving.py
+++ b/knowledgehub/pipelines/retrieving.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Sequence
 
 from theflow import Node, Param
 
@@ -9,6 +9,7 @@ from ..base import BaseComponent
 from ..base.schema import Document, RetrievedDocument
 from ..embeddings import BaseEmbeddings
 from ..storages import BaseDocumentStore, BaseVectorStore
+from .reranking import BaseRerankingPipeline
 
 VECTOR_STORE_FNAME = "vectorstore"
 DOC_STORE_FNAME = "docstore"
@@ -20,6 +21,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
     vector_store: Param[BaseVectorStore] = Param()
     doc_store: Param[BaseDocumentStore] = Param()
     embedding: Node[BaseEmbeddings] = Node()
+    rerankers: Sequence[BaseRerankingPipeline] = []
     top_k: int = 1
     # TODO: refer to llama_index's storage as well
 
@@ -51,6 +53,11 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
             RetrievedDocument(**doc.to_dict(), score=score)
             for doc, score in zip(docs, scores)
         ]
+        # use additional reranker to re-order the document list
+        if self.rerankers:
+            for reranker in self.rerankers:
+                result = reranker(documents=result, query=text)
+
         return result
 
     def save(
diff --git a/tests/test_reranking.py b/tests/test_reranking.py
new file mode 100644
index 0000000..3652b8c
--- /dev/null
+++ b/tests/test_reranking.py
@@ -0,0 +1,62 @@
+from unittest.mock import patch
+
+import pytest
+from openai.types.chat.chat_completion import ChatCompletion
+
+from kotaemon.base import Document
+from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.pipelines.reranking import LLMReranking
+
+_openai_chat_completion_responses = [
+    ChatCompletion.parse_obj(
+        {
+            "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
+            "object": "chat.completion",
+            "created": 1692338378,
+            "model": "gpt-35-turbo",
+            "system_fingerprint": None,
+            "choices": [
+                {
+                    "index": 0,
+                    "finish_reason": "stop",
+                    "message": {
+                        "role": "assistant",
+                        "content": text,
+                        "function_call": None,
+                        "tool_calls": None,
+                    },
+                }
+            ],
+            "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
+        }
+    )
+    for text in [
+        "YES",
+        "NO",
+        "YES",
+    ]
+]
+
+
+@pytest.fixture
+def llm():
+    return AzureChatOpenAI(
+        azure_endpoint="https://dummy.openai.azure.com/",
+        openai_api_key="dummy",
+        openai_api_version="2023-03-15-preview",
+        temperature=0,
+    )
+
+
+@patch(
+    "openai.resources.chat.completions.Completions.create",
+    side_effect=_openai_chat_completion_responses,
+)
+def test_reranking(openai_completion, llm):
+    documents = [Document(text=f"test {idx}") for idx in range(3)]
+    query = "test query"
+
+    reranker = LLMReranking(llm=llm, concurrent=False)
+    rerank_docs = reranker(documents, query=query)
+
+    assert len(rerank_docs) == 2
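
A minimal usage sketch for the new rerankers, using only the APIs added
in this patch. The input file, query, Azure endpoint/deployment, and the
environment variables are placeholder assumptions, and the indexing
pipeline is assumed to be invoked with the file list, mirroring
get_reader's signature:

    import os

    from kotaemon.llms.chats.openai import AzureChatOpenAI
    from kotaemon.pipelines.ingest import ReaderIndexingPipeline
    from kotaemon.pipelines.reranking import CohereReranking, LLMReranking

    # index some files first (the path is a placeholder)
    indexing = ReaderIndexingPipeline()
    indexing(["example.pdf"])

    # LLM-based relevance filtering; endpoint and deployment are dummies
    llm = AzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        openai_api_version="2023-03-15-preview",
        deployment_name="dummy-q2-gpt35",
        temperature=0,
    )
    rerankers = [
        LLMReranking(llm=llm, top_k=3),
        # CohereReranking(top_k=3),  # needs `pip install cohere` + COHERE_API_KEY
    ]

    # each reranker is applied, in order, to the vector-store hits
    retriever = indexing.to_retrieving_pipeline(top_k=10, rerankers=rerankers)
    docs = retriever("what is the document about?")

With an empty rerankers sequence the retrieval pipeline behaves exactly
as before, and LLMReranking falls back to the first top_k retrieved
documents rather than filtering everything out.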