feat: support TEI embedding service, configurable reranking model (#287)
* feat: add support for TEI embedding service, allow reranking model to be configurable.
* fix: add cohere default reranking model
* fix: comfort pre-commit

Signed-off-by: Kennywu <jdlow@live.cn>
Co-authored-by: wujiaye <wujiaye@bluemoon.com.cn>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in:
parent 2e3c17b256
commit 53530e296f
@@ -92,6 +92,7 @@ KH_VECTORSTORE = {
 }
 KH_LLMS = {}
 KH_EMBEDDINGS = {}
+KH_RERANKINGS = {}
 
 # populate options from config
 if config("AZURE_OPENAI_API_KEY", default="") and config(
@@ -212,7 +213,7 @@ KH_LLMS["cohere"] = {
     "spec": {
         "__type__": "kotaemon.llms.chats.LCCohereChat",
         "model_name": "command-r-plus-08-2024",
-        "api_key": "your-key",
+        "api_key": config("COHERE_API_KEY", default="your-key"),
     },
     "default": False,
 }
@@ -222,7 +223,7 @@ KH_EMBEDDINGS["cohere"] = {
     "spec": {
         "__type__": "kotaemon.embeddings.LCCohereEmbeddings",
         "model": "embed-multilingual-v3.0",
-        "cohere_api_key": "your-key",
+        "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
         "user_agent": "default",
     },
     "default": False,
@@ -235,6 +236,16 @@ KH_EMBEDDINGS["cohere"] = {
 #     "default": False,
 # }
 
+# default reranking models
+KH_RERANKINGS["cohere"] = {
+    "spec": {
+        "__type__": "kotaemon.rerankings.CohereReranking",
+        "model_name": "rerank-multilingual-v2.0",
+        "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
+    },
+    "default": True,
+}
+
 KH_REASONINGS = [
     "ktem.reasoning.simple.FullQAPipeline",
     "ktem.reasoning.simple.FullDecomposeQAPipeline",
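The hunk above registers only the Cohere reranker in KH_RERANKINGS; a TEI-backed entry would follow the same shape. A sketch, assuming a self-hosted TEI server (the "tei" key name and the localhost URL are illustrative, not part of this commit):

# hypothetical flowsettings entry for a self-hosted TEI reranker (not in this commit)
KH_RERANKINGS["tei"] = {
    "spec": {
        "__type__": "kotaemon.rerankings.TeiFastReranking",
        "endpoint_url": "http://localhost:8080/rerank",  # assumed local TEI server
    },
    "default": False,
}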
@@ -8,10 +8,12 @@ from .langchain_based import (
     LCOpenAIEmbeddings,
 )
 from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from .tei_endpoint_embed import TeiEndpointEmbeddings
 
 __all__ = [
     "BaseEmbeddings",
     "EndpointEmbeddings",
+    "TeiEndpointEmbeddings",
     "LCOpenAIEmbeddings",
     "LCAzureOpenAIEmbeddings",
     "LCCohereEmbeddings",
105  libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py  Normal file
@@ -0,0 +1,105 @@
+import aiohttp
+import requests
+
+from kotaemon.base import Document, DocumentWithEmbedding, Param
+
+from .base import BaseEmbeddings
+
+session = requests.session()
+
+
+class TeiEndpointEmbeddings(BaseEmbeddings):
+    """An Embeddings component that uses a
+    TEI (Text-Embedding-Inference) API compatible endpoint.
+
+    Ref: https://github.com/huggingface/text-embeddings-inference
+
+    Attributes:
+        endpoint_url (str): The URL of a TEI
+            (Text-Embedding-Inference) API compatible endpoint.
+        normalize (bool): Whether to normalize embeddings to unit length.
+        truncate (bool): Whether to truncate embeddings
+            to a fixed/default length.
+    """
+
+    endpoint_url: str = Param(None, help="TEI embedding service api base URL")
+    normalize: bool = Param(
+        True,
+        help="Normalize embeddings to unit length",
+    )
+    truncate: bool = Param(
+        True,
+        help="Truncate embeddings to a fixed/default length",
+    )
+
+    async def client_(self, inputs: list[str]):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=self.endpoint_url,
+                json={
+                    "inputs": inputs,
+                    "normalize": self.normalize,
+                    "truncate": self.truncate,
+                },
+            ) as resp:
+                embeddings = await resp.json()
+                return embeddings
+
+    async def ainvoke(
+        self, text: str | list[str] | Document | list[Document], *args, **kwargs
+    ) -> list[DocumentWithEmbedding]:
+        if not isinstance(text, list):
+            text = [text]
+        text = self.prepare_input(text)
+
+        outputs = []
+        batch_size = 6
+        num_batch = max(len(text) // batch_size, 1)
+        for i in range(num_batch):
+            if i == num_batch - 1:
+                mini_batch = text[batch_size * i :]
+            else:
+                mini_batch = text[batch_size * i : batch_size * (i + 1)]
+            mini_batch = [x.content for x in mini_batch]
+            embeddings = await self.client_(mini_batch)  # type: ignore
+            outputs.extend(
+                [
+                    DocumentWithEmbedding(content=doc, embedding=embedding)
+                    for doc, embedding in zip(mini_batch, embeddings)
+                ]
+            )
+
+        return outputs
+
+    def invoke(
+        self, text: str | list[str] | Document | list[Document], *args, **kwargs
+    ) -> list[DocumentWithEmbedding]:
+        if not isinstance(text, list):
+            text = [text]
+
+        text = self.prepare_input(text)
+
+        outputs = []
+        batch_size = 6
+        num_batch = max(len(text) // batch_size, 1)
+        for i in range(num_batch):
+            if i == num_batch - 1:
+                mini_batch = text[batch_size * i :]
+            else:
+                mini_batch = text[batch_size * i : batch_size * (i + 1)]
+            mini_batch = [x.content for x in mini_batch]
+            embeddings = session.post(
+                url=self.endpoint_url,
+                json={
+                    "inputs": mini_batch,
+                    "normalize": self.normalize,
+                    "truncate": self.truncate,
+                },
+            ).json()
+            outputs.extend(
+                [
+                    DocumentWithEmbedding(content=doc, embedding=embedding)
+                    for doc, embedding in zip(mini_batch, embeddings)
+                ]
+            )
+        return outputs
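A minimal usage sketch for the new component, assuming a TEI server is running locally (the URL is an assumption, not part of this commit):

from kotaemon.embeddings import TeiEndpointEmbeddings

embedder = TeiEndpointEmbeddings(
    endpoint_url="http://localhost:8080/embed",  # assumed local TEI /embed route
    normalize=True,
    truncate=True,
)
# invoke() is the synchronous requests path; ainvoke() is the aiohttp-based async path
docs = embedder.invoke(["hello world", "another passage"])
print(len(docs), len(docs[0].embedding))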
@@ -39,7 +39,7 @@ class CohereReranking(BaseReranking):
             print("Cannot get Cohere API key from `ktem`", e)
 
         if not self.cohere_api_key:
-            print("Cohere API key not found. Skipping reranking.")
+            print("Cohere API key not found. Skipping rerankings.")
             return documents
 
         cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ class CohereReranking(BaseReranking):
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
-            doc.metadata["cohere_reranking_score"] = r.relevance_score
+            doc.metadata["reranking_score"] = r.relevance_score
             compressed_docs.append(doc)
 
         return compressed_docs
@@ -241,7 +241,7 @@ class VectorRetrieval(BaseRetrieval):
             # if reranker is LLMReranking, limit the document with top_k items only
             if isinstance(reranker, LLMReranking):
                 result = self._filter_docs(result, top_k=top_k)
-            result = reranker(documents=result, query=text)
+            result = reranker.run(documents=result, query=text)
 
         result = self._filter_docs(result, top_k=top_k)
         print(f"Got raw {len(result)} retrieved documents")
5  libs/kotaemon/kotaemon/rerankings/__init__.py  Normal file
@@ -0,0 +1,5 @@
+from .base import BaseReranking
+from .cohere import CohereReranking
+from .tei_fast_rerank import TeiFastReranking
+
+__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
13  libs/kotaemon/kotaemon/rerankings/base.py  Normal file
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+
+from kotaemon.base import BaseComponent, Document
+
+
+class BaseReranking(BaseComponent):
+    @abstractmethod
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Main method to transform list of documents
+        (re-ranking, filtering, etc)"""
+        ...
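The abstract run() method above is the whole contract: anything that reorders a list of documents against a query can plug in. A toy subclass, purely illustrative and not part of this commit:

from kotaemon.base import Document
from kotaemon.rerankings.base import BaseReranking


class IdentityReranking(BaseReranking):
    """Illustrative no-op reranker: returns documents unchanged."""

    def run(self, documents: list[Document], query: str) -> list[Document]:
        # a real implementation would score each document against the query
        for doc in documents:
            doc.metadata["reranking_score"] = 1.0
        return documents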
56  libs/kotaemon/kotaemon/rerankings/cohere.py  Normal file
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from decouple import config
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+
+class CohereReranking(BaseReranking):
+    """Cohere Reranking model"""
+
+    model_name: str = Param(
+        "rerank-multilingual-v2.0",
+        help=(
+            "ID of the model to use. You can go to [Supported Models]"
+            "(https://docs.cohere.com/docs/rerank-2) to see the supported models"
+        ),
+        required=True,
+    )
+    cohere_api_key: str = Param(
+        config("COHERE_API_KEY", ""),
+        help="Cohere API key",
+        required=True,
+    )
+
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Use Cohere Reranker model to re-order documents
+        with their relevance score"""
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Please install Cohere `pip install cohere` to use Cohere Reranking"
+            )
+
+        if not self.cohere_api_key:
+            print("Cohere API key not found. Skipping rerankings.")
+            return documents
+
+        cohere_client = cohere.Client(self.cohere_api_key)
+        compressed_docs: list[Document] = []
+
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        _docs = [d.content for d in documents]
+        response = cohere_client.rerank(
+            model=self.model_name, query=query, documents=_docs
+        )
+        for r in response.results:
+            doc = documents[r.index]
+            doc.metadata["reranking_score"] = r.relevance_score
+            compressed_docs.append(doc)
+
+        return compressed_docs
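A minimal call sketch for this class, assuming COHERE_API_KEY is set in the environment (picked up via decouple's config above):

from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking

reranker = CohereReranking(model_name="rerank-multilingual-v2.0")
ranked = reranker.run(
    documents=[Document(content="Paris is the capital of France.")],
    query="Where is Paris?",
)
print(ranked[0].metadata["reranking_score"])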
77  libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py  Normal file
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import requests
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+session = requests.session()
+
+
+class TeiFastReranking(BaseReranking):
+    """Text Embeddings Inference (TEI) Reranking model
+    (https://huggingface.co/docs/text-embeddings-inference/en/index)
+    """
+
+    endpoint_url: str = Param(
+        None, help="TEI Reranking service api base URL", required=True
+    )
+    model_name: Optional[str] = Param(
+        None,
+        help=(
+            "ID of the model to use. You can go to [Supported Models]"
+            "(https://github.com/huggingface"
+            "/text-embeddings-inference?tab=readme-ov-file"
+            "#supported-models) to see the supported models"
+        ),
+    )
+    is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")
+
+    def client(self, query, texts):
+        response = session.post(
+            url=self.endpoint_url,
+            json={
+                "query": query,
+                "texts": texts,
+                "is_truncated": self.is_truncated,  # default is True
+            },
+        ).json()
+        return response
+
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Use the deployed TEI rerankings service to re-order documents
+        with their relevance score"""
+        if not self.endpoint_url:
+            print("TEI API reranking URL not found. Skipping rerankings.")
+            return documents
+
+        compressed_docs: list[Document] = []
+
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        if isinstance(documents[0], str):
+            documents = self.prepare_input(documents)
+
+        batch_size = 6
+        num_batch = max(len(documents) // batch_size, 1)
+        for i in range(num_batch):
+            if i == num_batch - 1:
+                mini_batch = documents[batch_size * i :]
+            else:
+                mini_batch = documents[batch_size * i : batch_size * (i + 1)]
+
+            _docs = [d.content for d in mini_batch]
+            rerank_resp = self.client(query, _docs)
+            for r in rerank_resp:
+                doc = mini_batch[r["index"]]
+                doc.metadata["reranking_score"] = r["score"]
+                compressed_docs.append(doc)
+
+        compressed_docs = sorted(
+            compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
+        )
+        return compressed_docs
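And the TEI counterpart of the Cohere sketch above, assuming a text-embeddings-inference reranker served locally (the URL is an assumption):

from kotaemon.base import Document
from kotaemon.rerankings import TeiFastReranking

reranker = TeiFastReranking(endpoint_url="http://localhost:8080/rerank")
ranked = reranker.run(
    documents=[Document(content="passage one"), Document(content="passage two")],
    query="which passage mentions one?",
)
# unlike CohereReranking, results come back sorted by reranking_score, highest first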
@@ -59,6 +59,7 @@ class EmbeddingManager:
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
+            TeiEndpointEmbeddings,
         )
 
         self._vendors = [
@@ -67,6 +68,7 @@ class EmbeddingManager:
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            TeiEndpointEmbeddings,
         ]
 
     def __getitem__(self, key: str) -> BaseEmbeddings:
@@ -16,6 +16,7 @@ import tiktoken
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.readers.file.base import default_file_metadata_func
 from llama_index.core.vector_stores import (
@@ -39,12 +40,7 @@ from kotaemon.indices.ingests.files import (
     azure_reader,
     unstructured,
 )
-from kotaemon.indices.rankings import (
-    BaseReranking,
-    CohereReranking,
-    LLMReranking,
-    LLMTrulensScoring,
-)
+from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 
 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -285,7 +281,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             ],
             retrieval_mode=user_settings["retrieval_mode"],
             llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
-            rerankers=[CohereReranking(use_key_from_ktem=True)],
+            rerankers=[
+                reranking_models_manager[
+                    index_settings.get(
+                        "reranking", reranking_models_manager.get_default_name()
+                    )
+                ]
+            ],
         )
         if not user_settings["use_reranking"]:
             retriever.rerankers = []  # type: ignore
@@ -4,6 +4,7 @@ from ktem.db.models import User, engine
 from ktem.embeddings.ui import EmbeddingManagement
 from ktem.index.ui import IndexManagement
 from ktem.llms.ui import LLMManagement
+from ktem.rerankings.ui import RerankingManagement
 from sqlmodel import Session, select
 
 from .user import UserManagement
@@ -24,6 +25,9 @@ class ResourcesTab(BasePage):
         with gr.Tab("Embeddings") as self.emb_management_tab:
             self.emb_management = EmbeddingManagement(self._app)
 
+        with gr.Tab("Rerankings") as self.rerank_management_tab:
+            self.rerank_management = RerankingManagement(self._app)
+
         if self._app.f_user_management:
             with gr.Tab("Users", visible=False) as self.user_management_tab:
                 self.user_management = UserManagement(self._app)
@@ -5,6 +5,7 @@ import requests
 from ktem.app import BasePage
 from ktem.embeddings.manager import embedding_models_manager as embeddings
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager as rerankers
 from theflow.settings import settings as flowsettings
 
 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
@@ -186,6 +187,15 @@ class SetupPage(BasePage):
                 },
                 default=True,
             )
+            rerankers.update(
+                name="cohere",
+                spec={
+                    "__type__": "kotaemon.rerankings.CohereReranking",
+                    "model_name": "rerank-multilingual-v2.0",
+                    "cohere_api_key": cohere_api_key,
+                },
+                default=True,
+            )
         elif radio_model_value == "openai":
             if openai_api_key:
                 llms.update(
@@ -100,7 +100,7 @@ class DocSearchTool(BaseTool):
             )
 
             print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))
 
             # trim context by trim_len
             if evidence:
@@ -138,7 +138,7 @@ class DocSearchTool(BaseTool):
             )
 
             print("Retrieved #{}: {}".format(_id, retrieved_content))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))
 
             # trim context by trim_len
             if evidence:
0  libs/ktem/ktem/rerankings/__init__.py  Normal file
36  libs/ktem/ktem/rerankings/db.py  Normal file
@@ -0,0 +1,36 @@
+from typing import Type
+
+from ktem.db.engine import engine
+from sqlalchemy import JSON, Boolean, Column, String
+from sqlalchemy.orm import DeclarativeBase
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import import_dotted_string
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class BaseRerankingTable(Base):
+    """Base table to store reranking models"""
+
+    __abstract__ = True
+
+    name = Column(String, primary_key=True, unique=True)
+    spec = Column(JSON, default={})
+    default = Column(Boolean, default=False)
+
+
+__base_reranking: Type[BaseRerankingTable] = (
+    import_dotted_string(flowsettings.KH_TABLE_RERANKING, safe=False)
+    if hasattr(flowsettings, "KH_TABLE_RERANKING")
+    else BaseRerankingTable
+)
+
+
+class RerankingTable(__base_reranking):  # type: ignore
+    __tablename__ = "reranking"
+
+
+if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
+    RerankingTable.metadata.create_all(engine)
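The KH_TABLE_RERANKING hook mirrors the table-override pattern used elsewhere in ktem: a deployment can substitute its own table class from flowsettings. A hypothetical override (the module path is illustrative):

# in flowsettings.py -- swap in a custom reranking table, e.g. with extra columns
KH_TABLE_RERANKING = "my_project.tables.MyRerankingTable"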
194  libs/ktem/ktem/rerankings/manager.py  Normal file
@@ -0,0 +1,194 @@
+from typing import Optional, Type
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import deserialize
+
+from kotaemon.rerankings.base import BaseReranking
+
+from .db import RerankingTable, engine
+
+
+class RerankingManager:
+    """Represent a pool of reranking models"""
+
+    def __init__(self):
+        self._models: dict[str, BaseReranking] = {}
+        self._info: dict[str, dict] = {}
+        self._default: str = ""
+        self._vendors: list[Type] = []
+
+        # populate the pool if empty
+        if hasattr(flowsettings, "KH_RERANKINGS"):
+            with Session(engine) as sess:
+                count = sess.query(RerankingTable).count()
+                if not count:
+                    for name, model in flowsettings.KH_RERANKINGS.items():
+                        self.add(
+                            name=name,
+                            spec=model["spec"],
+                            default=model.get("default", False),
+                        )
+
+        self.load()
+        self.load_vendors()
+
+    def load(self):
+        """Load the model pool from database"""
+        self._models, self._info, self._default = {}, {}, ""
+        with Session(engine) as sess:
+            stmt = select(RerankingTable)
+            items = sess.execute(stmt)
+
+            for (item,) in items:
+                self._models[item.name] = deserialize(item.spec, safe=False)
+                self._info[item.name] = {
+                    "name": item.name,
+                    "spec": item.spec,
+                    "default": item.default,
+                }
+                if item.default:
+                    self._default = item.name
+
+    def load_vendors(self):
+        from kotaemon.rerankings import CohereReranking, TeiFastReranking
+
+        self._vendors = [TeiFastReranking, CohereReranking]
+
+    def __getitem__(self, key: str) -> BaseReranking:
+        """Get model by name"""
+        return self._models[key]
+
+    def __contains__(self, key: str) -> bool:
+        """Check if model exists"""
+        return key in self._models
+
+    def get(
+        self, key: str, default: Optional[BaseReranking] = None
+    ) -> Optional[BaseReranking]:
+        """Get model by name with default value"""
+        return self._models.get(key, default)
+
+    def settings(self) -> dict:
+        """Present model pools option for gradio"""
+        return {
+            "label": "Reranking",
+            "choices": list(self._models.keys()),
+            "value": self.get_default_name(),
+        }
+
+    def options(self) -> dict:
+        """Present a dict of models"""
+        return self._models
+
+    def get_random_name(self) -> str:
+        """Get the name of a random model
+
+        Returns:
+            str: random model name in the pool
+        """
+        import random
+
+        if not self._models:
+            raise ValueError("No models in pool")
+
+        return random.choice(list(self._models.keys()))
+
+    def get_default_name(self) -> str:
+        """Get the name of the default model
+
+        In case there is no default model, choose random model from pool. In
+        case there are multiple default models, choose random from them.
+
+        Returns:
+            str: model name
+        """
+        if not self._models:
+            raise ValueError("No models in pool")
+
+        if not self._default:
+            return self.get_random_name()
+
+        return self._default
+
+    def get_random(self) -> BaseReranking:
+        """Get random model"""
+        return self._models[self.get_random_name()]
+
+    def get_default(self) -> BaseReranking:
+        """Get default model
+
+        In case there is no default model, choose random model from pool. In
+        case there are multiple default models, choose random from them.
+
+        Returns:
+            BaseReranking: model
+        """
+        return self._models[self.get_default_name()]
+
+    def info(self) -> dict:
+        """List all models"""
+        return self._info
+
+    def add(self, name: str, spec: dict, default: bool):
+        if not name:
+            raise ValueError("Name must not be empty")
+
+        try:
+            with Session(engine) as sess:
+                if default:
+                    # turn all models to non-default
+                    sess.query(RerankingTable).update({"default": False})
+                    sess.commit()
+
+                item = RerankingTable(name=name, spec=spec, default=default)
+                sess.add(item)
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to add model {name}: {e}")
+
+        self.load()
+
+    def delete(self, name: str):
+        """Delete a model from the pool"""
+        try:
+            with Session(engine) as sess:
+                item = sess.query(RerankingTable).filter_by(name=name).first()
+                sess.delete(item)
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to delete model {name}: {e}")
+
+        self.load()
+
+    def update(self, name: str, spec: dict, default: bool):
+        """Update a model in the pool"""
+        if not name:
+            raise ValueError("Name must not be empty")
+
+        try:
+            with Session(engine) as sess:
+
+                if default:
+                    # turn all models to non-default
+                    sess.query(RerankingTable).update({"default": False})
+                    sess.commit()
+
+                item = sess.query(RerankingTable).filter_by(name=name).first()
+                if not item:
+                    raise ValueError(f"Model {name} not found")
+                item.spec = spec
+                item.default = default
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to update model {name}: {e}")
+
+        self.load()
+
+    def vendors(self) -> dict:
+        """Return list of vendors"""
+        return {vendor.__qualname__: vendor for vendor in self._vendors}
+
+
+reranking_models_manager = RerankingManager()
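Other components resolve rerankers through this singleton; a sketch of the lookup pattern (the same one DocumentRetrievalPipeline uses earlier in this diff):

from ktem.rerankings.manager import reranking_models_manager

default_reranker = reranking_models_manager.get_default()
maybe_cohere = reranking_models_manager.get("cohere")  # None if not registered
print(reranking_models_manager.settings())  # label/choices/value for a gradio dropdown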
390  libs/ktem/ktem/rerankings/ui.py  Normal file
@@ -0,0 +1,390 @@
+from copy import deepcopy
+
+import gradio as gr
+import pandas as pd
+import yaml
+from ktem.app import BasePage
+from ktem.utils.file import YAMLNoDateSafeLoader
+from theflow.utils.modules import deserialize
+
+from .manager import reranking_models_manager
+
+
+def format_description(cls):
+    params = cls.describe()["params"]
+    params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
+    for key, value in params.items():
+        if isinstance(value["auto_callback"], str):
+            continue
+        params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
+    return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
+
+
+class RerankingManagement(BasePage):
+    def __init__(self, app):
+        self._app = app
+        self.spec_desc_default = (
+            "# Spec description\n\nSelect a model to view the spec description."
+        )
+        self.on_building_ui()
+
+    def on_building_ui(self):
+        with gr.Tab(label="View"):
+            self.rerank_list = gr.DataFrame(
+                headers=["name", "vendor", "default"],
+                interactive=False,
+            )
+
+            with gr.Column(visible=False) as self._selected_panel:
+                self.selected_rerank_name = gr.Textbox(value="", visible=False)
+                with gr.Row():
+                    with gr.Column():
+                        self.edit_default = gr.Checkbox(
+                            label="Set default",
+                            info=(
+                                "Set this Reranking model as default. This default "
+                                "Reranking will be used by other components by default "
+                                "if no Reranking is specified for such components."
+                            ),
+                        )
+                        self.edit_spec = gr.Textbox(
+                            label="Specification",
+                            info="Specification of the Reranking model in YAML format",
+                            lines=10,
+                        )
+
+                        with gr.Accordion(
+                            label="Test connection", visible=False, open=False
+                        ) as self._check_connection_panel:
+                            with gr.Row():
+                                with gr.Column(scale=4):
+                                    self.connection_logs = gr.HTML(
+                                        "Logs",
+                                    )
+
+                                with gr.Column(scale=1):
+                                    self.btn_test_connection = gr.Button("Test")
+
+                        with gr.Row(visible=False) as self._selected_panel_btn:
+                            with gr.Column():
+                                self.btn_edit_save = gr.Button(
+                                    "Save", min_width=10, variant="primary"
+                                )
+                            with gr.Column():
+                                self.btn_delete = gr.Button(
+                                    "Delete", min_width=10, variant="stop"
+                                )
+                                with gr.Row():
+                                    self.btn_delete_yes = gr.Button(
+                                        "Confirm Delete",
+                                        variant="stop",
+                                        visible=False,
+                                        min_width=10,
+                                    )
+                                    self.btn_delete_no = gr.Button(
+                                        "Cancel", visible=False, min_width=10
+                                    )
+                            with gr.Column():
+                                self.btn_close = gr.Button("Close", min_width=10)
+
+                    with gr.Column():
+                        self.edit_spec_desc = gr.Markdown("# Spec description")
+
+        with gr.Tab(label="Add"):
+            with gr.Row():
+                with gr.Column(scale=2):
+                    self.name = gr.Textbox(
+                        label="Name",
+                        info=(
+                            "Must be unique and non-empty. "
+                            "The name will be used to identify the reranking model."
+                        ),
+                    )
+                    self.rerank_choices = gr.Dropdown(
+                        label="Vendors",
+                        info=(
+                            "Choose the vendor of the Reranking model. Each vendor "
+                            "has different specification."
+                        ),
+                    )
+                    self.spec = gr.Textbox(
+                        label="Specification",
+                        info="Specification of the Reranking model in YAML format.",
+                    )
+                    self.default = gr.Checkbox(
+                        label="Set default",
+                        info=(
+                            "Set this Reranking model as default. This default "
+                            "Reranking will be used by other components by default "
+                            "if no Reranking is specified for such components."
+                        ),
+                    )
+                    self.btn_new = gr.Button("Add", variant="primary")
+
+                with gr.Column(scale=3):
+                    self.spec_desc = gr.Markdown(self.spec_desc_default)
+
+    def _on_app_created(self):
+        """Called when the app is created"""
+        self._app.app.load(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self._app.app.load(
+            lambda: gr.update(choices=list(reranking_models_manager.vendors().keys())),
+            outputs=[self.rerank_choices],
+        )
+
+    def on_rerank_vendor_change(self, vendor):
+        vendor = reranking_models_manager.vendors()[vendor]
+
+        required: dict = {}
+        desc = vendor.describe()
+        for key, value in desc["params"].items():
+            if value.get("required", False):
+                required[key] = value.get("default", None)
+
+        return yaml.dump(required), format_description(vendor)
+
+    def on_register_events(self):
+        self.rerank_choices.select(
+            self.on_rerank_vendor_change,
+            inputs=[self.rerank_choices],
+            outputs=[self.spec, self.spec_desc],
+        )
+        self.btn_new.click(
+            self.create_rerank,
+            inputs=[self.name, self.rerank_choices, self.spec, self.default],
+            outputs=None,
+        ).success(self.list_rerankings, inputs=[], outputs=[self.rerank_list]).success(
+            lambda: ("", None, "", False, self.spec_desc_default),
+            outputs=[
+                self.name,
+                self.rerank_choices,
+                self.spec,
+                self.default,
+                self.spec_desc,
+            ],
+        )
+        self.rerank_list.select(
+            self.select_rerank,
+            inputs=self.rerank_list,
+            outputs=[self.selected_rerank_name],
+            show_progress="hidden",
+        )
+        self.selected_rerank_name.change(
+            self.on_selected_rerank_change,
+            inputs=[self.selected_rerank_name],
+            outputs=[
+                self._selected_panel,
+                self._selected_panel_btn,
+                # delete section
+                self.btn_delete,
+                self.btn_delete_yes,
+                self.btn_delete_no,
+                # edit section
+                self.edit_spec,
+                self.edit_spec_desc,
+                self.edit_default,
+                self._check_connection_panel,
+            ],
+            show_progress="hidden",
+        ).success(lambda: gr.update(value=""), outputs=[self.connection_logs])
+
+        self.btn_delete.click(
+            self.on_btn_delete_click,
+            inputs=[],
+            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+            show_progress="hidden",
+        )
+        self.btn_delete_yes.click(
+            self.delete_rerank,
+            inputs=[self.selected_rerank_name],
+            outputs=[self.selected_rerank_name],
+            show_progress="hidden",
+        ).then(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self.btn_delete_no.click(
+            lambda: (
+                gr.update(visible=True),
+                gr.update(visible=False),
+                gr.update(visible=False),
+            ),
+            inputs=[],
+            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+            show_progress="hidden",
+        )
+        self.btn_edit_save.click(
+            self.save_rerank,
+            inputs=[
+                self.selected_rerank_name,
+                self.edit_default,
+                self.edit_spec,
+            ],
+            show_progress="hidden",
+        ).then(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self.btn_close.click(lambda: "", outputs=[self.selected_rerank_name])
+
+        self.btn_test_connection.click(
+            self.check_connection,
+            inputs=[self.selected_rerank_name, self.edit_spec],
+            outputs=[self.connection_logs],
+        )
+
+    def create_rerank(self, name, choices, spec, default):
+        try:
+            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
+            spec["__type__"] = (
+                reranking_models_manager.vendors()[choices].__module__
+                + "."
+                + reranking_models_manager.vendors()[choices].__qualname__
+            )
+
+            reranking_models_manager.add(name, spec=spec, default=default)
+            gr.Info(f'Create Reranking model "{name}" successfully')
+        except Exception as e:
+            raise gr.Error(f"Failed to create Reranking model {name}: {e}")
+
+    def list_rerankings(self):
+        """List the Reranking models"""
+        items = []
+        for item in reranking_models_manager.info().values():
+            record = {}
+            record["name"] = item["name"]
+            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
+            record["default"] = item["default"]
+            items.append(record)
+
+        if items:
+            rerank_list = pd.DataFrame.from_records(items)
+        else:
+            rerank_list = pd.DataFrame.from_records(
+                [{"name": "-", "vendor": "-", "default": "-"}]
+            )
+
+        return rerank_list
+
+    def select_rerank(self, rerank_list, ev: gr.SelectData):
+        if ev.value == "-" and ev.index[0] == 0:
+            gr.Info("No reranking model is loaded. Please add one first")
+            return ""
+
+        if not ev.selected:
+            return ""
+
+        return rerank_list["name"][ev.index[0]]
+
+    def on_selected_rerank_change(self, selected_rerank_name):
+        if selected_rerank_name == "":
+            _check_connection_panel = gr.update(visible=False)
+            _selected_panel = gr.update(visible=False)
+            _selected_panel_btn = gr.update(visible=False)
+            btn_delete = gr.update(visible=True)
+            btn_delete_yes = gr.update(visible=False)
+            btn_delete_no = gr.update(visible=False)
+            edit_spec = gr.update(value="")
+            edit_spec_desc = gr.update(value="")
+            edit_default = gr.update(value=False)
+        else:
+            _check_connection_panel = gr.update(visible=True)
+            _selected_panel = gr.update(visible=True)
+            _selected_panel_btn = gr.update(visible=True)
+            btn_delete = gr.update(visible=True)
+            btn_delete_yes = gr.update(visible=False)
+            btn_delete_no = gr.update(visible=False)
+
+            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
+            vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
+            vendor = reranking_models_manager.vendors()[vendor_str]
+
+            edit_spec = yaml.dump(info["spec"])
+            edit_spec_desc = format_description(vendor)
+            edit_default = info["default"]
+
+        return (
+            _selected_panel,
+            _selected_panel_btn,
+            btn_delete,
+            btn_delete_yes,
+            btn_delete_no,
+            edit_spec,
+            edit_spec_desc,
+            edit_default,
+            _check_connection_panel,
+        )
+
+    def on_btn_delete_click(self):
+        btn_delete = gr.update(visible=False)
+        btn_delete_yes = gr.update(visible=True)
+        btn_delete_no = gr.update(visible=True)
+
+        return btn_delete, btn_delete_yes, btn_delete_no
+
+    def check_connection(self, selected_rerank_name, selected_spec):
+        log_content: str = ""
+        try:
+            log_content += f"- Testing model: {selected_rerank_name}<br>"
+            yield log_content
+
+            # Parse content & init model
+            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
+
+            # Parse content & create dummy response
+            spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
+            info["spec"].update(spec)
+
+            rerank = deserialize(info["spec"], safe=False)
+
+            if rerank is None:
+                raise Exception(f"Cannot find model: {selected_rerank_name}")
+
+            log_content += "- Sending a message ([`Hello`], `Hi`)<br>"
+            yield log_content
+            _ = rerank(["Hello"], "Hi")
+
+            log_content += (
+                "<mark style='background: green; color: white'>- Connection success. "
+                "</mark><br>"
+            )
+            yield log_content
+
+            gr.Info(f"Reranking {selected_rerank_name} connected successfully")
+        except Exception as e:
+            print(e)
+            log_content += (
+                f"<mark style='color: yellow; background: red'>- Connection failed. "
+                f"Got error:\n {str(e)}</mark>"
+            )
+            yield log_content
+
+        return log_content
+
+    def save_rerank(self, selected_rerank_name, default, spec):
+        try:
+            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
+            spec["__type__"] = reranking_models_manager.info()[selected_rerank_name][
+                "spec"
+            ]["__type__"]
+            reranking_models_manager.update(
+                selected_rerank_name, spec=spec, default=default
+            )
+            gr.Info(f'Save Reranking model "{selected_rerank_name}" successfully')
+        except Exception as e:
+            gr.Error(f'Failed to save Reranking model "{selected_rerank_name}": {e}')
+
+    def delete_rerank(self, selected_rerank_name):
+        try:
+            reranking_models_manager.delete(selected_rerank_name)
+        except Exception as e:
+            gr.Error(f'Failed to delete Reranking model "{selected_rerank_name}": {e}')
+            return selected_rerank_name
+
+        return ""
@@ -154,9 +154,9 @@ class Render:
             if doc.metadata.get("llm_trulens_score") is not None
             else 0.0
         )
-        cohere_reranking_score = (
-            round(doc.metadata["cohere_reranking_score"], 2)
-            if doc.metadata.get("cohere_reranking_score") is not None
+        reranking_score = (
+            round(doc.metadata["reranking_score"], 2)
+            if doc.metadata.get("reranking_score") is not None
             else 0.0
         )
         item_type_prefix = doc.metadata.get("type", "")
@@ -166,8 +166,8 @@ class Render:
 
         if llm_reranking_score > 0:
             relevant_score = llm_reranking_score
-        elif cohere_reranking_score > 0:
-            relevant_score = cohere_reranking_score
+        elif reranking_score > 0:
+            relevant_score = reranking_score
         else:
             relevant_score = 0.0
 
@@ -179,7 +179,7 @@ class Render:
             "<b>  LLM relevant score:</b>"
             f" {llm_reranking_score}<br>"
             "<b>  Reranking score:</b>"
-            f" {cohere_reranking_score}<br>",
+            f" {reranking_score}<br>",
         )
 
         text = doc.text if not override_text else override_text