Add Reranker implementation and integration in Retrieving pipeline (#77)

* Add base Reranker
* Add LLM Reranker
* Add Cohere Reranker
* Add integration of Rerankers in Retrieving pipeline
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin)
2023-11-15 16:03:51 +07:00
committed by GitHub
parent b52f312d8e
commit 9945afdf6f
6 changed files with 207 additions and 8 deletions

View File

@@ -1,7 +1,8 @@
import os
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union
from llama_index.node_parser.extractors import MetadataExtractor
from llama_index.readers.base import BaseReader
from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _
@@ -18,6 +19,7 @@ from kotaemon.loaders import (
from kotaemon.parsers.splitter import SimpleNodeParser
from kotaemon.pipelines.agents import BaseAgent
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
from kotaemon.pipelines.reranking import BaseRerankingPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.storages import (
BaseDocumentStore,
@@ -43,12 +45,14 @@ class ReaderIndexingPipeline(BaseComponent):
chunk_overlap: int = 256
vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
metadata_extractor: Optional[MetadataExtractor] = None
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
chunk_size=16,
)
def get_reader(self, input_files: List[Union[str, Path]]):
@@ -79,7 +83,9 @@ class ReaderIndexingPipeline(BaseComponent):
@Node.auto(depends_on=["chunk_size", "chunk_overlap"])
def text_splitter(self) -> SimpleNodeParser:
return SimpleNodeParser(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
metadata_extractor=self.metadata_extractor,
)
def run(
@@ -111,12 +117,15 @@ class ReaderIndexingPipeline(BaseComponent):
else:
self.indexing_vector_pipeline.load(file_storage_path)
def to_retrieving_pipeline(self, top_k=3):
def to_retrieving_pipeline(
self, top_k=3, rerankers: Sequence[BaseRerankingPipeline] = []
):
retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
vector_store=self.vector_store,
doc_store=self.doc_store,
embedding=self.embedding,
top_k=top_k,
rerankers=rerankers,
)
return retrieving_pipeline