Add Reranker implementation and integration in Retrieving pipeline (#77)

* Add base Reranker
* Add LLM Reranker
* Add Cohere Reranker
* Add integration of Rerankers in Retrieving pipeline

This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin)
2023-11-15 16:03:51 +07:00
committed by GitHub
parent b52f312d8e
commit 9945afdf6f
6 changed files with 207 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import List
from typing import List, Sequence
from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _
@@ -11,6 +11,7 @@ from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.llms import PromptTemplate
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents import BaseAgent
from kotaemon.pipelines.reranking import BaseRerankingPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.pipelines.tools import ComponentTool
from kotaemon.storages import (
@@ -39,7 +40,7 @@ class QuestionAnsweringPipeline(BaseComponent):
)
llm: AzureChatOpenAI = AzureChatOpenAI.withx(
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2-gpt35",
@@ -49,11 +50,12 @@ class QuestionAnsweringPipeline(BaseComponent):
vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
rerankers: Sequence[BaseRerankingPipeline] = []
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
)
@@ -72,6 +74,7 @@ class QuestionAnsweringPipeline(BaseComponent):
doc_store=self.doc_store,
embedding=self.embedding,
top_k=self.retrieval_top_k,
rerankers=self.rerankers,
)
# load persistent from selected path
collection_name = file_names_to_collection_name(self.file_name_list)