feat: support TEI embedding service, configurable reranking model (#287)

* feat: add support for TEI embedding service, allow reranking model to be configurable.

Signed-off-by: Kennywu <jdlow@live.cn>

* fix: add cohere default reranking model

* fix: comfort pre-commit

---------

Signed-off-by: Kennywu <jdlow@live.cn>
Co-authored-by: wujiaye <wujiaye@bluemoon.com.cn>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
KennyWu
2024-09-30 23:00:00 +08:00
committed by GitHub
parent 2e3c17b256
commit 53530e296f
20 changed files with 928 additions and 22 deletions

View File

@@ -8,10 +8,12 @@ from .langchain_based import (
LCOpenAIEmbeddings,
)
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from .tei_endpoint_embed import TeiEndpointEmbeddings
__all__ = [
"BaseEmbeddings",
"EndpointEmbeddings",
"TeiEndpointEmbeddings",
"LCOpenAIEmbeddings",
"LCAzureOpenAIEmbeddings",
"LCCohereEmbeddings",

View File

@@ -0,0 +1,105 @@
import aiohttp
import requests
from kotaemon.base import Document, DocumentWithEmbedding, Param
from .base import BaseEmbeddings
session = requests.session()
class TeiEndpointEmbeddings(BaseEmbeddings):
"""An Embeddings component that uses an
TEI (Text-Embedding-Inference) API compatible endpoint.
Ref: https://github.com/huggingface/text-embeddings-inference
Attributes:
endpoint_url (str): The url of an TEI
(Text-Embedding-Inference) API compatible endpoint.
normalize (bool): Whether to normalize embeddings to unit length.
truncate (bool): Whether to truncate embeddings
to a fixed/default length.
"""
endpoint_url: str = Param(None, help="TEI embedding service api base URL")
normalize: bool = Param(
True,
help="Normalize embeddings to unit length",
)
truncate: bool = Param(
True,
help="Truncate embeddings to a fixed/default length",
)
async def client_(self, inputs: list[str]):
async with aiohttp.ClientSession() as session:
async with session.post(
url=self.endpoint_url,
json={
"inputs": inputs,
"normalize": self.normalize,
"truncate": self.truncate,
},
) as resp:
embeddings = await resp.json()
return embeddings
async def ainvoke(
self, text: str | list[str] | Document | list[Document], *args, **kwargs
) -> list[DocumentWithEmbedding]:
if not isinstance(text, list):
text = [text]
text = self.prepare_input(text)
outputs = []
batch_size = 6
num_batch = max(len(text) // batch_size, 1)
for i in range(num_batch):
if i == num_batch - 1:
mini_batch = text[batch_size * i :]
else:
mini_batch = text[batch_size * i : batch_size * (i + 1)]
mini_batch = [x.content for x in mini_batch]
embeddings = await self.client_(mini_batch) # type: ignore
outputs.extend(
[
DocumentWithEmbedding(content=doc, embedding=embedding)
for doc, embedding in zip(mini_batch, embeddings)
]
)
return outputs
def invoke(
self, text: str | list[str] | Document | list[Document], *args, **kwargs
) -> list[DocumentWithEmbedding]:
if not isinstance(text, list):
text = [text]
text = self.prepare_input(text)
outputs = []
batch_size = 6
num_batch = max(len(text) // batch_size, 1)
for i in range(num_batch):
if i == num_batch - 1:
mini_batch = text[batch_size * i :]
else:
mini_batch = text[batch_size * i : batch_size * (i + 1)]
mini_batch = [x.content for x in mini_batch]
embeddings = session.post(
url=self.endpoint_url,
json={
"inputs": mini_batch,
"normalize": self.normalize,
"truncate": self.truncate,
},
).json()
outputs.extend(
[
DocumentWithEmbedding(content=doc, embedding=embedding)
for doc, embedding in zip(mini_batch, embeddings)
]
)
return outputs

View File

@@ -39,7 +39,7 @@ class CohereReranking(BaseReranking):
print("Cannot get Cohere API key from `ktem`", e)
if not self.cohere_api_key:
print("Cohere API key not found. Skipping reranking.")
print("Cohere API key not found. Skipping rerankings.")
return documents
cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ class CohereReranking(BaseReranking):
response = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs
)
# print("Cohere score", [r.relevance_score for r in response.results])
for r in response.results:
doc = documents[r.index]
doc.metadata["cohere_reranking_score"] = r.relevance_score
doc.metadata["reranking_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@@ -241,7 +241,7 @@ class VectorRetrieval(BaseRetrieval):
# if reranker is LLMReranking, limit the document with top_k items only
if isinstance(reranker, LLMReranking):
result = self._filter_docs(result, top_k=top_k)
result = reranker(documents=result, query=text)
result = reranker.run(documents=result, query=text)
result = self._filter_docs(result, top_k=top_k)
print(f"Got raw {len(result)} retrieved documents")

View File

@@ -0,0 +1,5 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .tei_fast_rerank import TeiFastReranking
__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from abc import abstractmethod
from kotaemon.base import BaseComponent, Document
class BaseReranking(BaseComponent):
@abstractmethod
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Main method to transform list of documents
(re-ranking, filtering, etc)"""
...

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
from decouple import config
from kotaemon.base import Document, Param
from .base import BaseReranking
class CohereReranking(BaseReranking):
"""Cohere Reranking model"""
model_name: str = Param(
"rerank-multilingual-v2.0",
help=(
"ID of the model to use. You can go to [Supported Models]"
"(https://docs.cohere.com/docs/rerank-2) to see the supported models"
),
required=True,
)
cohere_api_key: str = Param(
config("COHERE_API_KEY", ""),
help="Cohere API key",
required=True,
)
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Use Cohere Reranker model to re-order documents
with their relevance score"""
try:
import cohere
except ImportError:
raise ImportError(
"Please install Cohere " "`pip install cohere` to use Cohere Reranking"
)
if not self.cohere_api_key:
print("Cohere API key not found. Skipping rerankings.")
return documents
cohere_client = cohere.Client(self.cohere_api_key)
compressed_docs: list[Document] = []
if not documents: # to avoid empty api call
return compressed_docs
_docs = [d.content for d in documents]
response = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs
)
for r in response.results:
doc = documents[r.index]
doc.metadata["reranking_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
from typing import Optional
import requests
from kotaemon.base import Document, Param
from .base import BaseReranking
session = requests.session()
class TeiFastReranking(BaseReranking):
"""Text Embeddings Inference (TEI) Reranking model
(https://huggingface.co/docs/text-embeddings-inference/en/index)
"""
endpoint_url: str = Param(
None, help="TEI Reranking service api base URL", required=True
)
model_name: Optional[str] = Param(
None,
help=(
"ID of the model to use. You can go to [Supported Models]"
"(https://github.com/huggingface"
"/text-embeddings-inference?tab=readme-ov-file"
"#supported-models) to see the supported models"
),
)
is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")
def client(self, query, texts):
response = session.post(
url=self.endpoint_url,
json={
"query": query,
"texts": texts,
"is_truncated": self.is_truncated, # default is True
},
).json()
return response
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Use the deployed TEI rerankings service to re-order documents
with their relevance score"""
if not self.endpoint_url:
print("TEI API reranking URL not found. Skipping rerankings.")
return documents
compressed_docs: list[Document] = []
if not documents: # to avoid empty api call
return compressed_docs
if isinstance(documents[0], str):
documents = self.prepare_input(documents)
batch_size = 6
num_batch = max(len(documents) // batch_size, 1)
for i in range(num_batch):
if i == num_batch - 1:
mini_batch = documents[batch_size * i :]
else:
mini_batch = documents[batch_size * i : batch_size * (i + 1)]
_docs = [d.content for d in mini_batch]
rerank_resp = self.client(query, _docs)
for r in rerank_resp:
doc = mini_batch[r["index"]]
doc.metadata["reranking_score"] = r["score"]
compressed_docs.append(doc)
compressed_docs = sorted(
compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
)
return compressed_docs