diff --git a/flowsettings.py b/flowsettings.py
index bb2cb7a..e9f9217 100644
--- a/flowsettings.py
+++ b/flowsettings.py
@@ -92,6 +92,7 @@ KH_VECTORSTORE = {
}
KH_LLMS = {}
KH_EMBEDDINGS = {}
+KH_RERANKINGS = {}
# populate options from config
if config("AZURE_OPENAI_API_KEY", default="") and config(
@@ -212,7 +213,7 @@ KH_LLMS["cohere"] = {
"spec": {
"__type__": "kotaemon.llms.chats.LCCohereChat",
"model_name": "command-r-plus-08-2024",
- "api_key": "your-key",
+ "api_key": config("COHERE_API_KEY", default="your-key"),
},
"default": False,
}
@@ -222,7 +223,7 @@ KH_EMBEDDINGS["cohere"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v3.0",
- "cohere_api_key": "your-key",
+ "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
"user_agent": "default",
},
"default": False,
@@ -235,6 +236,16 @@ KH_EMBEDDINGS["cohere"] = {
# "default": False,
# }
+# default reranking models
+KH_RERANKINGS["cohere"] = {
+ "spec": {
+ "__type__": "kotaemon.rerankings.CohereReranking",
+ "model_name": "rerank-multilingual-v2.0",
+ "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
+ },
+ "default": True,
+}
+
KH_REASONINGS = [
"ktem.reasoning.simple.FullQAPipeline",
"ktem.reasoning.simple.FullDecomposeQAPipeline",
diff --git a/libs/kotaemon/kotaemon/embeddings/__init__.py b/libs/kotaemon/kotaemon/embeddings/__init__.py
index af01ecc..92b3d1f 100644
--- a/libs/kotaemon/kotaemon/embeddings/__init__.py
+++ b/libs/kotaemon/kotaemon/embeddings/__init__.py
@@ -8,10 +8,12 @@ from .langchain_based import (
LCOpenAIEmbeddings,
)
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from .tei_endpoint_embed import TeiEndpointEmbeddings
__all__ = [
"BaseEmbeddings",
"EndpointEmbeddings",
+ "TeiEndpointEmbeddings",
"LCOpenAIEmbeddings",
"LCAzureOpenAIEmbeddings",
"LCCohereEmbeddings",
diff --git a/libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py b/libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py
new file mode 100644
index 0000000..6a436e8
--- /dev/null
+++ b/libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py
@@ -0,0 +1,105 @@
+import aiohttp
+import requests
+
+from kotaemon.base import Document, DocumentWithEmbedding, Param
+
+from .base import BaseEmbeddings
+
+session = requests.session()
+
+
+class TeiEndpointEmbeddings(BaseEmbeddings):
+ """An Embeddings component that uses an
+ TEI (Text-Embedding-Inference) API compatible endpoint.
+
+ Ref: https://github.com/huggingface/text-embeddings-inference
+
+ Attributes:
+        endpoint_url (str): The URL of a TEI
+            (Text Embeddings Inference) API-compatible endpoint.
+        normalize (bool): Whether to normalize embeddings to unit length.
+        truncate (bool): Whether to truncate inputs that exceed the
+            model's maximum sequence length.
+ """
+
+    endpoint_url: str = Param(None, help="TEI embedding service API base URL")
+ normalize: bool = Param(
+ True,
+ help="Normalize embeddings to unit length",
+ )
+ truncate: bool = Param(
+ True,
+ help="Truncate embeddings to a fixed/default length",
+ )
+
+ async def client_(self, inputs: list[str]):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ url=self.endpoint_url,
+ json={
+ "inputs": inputs,
+ "normalize": self.normalize,
+ "truncate": self.truncate,
+ },
+ ) as resp:
+ embeddings = await resp.json()
+ return embeddings
+
+ async def ainvoke(
+ self, text: str | list[str] | Document | list[Document], *args, **kwargs
+ ) -> list[DocumentWithEmbedding]:
+ if not isinstance(text, list):
+ text = [text]
+ text = self.prepare_input(text)
+
+ outputs = []
+ batch_size = 6
+        num_batch = max((len(text) + batch_size - 1) // batch_size, 1)  # ceil
+ for i in range(num_batch):
+ if i == num_batch - 1:
+ mini_batch = text[batch_size * i :]
+ else:
+ mini_batch = text[batch_size * i : batch_size * (i + 1)]
+ mini_batch = [x.content for x in mini_batch]
+ embeddings = await self.client_(mini_batch) # type: ignore
+ outputs.extend(
+ [
+ DocumentWithEmbedding(content=doc, embedding=embedding)
+ for doc, embedding in zip(mini_batch, embeddings)
+ ]
+ )
+
+ return outputs
+
+ def invoke(
+ self, text: str | list[str] | Document | list[Document], *args, **kwargs
+ ) -> list[DocumentWithEmbedding]:
+ if not isinstance(text, list):
+ text = [text]
+
+ text = self.prepare_input(text)
+
+ outputs = []
+ batch_size = 6
+        num_batch = max((len(text) + batch_size - 1) // batch_size, 1)  # ceil
+ for i in range(num_batch):
+ if i == num_batch - 1:
+ mini_batch = text[batch_size * i :]
+ else:
+ mini_batch = text[batch_size * i : batch_size * (i + 1)]
+ mini_batch = [x.content for x in mini_batch]
+ embeddings = session.post(
+ url=self.endpoint_url,
+ json={
+ "inputs": mini_batch,
+ "normalize": self.normalize,
+ "truncate": self.truncate,
+ },
+ ).json()
+ outputs.extend(
+ [
+ DocumentWithEmbedding(content=doc, embedding=embedding)
+ for doc, embedding in zip(mini_batch, embeddings)
+ ]
+ )
+ return outputs
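A quick smoke test for the new class against a running TEI embedding server; the endpoint URL below is an assumption, not part of this change:

```python
from kotaemon.embeddings import TeiEndpointEmbeddings

# Assumes a TEI embedding server is reachable locally; URL is illustrative.
embed = TeiEndpointEmbeddings(
    endpoint_url="http://localhost:8080/embed",
    normalize=True,
    truncate=True,
)
docs = embed.invoke(["hello world", "goodbye world"])
print(len(docs), len(docs[0].embedding))  # two DocumentWithEmbedding results
```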
diff --git a/libs/kotaemon/kotaemon/indices/rankings/cohere.py b/libs/kotaemon/kotaemon/indices/rankings/cohere.py
index b4ce97e..9515d12 100644
--- a/libs/kotaemon/kotaemon/indices/rankings/cohere.py
+++ b/libs/kotaemon/kotaemon/indices/rankings/cohere.py
@@ -39,7 +39,7 @@ class CohereReranking(BaseReranking):
print("Cannot get Cohere API key from `ktem`", e)
if not self.cohere_api_key:
- print("Cohere API key not found. Skipping reranking.")
+ print("Cohere API key not found. Skipping rerankings.")
return documents
cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ class CohereReranking(BaseReranking):
response = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs
)
- # print("Cohere score", [r.relevance_score for r in response.results])
for r in response.results:
doc = documents[r.index]
- doc.metadata["cohere_reranking_score"] = r.relevance_score
+ doc.metadata["reranking_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs
diff --git a/libs/kotaemon/kotaemon/indices/vectorindex.py b/libs/kotaemon/kotaemon/indices/vectorindex.py
index e2984c7..1906091 100644
--- a/libs/kotaemon/kotaemon/indices/vectorindex.py
+++ b/libs/kotaemon/kotaemon/indices/vectorindex.py
@@ -241,7 +241,7 @@ class VectorRetrieval(BaseRetrieval):
# if reranker is LLMReranking, limit the document with top_k items only
if isinstance(reranker, LLMReranking):
result = self._filter_docs(result, top_k=top_k)
- result = reranker(documents=result, query=text)
+ result = reranker.run(documents=result, query=text)
result = self._filter_docs(result, top_k=top_k)
print(f"Got raw {len(result)} retrieved documents")
diff --git a/libs/kotaemon/kotaemon/rerankings/__init__.py b/libs/kotaemon/kotaemon/rerankings/__init__.py
new file mode 100644
index 0000000..621b9a2
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseReranking
+from .cohere import CohereReranking
+from .tei_fast_rerank import TeiFastReranking
+
+__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
diff --git a/libs/kotaemon/kotaemon/rerankings/base.py b/libs/kotaemon/kotaemon/rerankings/base.py
new file mode 100644
index 0000000..c9c0b9b
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/base.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+
+from kotaemon.base import BaseComponent, Document
+
+
+class BaseReranking(BaseComponent):
+ @abstractmethod
+ def run(self, documents: list[Document], query: str) -> list[Document]:
+ """Main method to transform list of documents
+ (re-ranking, filtering, etc)"""
+ ...
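Third-party rerankers only need to subclass `BaseReranking` and implement `run`. A minimal illustrative sketch (the `IdentityReranking` name is hypothetical):

```python
from kotaemon.base import Document
from kotaemon.rerankings import BaseReranking


class IdentityReranking(BaseReranking):
    """Illustrative reranker: keeps input order, assigns a flat score."""

    def run(self, documents: list[Document], query: str) -> list[Document]:
        for doc in documents:
            doc.metadata["reranking_score"] = 1.0
        return documents
```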
diff --git a/libs/kotaemon/kotaemon/rerankings/cohere.py b/libs/kotaemon/kotaemon/rerankings/cohere.py
new file mode 100644
index 0000000..dbc5e9a
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/cohere.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from decouple import config
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+
+class CohereReranking(BaseReranking):
+ """Cohere Reranking model"""
+
+ model_name: str = Param(
+ "rerank-multilingual-v2.0",
+ help=(
+ "ID of the model to use. You can go to [Supported Models]"
+ "(https://docs.cohere.com/docs/rerank-2) to see the supported models"
+ ),
+ required=True,
+ )
+ cohere_api_key: str = Param(
+ config("COHERE_API_KEY", ""),
+ help="Cohere API key",
+ required=True,
+ )
+
+ def run(self, documents: list[Document], query: str) -> list[Document]:
+ """Use Cohere Reranker model to re-order documents
+ with their relevance score"""
+ try:
+ import cohere
+ except ImportError:
+ raise ImportError(
+ "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
+ )
+
+ if not self.cohere_api_key:
+ print("Cohere API key not found. Skipping rerankings.")
+ return documents
+
+ cohere_client = cohere.Client(self.cohere_api_key)
+ compressed_docs: list[Document] = []
+
+ if not documents: # to avoid empty api call
+ return compressed_docs
+
+ _docs = [d.content for d in documents]
+ response = cohere_client.rerank(
+ model=self.model_name, query=query, documents=_docs
+ )
+ for r in response.results:
+ doc = documents[r.index]
+ doc.metadata["reranking_score"] = r.relevance_score
+ compressed_docs.append(doc)
+
+ return compressed_docs
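Example usage of the relocated class; this sketch assumes `COHERE_API_KEY` is set in the environment so the `Param` default picks it up:

```python
from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking

# Assumes COHERE_API_KEY is set in the environment (read via decouple).
reranker = CohereReranking(model_name="rerank-multilingual-v2.0")
docs = [
    Document(content="Paris is the capital of France."),
    Document(content="Bananas are yellow."),
]
for doc in reranker.run(docs, query="What is the capital of France?"):
    print(doc.metadata["reranking_score"], doc.content)
```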
diff --git a/libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py b/libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py
new file mode 100644
index 0000000..4ac4b8e
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import requests
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+session = requests.session()
+
+
+class TeiFastReranking(BaseReranking):
+ """Text Embeddings Inference (TEI) Reranking model
+ (https://huggingface.co/docs/text-embeddings-inference/en/index)
+ """
+
+ endpoint_url: str = Param(
+ None, help="TEI Reranking service api base URL", required=True
+ )
+ model_name: Optional[str] = Param(
+ None,
+ help=(
+ "ID of the model to use. You can go to [Supported Models]"
+ "(https://github.com/huggingface"
+ "/text-embeddings-inference?tab=readme-ov-file"
+ "#supported-models) to see the supported models"
+ ),
+ )
+ is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")
+
+ def client(self, query, texts):
+ response = session.post(
+ url=self.endpoint_url,
+ json={
+ "query": query,
+ "texts": texts,
+ "is_truncated": self.is_truncated, # default is True
+ },
+ ).json()
+ return response
+
+ def run(self, documents: list[Document], query: str) -> list[Document]:
+ """Use the deployed TEI rerankings service to re-order documents
+ with their relevance score"""
+ if not self.endpoint_url:
+ print("TEI API reranking URL not found. Skipping rerankings.")
+ return documents
+
+ compressed_docs: list[Document] = []
+
+ if not documents: # to avoid empty api call
+ return compressed_docs
+
+ if isinstance(documents[0], str):
+ documents = self.prepare_input(documents)
+
+ batch_size = 6
+        num_batch = max((len(documents) + batch_size - 1) // batch_size, 1)  # ceil
+ for i in range(num_batch):
+ if i == num_batch - 1:
+ mini_batch = documents[batch_size * i :]
+ else:
+ mini_batch = documents[batch_size * i : batch_size * (i + 1)]
+
+ _docs = [d.content for d in mini_batch]
+ rerank_resp = self.client(query, _docs)
+ for r in rerank_resp:
+ doc = mini_batch[r["index"]]
+ doc.metadata["reranking_score"] = r["score"]
+ compressed_docs.append(doc)
+
+ compressed_docs = sorted(
+ compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
+ )
+ return compressed_docs
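For reference, `client()` above exchanges plain JSON with the TEI `/rerank` route and expects a list of `{index, score}` records back. A sketch of the raw call, under an assumed local deployment:

```python
import requests

# Illustrative only: the request/response shape this class relies on.
resp = requests.post(
    "http://localhost:8080/rerank",  # assumed local TEI deployment
    json={"query": "capital of France", "texts": ["Paris", "Bananas"]},
).json()
# Expected form: [{"index": 0, "score": 0.98}, {"index": 1, "score": 0.01}]
for record in sorted(resp, key=lambda r: r["score"], reverse=True):
    print(record["index"], record["score"])
```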
diff --git a/libs/ktem/ktem/embeddings/manager.py b/libs/ktem/ktem/embeddings/manager.py
index 88cdacb..c33d151 100644
--- a/libs/ktem/ktem/embeddings/manager.py
+++ b/libs/ktem/ktem/embeddings/manager.py
@@ -59,6 +59,7 @@ class EmbeddingManager:
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
+ TeiEndpointEmbeddings,
)
self._vendors = [
@@ -67,6 +68,7 @@ class EmbeddingManager:
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
+ TeiEndpointEmbeddings,
]
def __getitem__(self, key: str) -> BaseEmbeddings:
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index 598064a..cd94852 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -16,6 +16,7 @@ import tiktoken
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager
from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.file.base import default_file_metadata_func
from llama_index.core.vector_stores import (
@@ -39,12 +40,7 @@ from kotaemon.indices.ingests.files import (
azure_reader,
unstructured,
)
-from kotaemon.indices.rankings import (
- BaseReranking,
- CohereReranking,
- LLMReranking,
- LLMTrulensScoring,
-)
+from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -285,7 +281,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
],
retrieval_mode=user_settings["retrieval_mode"],
llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
- rerankers=[CohereReranking(use_key_from_ktem=True)],
+ rerankers=[
+ reranking_models_manager[
+ index_settings.get(
+ "reranking", reranking_models_manager.get_default_name()
+ )
+ ]
+ ],
)
if not user_settings["use_reranking"]:
retriever.rerankers = [] # type: ignore
@@ -715,7 +717,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
for idx, file_path in enumerate(file_paths):
file_path = Path(file_path)
yield Document(
- content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
+ content=f"Indexing [{idx + 1}/{n_files}]: {file_path.name}",
channel="debug",
)
diff --git a/libs/ktem/ktem/pages/resources/__init__.py b/libs/ktem/ktem/pages/resources/__init__.py
index aa606c9..35bf54c 100644
--- a/libs/ktem/ktem/pages/resources/__init__.py
+++ b/libs/ktem/ktem/pages/resources/__init__.py
@@ -4,6 +4,7 @@ from ktem.db.models import User, engine
from ktem.embeddings.ui import EmbeddingManagement
from ktem.index.ui import IndexManagement
from ktem.llms.ui import LLMManagement
+from ktem.rerankings.ui import RerankingManagement
from sqlmodel import Session, select
from .user import UserManagement
@@ -24,6 +25,9 @@ class ResourcesTab(BasePage):
with gr.Tab("Embeddings") as self.emb_management_tab:
self.emb_management = EmbeddingManagement(self._app)
+ with gr.Tab("Rerankings") as self.rerank_management_tab:
+ self.rerank_management = RerankingManagement(self._app)
+
if self._app.f_user_management:
with gr.Tab("Users", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
diff --git a/libs/ktem/ktem/pages/setup.py b/libs/ktem/ktem/pages/setup.py
index 5199ec4..f7e70a1 100644
--- a/libs/ktem/ktem/pages/setup.py
+++ b/libs/ktem/ktem/pages/setup.py
@@ -5,6 +5,7 @@ import requests
from ktem.app import BasePage
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager as rerankers
from theflow.settings import settings as flowsettings
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
@@ -186,6 +187,15 @@ class SetupPage(BasePage):
},
default=True,
)
+ rerankers.update(
+ name="cohere",
+ spec={
+ "__type__": "kotaemon.rerankings.CohereReranking",
+ "model_name": "rerank-multilingual-v2.0",
+ "cohere_api_key": cohere_api_key,
+ },
+ default=True,
+ )
elif radio_model_value == "openai":
if openai_api_key:
llms.update(
diff --git a/libs/ktem/ktem/reasoning/react.py b/libs/ktem/ktem/reasoning/react.py
index afdd931..d73a568 100644
--- a/libs/ktem/ktem/reasoning/react.py
+++ b/libs/ktem/ktem/reasoning/react.py
@@ -100,7 +100,7 @@ class DocSearchTool(BaseTool):
)
print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
- print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+ print("Score", retrieved_item.metadata.get("reranking_score", None))
# trim context by trim_len
if evidence:
diff --git a/libs/ktem/ktem/reasoning/rewoo.py b/libs/ktem/ktem/reasoning/rewoo.py
index e4d461f..f751342 100644
--- a/libs/ktem/ktem/reasoning/rewoo.py
+++ b/libs/ktem/ktem/reasoning/rewoo.py
@@ -138,7 +138,7 @@ class DocSearchTool(BaseTool):
)
print("Retrieved #{}: {}".format(_id, retrieved_content))
- print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+ print("Score", retrieved_item.metadata.get("reranking_score", None))
# trim context by trim_len
if evidence:
diff --git a/libs/ktem/ktem/rerankings/__init__.py b/libs/ktem/ktem/rerankings/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/libs/ktem/ktem/rerankings/db.py b/libs/ktem/ktem/rerankings/db.py
new file mode 100644
index 0000000..85049df
--- /dev/null
+++ b/libs/ktem/ktem/rerankings/db.py
@@ -0,0 +1,36 @@
+from typing import Type
+
+from ktem.db.engine import engine
+from sqlalchemy import JSON, Boolean, Column, String
+from sqlalchemy.orm import DeclarativeBase
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import import_dotted_string
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class BaseRerankingTable(Base):
+ """Base table to store rerankings model"""
+
+ __abstract__ = True
+
+ name = Column(String, primary_key=True, unique=True)
+ spec = Column(JSON, default={})
+ default = Column(Boolean, default=False)
+
+
+__base_reranking: Type[BaseRerankingTable] = (
+ import_dotted_string(flowsettings.KH_TABLE_RERANKING, safe=False)
+ if hasattr(flowsettings, "KH_TABLE_RERANKING")
+ else BaseRerankingTable
+)
+
+
+class RerankingTable(__base_reranking): # type: ignore
+ __tablename__ = "reranking"
+
+
+if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
+ RerankingTable.metadata.create_all(engine)
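The `KH_TABLE_RERANKING` hook above lets a deployment substitute its own table class; a sketch, where `my_package.tables` is a hypothetical module:

```python
# Hypothetical my_package/tables.py
from sqlalchemy import Column, String

from ktem.rerankings.db import BaseRerankingTable


class MyRerankingTable(BaseRerankingTable):
    """Illustrative custom base: adds one extra column."""

    __abstract__ = True

    note = Column(String, default="")


# Then, in flowsettings.py:
# KH_TABLE_RERANKING = "my_package.tables.MyRerankingTable"
```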
diff --git a/libs/ktem/ktem/rerankings/manager.py b/libs/ktem/ktem/rerankings/manager.py
new file mode 100644
index 0000000..a9facc1
--- /dev/null
+++ b/libs/ktem/ktem/rerankings/manager.py
@@ -0,0 +1,194 @@
+from typing import Optional, Type
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import deserialize
+
+from kotaemon.rerankings.base import BaseReranking
+
+from .db import RerankingTable, engine
+
+
+class RerankingManager:
+ """Represent a pool of rerankings models"""
+
+ def __init__(self):
+ self._models: dict[str, BaseReranking] = {}
+ self._info: dict[str, dict] = {}
+ self._default: str = ""
+ self._vendors: list[Type] = []
+
+ # populate the pool if empty
+ if hasattr(flowsettings, "KH_RERANKINGS"):
+ with Session(engine) as sess:
+ count = sess.query(RerankingTable).count()
+ if not count:
+ for name, model in flowsettings.KH_RERANKINGS.items():
+ self.add(
+ name=name,
+ spec=model["spec"],
+ default=model.get("default", False),
+ )
+
+ self.load()
+ self.load_vendors()
+
+ def load(self):
+ """Load the model pool from database"""
+ self._models, self._info, self._default = {}, {}, ""
+ with Session(engine) as sess:
+ stmt = select(RerankingTable)
+ items = sess.execute(stmt)
+
+ for (item,) in items:
+ self._models[item.name] = deserialize(item.spec, safe=False)
+ self._info[item.name] = {
+ "name": item.name,
+ "spec": item.spec,
+ "default": item.default,
+ }
+ if item.default:
+ self._default = item.name
+
+ def load_vendors(self):
+ from kotaemon.rerankings import CohereReranking, TeiFastReranking
+
+ self._vendors = [TeiFastReranking, CohereReranking]
+
+ def __getitem__(self, key: str) -> BaseReranking:
+ """Get model by name"""
+ return self._models[key]
+
+ def __contains__(self, key: str) -> bool:
+ """Check if model exists"""
+ return key in self._models
+
+ def get(
+ self, key: str, default: Optional[BaseReranking] = None
+ ) -> Optional[BaseReranking]:
+ """Get model by name with default value"""
+ return self._models.get(key, default)
+
+ def settings(self) -> dict:
+ """Present model pools option for gradio"""
+ return {
+ "label": "Reranking",
+ "choices": list(self._models.keys()),
+ "value": self.get_default_name(),
+ }
+
+ def options(self) -> dict:
+ """Present a dict of models"""
+ return self._models
+
+ def get_random_name(self) -> str:
+ """Get the name of random model
+
+ Returns:
+ str: random model name in the pool
+ """
+ import random
+
+ if not self._models:
+ raise ValueError("No models is pool")
+
+ return random.choice(list(self._models.keys()))
+
+ def get_default_name(self) -> str:
+ """Get the name of default model
+
+ In case there is no default model, choose random model from pool. In
+ case there are multiple default models, choose random from them.
+
+ Returns:
+ str: model name
+ """
+ if not self._models:
+ raise ValueError("No models in pool")
+
+ if not self._default:
+ return self.get_random_name()
+
+ return self._default
+
+ def get_random(self) -> BaseReranking:
+ """Get random model"""
+ return self._models[self.get_random_name()]
+
+ def get_default(self) -> BaseReranking:
+ """Get default model
+
+ In case there is no default model, choose random model from pool. In
+ case there are multiple default models, choose random from them.
+
+ Returns:
+ BaseReranking: model
+ """
+ return self._models[self.get_default_name()]
+
+ def info(self) -> dict:
+ """List all models"""
+ return self._info
+
+ def add(self, name: str, spec: dict, default: bool):
+ if not name:
+ raise ValueError("Name must not be empty")
+
+ try:
+ with Session(engine) as sess:
+ if default:
+ # turn all models to non-default
+ sess.query(RerankingTable).update({"default": False})
+ sess.commit()
+
+ item = RerankingTable(name=name, spec=spec, default=default)
+ sess.add(item)
+ sess.commit()
+ except Exception as e:
+ raise ValueError(f"Failed to add model {name}: {e}")
+
+ self.load()
+
+ def delete(self, name: str):
+ """Delete a model from the pool"""
+ try:
+ with Session(engine) as sess:
+ item = sess.query(RerankingTable).filter_by(name=name).first()
+ sess.delete(item)
+ sess.commit()
+ except Exception as e:
+ raise ValueError(f"Failed to delete model {name}: {e}")
+
+ self.load()
+
+ def update(self, name: str, spec: dict, default: bool):
+ """Update a model in the pool"""
+ if not name:
+ raise ValueError("Name must not be empty")
+
+ try:
+ with Session(engine) as sess:
+
+ if default:
+ # turn all models to non-default
+ sess.query(RerankingTable).update({"default": False})
+ sess.commit()
+
+ item = sess.query(RerankingTable).filter_by(name=name).first()
+ if not item:
+ raise ValueError(f"Model {name} not found")
+ item.spec = spec
+ item.default = default
+ sess.commit()
+ except Exception as e:
+ raise ValueError(f"Failed to update model {name}: {e}")
+
+ self.load()
+
+ def vendors(self) -> dict:
+ """Return list of vendors"""
+ return {vendor.__qualname__: vendor for vendor in self._vendors}
+
+
+reranking_models_manager = RerankingManager()
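Downstream callers (e.g. the retrieval pipeline change in `pipelines.py`) resolve models through this singleton; a short usage sketch:

```python
from kotaemon.base import Document
from ktem.rerankings.manager import reranking_models_manager

# Resolve the configured default model (falls back to a random
# pool entry when no default is set).
reranker = reranking_models_manager.get_default()
ranked = reranker.run(
    documents=[Document(content="an indexed chunk")],  # illustrative input
    query="an example query",
)
```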
diff --git a/libs/ktem/ktem/rerankings/ui.py b/libs/ktem/ktem/rerankings/ui.py
new file mode 100644
index 0000000..311a794
--- /dev/null
+++ b/libs/ktem/ktem/rerankings/ui.py
@@ -0,0 +1,390 @@
+from copy import deepcopy
+
+import gradio as gr
+import pandas as pd
+import yaml
+from ktem.app import BasePage
+from ktem.utils.file import YAMLNoDateSafeLoader
+from theflow.utils.modules import deserialize
+
+from .manager import reranking_models_manager
+
+
+def format_description(cls):
+ params = cls.describe()["params"]
+ params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
+ for key, value in params.items():
+ if isinstance(value["auto_callback"], str):
+ continue
+ params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
+ return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
+
+
+class RerankingManagement(BasePage):
+ def __init__(self, app):
+ self._app = app
+ self.spec_desc_default = (
+ "# Spec description\n\nSelect a model to view the spec description."
+ )
+ self.on_building_ui()
+
+ def on_building_ui(self):
+ with gr.Tab(label="View"):
+ self.rerank_list = gr.DataFrame(
+ headers=["name", "vendor", "default"],
+ interactive=False,
+ )
+
+ with gr.Column(visible=False) as self._selected_panel:
+ self.selected_rerank_name = gr.Textbox(value="", visible=False)
+ with gr.Row():
+ with gr.Column():
+ self.edit_default = gr.Checkbox(
+ label="Set default",
+ info=(
+ "Set this Reranking model as default. This default "
+ "Reranking will be used by other components by default "
+ "if no Reranking is specified for such components."
+ ),
+ )
+ self.edit_spec = gr.Textbox(
+ label="Specification",
+ info="Specification of the Embedding model in YAML format",
+ lines=10,
+ )
+
+ with gr.Accordion(
+ label="Test connection", visible=False, open=False
+ ) as self._check_connection_panel:
+ with gr.Row():
+ with gr.Column(scale=4):
+ self.connection_logs = gr.HTML(
+ "Logs",
+ )
+
+ with gr.Column(scale=1):
+ self.btn_test_connection = gr.Button("Test")
+
+ with gr.Row(visible=False) as self._selected_panel_btn:
+ with gr.Column():
+ self.btn_edit_save = gr.Button(
+ "Save", min_width=10, variant="primary"
+ )
+ with gr.Column():
+ self.btn_delete = gr.Button(
+ "Delete", min_width=10, variant="stop"
+ )
+ with gr.Row():
+ self.btn_delete_yes = gr.Button(
+ "Confirm Delete",
+ variant="stop",
+ visible=False,
+ min_width=10,
+ )
+ self.btn_delete_no = gr.Button(
+ "Cancel", visible=False, min_width=10
+ )
+ with gr.Column():
+ self.btn_close = gr.Button("Close", min_width=10)
+
+ with gr.Column():
+ self.edit_spec_desc = gr.Markdown("# Spec description")
+
+ with gr.Tab(label="Add"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ self.name = gr.Textbox(
+ label="Name",
+ info=(
+ "Must be unique and non-empty. "
+ "The name will be used to identify the reranking model."
+ ),
+ )
+ self.rerank_choices = gr.Dropdown(
+ label="Vendors",
+ info=(
+ "Choose the vendor of the Reranking model. Each vendor "
+ "has different specification."
+ ),
+ )
+ self.spec = gr.Textbox(
+ label="Specification",
+ info="Specification of the Embedding model in YAML format.",
+ )
+ self.default = gr.Checkbox(
+ label="Set default",
+ info=(
+ "Set this Reranking model as default. This default "
+ "Reranking will be used by other components by default "
+ "if no Reranking is specified for such components."
+ ),
+ )
+ self.btn_new = gr.Button("Add", variant="primary")
+
+ with gr.Column(scale=3):
+ self.spec_desc = gr.Markdown(self.spec_desc_default)
+
+ def _on_app_created(self):
+ """Called when the app is created"""
+ self._app.app.load(
+ self.list_rerankings,
+ inputs=[],
+ outputs=[self.rerank_list],
+ )
+ self._app.app.load(
+ lambda: gr.update(choices=list(reranking_models_manager.vendors().keys())),
+ outputs=[self.rerank_choices],
+ )
+
+ def on_rerank_vendor_change(self, vendor):
+ vendor = reranking_models_manager.vendors()[vendor]
+
+ required: dict = {}
+ desc = vendor.describe()
+ for key, value in desc["params"].items():
+ if value.get("required", False):
+ required[key] = value.get("default", None)
+
+ return yaml.dump(required), format_description(vendor)
+
+ def on_register_events(self):
+ self.rerank_choices.select(
+ self.on_rerank_vendor_change,
+ inputs=[self.rerank_choices],
+ outputs=[self.spec, self.spec_desc],
+ )
+ self.btn_new.click(
+ self.create_rerank,
+ inputs=[self.name, self.rerank_choices, self.spec, self.default],
+ outputs=None,
+ ).success(self.list_rerankings, inputs=[], outputs=[self.rerank_list]).success(
+ lambda: ("", None, "", False, self.spec_desc_default),
+ outputs=[
+ self.name,
+ self.rerank_choices,
+ self.spec,
+ self.default,
+ self.spec_desc,
+ ],
+ )
+ self.rerank_list.select(
+ self.select_rerank,
+ inputs=self.rerank_list,
+ outputs=[self.selected_rerank_name],
+ show_progress="hidden",
+ )
+ self.selected_rerank_name.change(
+ self.on_selected_rerank_change,
+ inputs=[self.selected_rerank_name],
+ outputs=[
+ self._selected_panel,
+ self._selected_panel_btn,
+ # delete section
+ self.btn_delete,
+ self.btn_delete_yes,
+ self.btn_delete_no,
+ # edit section
+ self.edit_spec,
+ self.edit_spec_desc,
+ self.edit_default,
+ self._check_connection_panel,
+ ],
+ show_progress="hidden",
+ ).success(lambda: gr.update(value=""), outputs=[self.connection_logs])
+
+ self.btn_delete.click(
+ self.on_btn_delete_click,
+ inputs=[],
+ outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+ show_progress="hidden",
+ )
+ self.btn_delete_yes.click(
+ self.delete_rerank,
+ inputs=[self.selected_rerank_name],
+ outputs=[self.selected_rerank_name],
+ show_progress="hidden",
+ ).then(
+ self.list_rerankings,
+ inputs=[],
+ outputs=[self.rerank_list],
+ )
+ self.btn_delete_no.click(
+ lambda: (
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ ),
+ inputs=[],
+ outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+ show_progress="hidden",
+ )
+ self.btn_edit_save.click(
+ self.save_rerank,
+ inputs=[
+ self.selected_rerank_name,
+ self.edit_default,
+ self.edit_spec,
+ ],
+ show_progress="hidden",
+ ).then(
+ self.list_rerankings,
+ inputs=[],
+ outputs=[self.rerank_list],
+ )
+ self.btn_close.click(lambda: "", outputs=[self.selected_rerank_name])
+
+ self.btn_test_connection.click(
+ self.check_connection,
+ inputs=[self.selected_rerank_name, self.edit_spec],
+ outputs=[self.connection_logs],
+ )
+
+ def create_rerank(self, name, choices, spec, default):
+ try:
+ spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
+ spec["__type__"] = (
+ reranking_models_manager.vendors()[choices].__module__
+ + "."
+ + reranking_models_manager.vendors()[choices].__qualname__
+ )
+
+ reranking_models_manager.add(name, spec=spec, default=default)
+            gr.Info(f'Created Reranking model "{name}" successfully')
+ except Exception as e:
+ raise gr.Error(f"Failed to create Reranking model {name}: {e}")
+
+ def list_rerankings(self):
+ """List the Reranking models"""
+ items = []
+ for item in reranking_models_manager.info().values():
+ record = {}
+ record["name"] = item["name"]
+ record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
+ record["default"] = item["default"]
+ items.append(record)
+
+ if items:
+ rerank_list = pd.DataFrame.from_records(items)
+ else:
+ rerank_list = pd.DataFrame.from_records(
+ [{"name": "-", "vendor": "-", "default": "-"}]
+ )
+
+ return rerank_list
+
+ def select_rerank(self, rerank_list, ev: gr.SelectData):
+ if ev.value == "-" and ev.index[0] == 0:
+ gr.Info("No reranking model is loaded. Please add first")
+ return ""
+
+ if not ev.selected:
+ return ""
+
+ return rerank_list["name"][ev.index[0]]
+
+ def on_selected_rerank_change(self, selected_rerank_name):
+ if selected_rerank_name == "":
+ _check_connection_panel = gr.update(visible=False)
+ _selected_panel = gr.update(visible=False)
+ _selected_panel_btn = gr.update(visible=False)
+ btn_delete = gr.update(visible=True)
+ btn_delete_yes = gr.update(visible=False)
+ btn_delete_no = gr.update(visible=False)
+ edit_spec = gr.update(value="")
+ edit_spec_desc = gr.update(value="")
+ edit_default = gr.update(value=False)
+ else:
+ _check_connection_panel = gr.update(visible=True)
+ _selected_panel = gr.update(visible=True)
+ _selected_panel_btn = gr.update(visible=True)
+ btn_delete = gr.update(visible=True)
+ btn_delete_yes = gr.update(visible=False)
+ btn_delete_no = gr.update(visible=False)
+
+ info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
+ vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
+ vendor = reranking_models_manager.vendors()[vendor_str]
+
+ edit_spec = yaml.dump(info["spec"])
+ edit_spec_desc = format_description(vendor)
+ edit_default = info["default"]
+
+ return (
+ _selected_panel,
+ _selected_panel_btn,
+ btn_delete,
+ btn_delete_yes,
+ btn_delete_no,
+ edit_spec,
+ edit_spec_desc,
+ edit_default,
+ _check_connection_panel,
+ )
+
+ def on_btn_delete_click(self):
+ btn_delete = gr.update(visible=False)
+ btn_delete_yes = gr.update(visible=True)
+ btn_delete_no = gr.update(visible=True)
+
+ return btn_delete, btn_delete_yes, btn_delete_no
+
+ def check_connection(self, selected_rerank_name, selected_spec):
+ log_content: str = ""
+ try:
+ log_content += f"- Testing model: {selected_rerank_name}
"
+ yield log_content
+
+ # Parse content & init model
+ info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
+
+ # Parse content & create dummy response
+ spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
+ info["spec"].update(spec)
+
+ rerank = deserialize(info["spec"], safe=False)
+
+ if rerank is None:
+ raise Exception(f"Can not found model: {selected_rerank_name}")
+
+ log_content += "- Sending a message ([`Hello`], `Hi`)
"
+ yield log_content
+ _ = rerank(["Hello"], "Hi")
+
+            log_content += "- Connection success.<br>"
+ yield log_content
+
+ gr.Info(f"Embedding {selected_rerank_name} connect successfully")
+ except Exception as e:
+ print(e)
+ log_content += (
+ f"- Connection failed. "
+ f"Got error:\n {str(e)}"
+ )
+ yield log_content
+
+ return log_content
+
+ def save_rerank(self, selected_rerank_name, default, spec):
+ try:
+ spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
+ spec["__type__"] = reranking_models_manager.info()[selected_rerank_name][
+ "spec"
+ ]["__type__"]
+ reranking_models_manager.update(
+ selected_rerank_name, spec=spec, default=default
+ )
+            gr.Info(f'Saved Reranking model "{selected_rerank_name}" successfully')
+ except Exception as e:
+            gr.Error(f'Failed to save Reranking model "{selected_rerank_name}": {e}')
+
+ def delete_rerank(self, selected_rerank_name):
+ try:
+ reranking_models_manager.delete(selected_rerank_name)
+ except Exception as e:
+ gr.Error(f'Failed to delete Reranking model "{selected_rerank_name}": {e}')
+ return selected_rerank_name
+
+ return ""
diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py
index 9176627..3e6c434 100644
--- a/libs/ktem/ktem/utils/render.py
+++ b/libs/ktem/ktem/utils/render.py
@@ -154,9 +154,9 @@ class Render:
if doc.metadata.get("llm_trulens_score") is not None
else 0.0
)
- cohere_reranking_score = (
- round(doc.metadata["cohere_reranking_score"], 2)
- if doc.metadata.get("cohere_reranking_score") is not None
+ reranking_score = (
+ round(doc.metadata["reranking_score"], 2)
+ if doc.metadata.get("reranking_score") is not None
else 0.0
)
item_type_prefix = doc.metadata.get("type", "")
@@ -166,8 +166,8 @@ class Render:
if llm_reranking_score > 0:
relevant_score = llm_reranking_score
- elif cohere_reranking_score > 0:
- relevant_score = cohere_reranking_score
+ elif reranking_score > 0:
+ relevant_score = reranking_score
else:
relevant_score = 0.0
@@ -179,7 +179,7 @@ class Render:
" LLM relevant score:"
f" {llm_reranking_score}
"
" Reranking score:"
- f" {cohere_reranking_score}
",
+ f" {reranking_score}
",
)
text = doc.text if not override_text else override_text