feat: support TEI embedding service, configurable reranking model (#287)

* feat: add support for TEI embedding service, allow reranking model to be configurable.

Signed-off-by: Kennywu <jdlow@live.cn>

* fix: add cohere default reranking model

* fix: comfort pre-commit

---------

Signed-off-by: Kennywu <jdlow@live.cn>
Co-authored-by: wujiaye <wujiaye@bluemoon.com.cn>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
KennyWu 2024-09-30 23:00:00 +08:00, committed by GitHub
commit 53530e296f (parent 2e3c17b256)
20 changed files with 928 additions and 22 deletions


@@ -92,6 +92,7 @@ KH_VECTORSTORE = {
 }
 KH_LLMS = {}
 KH_EMBEDDINGS = {}
+KH_RERANKINGS = {}
 
 # populate options from config
 if config("AZURE_OPENAI_API_KEY", default="") and config(
@ -212,7 +213,7 @@ KH_LLMS["cohere"] = {
"spec": { "spec": {
"__type__": "kotaemon.llms.chats.LCCohereChat", "__type__": "kotaemon.llms.chats.LCCohereChat",
"model_name": "command-r-plus-08-2024", "model_name": "command-r-plus-08-2024",
"api_key": "your-key", "api_key": config("COHERE_API_KEY", default="your-key"),
}, },
"default": False, "default": False,
} }
@ -222,7 +223,7 @@ KH_EMBEDDINGS["cohere"] = {
"spec": { "spec": {
"__type__": "kotaemon.embeddings.LCCohereEmbeddings", "__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v3.0", "model": "embed-multilingual-v3.0",
"cohere_api_key": "your-key", "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
"user_agent": "default", "user_agent": "default",
}, },
"default": False, "default": False,
@ -235,6 +236,16 @@ KH_EMBEDDINGS["cohere"] = {
# "default": False, # "default": False,
# } # }
# default reranking models
KH_RERANKINGS["cohere"] = {
"spec": {
"__type__": "kotaemon.rerankings.CohereReranking",
"model_name": "rerank-multilingual-v2.0",
"cohere_api_key": config("COHERE_API_KEY", default="your-key"),
},
"default": True,
}
KH_REASONINGS = [ KH_REASONINGS = [
"ktem.reasoning.simple.FullQAPipeline", "ktem.reasoning.simple.FullQAPipeline",
"ktem.reasoning.simple.FullDecomposeQAPipeline", "ktem.reasoning.simple.FullDecomposeQAPipeline",


@@ -8,10 +8,12 @@ from .langchain_based import (
     LCOpenAIEmbeddings,
 )
 from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from .tei_endpoint_embed import TeiEndpointEmbeddings
 
 __all__ = [
     "BaseEmbeddings",
     "EndpointEmbeddings",
+    "TeiEndpointEmbeddings",
     "LCOpenAIEmbeddings",
     "LCAzureOpenAIEmbeddings",
     "LCCohereEmbeddings",


@@ -0,0 +1,105 @@
import aiohttp
import requests

from kotaemon.base import Document, DocumentWithEmbedding, Param

from .base import BaseEmbeddings

session = requests.session()


class TeiEndpointEmbeddings(BaseEmbeddings):
    """An Embeddings component that uses a
    TEI (Text-Embeddings-Inference) API compatible endpoint.

    Ref: https://github.com/huggingface/text-embeddings-inference

    Attributes:
        endpoint_url (str): The URL of a TEI
            (Text-Embeddings-Inference) API compatible endpoint.
        normalize (bool): Whether to normalize embeddings to unit length.
        truncate (bool): Whether to truncate inputs to the model's
            maximum length.
    """

    endpoint_url: str = Param(None, help="TEI embedding service API base URL")
    normalize: bool = Param(
        True,
        help="Normalize embeddings to unit length",
    )
    truncate: bool = Param(
        True,
        help="Truncate inputs to the model's maximum length",
    )

    async def client_(self, inputs: list[str]):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url=self.endpoint_url,
                json={
                    "inputs": inputs,
                    "normalize": self.normalize,
                    "truncate": self.truncate,
                },
            ) as resp:
                embeddings = await resp.json()
        return embeddings

    async def ainvoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        if not isinstance(text, list):
            text = [text]

        text = self.prepare_input(text)

        outputs = []
        batch_size = 6
        num_batch = max(len(text) // batch_size, 1)
        for i in range(num_batch):
            # the last batch picks up any remainder
            if i == num_batch - 1:
                mini_batch = text[batch_size * i :]
            else:
                mini_batch = text[batch_size * i : batch_size * (i + 1)]
            mini_batch = [x.content for x in mini_batch]
            embeddings = await self.client_(mini_batch)  # type: ignore
            outputs.extend(
                [
                    DocumentWithEmbedding(content=doc, embedding=embedding)
                    for doc, embedding in zip(mini_batch, embeddings)
                ]
            )

        return outputs

    def invoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        if not isinstance(text, list):
            text = [text]

        text = self.prepare_input(text)

        outputs = []
        batch_size = 6
        num_batch = max(len(text) // batch_size, 1)
        for i in range(num_batch):
            # the last batch picks up any remainder
            if i == num_batch - 1:
                mini_batch = text[batch_size * i :]
            else:
                mini_batch = text[batch_size * i : batch_size * (i + 1)]
            mini_batch = [x.content for x in mini_batch]
            embeddings = session.post(
                url=self.endpoint_url,
                json={
                    "inputs": mini_batch,
                    "normalize": self.normalize,
                    "truncate": self.truncate,
                },
            ).json()
            outputs.extend(
                [
                    DocumentWithEmbedding(content=doc, embedding=embedding)
                    for doc, embedding in zip(mini_batch, embeddings)
                ]
            )

        return outputs
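
A minimal usage sketch of the new component, assuming a TEI server is already running (illustrative only; the endpoint URL is a placeholder, and TEI's embedding route is typically `/embed`):

# Illustrative only; not part of this commit.
from kotaemon.embeddings import TeiEndpointEmbeddings

embedder = TeiEndpointEmbeddings(
    endpoint_url="http://localhost:8080/embed",  # placeholder deployment
    normalize=True,
    truncate=True,
)
docs = embedder.invoke(["hello world", "text embeddings inference"])
print(len(docs), len(docs[0].embedding))  # 2 documents, one vector each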


@@ -39,7 +39,7 @@ class CohereReranking(BaseReranking):
             print("Cannot get Cohere API key from `ktem`", e)
 
         if not self.cohere_api_key:
-            print("Cohere API key not found. Skipping reranking.")
+            print("Cohere API key not found. Skipping rerankings.")
             return documents
 
         cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ class CohereReranking(BaseReranking):
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
-            doc.metadata["cohere_reranking_score"] = r.relevance_score
+            doc.metadata["reranking_score"] = r.relevance_score
             compressed_docs.append(doc)
 
         return compressed_docs


@@ -241,7 +241,7 @@ class VectorRetrieval(BaseRetrieval):
                 # if reranker is LLMReranking, limit the document with top_k items only
                 if isinstance(reranker, LLMReranking):
                     result = self._filter_docs(result, top_k=top_k)
-                result = reranker(documents=result, query=text)
+                result = reranker.run(documents=result, query=text)
 
             result = self._filter_docs(result, top_k=top_k)
         print(f"Got raw {len(result)} retrieved documents")


@@ -0,0 +1,5 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .tei_fast_rerank import TeiFastReranking

__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]


@@ -0,0 +1,13 @@
from __future__ import annotations

from abc import abstractmethod

from kotaemon.base import BaseComponent, Document


class BaseReranking(BaseComponent):
    @abstractmethod
    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Main method to transform a list of documents
        (re-ranking, filtering, etc.)"""
        ...
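
For orientation, a custom reranker only needs to subclass this interface and implement `run`. A minimal sketch (illustrative only; the class name and its keyword-overlap scoring are made up for the example and are not part of this commit):

from kotaemon.base import Document
from kotaemon.rerankings import BaseReranking


class KeywordOverlapReranking(BaseReranking):  # hypothetical example
    def run(self, documents: list[Document], query: str) -> list[Document]:
        terms = set(query.lower().split())
        for doc in documents:
            # crude relevance proxy: tokens shared between query and document
            doc.metadata["reranking_score"] = len(
                terms & set(doc.content.lower().split())
            )
        return sorted(
            documents, key=lambda d: d.metadata["reranking_score"], reverse=True
        )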


@@ -0,0 +1,56 @@
from __future__ import annotations

from decouple import config

from kotaemon.base import Document, Param

from .base import BaseReranking


class CohereReranking(BaseReranking):
    """Cohere Reranking model"""

    model_name: str = Param(
        "rerank-multilingual-v2.0",
        help=(
            "ID of the model to use. You can go to [Supported Models]"
            "(https://docs.cohere.com/docs/rerank-2) to see the supported models"
        ),
        required=True,
    )
    cohere_api_key: str = Param(
        config("COHERE_API_KEY", ""),
        help="Cohere API key",
        required=True,
    )

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use the Cohere Reranker model to re-order documents
        by their relevance score"""
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Please install cohere (`pip install cohere`) to use Cohere Reranking"
            )

        if not self.cohere_api_key:
            print("Cohere API key not found. Skipping rerankings.")
            return documents

        cohere_client = cohere.Client(self.cohere_api_key)
        compressed_docs: list[Document] = []

        if not documents:  # to avoid empty API call
            return compressed_docs

        _docs = [d.content for d in documents]
        response = cohere_client.rerank(
            model=self.model_name, query=query, documents=_docs
        )
        for r in response.results:
            doc = documents[r.index]
            doc.metadata["reranking_score"] = r.relevance_score
            compressed_docs.append(doc)

        return compressed_docs
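
A quick usage sketch (assumes COHERE_API_KEY is available via the environment or a .env file; the documents and query are illustrative):

# Illustrative only; not part of this commit.
from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking

reranker = CohereReranking(model_name="rerank-multilingual-v2.0")
ranked = reranker.run(
    documents=[
        Document(content="Paris is the capital of France."),
        Document(content="Bananas are rich in potassium."),
    ],
    query="What is the capital of France?",
)
print(ranked[0].metadata["reranking_score"])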


@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import Optional

import requests

from kotaemon.base import Document, Param

from .base import BaseReranking

session = requests.session()


class TeiFastReranking(BaseReranking):
    """Text Embeddings Inference (TEI) Reranking model
    (https://huggingface.co/docs/text-embeddings-inference/en/index)
    """

    endpoint_url: str = Param(
        None, help="TEI Reranking service API base URL", required=True
    )
    model_name: Optional[str] = Param(
        None,
        help=(
            "ID of the model to use. You can go to [Supported Models]"
            "(https://github.com/huggingface"
            "/text-embeddings-inference?tab=readme-ov-file"
            "#supported-models) to see the supported models"
        ),
    )
    is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")

    def client(self, query, texts):
        response = session.post(
            url=self.endpoint_url,
            json={
                "query": query,
                "texts": texts,
                "is_truncated": self.is_truncated,  # default is True
            },
        ).json()
        return response

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use the deployed TEI reranking service to re-order documents
        by their relevance score"""
        if not self.endpoint_url:
            print("TEI API reranking URL not found. Skipping rerankings.")
            return documents

        compressed_docs: list[Document] = []

        if not documents:  # to avoid empty API call
            return compressed_docs

        if isinstance(documents[0], str):
            documents = self.prepare_input(documents)

        batch_size = 6
        num_batch = max(len(documents) // batch_size, 1)
        for i in range(num_batch):
            # the last batch picks up any remainder
            if i == num_batch - 1:
                mini_batch = documents[batch_size * i :]
            else:
                mini_batch = documents[batch_size * i : batch_size * (i + 1)]
            _docs = [d.content for d in mini_batch]
            rerank_resp = self.client(query, _docs)
            for r in rerank_resp:
                doc = mini_batch[r["index"]]
                doc.metadata["reranking_score"] = r["score"]
                compressed_docs.append(doc)

        compressed_docs = sorted(
            compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
        )

        return compressed_docs
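
A usage sketch, assuming a TEI server with a reranker model is already running (illustrative only; the endpoint URL is a placeholder, and TEI's rerank route is typically `/rerank`):

# Illustrative only; not part of this commit.
from kotaemon.base import Document
from kotaemon.rerankings import TeiFastReranking

reranker = TeiFastReranking(
    endpoint_url="http://localhost:8080/rerank",  # placeholder deployment
)
ranked = reranker.run(
    documents=[
        Document(content="Paris is the capital of France."),
        Document(content="Bananas are rich in potassium."),
    ],
    query="What is the capital of France?",
)
print(ranked[0].content, ranked[0].metadata["reranking_score"])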


@@ -59,6 +59,7 @@ class EmbeddingManager:
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
+            TeiEndpointEmbeddings,
         )
 
         self._vendors = [
@@ -67,6 +68,7 @@ class EmbeddingManager:
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            TeiEndpointEmbeddings,
         ]
 
     def __getitem__(self, key: str) -> BaseEmbeddings:


@@ -16,6 +16,7 @@ import tiktoken
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.readers.file.base import default_file_metadata_func
 from llama_index.core.vector_stores import (
@@ -39,12 +40,7 @@ from kotaemon.indices.ingests.files import (
     azure_reader,
     unstructured,
 )
-from kotaemon.indices.rankings import (
-    BaseReranking,
-    CohereReranking,
-    LLMReranking,
-    LLMTrulensScoring,
-)
+from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 
 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -285,7 +281,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             ],
             retrieval_mode=user_settings["retrieval_mode"],
             llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
-            rerankers=[CohereReranking(use_key_from_ktem=True)],
+            rerankers=[
+                reranking_models_manager[
+                    index_settings.get(
+                        "reranking", reranking_models_manager.get_default_name()
+                    )
+                ]
+            ],
         )
         if not user_settings["use_reranking"]:
             retriever.rerankers = []  # type: ignore
@@ -715,7 +717,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
         for idx, file_path in enumerate(file_paths):
             file_path = Path(file_path)
             yield Document(
-                content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
+                content=f"Indexing [{idx + 1}/{n_files}]: {file_path.name}",
                 channel="debug",
             )


@@ -4,6 +4,7 @@ from ktem.db.models import User, engine
 from ktem.embeddings.ui import EmbeddingManagement
 from ktem.index.ui import IndexManagement
 from ktem.llms.ui import LLMManagement
+from ktem.rerankings.ui import RerankingManagement
 from sqlmodel import Session, select
 
 from .user import UserManagement
@@ -24,6 +25,9 @@ class ResourcesTab(BasePage):
                 with gr.Tab("Embeddings") as self.emb_management_tab:
                     self.emb_management = EmbeddingManagement(self._app)
 
+                with gr.Tab("Rerankings") as self.rerank_management_tab:
+                    self.rerank_management = RerankingManagement(self._app)
+
                 if self._app.f_user_management:
                     with gr.Tab("Users", visible=False) as self.user_management_tab:
                         self.user_management = UserManagement(self._app)


@@ -5,6 +5,7 @@ import requests
 from ktem.app import BasePage
 from ktem.embeddings.manager import embedding_models_manager as embeddings
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager as rerankers
 from theflow.settings import settings as flowsettings
 
 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
@@ -186,6 +187,15 @@ class SetupPage(BasePage):
                 },
                 default=True,
             )
+            rerankers.update(
+                name="cohere",
+                spec={
+                    "__type__": "kotaemon.rerankings.CohereReranking",
+                    "model_name": "rerank-multilingual-v2.0",
+                    "cohere_api_key": cohere_api_key,
+                },
+                default=True,
+            )
         elif radio_model_value == "openai":
             if openai_api_key:
                 llms.update(


@@ -100,7 +100,7 @@ class DocSearchTool(BaseTool):
             )
             print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))
 
         # trim context by trim_len
         if evidence:


@@ -138,7 +138,7 @@ class DocSearchTool(BaseTool):
             )
             print("Retrieved #{}: {}".format(_id, retrieved_content))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))
 
         # trim context by trim_len
         if evidence:


@@ -0,0 +1,36 @@
from typing import Type

from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string


class Base(DeclarativeBase):
    pass


class BaseRerankingTable(Base):
    """Base table to store reranking models"""

    __abstract__ = True

    name = Column(String, primary_key=True, unique=True)
    spec = Column(JSON, default={})
    default = Column(Boolean, default=False)


__base_reranking: Type[BaseRerankingTable] = (
    import_dotted_string(flowsettings.KH_TABLE_RERANKING, safe=False)
    if hasattr(flowsettings, "KH_TABLE_RERANKING")
    else BaseRerankingTable
)


class RerankingTable(__base_reranking):  # type: ignore
    __tablename__ = "reranking"


if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
    RerankingTable.metadata.create_all(engine)
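
As the `import_dotted_string` branch above suggests, the concrete table class can be swapped out through settings. A hypothetical override (the dotted path is a placeholder; the target class would need to subclass BaseRerankingTable):

# In your flowsettings; illustrative only, not part of this commit.
KH_TABLE_RERANKING = "my_project.tables.MyRerankingTable"  # placeholder path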


@@ -0,0 +1,194 @@
from typing import Optional, Type

from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize

from kotaemon.rerankings.base import BaseReranking

from .db import RerankingTable, engine


class RerankingManager:
    """Represent a pool of reranking models"""

    def __init__(self):
        self._models: dict[str, BaseReranking] = {}
        self._info: dict[str, dict] = {}
        self._default: str = ""
        self._vendors: list[Type] = []

        # populate the pool if empty
        if hasattr(flowsettings, "KH_RERANKINGS"):
            with Session(engine) as sess:
                count = sess.query(RerankingTable).count()
                if not count:
                    for name, model in flowsettings.KH_RERANKINGS.items():
                        self.add(
                            name=name,
                            spec=model["spec"],
                            default=model.get("default", False),
                        )

        self.load()
        self.load_vendors()

    def load(self):
        """Load the model pool from the database"""
        self._models, self._info, self._default = {}, {}, ""

        with Session(engine) as sess:
            stmt = select(RerankingTable)
            items = sess.execute(stmt)

            for (item,) in items:
                self._models[item.name] = deserialize(item.spec, safe=False)
                self._info[item.name] = {
                    "name": item.name,
                    "spec": item.spec,
                    "default": item.default,
                }
                if item.default:
                    self._default = item.name

    def load_vendors(self):
        from kotaemon.rerankings import CohereReranking, TeiFastReranking

        self._vendors = [TeiFastReranking, CohereReranking]

    def __getitem__(self, key: str) -> BaseReranking:
        """Get model by name"""
        return self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    def get(
        self, key: str, default: Optional[BaseReranking] = None
    ) -> Optional[BaseReranking]:
        """Get model by name with a default value"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present model pool options for gradio"""
        return {
            "label": "Reranking",
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a dict of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of a random model

        Returns:
            str: random model name in the pool
        """
        import random

        if not self._models:
            raise ValueError("No models in pool")

        return random.choice(list(self._models.keys()))

    def get_default_name(self) -> str:
        """Get the name of the default model

        In case there is no default model, choose a random model from the pool.
        In case there are multiple default models, choose randomly among them.

        Returns:
            str: model name
        """
        if not self._models:
            raise ValueError("No models in pool")

        if not self._default:
            return self.get_random_name()

        return self._default

    def get_random(self) -> BaseReranking:
        """Get a random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> BaseReranking:
        """Get the default model

        In case there is no default model, choose a random model from the pool.
        In case there are multiple default models, choose randomly among them.

        Returns:
            BaseReranking: model
        """
        return self._models[self.get_default_name()]

    def info(self) -> dict:
        """List all models"""
        return self._info

    def add(self, name: str, spec: dict, default: bool):
        """Add a new model to the pool"""
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:
                if default:
                    # turn all models to non-default
                    sess.query(RerankingTable).update({"default": False})
                    sess.commit()

                item = RerankingTable(name=name, spec=spec, default=default)
                sess.add(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to add model {name}: {e}")

        self.load()

    def delete(self, name: str):
        """Delete a model from the pool"""
        try:
            with Session(engine) as sess:
                item = sess.query(RerankingTable).filter_by(name=name).first()
                sess.delete(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to delete model {name}: {e}")

        self.load()

    def update(self, name: str, spec: dict, default: bool):
        """Update a model in the pool"""
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:
                if default:
                    # turn all models to non-default
                    sess.query(RerankingTable).update({"default": False})
                    sess.commit()

                item = sess.query(RerankingTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")

                item.spec = spec
                item.default = default
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to update model {name}: {e}")

        self.load()

    def vendors(self) -> dict:
        """Return the list of vendors"""
        return {vendor.__qualname__: vendor for vendor in self._vendors}


reranking_models_manager = RerankingManager()
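
A short sketch of how the pool is meant to be consumed (names are illustrative; "cohere" only exists if it was seeded from KH_RERANKINGS or added through the UI):

# Illustrative only; not part of this commit.
from ktem.rerankings.manager import reranking_models_manager

default_reranker = reranking_models_manager.get_default()
if "cohere" in reranking_models_manager:
    reranker = reranking_models_manager["cohere"]
print(reranking_models_manager.settings())  # gradio-ready choices/value dict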


@@ -0,0 +1,390 @@
from copy import deepcopy

import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage
from ktem.utils.file import YAMLNoDateSafeLoader
from theflow.utils.modules import deserialize

from .manager import reranking_models_manager


def format_description(cls):
    params = cls.describe()["params"]
    params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
    for key, value in params.items():
        if isinstance(value["auto_callback"], str):
            continue
        params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
    return f"{cls.__doc__}\n\n" + "\n".join(params_lines)


class RerankingManagement(BasePage):
    def __init__(self, app):
        self._app = app
        self.spec_desc_default = (
            "# Spec description\n\nSelect a model to view the spec description."
        )
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Tab(label="View"):
            self.rerank_list = gr.DataFrame(
                headers=["name", "vendor", "default"],
                interactive=False,
            )

            with gr.Column(visible=False) as self._selected_panel:
                self.selected_rerank_name = gr.Textbox(value="", visible=False)
                with gr.Row():
                    with gr.Column():
                        self.edit_default = gr.Checkbox(
                            label="Set default",
                            info=(
                                "Set this Reranking model as default. This default "
                                "Reranking will be used by other components by default "
                                "if no Reranking is specified for such components."
                            ),
                        )
                        self.edit_spec = gr.Textbox(
                            label="Specification",
                            info="Specification of the Reranking model in YAML format",
                            lines=10,
                        )

                        with gr.Accordion(
                            label="Test connection", visible=False, open=False
                        ) as self._check_connection_panel:
                            with gr.Row():
                                with gr.Column(scale=4):
                                    self.connection_logs = gr.HTML(
                                        "Logs",
                                    )
                                with gr.Column(scale=1):
                                    self.btn_test_connection = gr.Button("Test")

                        with gr.Row(visible=False) as self._selected_panel_btn:
                            with gr.Column():
                                self.btn_edit_save = gr.Button(
                                    "Save", min_width=10, variant="primary"
                                )
                            with gr.Column():
                                self.btn_delete = gr.Button(
                                    "Delete", min_width=10, variant="stop"
                                )
                                with gr.Row():
                                    self.btn_delete_yes = gr.Button(
                                        "Confirm Delete",
                                        variant="stop",
                                        visible=False,
                                        min_width=10,
                                    )
                                    self.btn_delete_no = gr.Button(
                                        "Cancel", visible=False, min_width=10
                                    )
                            with gr.Column():
                                self.btn_close = gr.Button("Close", min_width=10)

                    with gr.Column():
                        self.edit_spec_desc = gr.Markdown("# Spec description")

        with gr.Tab(label="Add"):
            with gr.Row():
                with gr.Column(scale=2):
                    self.name = gr.Textbox(
                        label="Name",
                        info=(
                            "Must be unique and non-empty. "
                            "The name will be used to identify the reranking model."
                        ),
                    )
                    self.rerank_choices = gr.Dropdown(
                        label="Vendors",
                        info=(
                            "Choose the vendor of the Reranking model. Each vendor "
                            "has a different specification."
                        ),
                    )
                    self.spec = gr.Textbox(
                        label="Specification",
                        info="Specification of the Reranking model in YAML format.",
                    )
                    self.default = gr.Checkbox(
                        label="Set default",
                        info=(
                            "Set this Reranking model as default. This default "
                            "Reranking will be used by other components by default "
                            "if no Reranking is specified for such components."
                        ),
                    )
                    self.btn_new = gr.Button("Add", variant="primary")

                with gr.Column(scale=3):
                    self.spec_desc = gr.Markdown(self.spec_desc_default)

    def _on_app_created(self):
        """Called when the app is created"""
        self._app.app.load(
            self.list_rerankings,
            inputs=[],
            outputs=[self.rerank_list],
        )
        self._app.app.load(
            lambda: gr.update(choices=list(reranking_models_manager.vendors().keys())),
            outputs=[self.rerank_choices],
        )

    def on_rerank_vendor_change(self, vendor):
        vendor = reranking_models_manager.vendors()[vendor]

        required: dict = {}
        desc = vendor.describe()
        for key, value in desc["params"].items():
            if value.get("required", False):
                required[key] = value.get("default", None)

        return yaml.dump(required), format_description(vendor)

    def on_register_events(self):
        self.rerank_choices.select(
            self.on_rerank_vendor_change,
            inputs=[self.rerank_choices],
            outputs=[self.spec, self.spec_desc],
        )
        self.btn_new.click(
            self.create_rerank,
            inputs=[self.name, self.rerank_choices, self.spec, self.default],
            outputs=None,
        ).success(self.list_rerankings, inputs=[], outputs=[self.rerank_list]).success(
            lambda: ("", None, "", False, self.spec_desc_default),
            outputs=[
                self.name,
                self.rerank_choices,
                self.spec,
                self.default,
                self.spec_desc,
            ],
        )
        self.rerank_list.select(
            self.select_rerank,
            inputs=self.rerank_list,
            outputs=[self.selected_rerank_name],
            show_progress="hidden",
        )
        self.selected_rerank_name.change(
            self.on_selected_rerank_change,
            inputs=[self.selected_rerank_name],
            outputs=[
                self._selected_panel,
                self._selected_panel_btn,
                # delete section
                self.btn_delete,
                self.btn_delete_yes,
                self.btn_delete_no,
                # edit section
                self.edit_spec,
                self.edit_spec_desc,
                self.edit_default,
                self._check_connection_panel,
            ],
            show_progress="hidden",
        ).success(lambda: gr.update(value=""), outputs=[self.connection_logs])
        self.btn_delete.click(
            self.on_btn_delete_click,
            inputs=[],
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_delete_yes.click(
            self.delete_rerank,
            inputs=[self.selected_rerank_name],
            outputs=[self.selected_rerank_name],
            show_progress="hidden",
        ).then(
            self.list_rerankings,
            inputs=[],
            outputs=[self.rerank_list],
        )
        self.btn_delete_no.click(
            lambda: (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            ),
            inputs=[],
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_edit_save.click(
            self.save_rerank,
            inputs=[
                self.selected_rerank_name,
                self.edit_default,
                self.edit_spec,
            ],
            show_progress="hidden",
        ).then(
            self.list_rerankings,
            inputs=[],
            outputs=[self.rerank_list],
        )
        self.btn_close.click(lambda: "", outputs=[self.selected_rerank_name])
        self.btn_test_connection.click(
            self.check_connection,
            inputs=[self.selected_rerank_name, self.edit_spec],
            outputs=[self.connection_logs],
        )

    def create_rerank(self, name, choices, spec, default):
        try:
            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
            spec["__type__"] = (
                reranking_models_manager.vendors()[choices].__module__
                + "."
                + reranking_models_manager.vendors()[choices].__qualname__
            )

            reranking_models_manager.add(name, spec=spec, default=default)
            gr.Info(f'Created Reranking model "{name}" successfully')
        except Exception as e:
            raise gr.Error(f"Failed to create Reranking model {name}: {e}")

    def list_rerankings(self):
        """List the Reranking models"""
        items = []
        for item in reranking_models_manager.info().values():
            record = {}
            record["name"] = item["name"]
            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
            record["default"] = item["default"]
            items.append(record)

        if items:
            rerank_list = pd.DataFrame.from_records(items)
        else:
            rerank_list = pd.DataFrame.from_records(
                [{"name": "-", "vendor": "-", "default": "-"}]
            )

        return rerank_list

    def select_rerank(self, rerank_list, ev: gr.SelectData):
        if ev.value == "-" and ev.index[0] == 0:
            gr.Info("No reranking model is loaded. Please add one first.")
            return ""

        if not ev.selected:
            return ""

        return rerank_list["name"][ev.index[0]]

    def on_selected_rerank_change(self, selected_rerank_name):
        if selected_rerank_name == "":
            _check_connection_panel = gr.update(visible=False)
            _selected_panel = gr.update(visible=False)
            _selected_panel_btn = gr.update(visible=False)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)
            edit_spec = gr.update(value="")
            edit_spec_desc = gr.update(value="")
            edit_default = gr.update(value=False)
        else:
            _check_connection_panel = gr.update(visible=True)
            _selected_panel = gr.update(visible=True)
            _selected_panel_btn = gr.update(visible=True)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)

            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
            vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
            vendor = reranking_models_manager.vendors()[vendor_str]

            edit_spec = yaml.dump(info["spec"])
            edit_spec_desc = format_description(vendor)
            edit_default = info["default"]

        return (
            _selected_panel,
            _selected_panel_btn,
            btn_delete,
            btn_delete_yes,
            btn_delete_no,
            edit_spec,
            edit_spec_desc,
            edit_default,
            _check_connection_panel,
        )

    def on_btn_delete_click(self):
        btn_delete = gr.update(visible=False)
        btn_delete_yes = gr.update(visible=True)
        btn_delete_no = gr.update(visible=True)

        return btn_delete, btn_delete_yes, btn_delete_no

    def check_connection(self, selected_rerank_name, selected_spec):
        log_content: str = ""
        try:
            log_content += f"- Testing model: {selected_rerank_name}<br>"
            yield log_content

            # parse the spec and initialize the model
            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
            spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
            info["spec"].update(spec)

            rerank = deserialize(info["spec"], safe=False)

            if rerank is None:
                raise Exception(f"Cannot find model: {selected_rerank_name}")

            log_content += "- Sending a message ([`Hello`], `Hi`)<br>"
            yield log_content
            _ = rerank(["Hello"], "Hi")

            log_content += (
                "<mark style='background: green; color: white'>- Connection success. "
                "</mark><br>"
            )
            yield log_content

            gr.Info(f"Reranking model {selected_rerank_name} connected successfully")
        except Exception as e:
            print(e)
            log_content += (
                f"<mark style='color: yellow; background: red'>- Connection failed. "
                f"Got error:\n {str(e)}</mark>"
            )
            yield log_content

        return log_content

    def save_rerank(self, selected_rerank_name, default, spec):
        try:
            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
            spec["__type__"] = reranking_models_manager.info()[selected_rerank_name][
                "spec"
            ]["__type__"]
            reranking_models_manager.update(
                selected_rerank_name, spec=spec, default=default
            )
            gr.Info(f'Saved Reranking model "{selected_rerank_name}" successfully')
        except Exception as e:
            gr.Error(f'Failed to save Reranking model "{selected_rerank_name}": {e}')

    def delete_rerank(self, selected_rerank_name):
        try:
            reranking_models_manager.delete(selected_rerank_name)
        except Exception as e:
            gr.Error(f'Failed to delete Reranking model "{selected_rerank_name}": {e}')
            return selected_rerank_name

        return ""


@@ -154,9 +154,9 @@ class Render:
             if doc.metadata.get("llm_trulens_score") is not None
             else 0.0
         )
-        cohere_reranking_score = (
-            round(doc.metadata["cohere_reranking_score"], 2)
-            if doc.metadata.get("cohere_reranking_score") is not None
+        reranking_score = (
+            round(doc.metadata["reranking_score"], 2)
+            if doc.metadata.get("reranking_score") is not None
             else 0.0
         )
         item_type_prefix = doc.metadata.get("type", "")
@@ -166,8 +166,8 @@ class Render:
         if llm_reranking_score > 0:
             relevant_score = llm_reranking_score
-        elif cohere_reranking_score > 0:
-            relevant_score = cohere_reranking_score
+        elif reranking_score > 0:
+            relevant_score = reranking_score
         else:
             relevant_score = 0.0
@@ -179,7 +179,7 @@ class Render:
             "<b>&emsp;&emsp;LLM relevant score:</b>"
             f" {llm_reranking_score}<br>"
             "<b>&emsp;&emsp;Reranking score:</b>"
-            f" {cohere_reranking_score}<br>",
+            f" {reranking_score}<br>",
         )
 
         text = doc.text if not override_text else override_text