Separate rerankers, splitters and extractors (#85)
commit 2186c5558f (parent 0dede9c82d), committed via GitHub
knowledgehub/indices/rankings/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .llm import LLMReranking

__all__ = ["CohereReranking", "LLMReranking", "BaseReranking"]
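The package root re-exports all three classes, so call sites can pull every reranker from one place:

from knowledgehub.indices.rankings import BaseReranking, CohereReranking, LLMReranking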
knowledgehub/indices/rankings/base.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from __future__ import annotations

from abc import abstractmethod

from ...base import BaseComponent, Document


class BaseReranking(BaseComponent):
    @abstractmethod
    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Main method to transform a list of documents
        (re-ranking, filtering, etc.)"""
        ...
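A concrete reranker only has to implement `run`. As an illustrative sketch of the contract, not part of this commit (it assumes `BaseComponent` accepts annotated class attributes as fields, as the subclasses below do):

from knowledgehub.base import Document
from knowledgehub.indices.rankings import BaseReranking


class TopNPassthrough(BaseReranking):
    """Illustrative reranker: keep the first `top_k` candidates unchanged."""

    top_k: int = 5

    def run(self, documents: list[Document], query: str) -> list[Document]:
        # no scoring at all, simply truncate the candidate list
        return documents[: self.top_k]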
knowledgehub/indices/rankings/cohere.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from __future__ import annotations

import os

from ...base import Document
from .base import BaseReranking


class CohereReranking(BaseReranking):
    model_name: str = "rerank-multilingual-v2.0"
    cohere_api_key: str = os.environ.get("COHERE_API_KEY", "")
    top_k: int = 1

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use a Cohere Reranker model to re-order documents
        with their relevance scores"""
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Please install Cohere: `pip install cohere` to use Cohere Reranking"
            )

        cohere_client = cohere.Client(self.cohere_api_key)

        # output documents
        compressed_docs = []
        if len(documents) > 0:  # avoid an empty API call
            _docs = [d.content for d in documents]
            results = cohere_client.rerank(
                model=self.model_name, query=query, documents=_docs, top_n=self.top_k
            )
            for r in results:
                doc = documents[r.index]
                doc.metadata["relevance_score"] = r.relevance_score
                compressed_docs.append(doc)

        return compressed_docs
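A hedged usage sketch: it assumes a valid COHERE_API_KEY in the environment and that `Document` can be built with `content=` (the class above only shows that `.content` and `.metadata` exist):

from knowledgehub.base import Document
from knowledgehub.indices.rankings import CohereReranking

reranker = CohereReranking(top_k=2)  # picks up COHERE_API_KEY from the environment
docs = [
    Document(content=t)  # Document(content=...) is assumed here
    for t in ["Paris is in France.", "Cats purr.", "The Eiffel Tower is in Paris."]
]
ranked = reranker.run(docs, query="Where is the Eiffel Tower?")
for d in ranked:
    print(d.metadata["relevance_score"], d.content)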
knowledgehub/indices/rankings/llm.py (new file, 70 lines)
@@ -0,0 +1,70 @@
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from typing import Union

from langchain.output_parsers.boolean import BooleanOutputParser

from ...base import Document
from ...llms import PromptTemplate
from ...llms.chats.base import ChatLLM
from ...llms.completions.base import LLM
from .base import BaseReranking

BaseLLM = Union[ChatLLM, LLM]

RERANK_PROMPT_TEMPLATE = """Given the following question and context,
return YES if the context is relevant to the question and NO if it isn't.

> Question: {question}
> Context:
>>>
{context}
>>>
> Relevant (YES / NO):"""


class LLMReranking(BaseReranking):
    llm: BaseLLM
    prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
    top_k: int = 3
    concurrent: bool = True

    def run(
        self,
        documents: list[Document],
        query: str,
    ) -> list[Document]:
        """Filter down documents based on their relevance to the query."""
        filtered_docs = []
        output_parser = BooleanOutputParser()

        if self.concurrent:
            with ThreadPoolExecutor() as executor:
                futures = []
                for doc in documents:
                    _prompt = self.prompt_template.populate(
                        question=query, context=doc.get_content()
                    )
                    # bind the prompt as a default argument so each worker sees
                    # its own prompt rather than the loop variable's final value
                    futures.append(
                        executor.submit(lambda p=_prompt: self.llm(p).text)
                    )

                results = [future.result() for future in futures]
        else:
            results = []
            for doc in documents:
                _prompt = self.prompt_template.populate(
                    question=query, context=doc.get_content()
                )
                results.append(self.llm(_prompt).text)

        # use the boolean parser to extract the YES/NO relevance verdict
        results = [output_parser.parse(result) for result in results]
        for include_doc, doc in zip(results, documents):
            if include_doc:
                filtered_docs.append(doc)

        # prevent returning an empty result
        if len(filtered_docs) == 0:
            filtered_docs = documents[: self.top_k]

        return filtered_docs
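For completeness, a hedged usage sketch. The stub below merely satisfies the call convention the class relies on (callable with a prompt string, returning an object exposing `.text`) and stands in for a real `ChatLLM`/`LLM` from `knowledgehub.llms`; if `BaseComponent` validates field types, an actual model instance would be required instead:

from types import SimpleNamespace

from knowledgehub.base import Document
from knowledgehub.indices.rankings import LLMReranking


class AlwaysYesLLM:
    # illustrative stub only: a real ChatLLM/LLM would be used in practice
    def __call__(self, prompt: str):
        return SimpleNamespace(text="YES")


reranker = LLMReranking(llm=AlwaysYesLLM(), top_k=3, concurrent=False)
docs = [Document(content=t) for t in ["doc one", "doc two"]]  # content= assumed
print(len(reranker.run(docs, query="anything")))  # 2: the stub keeps every doc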