Add Reranker implementation and integration in Retrieving pipeline (#77)
* Add base Reranker
* Add LLM Reranker
* Add Cohere Reranker
* Add integration of Rerankers in Retrieving pipeline
parent b52f312d8e
commit 9945afdf6f
.gitignore (vendored) | 4

@@ -446,6 +446,9 @@ $RECYCLE.BIN/
 # Windows shortcuts
 *.lnk
 
+# PDF files
+*.pdf
+
 .theflow/
 
 # End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
@@ -459,3 +462,4 @@ logs/
 S.gpg-agent*
 .vscode/settings.json
 examples/example1/assets
+storage/*
@@ -1,7 +1,8 @@
 import os
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union
 
+from llama_index.node_parser.extractors import MetadataExtractor
 from llama_index.readers.base import BaseReader
 from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
@@ -18,6 +19,7 @@ from kotaemon.loaders import (
 from kotaemon.parsers.splitter import SimpleNodeParser
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
+from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.storages import (
     BaseDocumentStore,
@@ -43,12 +45,14 @@ class ReaderIndexingPipeline(BaseComponent):
     chunk_overlap: int = 256
     vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
     doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    metadata_extractor: Optional[MetadataExtractor] = None
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        chunk_size=16,
     )
 
     def get_reader(self, input_files: List[Union[str, Path]]):
@@ -79,7 +83,9 @@ class ReaderIndexingPipeline(BaseComponent):
     @Node.auto(depends_on=["chunk_size", "chunk_overlap"])
     def text_splitter(self) -> SimpleNodeParser:
         return SimpleNodeParser(
-            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            metadata_extractor=self.metadata_extractor,
         )
 
     def run(
@@ -111,12 +117,15 @@ class ReaderIndexingPipeline(BaseComponent):
         else:
             self.indexing_vector_pipeline.load(file_storage_path)
 
-    def to_retrieving_pipeline(self, top_k=3):
+    def to_retrieving_pipeline(
+        self, top_k=3, rerankers: Sequence[BaseRerankingPipeline] = []
+    ):
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
             doc_store=self.doc_store,
             embedding=self.embedding,
             top_k=top_k,
+            rerankers=rerankers,
         )
         return retrieving_pipeline
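With this change, rerankers can be threaded straight from the indexing pipeline into retrieval. A minimal sketch, not part of the commit; the module path for `ReaderIndexingPipeline`, its run-call convention, and the input file are assumptions:

from kotaemon.pipelines.reranking import CohereReranking

# `indexing` is a configured ReaderIndexingPipeline; its import path and
# ingestion call are assumed, since this diff only shows part of the class.
indexing = ReaderIndexingPipeline()
indexing.run(["report.pdf"])  # hypothetical input file

# retrieval fetches 10 candidates, Cohere keeps the best 3
retriever = indexing.to_retrieving_pipeline(
    top_k=10,
    rerankers=[CohereReranking(top_k=3)],
)
docs = retriever("what does the report conclude?")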
@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Sequence
 
 from theflow import Node
 from theflow.utils.modules import ObjectInitDeclaration as _
@@ -11,6 +11,7 @@ from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
+from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
 from kotaemon.storages import (
@@ -39,7 +40,7 @@ class QuestionAnsweringPipeline(BaseComponent):
     )
 
     llm: AzureChatOpenAI = AzureChatOpenAI.withx(
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2-gpt35",
@@ -49,11 +50,12 @@ class QuestionAnsweringPipeline(BaseComponent):
 
     vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
     doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    rerankers: Sequence[BaseRerankingPipeline] = []
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
     )
@@ -72,6 +74,7 @@ class QuestionAnsweringPipeline(BaseComponent):
             doc_store=self.doc_store,
             embedding=self.embedding,
             top_k=self.retrieval_top_k,
+            rerankers=self.rerankers,
         )
         # load persistent from selected path
        collection_name = file_names_to_collection_name(self.file_name_list)
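Since `rerankers` is now a declared attribute, it can be supplied at construction time like any other field. A sketch with placeholder Azure credentials; constructor arguments beyond those visible in this diff are assumptions:

from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.reranking import LLMReranking

llm = AzureChatOpenAI(
    azure_endpoint="https://example.openai.azure.com/",  # placeholder
    openai_api_key="...",                                # placeholder
    openai_api_version="2023-03-15-preview",
    deployment_name="gpt-35-turbo",                      # placeholder
)

# QuestionAnsweringPipeline is the class modified above; any other
# constructor arguments it requires are omitted from this sketch.
qa = QuestionAnsweringPipeline(
    file_name_list=["report.pdf"],
    rerankers=[LLMReranking(llm=llm, top_k=3)],
)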
knowledgehub/pipelines/reranking.py | 114 (new file)

@@ -0,0 +1,114 @@
import os
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Union

from langchain.output_parsers.boolean import BooleanOutputParser

from ..base import BaseComponent
from ..base.schema import Document
from ..llms import PromptTemplate
from ..llms.chats.base import ChatLLM
from ..llms.completions.base import LLM

BaseLLM = Union[ChatLLM, LLM]


class BaseRerankingPipeline(BaseComponent):
    @abstractmethod
    def run(self, documents: List[Document], query: str) -> List[Document]:
        """Main method to transform a list of documents
        (re-ranking, filtering, etc.)"""
        ...


class CohereReranking(BaseRerankingPipeline):
    model_name: str = "rerank-multilingual-v2.0"
    cohere_api_key: Optional[str] = None
    top_k: int = 1

    def run(self, documents: List[Document], query: str) -> List[Document]:
        """Use the Cohere Rerank model to re-order documents
        with their relevance score"""
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Please install Cohere (`pip install cohere`) to use Cohere Reranking"
            )

        cohere_api_key = (
            self.cohere_api_key if self.cohere_api_key else os.environ["COHERE_API_KEY"]
        )
        cohere_client = cohere.Client(cohere_api_key)

        # output documents
        compressed_docs = []
        if len(documents) > 0:  # avoid an empty API call
            _docs = [d.content for d in documents]
            results = cohere_client.rerank(
                model=self.model_name, query=query, documents=_docs, top_n=self.top_k
            )
            for r in results:
                doc = documents[r.index]
                doc.metadata["relevance_score"] = r.relevance_score
                compressed_docs.append(doc)

        return compressed_docs


RERANK_PROMPT_TEMPLATE = """Given the following question and context,
return YES if the context is relevant to the question and NO if it isn't.

> Question: {question}
> Context:
>>>
{context}
>>>
> Relevant (YES / NO):"""


class LLMReranking(BaseRerankingPipeline):
    llm: BaseLLM
    prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
    top_k: int = 3
    concurrent: bool = True

    def run(
        self,
        documents: List[Document],
        query: str,
    ) -> List[Document]:
        """Filter down documents based on their relevance to the query."""
        filtered_docs = []
        output_parser = BooleanOutputParser()

        if self.concurrent:
            with ThreadPoolExecutor() as executor:
                futures = []
                for doc in documents:
                    _prompt = self.prompt_template.populate(
                        question=query, context=doc.get_content()
                    )
                    # bind the prompt via a default argument: a bare closure
                    # would capture the loop variable and could evaluate a
                    # later iteration's prompt by the time the worker runs
                    futures.append(
                        executor.submit(lambda prompt=_prompt: self.llm(prompt).text)
                    )

                results = [future.result() for future in futures]
        else:
            results = []
            for doc in documents:
                _prompt = self.prompt_template.populate(
                    question=query, context=doc.get_content()
                )
                results.append(self.llm(_prompt).text)

        # use the Boolean parser to extract the YES/NO relevance output from the LLM
        results = [output_parser.parse(result) for result in results]
        for include_doc, doc in zip(results, documents):
            if include_doc:
                filtered_docs.append(doc)

        # prevent returning an empty result
        if len(filtered_docs) == 0:
            filtered_docs = documents[: self.top_k]

        return filtered_docs
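`BaseRerankingPipeline` is the whole extension contract: subclass it and implement `run`. A toy illustration, not part of this commit, that orders documents by word overlap with the query (`get_content()` is used the same way `LLMReranking` uses it):

from typing import List

from kotaemon.base import Document  # import path as used in the tests below
from kotaemon.pipelines.reranking import BaseRerankingPipeline


class KeywordOverlapReranking(BaseRerankingPipeline):
    """Toy reranker: order documents by word overlap with the query."""

    top_k: int = 3

    def run(self, documents: List[Document], query: str) -> List[Document]:
        query_words = set(query.lower().split())

        def overlap(doc: Document) -> int:
            # count query words that also appear in the document text
            return len(query_words & set(doc.get_content().lower().split()))

        ranked = sorted(documents, key=overlap, reverse=True)
        return ranked[: self.top_k]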
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Sequence
 
 from theflow import Node, Param
 
@@ -9,6 +9,7 @@ from ..base import BaseComponent
 from ..base.schema import Document, RetrievedDocument
 from ..embeddings import BaseEmbeddings
 from ..storages import BaseDocumentStore, BaseVectorStore
+from .reranking import BaseRerankingPipeline
 
 VECTOR_STORE_FNAME = "vectorstore"
 DOC_STORE_FNAME = "docstore"
@@ -20,6 +21,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
     vector_store: Param[BaseVectorStore] = Param()
     doc_store: Param[BaseDocumentStore] = Param()
     embedding: Node[BaseEmbeddings] = Node()
+    rerankers: Sequence[BaseRerankingPipeline] = []
     top_k: int = 1
     # TODO: refer to llama_index's storage as well
 
@@ -51,6 +53,11 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
             RetrievedDocument(**doc.to_dict(), score=score)
             for doc, score in zip(docs, scores)
         ]
+        # use additional rerankers to re-order the document list
+        if self.rerankers:
+            for reranker in self.rerankers:
+                result = reranker(documents=result, query=text)
+
         return result
 
     def save(
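Note that the rerankers run sequentially, each consuming the previous stage's output, so order matters: a broad LLM relevance filter can feed a stricter Cohere pass. A sketch assuming `vector_store`, `doc_store`, `embedding`, and `llm` are already configured and populated elsewhere (e.g. by the indexing pipeline above):

from kotaemon.pipelines.reranking import CohereReranking, LLMReranking
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline

# vector_store, doc_store, embedding, llm are assumed to exist;
# this only shows how the reranker chain composes with retrieval
retriever = RetrieveDocumentFromVectorStorePipeline(
    vector_store=vector_store,
    doc_store=doc_store,
    embedding=embedding,
    top_k=20,  # fetch broadly...
    rerankers=[
        LLMReranking(llm=llm),     # ...drop clearly irrelevant hits...
        CohereReranking(top_k=3),  # ...then keep Cohere's best 3
    ],
)
docs = retriever("what does the report conclude?")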
tests/test_reranking.py | 62 (new file)

@@ -0,0 +1,62 @@
from unittest.mock import patch

import pytest
from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.base import Document
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.reranking import LLMReranking

_openai_chat_completion_responses = [
    ChatCompletion.parse_obj(
        {
            "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
            "object": "chat.completion",
            "created": 1692338378,
            "model": "gpt-35-turbo",
            "system_fingerprint": None,
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": text,
                        "function_call": None,
                        "tool_calls": None,
                    },
                }
            ],
            "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
        }
    )
    for text in [
        "YES",
        "NO",
        "YES",
    ]
]


@pytest.fixture
def llm():
    return AzureChatOpenAI(
        azure_endpoint="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
        temperature=0,
    )


@patch(
    "openai.resources.chat.completions.Completions.create",
    side_effect=_openai_chat_completion_responses,
)
def test_reranking(openai_completion, llm):
    # three documents; the mocked LLM answers YES / NO / YES,
    # so the reranker should keep exactly two of them
    documents = [Document(text=f"test {idx}") for idx in range(3)]
    query = "test query"

    reranker = LLMReranking(llm=llm, concurrent=False)
    rerank_docs = reranker(documents, query=query)

    assert len(rerank_docs) == 2