feat: support TEI embedding service, configurable reranking model (#287)
* feat: add support for the TEI embedding service; allow the reranking model to be configurable
* fix: add a default Cohere reranking model
* fix: satisfy pre-commit

Signed-off-by: Kennywu <jdlow@live.cn>
Co-authored-by: wujiaye <wujiaye@bluemoon.com.cn>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
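Note: a minimal sketch of how a reranking model could be declared in flowsettings.py, based on the KH_RERANKINGS handling introduced by this commit. The Cohere spec mirrors the one registered in the setup page diff below; the TEI entry and its endpoint_url field are assumptions and should be checked against TeiFastReranking's actual parameters.

    # Hypothetical flowsettings.py snippet (sketch only, not part of this commit)
    KH_RERANKINGS = {
        "cohere": {
            "spec": {
                "__type__": "kotaemon.rerankings.CohereReranking",
                "model_name": "rerank-multilingual-v2.0",
                "cohere_api_key": "<COHERE_API_KEY>",
            },
            "default": True,
        },
        "tei": {
            "spec": {
                "__type__": "kotaemon.rerankings.TeiFastReranking",
                "endpoint_url": "http://localhost:8080",  # assumed field name
            },
            "default": False,
        },
    }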
@@ -59,6 +59,7 @@ class EmbeddingManager:
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
+            TeiEndpointEmbeddings,
         )

         self._vendors = [
@@ -67,6 +68,7 @@ class EmbeddingManager:
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            TeiEndpointEmbeddings,
         ]

     def __getitem__(self, key: str) -> BaseEmbeddings:
@@ -16,6 +16,7 @@ import tiktoken
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.readers.file.base import default_file_metadata_func
 from llama_index.core.vector_stores import (
@@ -39,12 +40,7 @@ from kotaemon.indices.ingests.files import (
     azure_reader,
     unstructured,
 )
-from kotaemon.indices.rankings import (
-    BaseReranking,
-    CohereReranking,
-    LLMReranking,
-    LLMTrulensScoring,
-)
+from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter

 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -285,7 +281,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             ],
             retrieval_mode=user_settings["retrieval_mode"],
             llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
-            rerankers=[CohereReranking(use_key_from_ktem=True)],
+            rerankers=[
+                reranking_models_manager[
+                    index_settings.get(
+                        "reranking", reranking_models_manager.get_default_name()
+                    )
+                ]
+            ],
         )
         if not user_settings["use_reranking"]:
             retriever.rerankers = []  # type: ignore
@@ -715,7 +717,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
         for idx, file_path in enumerate(file_paths):
             file_path = Path(file_path)
             yield Document(
-                content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
+                content=f"Indexing [{idx + 1}/{n_files}]: {file_path.name}",
                 channel="debug",
             )

@@ -4,6 +4,7 @@ from ktem.db.models import User, engine
 from ktem.embeddings.ui import EmbeddingManagement
 from ktem.index.ui import IndexManagement
 from ktem.llms.ui import LLMManagement
+from ktem.rerankings.ui import RerankingManagement
 from sqlmodel import Session, select

 from .user import UserManagement
@@ -24,6 +25,9 @@ class ResourcesTab(BasePage):
         with gr.Tab("Embeddings") as self.emb_management_tab:
             self.emb_management = EmbeddingManagement(self._app)

+        with gr.Tab("Rerankings") as self.rerank_management_tab:
+            self.rerank_management = RerankingManagement(self._app)
+
         if self._app.f_user_management:
             with gr.Tab("Users", visible=False) as self.user_management_tab:
                 self.user_management = UserManagement(self._app)
@@ -5,6 +5,7 @@ import requests
 from ktem.app import BasePage
 from ktem.embeddings.manager import embedding_models_manager as embeddings
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager as rerankers
 from theflow.settings import settings as flowsettings

 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
@@ -186,6 +187,15 @@ class SetupPage(BasePage):
                     },
                     default=True,
                 )
+                rerankers.update(
+                    name="cohere",
+                    spec={
+                        "__type__": "kotaemon.rerankings.CohereReranking",
+                        "model_name": "rerank-multilingual-v2.0",
+                        "cohere_api_key": cohere_api_key,
+                    },
+                    default=True,
+                )
         elif radio_model_value == "openai":
             if openai_api_key:
                 llms.update(
@@ -100,7 +100,7 @@ class DocSearchTool(BaseTool):
             )

             print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))

             # trim context by trim_len
             if evidence:
@@ -138,7 +138,7 @@ class DocSearchTool(BaseTool):
             )

             print("Retrieved #{}: {}".format(_id, retrieved_content))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))

             # trim context by trim_len
             if evidence:
libs/ktem/ktem/rerankings/__init__.py (new file, empty)

libs/ktem/ktem/rerankings/db.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from typing import Type

from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string


class Base(DeclarativeBase):
    pass


class BaseRerankingTable(Base):
    """Base table to store reranking models"""

    __abstract__ = True

    name = Column(String, primary_key=True, unique=True)
    spec = Column(JSON, default={})
    default = Column(Boolean, default=False)


__base_reranking: Type[BaseRerankingTable] = (
    import_dotted_string(flowsettings.KH_TABLE_RERANKING, safe=False)
    if hasattr(flowsettings, "KH_TABLE_RERANKING")
    else BaseRerankingTable
)


class RerankingTable(__base_reranking):  # type: ignore
    __tablename__ = "reranking"


if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
    RerankingTable.metadata.create_all(engine)
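Note: the KH_TABLE_RERANKING hook above lets a deployment swap in its own reranking table. A minimal sketch of such an override; the module path mypackage.tables and the extra owner column are illustrative assumptions.

    # mypackage/tables.py -- hypothetical override of the reranking table
    from ktem.rerankings.db import BaseRerankingTable
    from sqlalchemy import Column, String


    class MyRerankingTable(BaseRerankingTable):
        """Adds an owner column; stays abstract so RerankingTable supplies __tablename__."""

        __abstract__ = True

        owner = Column(String, default="")


    # flowsettings.py
    KH_TABLE_RERANKING = "mypackage.tables.MyRerankingTable"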
libs/ktem/ktem/rerankings/manager.py (new file, 194 lines)
@@ -0,0 +1,194 @@
from typing import Optional, Type

from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize

from kotaemon.rerankings.base import BaseReranking

from .db import RerankingTable, engine


class RerankingManager:
    """Represent a pool of reranking models"""

    def __init__(self):
        self._models: dict[str, BaseReranking] = {}
        self._info: dict[str, dict] = {}
        self._default: str = ""
        self._vendors: list[Type] = []

        # populate the pool if empty
        if hasattr(flowsettings, "KH_RERANKINGS"):
            with Session(engine) as sess:
                count = sess.query(RerankingTable).count()
                if not count:
                    for name, model in flowsettings.KH_RERANKINGS.items():
                        self.add(
                            name=name,
                            spec=model["spec"],
                            default=model.get("default", False),
                        )

        self.load()
        self.load_vendors()

    def load(self):
        """Load the model pool from database"""
        self._models, self._info, self._default = {}, {}, ""
        with Session(engine) as sess:
            stmt = select(RerankingTable)
            items = sess.execute(stmt)

            for (item,) in items:
                self._models[item.name] = deserialize(item.spec, safe=False)
                self._info[item.name] = {
                    "name": item.name,
                    "spec": item.spec,
                    "default": item.default,
                }
                if item.default:
                    self._default = item.name

    def load_vendors(self):
        from kotaemon.rerankings import CohereReranking, TeiFastReranking

        self._vendors = [TeiFastReranking, CohereReranking]

    def __getitem__(self, key: str) -> BaseReranking:
        """Get model by name"""
        return self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    def get(
        self, key: str, default: Optional[BaseReranking] = None
    ) -> Optional[BaseReranking]:
        """Get model by name with default value"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present model pools option for gradio"""
        return {
            "label": "Reranking",
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a dict of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of a random model

        Returns:
            str: random model name in the pool
        """
        import random

        if not self._models:
            raise ValueError("No models in pool")

        return random.choice(list(self._models.keys()))

    def get_default_name(self) -> str:
        """Get the name of the default model

        In case there is no default model, choose a random model from the pool. In
        case there are multiple default models, choose randomly among them.

        Returns:
            str: model name
        """
        if not self._models:
            raise ValueError("No models in pool")

        if not self._default:
            return self.get_random_name()

        return self._default

    def get_random(self) -> BaseReranking:
        """Get random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> BaseReranking:
        """Get the default model

        In case there is no default model, choose a random model from the pool. In
        case there are multiple default models, choose randomly among them.

        Returns:
            BaseReranking: model
        """
        return self._models[self.get_default_name()]

    def info(self) -> dict:
        """List all models"""
        return self._info

    def add(self, name: str, spec: dict, default: bool):
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:
                if default:
                    # turn all models to non-default
                    sess.query(RerankingTable).update({"default": False})
                    sess.commit()

                item = RerankingTable(name=name, spec=spec, default=default)
                sess.add(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to add model {name}: {e}")

        self.load()

    def delete(self, name: str):
        """Delete a model from the pool"""
        try:
            with Session(engine) as sess:
                item = sess.query(RerankingTable).filter_by(name=name).first()
                sess.delete(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to delete model {name}: {e}")

        self.load()

    def update(self, name: str, spec: dict, default: bool):
        """Update a model in the pool"""
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:

                if default:
                    # turn all models to non-default
                    sess.query(RerankingTable).update({"default": False})
                    sess.commit()

                item = sess.query(RerankingTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")
                item.spec = spec
                item.default = default
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to update model {name}: {e}")

        self.load()

    def vendors(self) -> dict:
        """Return list of vendors"""
        return {vendor.__qualname__: vendor for vendor in self._vendors}


reranking_models_manager = RerankingManager()
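Note: a short usage sketch of the manager above, mirroring how DocumentRetrievalPipeline resolves its reranker; the "my-tei" name, the endpoint_url field, and the inline index_settings dict are illustrative assumptions.

    from ktem.rerankings.manager import reranking_models_manager as rerankers

    # Register a model programmatically (normally done through the Resources UI)
    rerankers.add(
        name="my-tei",  # illustrative name
        spec={
            "__type__": "kotaemon.rerankings.TeiFastReranking",
            "endpoint_url": "http://localhost:8080",  # assumed field name
        },
        default=False,
    )

    # Resolve the reranker an index should use, falling back to the default,
    # as DocumentRetrievalPipeline does above
    index_settings: dict = {}  # stand-in for the index's stored settings
    reranker = rerankers[index_settings.get("reranking", rerankers.get_default_name())]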
libs/ktem/ktem/rerankings/ui.py (new file, 390 lines)
@@ -0,0 +1,390 @@
from copy import deepcopy

import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage
from ktem.utils.file import YAMLNoDateSafeLoader
from theflow.utils.modules import deserialize

from .manager import reranking_models_manager


def format_description(cls):
    params = cls.describe()["params"]
    params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
    for key, value in params.items():
        if isinstance(value["auto_callback"], str):
            continue
        params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
    return f"{cls.__doc__}\n\n" + "\n".join(params_lines)


class RerankingManagement(BasePage):
    def __init__(self, app):
        self._app = app
        self.spec_desc_default = (
            "# Spec description\n\nSelect a model to view the spec description."
        )
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Tab(label="View"):
            self.rerank_list = gr.DataFrame(
                headers=["name", "vendor", "default"],
                interactive=False,
            )

            with gr.Column(visible=False) as self._selected_panel:
                self.selected_rerank_name = gr.Textbox(value="", visible=False)
                with gr.Row():
                    with gr.Column():
                        self.edit_default = gr.Checkbox(
                            label="Set default",
                            info=(
                                "Set this Reranking model as default. This default "
                                "Reranking will be used by other components by default "
                                "if no Reranking is specified for such components."
                            ),
                        )
                        self.edit_spec = gr.Textbox(
                            label="Specification",
                            info="Specification of the Reranking model in YAML format",
                            lines=10,
                        )

                        with gr.Accordion(
                            label="Test connection", visible=False, open=False
                        ) as self._check_connection_panel:
                            with gr.Row():
                                with gr.Column(scale=4):
                                    self.connection_logs = gr.HTML(
                                        "Logs",
                                    )

                                with gr.Column(scale=1):
                                    self.btn_test_connection = gr.Button("Test")

                        with gr.Row(visible=False) as self._selected_panel_btn:
                            with gr.Column():
                                self.btn_edit_save = gr.Button(
                                    "Save", min_width=10, variant="primary"
                                )
                            with gr.Column():
                                self.btn_delete = gr.Button(
                                    "Delete", min_width=10, variant="stop"
                                )
                                with gr.Row():
                                    self.btn_delete_yes = gr.Button(
                                        "Confirm Delete",
                                        variant="stop",
                                        visible=False,
                                        min_width=10,
                                    )
                                    self.btn_delete_no = gr.Button(
                                        "Cancel", visible=False, min_width=10
                                    )
                            with gr.Column():
                                self.btn_close = gr.Button("Close", min_width=10)

                    with gr.Column():
                        self.edit_spec_desc = gr.Markdown("# Spec description")

        with gr.Tab(label="Add"):
            with gr.Row():
                with gr.Column(scale=2):
                    self.name = gr.Textbox(
                        label="Name",
                        info=(
                            "Must be unique and non-empty. "
                            "The name will be used to identify the reranking model."
                        ),
                    )
                    self.rerank_choices = gr.Dropdown(
                        label="Vendors",
                        info=(
                            "Choose the vendor of the Reranking model. Each vendor "
                            "has a different specification."
                        ),
                    )
                    self.spec = gr.Textbox(
                        label="Specification",
                        info="Specification of the Reranking model in YAML format.",
                    )
                    self.default = gr.Checkbox(
                        label="Set default",
                        info=(
                            "Set this Reranking model as default. This default "
                            "Reranking will be used by other components by default "
                            "if no Reranking is specified for such components."
                        ),
                    )
                    self.btn_new = gr.Button("Add", variant="primary")

                with gr.Column(scale=3):
                    self.spec_desc = gr.Markdown(self.spec_desc_default)

    def _on_app_created(self):
        """Called when the app is created"""
        self._app.app.load(
            self.list_rerankings,
            inputs=[],
            outputs=[self.rerank_list],
        )
        self._app.app.load(
            lambda: gr.update(choices=list(reranking_models_manager.vendors().keys())),
            outputs=[self.rerank_choices],
        )

    def on_rerank_vendor_change(self, vendor):
        vendor = reranking_models_manager.vendors()[vendor]

        required: dict = {}
        desc = vendor.describe()
        for key, value in desc["params"].items():
            if value.get("required", False):
                required[key] = value.get("default", None)

        return yaml.dump(required), format_description(vendor)

    def on_register_events(self):
        self.rerank_choices.select(
            self.on_rerank_vendor_change,
            inputs=[self.rerank_choices],
            outputs=[self.spec, self.spec_desc],
        )
        self.btn_new.click(
            self.create_rerank,
            inputs=[self.name, self.rerank_choices, self.spec, self.default],
            outputs=None,
        ).success(self.list_rerankings, inputs=[], outputs=[self.rerank_list]).success(
            lambda: ("", None, "", False, self.spec_desc_default),
            outputs=[
                self.name,
                self.rerank_choices,
                self.spec,
                self.default,
                self.spec_desc,
            ],
        )
        self.rerank_list.select(
            self.select_rerank,
            inputs=self.rerank_list,
            outputs=[self.selected_rerank_name],
            show_progress="hidden",
        )
        self.selected_rerank_name.change(
            self.on_selected_rerank_change,
            inputs=[self.selected_rerank_name],
            outputs=[
                self._selected_panel,
                self._selected_panel_btn,
                # delete section
                self.btn_delete,
                self.btn_delete_yes,
                self.btn_delete_no,
                # edit section
                self.edit_spec,
                self.edit_spec_desc,
                self.edit_default,
                self._check_connection_panel,
            ],
            show_progress="hidden",
        ).success(lambda: gr.update(value=""), outputs=[self.connection_logs])

        self.btn_delete.click(
            self.on_btn_delete_click,
            inputs=[],
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_delete_yes.click(
            self.delete_rerank,
            inputs=[self.selected_rerank_name],
            outputs=[self.selected_rerank_name],
            show_progress="hidden",
        ).then(
            self.list_rerankings,
            inputs=[],
            outputs=[self.rerank_list],
        )
        self.btn_delete_no.click(
            lambda: (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            ),
            inputs=[],
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_edit_save.click(
            self.save_rerank,
            inputs=[
                self.selected_rerank_name,
                self.edit_default,
                self.edit_spec,
            ],
            show_progress="hidden",
        ).then(
            self.list_rerankings,
            inputs=[],
            outputs=[self.rerank_list],
        )
        self.btn_close.click(lambda: "", outputs=[self.selected_rerank_name])

        self.btn_test_connection.click(
            self.check_connection,
            inputs=[self.selected_rerank_name, self.edit_spec],
            outputs=[self.connection_logs],
        )

    def create_rerank(self, name, choices, spec, default):
        try:
            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
            spec["__type__"] = (
                reranking_models_manager.vendors()[choices].__module__
                + "."
                + reranking_models_manager.vendors()[choices].__qualname__
            )

            reranking_models_manager.add(name, spec=spec, default=default)
            gr.Info(f'Create Reranking model "{name}" successfully')
        except Exception as e:
            raise gr.Error(f"Failed to create Reranking model {name}: {e}")

    def list_rerankings(self):
        """List the Reranking models"""
        items = []
        for item in reranking_models_manager.info().values():
            record = {}
            record["name"] = item["name"]
            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
            record["default"] = item["default"]
            items.append(record)

        if items:
            rerank_list = pd.DataFrame.from_records(items)
        else:
            rerank_list = pd.DataFrame.from_records(
                [{"name": "-", "vendor": "-", "default": "-"}]
            )

        return rerank_list

    def select_rerank(self, rerank_list, ev: gr.SelectData):
        if ev.value == "-" and ev.index[0] == 0:
            gr.Info("No reranking model is loaded. Please add one first")
            return ""

        if not ev.selected:
            return ""

        return rerank_list["name"][ev.index[0]]

    def on_selected_rerank_change(self, selected_rerank_name):
        if selected_rerank_name == "":
            _check_connection_panel = gr.update(visible=False)
            _selected_panel = gr.update(visible=False)
            _selected_panel_btn = gr.update(visible=False)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)
            edit_spec = gr.update(value="")
            edit_spec_desc = gr.update(value="")
            edit_default = gr.update(value=False)
        else:
            _check_connection_panel = gr.update(visible=True)
            _selected_panel = gr.update(visible=True)
            _selected_panel_btn = gr.update(visible=True)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)

            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
            vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
            vendor = reranking_models_manager.vendors()[vendor_str]

            edit_spec = yaml.dump(info["spec"])
            edit_spec_desc = format_description(vendor)
            edit_default = info["default"]

        return (
            _selected_panel,
            _selected_panel_btn,
            btn_delete,
            btn_delete_yes,
            btn_delete_no,
            edit_spec,
            edit_spec_desc,
            edit_default,
            _check_connection_panel,
        )

    def on_btn_delete_click(self):
        btn_delete = gr.update(visible=False)
        btn_delete_yes = gr.update(visible=True)
        btn_delete_no = gr.update(visible=True)

        return btn_delete, btn_delete_yes, btn_delete_no

    def check_connection(self, selected_rerank_name, selected_spec):
        log_content: str = ""
        try:
            log_content += f"- Testing model: {selected_rerank_name}<br>"
            yield log_content

            # Parse content & init model
            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])

            # Parse content & create dummy response
            spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
            info["spec"].update(spec)

            rerank = deserialize(info["spec"], safe=False)

            if rerank is None:
                raise Exception(f"Cannot find model: {selected_rerank_name}")

            log_content += "- Sending a message ([`Hello`], `Hi`)<br>"
            yield log_content
            _ = rerank(["Hello"], "Hi")

            log_content += (
                "<mark style='background: green; color: white'>- Connection success. "
                "</mark><br>"
            )
            yield log_content

            gr.Info(f"Reranking {selected_rerank_name} connected successfully")
        except Exception as e:
            print(e)
            log_content += (
                f"<mark style='color: yellow; background: red'>- Connection failed. "
                f"Got error:\n {str(e)}</mark>"
            )
            yield log_content

        return log_content

    def save_rerank(self, selected_rerank_name, default, spec):
        try:
            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
            spec["__type__"] = reranking_models_manager.info()[selected_rerank_name][
                "spec"
            ]["__type__"]
            reranking_models_manager.update(
                selected_rerank_name, spec=spec, default=default
            )
            gr.Info(f'Save Reranking model "{selected_rerank_name}" successfully')
        except Exception as e:
            gr.Error(f'Failed to save Reranking model "{selected_rerank_name}": {e}')

    def delete_rerank(self, selected_rerank_name):
        try:
            reranking_models_manager.delete(selected_rerank_name)
        except Exception as e:
            gr.Error(f'Failed to delete Reranking model "{selected_rerank_name}": {e}')
            return selected_rerank_name

        return ""
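Note: the connection test above invokes the reranker as rerank(["Hello"], "Hi"), i.e. a list of documents followed by the query. A hedged sketch of calling a registered model directly the same way; the exact return value depends on the vendor implementation.

    from ktem.rerankings.manager import reranking_models_manager as rerankers

    reranker = rerankers.get_default()
    # Same call shape as RerankingManagement.check_connection: documents, then query
    result = reranker(["Hello"], "Hi")
    print(result)  # return type depends on the vendor implementation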
@@ -154,9 +154,9 @@ class Render:
             if doc.metadata.get("llm_trulens_score") is not None
             else 0.0
         )
-        cohere_reranking_score = (
-            round(doc.metadata["cohere_reranking_score"], 2)
-            if doc.metadata.get("cohere_reranking_score") is not None
+        reranking_score = (
+            round(doc.metadata["reranking_score"], 2)
+            if doc.metadata.get("reranking_score") is not None
             else 0.0
         )
         item_type_prefix = doc.metadata.get("type", "")
@@ -166,8 +166,8 @@ class Render:

         if llm_reranking_score > 0:
             relevant_score = llm_reranking_score
-        elif cohere_reranking_score > 0:
-            relevant_score = cohere_reranking_score
+        elif reranking_score > 0:
+            relevant_score = reranking_score
         else:
             relevant_score = 0.0

@@ -179,7 +179,7 @@ class Render:
             "<b>  LLM relevant score:</b>"
             f" {llm_reranking_score}<br>"
             "<b>  Reranking score:</b>"
-            f" {cohere_reranking_score}<br>",
+            f" {reranking_score}<br>",
         )

         text = doc.text if not override_text else override_text