Add Reranker implementation and integration in Retrieving pipeline (#77)

* Add base Reranker
* Add LLM Reranker
* Add Cohere Reranker
* Add integration of Rerankers in Retrieving pipeline
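
Example usage, as a minimal sketch of the new API (not part of the diff: it assumes an already-indexed ReaderIndexingPipeline instance named `indexing_pipeline`, a COHERE_API_KEY in the environment, and an invented query string):

    from kotaemon.pipelines.reranking import CohereReranking

    # retrieve a generous top-10, then let Cohere keep the 3 most relevant
    reranker = CohereReranking(top_k=3)  # falls back to COHERE_API_KEY if no key is given
    retriever = indexing_pipeline.to_retrieving_pipeline(top_k=10, rerankers=[reranker])
    docs = retriever("what is reranking?")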
Tuan Anh Nguyen Dang (Tadashi_Cin) 2023-11-15 16:03:51 +07:00 committed by GitHub
parent b52f312d8e
commit 9945afdf6f
6 changed files with 207 additions and 8 deletions

.gitignore (vendored, 4 changes)

@@ -446,6 +446,9 @@ $RECYCLE.BIN/
 # Windows shortcuts
 *.lnk
+# PDF files
+*.pdf
+.theflow/
 # End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
@@ -459,3 +462,4 @@ logs/
 S.gpg-agent*
 .vscode/settings.json
 examples/example1/assets
+storage/*


@@ -1,7 +1,8 @@
 import os
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union

+from llama_index.node_parser.extractors import MetadataExtractor
 from llama_index.readers.base import BaseReader
 from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
@@ -18,6 +19,7 @@ from kotaemon.loaders import (
 from kotaemon.parsers.splitter import SimpleNodeParser
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
+from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.storages import (
     BaseDocumentStore,
@@ -43,12 +45,14 @@ class ReaderIndexingPipeline(BaseComponent):
     chunk_overlap: int = 256
     vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
     doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    metadata_extractor: Optional[MetadataExtractor] = None
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        chunk_size=16,
     )

     def get_reader(self, input_files: List[Union[str, Path]]):
@@ -79,7 +83,9 @@ class ReaderIndexingPipeline(BaseComponent):
     @Node.auto(depends_on=["chunk_size", "chunk_overlap"])
     def text_splitter(self) -> SimpleNodeParser:
         return SimpleNodeParser(
-            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            metadata_extractor=self.metadata_extractor,
         )

     def run(
@@ -111,12 +117,15 @@ class ReaderIndexingPipeline(BaseComponent):
         else:
             self.indexing_vector_pipeline.load(file_storage_path)

-    def to_retrieving_pipeline(self, top_k=3):
+    def to_retrieving_pipeline(
+        self, top_k=3, rerankers: Sequence[BaseRerankingPipeline] = []
+    ):
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
             doc_store=self.doc_store,
             embedding=self.embedding,
             top_k=top_k,
+            rerankers=rerankers,
         )
         return retrieving_pipeline


@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Sequence

 from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
@@ -11,6 +11,7 @@ from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
+from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
 from kotaemon.storages import (
@@ -39,7 +40,7 @@ class QuestionAnsweringPipeline(BaseComponent):
     )
     llm: AzureChatOpenAI = AzureChatOpenAI.withx(
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2-gpt35",
@@ -49,11 +50,12 @@ class QuestionAnsweringPipeline(BaseComponent):
     vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
     doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    rerankers: Sequence[BaseRerankingPipeline] = []
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
     )
@@ -72,6 +74,7 @@ class QuestionAnsweringPipeline(BaseComponent):
             doc_store=self.doc_store,
             embedding=self.embedding,
             top_k=self.retrieval_top_k,
+            rerankers=self.rerankers,
         )
         # load the persisted index from the selected path
         collection_name = file_names_to_collection_name(self.file_name_list)


@@ -0,0 +1,114 @@
import os
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Union

from langchain.output_parsers.boolean import BooleanOutputParser

from ..base import BaseComponent
from ..base.schema import Document
from ..llms import PromptTemplate
from ..llms.chats.base import ChatLLM
from ..llms.completions.base import LLM

BaseLLM = Union[ChatLLM, LLM]


class BaseRerankingPipeline(BaseComponent):
    @abstractmethod
    def run(self, documents: List[Document], query: str) -> List[Document]:
        """Main method to transform a list of documents
        (re-ranking, filtering, etc.)"""
        ...


class CohereReranking(BaseRerankingPipeline):
    model_name: str = "rerank-multilingual-v2.0"
    cohere_api_key: Optional[str] = None
    top_k: int = 1

    def run(self, documents: List[Document], query: str) -> List[Document]:
        """Use the Cohere Rerank model to re-order documents
        by their relevance score"""
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Please install Cohere (`pip install cohere`) to use Cohere reranking"
            )

        cohere_api_key = (
            self.cohere_api_key if self.cohere_api_key else os.environ["COHERE_API_KEY"]
        )
        cohere_client = cohere.Client(cohere_api_key)

        # output documents
        compressed_docs = []
        if len(documents) > 0:  # avoid an empty API call
            _docs = [d.content for d in documents]
            results = cohere_client.rerank(
                model=self.model_name, query=query, documents=_docs, top_n=self.top_k
            )
            for r in results:
                doc = documents[r.index]
                doc.metadata["relevance_score"] = r.relevance_score
                compressed_docs.append(doc)

        return compressed_docs


RERANK_PROMPT_TEMPLATE = """Given the following question and context,
return YES if the context is relevant to the question and NO if it isn't.

> Question: {question}
> Context:
>>>
{context}
>>>
> Relevant (YES / NO):"""


class LLMReranking(BaseRerankingPipeline):
    llm: BaseLLM
    prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
    top_k: int = 3
    concurrent: bool = True

    def run(
        self,
        documents: List[Document],
        query: str,
    ) -> List[Document]:
        """Filter down documents based on their relevance to the query."""
        filtered_docs = []
        output_parser = BooleanOutputParser()

        if self.concurrent:
            with ThreadPoolExecutor() as executor:
                futures = []
                for doc in documents:
                    _prompt = self.prompt_template.populate(
                        question=query, context=doc.get_content()
                    )
                    # bind the prompt as a default argument so each future uses
                    # its own prompt rather than the loop's last value
                    futures.append(executor.submit(lambda p=_prompt: self.llm(p).text))
                results = [future.result() for future in futures]
        else:
            results = []
            for doc in documents:
                _prompt = self.prompt_template.populate(
                    question=query, context=doc.get_content()
                )
                results.append(self.llm(_prompt).text)

        # use the boolean parser to extract the YES/NO relevance decision
        results = [output_parser.parse(result) for result in results]
        for include_doc, doc in zip(results, documents):
            if include_doc:
                filtered_docs.append(doc)

        # prevent returning an empty result
        if len(filtered_docs) == 0:
            filtered_docs = documents[: self.top_k]

        return filtered_docs
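
The two rerankers can also be used standalone; a sketch (the Azure endpoint, API version, and key mirror the dummy values used elsewhere in this commit, and the documents and query are invented):

    import os

    from kotaemon.base import Document
    from kotaemon.llms.chats.openai import AzureChatOpenAI
    from kotaemon.pipelines.reranking import CohereReranking, LLMReranking

    docs = [
        Document(text="Rerankers re-order retrieved documents by relevance."),
        Document(text="A recipe for sourdough bread."),
    ]
    query = "what does a reranker do?"

    # Cohere: keep the top_k documents, each annotated with a relevance_score
    top_docs = CohereReranking(top_k=1)(docs, query=query)

    # LLM-as-judge: keep only the documents the LLM answers YES for
    llm = AzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key=os.environ.get("OPENAI_API_KEY", "dummy"),
        openai_api_version="2023-03-15-preview",
        temperature=0,
    )
    relevant_docs = LLMReranking(llm=llm, concurrent=False)(docs, query=query)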


@@ -1,7 +1,7 @@
 from __future__ import annotations

 from pathlib import Path
-from typing import Optional
+from typing import Optional, Sequence

 from theflow import Node, Param
@@ -9,6 +9,7 @@ from ..base import BaseComponent
 from ..base.schema import Document, RetrievedDocument
 from ..embeddings import BaseEmbeddings
 from ..storages import BaseDocumentStore, BaseVectorStore
+from .reranking import BaseRerankingPipeline

 VECTOR_STORE_FNAME = "vectorstore"
 DOC_STORE_FNAME = "docstore"
@@ -20,6 +21,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
     vector_store: Param[BaseVectorStore] = Param()
     doc_store: Param[BaseDocumentStore] = Param()
     embedding: Node[BaseEmbeddings] = Node()
+    rerankers: Sequence[BaseRerankingPipeline] = []
     top_k: int = 1

     # TODO: refer to llama_index's storage as well
@@ -51,6 +53,11 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
             RetrievedDocument(**doc.to_dict(), score=score)
             for doc, score in zip(docs, scores)
         ]
+        # apply any additional rerankers, in order, to re-rank the result list
+        if self.rerankers:
+            for reranker in self.rerankers:
+                result = reranker(documents=result, query=text)
+
         return result

     def save(
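
Since rerankers is just an ordered sequence of BaseRerankingPipeline steps, a custom stage only has to implement run(). A hypothetical sketch (this class is illustrative, not part of the commit):

    from typing import List

    from kotaemon.base import Document
    from kotaemon.pipelines.reranking import BaseRerankingPipeline


    class ScoreThresholdFilter(BaseRerankingPipeline):
        """Hypothetical stage: drop documents whose relevance_score, as attached
        by an upstream reranker such as CohereReranking, falls below a cutoff."""

        threshold: float = 0.5

        def run(self, documents: List[Document], query: str) -> List[Document]:
            # documents without a score are kept, to stay conservative
            return [
                doc
                for doc in documents
                if doc.metadata.get("relevance_score", 1.0) >= self.threshold
            ]

Passed as rerankers=[CohereReranking(top_k=5), ScoreThresholdFilter()], the two stages run in the order given.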

tests/test_reranking.py (new file, 62 additions)

@@ -0,0 +1,62 @@
from unittest.mock import patch

import pytest
from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.base import Document
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.reranking import LLMReranking

_openai_chat_completion_responses = [
    ChatCompletion.parse_obj(
        {
            "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
            "object": "chat.completion",
            "created": 1692338378,
            "model": "gpt-35-turbo",
            "system_fingerprint": None,
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": text,
                        "function_call": None,
                        "tool_calls": None,
                    },
                }
            ],
            "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
        }
    )
    for text in [
        "YES",
        "NO",
        "YES",
    ]
]


@pytest.fixture
def llm():
    return AzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
        temperature=0,
    )


@patch(
    "openai.resources.chat.completions.Completions.create",
    side_effect=_openai_chat_completion_responses,
)
def test_reranking(openai_completion, llm):
    documents = [Document(text=f"test {idx}") for idx in range(3)]
    query = "test query"

    reranker = LLMReranking(llm=llm, concurrent=False)
    rerank_docs = reranker(documents, query=query)

    # the mocked responses are YES / NO / YES, so two documents survive
    assert len(rerank_docs) == 2