feat: support TEI embedding service, configurable reranking model (#287)
* feat: add support for TEI embedding service, allow reranking model to be configurable.
* fix: add cohere default reranking model
* fix: comfort pre-commit

Signed-off-by: Kennywu <jdlow@live.cn>
Co-authored-by: wujiaye <wujiaye@bluemoon.com.cn>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
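For context, a minimal sketch of the surface this commit adds (the endpoint URLs are hypothetical placeholders; the sketch itself is not part of the diff):

# Hypothetical endpoints; a TEI server must actually be running there.
from kotaemon.embeddings import TeiEndpointEmbeddings
from kotaemon.rerankings import CohereReranking, TeiFastReranking

embedder = TeiEndpointEmbeddings(endpoint_url="http://localhost:8080/embed")
tei_reranker = TeiFastReranking(endpoint_url="http://localhost:8081/rerank")

# The Cohere reranking model is now a configurable Param with a default,
# rather than a hard-coded model name.
cohere_reranker = CohereReranking(model_name="rerank-multilingual-v2.0")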
@@ -8,10 +8,12 @@ from .langchain_based import (
    LCOpenAIEmbeddings,
)
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from .tei_endpoint_embed import TeiEndpointEmbeddings

__all__ = [
    "BaseEmbeddings",
    "EndpointEmbeddings",
+    "TeiEndpointEmbeddings",
    "LCOpenAIEmbeddings",
    "LCAzureOpenAIEmbeddings",
    "LCCohereEmbeddings",
libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import aiohttp
import requests

from kotaemon.base import Document, DocumentWithEmbedding, Param

from .base import BaseEmbeddings

session = requests.session()


class TeiEndpointEmbeddings(BaseEmbeddings):
    """An Embeddings component that uses a
    TEI (Text-Embedding-Inference) API compatible endpoint.

    Ref: https://github.com/huggingface/text-embeddings-inference

    Attributes:
        endpoint_url (str): The URL of a TEI
            (Text-Embedding-Inference) API compatible endpoint.
        normalize (bool): Whether to normalize embeddings to unit length.
        truncate (bool): Whether to truncate embeddings
            to a fixed/default length.
    """

    endpoint_url: str = Param(None, help="TEI embedding service api base URL")
    normalize: bool = Param(
        True,
        help="Normalize embeddings to unit length",
    )
    truncate: bool = Param(
        True,
        help="Truncate embeddings to a fixed/default length",
    )

    async def client_(self, inputs: list[str]):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url=self.endpoint_url,
                json={
                    "inputs": inputs,
                    "normalize": self.normalize,
                    "truncate": self.truncate,
                },
            ) as resp:
                embeddings = await resp.json()
        return embeddings

    async def ainvoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        if not isinstance(text, list):
            text = [text]
        text = self.prepare_input(text)

        outputs = []
        batch_size = 6
        num_batch = max(len(text) // batch_size, 1)
        for i in range(num_batch):
            if i == num_batch - 1:
                mini_batch = text[batch_size * i :]
            else:
                mini_batch = text[batch_size * i : batch_size * (i + 1)]
            mini_batch = [x.content for x in mini_batch]
            embeddings = await self.client_(mini_batch)  # type: ignore
            outputs.extend(
                [
                    DocumentWithEmbedding(content=doc, embedding=embedding)
                    for doc, embedding in zip(mini_batch, embeddings)
                ]
            )

        return outputs

    def invoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        if not isinstance(text, list):
            text = [text]

        text = self.prepare_input(text)

        outputs = []
        batch_size = 6
        num_batch = max(len(text) // batch_size, 1)
        for i in range(num_batch):
            if i == num_batch - 1:
                mini_batch = text[batch_size * i :]
            else:
                mini_batch = text[batch_size * i : batch_size * (i + 1)]
            mini_batch = [x.content for x in mini_batch]
            embeddings = session.post(
                url=self.endpoint_url,
                json={
                    "inputs": mini_batch,
                    "normalize": self.normalize,
                    "truncate": self.truncate,
                },
            ).json()
            outputs.extend(
                [
                    DocumentWithEmbedding(content=doc, embedding=embedding)
                    for doc, embedding in zip(mini_batch, embeddings)
                ]
            )
        return outputs
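A short usage sketch for this component (not part of the commit; the endpoint URL is a placeholder for a running TEI instance, and the `.embedding` attribute access is an assumption about DocumentWithEmbedding):

import asyncio

from kotaemon.embeddings import TeiEndpointEmbeddings

embedder = TeiEndpointEmbeddings(
    endpoint_url="http://localhost:8080/embed",  # placeholder TEI endpoint
    normalize=True,
    truncate=True,
)

# Synchronous path: batches of 6 inputs are POSTed via requests.
docs = embedder.invoke(["hello world", "text embeddings inference"])
print(len(docs), len(docs[0].embedding))

# Asynchronous path: same batching, but over aiohttp.
docs = asyncio.run(embedder.ainvoke("a single string is wrapped in a list"))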
@@ -39,7 +39,7 @@ class CohereReranking(BaseReranking):
            print("Cannot get Cohere API key from `ktem`", e)

        if not self.cohere_api_key:
-            print("Cohere API key not found. Skipping reranking.")
+            print("Cohere API key not found. Skipping rerankings.")
            return documents

        cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ class CohereReranking(BaseReranking):
        response = cohere_client.rerank(
            model=self.model_name, query=query, documents=_docs
        )
-        # print("Cohere score", [r.relevance_score for r in response.results])
        for r in response.results:
            doc = documents[r.index]
-            doc.metadata["cohere_reranking_score"] = r.relevance_score
+            doc.metadata["reranking_score"] = r.relevance_score
            compressed_docs.append(doc)

        return compressed_docs
@@ -241,7 +241,7 @@ class VectorRetrieval(BaseRetrieval):
            # if reranker is LLMReranking, limit the document with top_k items only
            if isinstance(reranker, LLMReranking):
                result = self._filter_docs(result, top_k=top_k)
-            result = reranker(documents=result, query=text)
+            result = reranker.run(documents=result, query=text)

        result = self._filter_docs(result, top_k=top_k)
        print(f"Got raw {len(result)} retrieved documents")
libs/kotaemon/kotaemon/rerankings/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .tei_fast_rerank import TeiFastReranking

__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
libs/kotaemon/kotaemon/rerankings/base.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from __future__ import annotations

from abc import abstractmethod

from kotaemon.base import BaseComponent, Document


class BaseReranking(BaseComponent):
    @abstractmethod
    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Main method to transform list of documents
        (re-ranking, filtering, etc)"""
        ...
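The contract is the single abstract `run` method above. A hypothetical minimal subclass, just to illustrate what an implementer provides (not part of the commit):

from kotaemon.base import Document
from kotaemon.rerankings.base import BaseReranking


class IdentityReranking(BaseReranking):
    """Hypothetical reranker: returns the documents unchanged."""

    def run(self, documents: list[Document], query: str) -> list[Document]:
        # A real implementation would score every document against `query`
        # and return them re-ordered; this stub keeps the input order.
        return documents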
libs/kotaemon/kotaemon/rerankings/cohere.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from __future__ import annotations

from decouple import config

from kotaemon.base import Document, Param

from .base import BaseReranking


class CohereReranking(BaseReranking):
    """Cohere Reranking model"""

    model_name: str = Param(
        "rerank-multilingual-v2.0",
        help=(
            "ID of the model to use. You can go to [Supported Models]"
            "(https://docs.cohere.com/docs/rerank-2) to see the supported models"
        ),
        required=True,
    )
    cohere_api_key: str = Param(
        config("COHERE_API_KEY", ""),
        help="Cohere API key",
        required=True,
    )

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use Cohere Reranker model to re-order documents
        with their relevance score"""
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Please install Cohere `pip install cohere` to use Cohere Reranking"
            )

        if not self.cohere_api_key:
            print("Cohere API key not found. Skipping rerankings.")
            return documents

        cohere_client = cohere.Client(self.cohere_api_key)
        compressed_docs: list[Document] = []

        if not documents:  # to avoid empty api call
            return compressed_docs

        _docs = [d.content for d in documents]
        response = cohere_client.rerank(
            model=self.model_name, query=query, documents=_docs
        )
        for r in response.results:
            doc = documents[r.index]
            doc.metadata["reranking_score"] = r.relevance_score
            compressed_docs.append(doc)

        return compressed_docs
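A usage sketch for the class above (not part of the commit); it assumes `pip install cohere` and a valid COHERE_API_KEY in the environment:

from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking

# model_name is now configurable; "rerank-multilingual-v2.0" is the default.
reranker = CohereReranking(model_name="rerank-multilingual-v2.0")
docs = [
    Document(content="Paris is the capital of France."),
    Document(content="Bananas are rich in potassium."),
]
ranked = reranker.run(docs, query="What is the capital of France?")
for d in ranked:
    print(d.metadata["reranking_score"], d.content)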
libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import Optional

import requests

from kotaemon.base import Document, Param

from .base import BaseReranking

session = requests.session()


class TeiFastReranking(BaseReranking):
    """Text Embeddings Inference (TEI) Reranking model
    (https://huggingface.co/docs/text-embeddings-inference/en/index)
    """

    endpoint_url: str = Param(
        None, help="TEI Reranking service api base URL", required=True
    )
    model_name: Optional[str] = Param(
        None,
        help=(
            "ID of the model to use. You can go to [Supported Models]"
            "(https://github.com/huggingface"
            "/text-embeddings-inference?tab=readme-ov-file"
            "#supported-models) to see the supported models"
        ),
    )
    is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")

    def client(self, query, texts):
        response = session.post(
            url=self.endpoint_url,
            json={
                "query": query,
                "texts": texts,
                "is_truncated": self.is_truncated,  # default is True
            },
        ).json()
        return response

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use the deployed TEI rerankings service to re-order documents
        with their relevance score"""
        if not self.endpoint_url:
            print("TEI API reranking URL not found. Skipping rerankings.")
            return documents

        compressed_docs: list[Document] = []

        if not documents:  # to avoid empty api call
            return compressed_docs

        if isinstance(documents[0], str):
            documents = self.prepare_input(documents)

        batch_size = 6
        num_batch = max(len(documents) // batch_size, 1)
        for i in range(num_batch):
            if i == num_batch - 1:
                mini_batch = documents[batch_size * i :]
            else:
                mini_batch = documents[batch_size * i : batch_size * (i + 1)]

            _docs = [d.content for d in mini_batch]
            rerank_resp = self.client(query, _docs)
            for r in rerank_resp:
                doc = mini_batch[r["index"]]
                doc.metadata["reranking_score"] = r["score"]
                compressed_docs.append(doc)

        compressed_docs = sorted(
            compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
        )
        return compressed_docs
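And a matching sketch for the TEI reranker (not part of the commit; the endpoint is a placeholder, behind which a rerank-capable model such as BAAI/bge-reranker-base would need to be deployed):

from kotaemon.base import Document
from kotaemon.rerankings import TeiFastReranking

reranker = TeiFastReranking(
    endpoint_url="http://localhost:8081/rerank",  # placeholder TEI endpoint
)
docs = [
    Document(content="TEI can serve reranking models over HTTP."),
    Document(content="An unrelated sentence about the weather."),
]
ranked = reranker.run(docs, query="how do I serve a reranker with TEI?")
# Results come back sorted by metadata["reranking_score"], best match first.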