feat: support TEI embedding service, configurable reranking model (#287)
* feat: add support for TEI embedding service, allow reranking model to be configurable

Signed-off-by: Kennywu <jdlow@live.cn>

* fix: add cohere default reranking model

* fix: comfort pre-commit

---------

Signed-off-by: Kennywu <jdlow@live.cn>
Co-authored-by: wujiaye <wujiaye@bluemoon.com.cn>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
This commit is contained in: parent 2e3c17b256 · commit 53530e296f
@@ -92,6 +92,7 @@ KH_VECTORSTORE = {
 }
 KH_LLMS = {}
 KH_EMBEDDINGS = {}
+KH_RERANKINGS = {}

 # populate options from config
 if config("AZURE_OPENAI_API_KEY", default="") and config(
@@ -212,7 +213,7 @@ KH_LLMS["cohere"] = {
     "spec": {
         "__type__": "kotaemon.llms.chats.LCCohereChat",
         "model_name": "command-r-plus-08-2024",
-        "api_key": "your-key",
+        "api_key": config("COHERE_API_KEY", default="your-key"),
     },
     "default": False,
 }
@@ -222,7 +223,7 @@ KH_EMBEDDINGS["cohere"] = {
     "spec": {
         "__type__": "kotaemon.embeddings.LCCohereEmbeddings",
         "model": "embed-multilingual-v3.0",
-        "cohere_api_key": "your-key",
+        "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
         "user_agent": "default",
     },
     "default": False,
@@ -235,6 +236,16 @@ KH_EMBEDDINGS["cohere"] = {
 #     "default": False,
 # }

+# default reranking models
+KH_RERANKINGS["cohere"] = {
+    "spec": {
+        "__type__": "kotaemon.rerankings.CohereReranking",
+        "model_name": "rerank-multilingual-v2.0",
+        "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
+    },
+    "default": True,
+}
+
 KH_REASONINGS = [
     "ktem.reasoning.simple.FullQAPipeline",
     "ktem.reasoning.simple.FullDecomposeQAPipeline",
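Note: only the Cohere reranker is registered in flowsettings by this commit. If you want to point the new TEI embedding component at your own server, a minimal sketch might look like the following (the "tei" name and the endpoint URL are assumptions, not part of this diff):

# hypothetical flowsettings entry -- adjust endpoint_url to your TEI deployment
KH_EMBEDDINGS["tei"] = {
    "spec": {
        "__type__": "kotaemon.embeddings.TeiEndpointEmbeddings",
        "endpoint_url": "http://localhost:8080/embed",
        "normalize": True,
        "truncate": True,
    },
    "default": False,
}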
@@ -8,10 +8,12 @@ from .langchain_based import (
     LCOpenAIEmbeddings,
 )
 from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from .tei_endpoint_embed import TeiEndpointEmbeddings

 __all__ = [
     "BaseEmbeddings",
     "EndpointEmbeddings",
+    "TeiEndpointEmbeddings",
     "LCOpenAIEmbeddings",
     "LCAzureOpenAIEmbeddings",
     "LCCohereEmbeddings",
libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+import aiohttp
+import requests
+
+from kotaemon.base import Document, DocumentWithEmbedding, Param
+
+from .base import BaseEmbeddings
+
+session = requests.session()
+
+
+class TeiEndpointEmbeddings(BaseEmbeddings):
+    """An Embeddings component that uses a
+    TEI (Text-Embedding-Inference) API compatible endpoint.
+
+    Ref: https://github.com/huggingface/text-embeddings-inference
+
+    Attributes:
+        endpoint_url (str): The URL of a TEI
+            (Text-Embedding-Inference) API compatible endpoint.
+        normalize (bool): Whether to normalize embeddings to unit length.
+        truncate (bool): Whether to truncate embeddings
+            to a fixed/default length.
+    """
+
+    endpoint_url: str = Param(None, help="TEI embedding service api base URL")
+    normalize: bool = Param(
+        True,
+        help="Normalize embeddings to unit length",
+    )
+    truncate: bool = Param(
+        True,
+        help="Truncate embeddings to a fixed/default length",
+    )
+
+    async def client_(self, inputs: list[str]):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=self.endpoint_url,
+                json={
+                    "inputs": inputs,
+                    "normalize": self.normalize,
+                    "truncate": self.truncate,
+                },
+            ) as resp:
+                embeddings = await resp.json()
+        return embeddings
+
+    async def ainvoke(
+        self, text: str | list[str] | Document | list[Document], *args, **kwargs
+    ) -> list[DocumentWithEmbedding]:
+        if not isinstance(text, list):
+            text = [text]
+        text = self.prepare_input(text)
+
+        outputs = []
+        batch_size = 6
+        num_batch = max(len(text) // batch_size, 1)
+        for i in range(num_batch):
+            if i == num_batch - 1:
+                mini_batch = text[batch_size * i :]
+            else:
+                mini_batch = text[batch_size * i : batch_size * (i + 1)]
+            mini_batch = [x.content for x in mini_batch]
+            embeddings = await self.client_(mini_batch)  # type: ignore
+            outputs.extend(
+                [
+                    DocumentWithEmbedding(content=doc, embedding=embedding)
+                    for doc, embedding in zip(mini_batch, embeddings)
+                ]
+            )
+
+        return outputs
+
+    def invoke(
+        self, text: str | list[str] | Document | list[Document], *args, **kwargs
+    ) -> list[DocumentWithEmbedding]:
+        if not isinstance(text, list):
+            text = [text]
+
+        text = self.prepare_input(text)
+
+        outputs = []
+        batch_size = 6
+        num_batch = max(len(text) // batch_size, 1)
+        for i in range(num_batch):
+            if i == num_batch - 1:
+                mini_batch = text[batch_size * i :]
+            else:
+                mini_batch = text[batch_size * i : batch_size * (i + 1)]
+            mini_batch = [x.content for x in mini_batch]
+            embeddings = session.post(
+                url=self.endpoint_url,
+                json={
+                    "inputs": mini_batch,
+                    "normalize": self.normalize,
+                    "truncate": self.truncate,
+                },
+            ).json()
+            outputs.extend(
+                [
+                    DocumentWithEmbedding(content=doc, embedding=embedding)
+                    for doc, embedding in zip(mini_batch, embeddings)
+                ]
+            )
+        return outputs
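Note: a minimal usage sketch for the new component, assuming a TEI-compatible /embed endpoint is reachable (the URL below is illustrative):

# hypothetical usage, not part of this diff
from kotaemon.embeddings import TeiEndpointEmbeddings

embedder = TeiEndpointEmbeddings(endpoint_url="http://localhost:8080/embed")
docs = embedder.invoke(["hello world", "another document"])
print(len(docs), len(docs[0].embedding))  # one DocumentWithEmbedding per input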
@@ -39,7 +39,7 @@ class CohereReranking(BaseReranking):
             print("Cannot get Cohere API key from `ktem`", e)

         if not self.cohere_api_key:
-            print("Cohere API key not found. Skipping reranking.")
+            print("Cohere API key not found. Skipping rerankings.")
             return documents

         cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ class CohereReranking(BaseReranking):
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
-            doc.metadata["cohere_reranking_score"] = r.relevance_score
+            doc.metadata["reranking_score"] = r.relevance_score
             compressed_docs.append(doc)

         return compressed_docs
@@ -241,7 +241,7 @@ class VectorRetrieval(BaseRetrieval):
             # if reranker is LLMReranking, limit the document with top_k items only
             if isinstance(reranker, LLMReranking):
                 result = self._filter_docs(result, top_k=top_k)
-            result = reranker(documents=result, query=text)
+            result = reranker.run(documents=result, query=text)

         result = self._filter_docs(result, top_k=top_k)
         print(f"Got raw {len(result)} retrieved documents")
libs/kotaemon/kotaemon/rerankings/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from .base import BaseReranking
+from .cohere import CohereReranking
+from .tei_fast_rerank import TeiFastReranking
+
+__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
libs/kotaemon/kotaemon/rerankings/base.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+
+from kotaemon.base import BaseComponent, Document
+
+
+class BaseReranking(BaseComponent):
+    @abstractmethod
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Main method to transform list of documents
+        (re-ranking, filtering, etc)"""
+        ...
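Note: BaseReranking is the single extension point for custom rerankers. A hypothetical no-op implementation (illustrative only, not part of this diff) shows the contract:

from kotaemon.base import Document
from kotaemon.rerankings.base import BaseReranking


class IdentityReranking(BaseReranking):
    """Return documents unchanged; handy as a stub in tests."""

    def run(self, documents: list[Document], query: str) -> list[Document]:
        return documents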
libs/kotaemon/kotaemon/rerankings/cohere.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from decouple import config
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+
+class CohereReranking(BaseReranking):
+    """Cohere Reranking model"""
+
+    model_name: str = Param(
+        "rerank-multilingual-v2.0",
+        help=(
+            "ID of the model to use. You can go to [Supported Models]"
+            "(https://docs.cohere.com/docs/rerank-2) to see the supported models"
+        ),
+        required=True,
+    )
+    cohere_api_key: str = Param(
+        config("COHERE_API_KEY", ""),
+        help="Cohere API key",
+        required=True,
+    )
+
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Use Cohere Reranker model to re-order documents
+        with their relevance score"""
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Please install Cohere `pip install cohere` to use Cohere Reranking"
+            )
+
+        if not self.cohere_api_key:
+            print("Cohere API key not found. Skipping rerankings.")
+            return documents
+
+        cohere_client = cohere.Client(self.cohere_api_key)
+        compressed_docs: list[Document] = []
+
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        _docs = [d.content for d in documents]
+        response = cohere_client.rerank(
+            model=self.model_name, query=query, documents=_docs
+        )
+        for r in response.results:
+            doc = documents[r.index]
+            doc.metadata["reranking_score"] = r.relevance_score
+            compressed_docs.append(doc)
+
+        return compressed_docs
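Note: a minimal usage sketch; the key is read from the COHERE_API_KEY environment variable via python-decouple, and the documents below are illustrative:

# hypothetical usage, not part of this diff
from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking

reranker = CohereReranking()  # picks up COHERE_API_KEY from the environment
ranked = reranker.run(
    [Document(content="the sky is blue"), Document(content="grass is green")],
    query="what color is the sky?",
)
print([d.metadata["reranking_score"] for d in ranked])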
libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import requests
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+session = requests.session()
+
+
+class TeiFastReranking(BaseReranking):
+    """Text Embeddings Inference (TEI) Reranking model
+    (https://huggingface.co/docs/text-embeddings-inference/en/index)
+    """
+
+    endpoint_url: str = Param(
+        None, help="TEI Reranking service api base URL", required=True
+    )
+    model_name: Optional[str] = Param(
+        None,
+        help=(
+            "ID of the model to use. You can go to [Supported Models]"
+            "(https://github.com/huggingface"
+            "/text-embeddings-inference?tab=readme-ov-file"
+            "#supported-models) to see the supported models"
+        ),
+    )
+    is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")
+
+    def client(self, query, texts):
+        response = session.post(
+            url=self.endpoint_url,
+            json={
+                "query": query,
+                "texts": texts,
+                "is_truncated": self.is_truncated,  # default is True
+            },
+        ).json()
+        return response
+
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Use the deployed TEI rerankings service to re-order documents
+        with their relevance score"""
+        if not self.endpoint_url:
+            print("TEI API reranking URL not found. Skipping rerankings.")
+            return documents
+
+        compressed_docs: list[Document] = []
+
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        if isinstance(documents[0], str):
+            documents = self.prepare_input(documents)
+
+        batch_size = 6
+        num_batch = max(len(documents) // batch_size, 1)
+        for i in range(num_batch):
+            if i == num_batch - 1:
+                mini_batch = documents[batch_size * i :]
+            else:
+                mini_batch = documents[batch_size * i : batch_size * (i + 1)]
+
+            _docs = [d.content for d in mini_batch]
+            rerank_resp = self.client(query, _docs)
+            for r in rerank_resp:
+                doc = mini_batch[r["index"]]
+                doc.metadata["reranking_score"] = r["score"]
+                compressed_docs.append(doc)
+
+        compressed_docs = sorted(
+            compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
+        )
+        return compressed_docs
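Note: unlike CohereReranking, TeiFastReranking sorts its output by reranking_score in descending order. A minimal usage sketch, assuming a TEI rerank endpoint (the URL is illustrative):

# hypothetical usage, not part of this diff
from kotaemon.base import Document
from kotaemon.rerankings import TeiFastReranking

reranker = TeiFastReranking(endpoint_url="http://localhost:8080/rerank")
ranked = reranker.run(
    [Document(content="the sky is blue"), Document(content="grass is green")],
    query="what color is the sky?",
)
print([d.metadata["reranking_score"] for d in ranked])  # highest first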
@@ -59,6 +59,7 @@ class EmbeddingManager:
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
+            TeiEndpointEmbeddings,
         )

         self._vendors = [
@@ -67,6 +68,7 @@ class EmbeddingManager:
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            TeiEndpointEmbeddings,
         ]

     def __getitem__(self, key: str) -> BaseEmbeddings:
@@ -16,6 +16,7 @@ import tiktoken
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.readers.file.base import default_file_metadata_func
 from llama_index.core.vector_stores import (
@@ -39,12 +40,7 @@ from kotaemon.indices.ingests.files import (
     azure_reader,
     unstructured,
 )
-from kotaemon.indices.rankings import (
-    BaseReranking,
-    CohereReranking,
-    LLMReranking,
-    LLMTrulensScoring,
-)
+from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter

 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -285,7 +281,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             ],
             retrieval_mode=user_settings["retrieval_mode"],
             llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
-            rerankers=[CohereReranking(use_key_from_ktem=True)],
+            rerankers=[
+                reranking_models_manager[
+                    index_settings.get(
+                        "reranking", reranking_models_manager.get_default_name()
+                    )
+                ]
+            ],
         )
         if not user_settings["use_reranking"]:
             retriever.rerankers = []  # type: ignore
@@ -715,7 +717,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
         for idx, file_path in enumerate(file_paths):
             file_path = Path(file_path)
             yield Document(
-                content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
+                content=f"Indexing [{idx + 1}/{n_files}]: {file_path.name}",
                 channel="debug",
             )

@@ -4,6 +4,7 @@ from ktem.db.models import User, engine
 from ktem.embeddings.ui import EmbeddingManagement
 from ktem.index.ui import IndexManagement
 from ktem.llms.ui import LLMManagement
+from ktem.rerankings.ui import RerankingManagement
 from sqlmodel import Session, select

 from .user import UserManagement
@@ -24,6 +25,9 @@ class ResourcesTab(BasePage):
             with gr.Tab("Embeddings") as self.emb_management_tab:
                 self.emb_management = EmbeddingManagement(self._app)

+            with gr.Tab("Rerankings") as self.rerank_management_tab:
+                self.rerank_management = RerankingManagement(self._app)
+
             if self._app.f_user_management:
                 with gr.Tab("Users", visible=False) as self.user_management_tab:
                     self.user_management = UserManagement(self._app)
@@ -5,6 +5,7 @@ import requests
 from ktem.app import BasePage
 from ktem.embeddings.manager import embedding_models_manager as embeddings
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager as rerankers
 from theflow.settings import settings as flowsettings

 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
@@ -186,6 +187,15 @@ class SetupPage(BasePage):
                 },
                 default=True,
             )
+            rerankers.update(
+                name="cohere",
+                spec={
+                    "__type__": "kotaemon.rerankings.CohereReranking",
+                    "model_name": "rerank-multilingual-v2.0",
+                    "cohere_api_key": cohere_api_key,
+                },
+                default=True,
+            )
         elif radio_model_value == "openai":
             if openai_api_key:
                 llms.update(
@@ -100,7 +100,7 @@ class DocSearchTool(BaseTool):
             )

             print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))

             # trim context by trim_len
             if evidence:
@@ -138,7 +138,7 @@ class DocSearchTool(BaseTool):
             )

             print("Retrieved #{}: {}".format(_id, retrieved_content))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))

             # trim context by trim_len
             if evidence:
libs/ktem/ktem/rerankings/__init__.py (new file, empty)
libs/ktem/ktem/rerankings/db.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from typing import Type
+
+from ktem.db.engine import engine
+from sqlalchemy import JSON, Boolean, Column, String
+from sqlalchemy.orm import DeclarativeBase
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import import_dotted_string
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class BaseRerankingTable(Base):
+    """Base table to store rerankings model"""
+
+    __abstract__ = True
+
+    name = Column(String, primary_key=True, unique=True)
+    spec = Column(JSON, default={})
+    default = Column(Boolean, default=False)
+
+
+__base_reranking: Type[BaseRerankingTable] = (
+    import_dotted_string(flowsettings.KH_TABLE_RERANKING, safe=False)
+    if hasattr(flowsettings, "KH_TABLE_RERANKING")
+    else BaseRerankingTable
+)
+
+
+class RerankingTable(__base_reranking):  # type: ignore
+    __tablename__ = "reranking"
+
+
+if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
+    RerankingTable.metadata.create_all(engine)
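Note: the concrete table can be swapped by setting KH_TABLE_RERANKING in flowsettings to a dotted string. A hypothetical override adding a column (names are illustrative, not part of this diff):

# e.g. KH_TABLE_RERANKING = "my_app.tables.MyRerankingTable" in flowsettings
from sqlalchemy import Column, String

from ktem.rerankings.db import BaseRerankingTable


class MyRerankingTable(BaseRerankingTable):
    __abstract__ = True  # ktem derives the concrete "reranking" table from this
    owner = Column(String, default="")  # extra column, illustrative only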
libs/ktem/ktem/rerankings/manager.py (new file, 194 lines)
@@ -0,0 +1,194 @@
+from typing import Optional, Type
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import deserialize
+
+from kotaemon.rerankings.base import BaseReranking
+
+from .db import RerankingTable, engine
+
+
+class RerankingManager:
+    """Represent a pool of rerankings models"""
+
+    def __init__(self):
+        self._models: dict[str, BaseReranking] = {}
+        self._info: dict[str, dict] = {}
+        self._default: str = ""
+        self._vendors: list[Type] = []
+
+        # populate the pool if empty
+        if hasattr(flowsettings, "KH_RERANKINGS"):
+            with Session(engine) as sess:
+                count = sess.query(RerankingTable).count()
+                if not count:
+                    for name, model in flowsettings.KH_RERANKINGS.items():
+                        self.add(
+                            name=name,
+                            spec=model["spec"],
+                            default=model.get("default", False),
+                        )
+
+        self.load()
+        self.load_vendors()
+
+    def load(self):
+        """Load the model pool from database"""
+        self._models, self._info, self._default = {}, {}, ""
+        with Session(engine) as sess:
+            stmt = select(RerankingTable)
+            items = sess.execute(stmt)
+
+            for (item,) in items:
+                self._models[item.name] = deserialize(item.spec, safe=False)
+                self._info[item.name] = {
+                    "name": item.name,
+                    "spec": item.spec,
+                    "default": item.default,
+                }
+                if item.default:
+                    self._default = item.name
+
+    def load_vendors(self):
+        from kotaemon.rerankings import CohereReranking, TeiFastReranking
+
+        self._vendors = [TeiFastReranking, CohereReranking]
+
+    def __getitem__(self, key: str) -> BaseReranking:
+        """Get model by name"""
+        return self._models[key]
+
+    def __contains__(self, key: str) -> bool:
+        """Check if model exists"""
+        return key in self._models
+
+    def get(
+        self, key: str, default: Optional[BaseReranking] = None
+    ) -> Optional[BaseReranking]:
+        """Get model by name with default value"""
+        return self._models.get(key, default)
+
+    def settings(self) -> dict:
+        """Present model pools option for gradio"""
+        return {
+            "label": "Reranking",
+            "choices": list(self._models.keys()),
+            "value": self.get_default_name(),
+        }
+
+    def options(self) -> dict:
+        """Present a dict of models"""
+        return self._models
+
+    def get_random_name(self) -> str:
+        """Get the name of random model
+
+        Returns:
+            str: random model name in the pool
+        """
+        import random
+
+        if not self._models:
+            raise ValueError("No models in pool")
+
+        return random.choice(list(self._models.keys()))
+
+    def get_default_name(self) -> str:
+        """Get the name of default model
+
+        In case there is no default model, choose random model from pool. In
+        case there are multiple default models, choose random from them.
+
+        Returns:
+            str: model name
+        """
+        if not self._models:
+            raise ValueError("No models in pool")
+
+        if not self._default:
+            return self.get_random_name()
+
+        return self._default
+
+    def get_random(self) -> BaseReranking:
+        """Get random model"""
+        return self._models[self.get_random_name()]
+
+    def get_default(self) -> BaseReranking:
+        """Get default model
+
+        In case there is no default model, choose random model from pool. In
+        case there are multiple default models, choose random from them.
+
+        Returns:
+            BaseReranking: model
+        """
+        return self._models[self.get_default_name()]
+
+    def info(self) -> dict:
+        """List all models"""
+        return self._info
+
+    def add(self, name: str, spec: dict, default: bool):
+        if not name:
+            raise ValueError("Name must not be empty")
+
+        try:
+            with Session(engine) as sess:
+                if default:
+                    # turn all models to non-default
+                    sess.query(RerankingTable).update({"default": False})
+                    sess.commit()
+
+                item = RerankingTable(name=name, spec=spec, default=default)
+                sess.add(item)
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to add model {name}: {e}")
+
+        self.load()
+
+    def delete(self, name: str):
+        """Delete a model from the pool"""
+        try:
+            with Session(engine) as sess:
+                item = sess.query(RerankingTable).filter_by(name=name).first()
+                sess.delete(item)
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to delete model {name}: {e}")
+
+        self.load()
+
+    def update(self, name: str, spec: dict, default: bool):
+        """Update a model in the pool"""
+        if not name:
+            raise ValueError("Name must not be empty")
+
+        try:
+            with Session(engine) as sess:
+                if default:
+                    # turn all models to non-default
+                    sess.query(RerankingTable).update({"default": False})
+                    sess.commit()
+
+                item = sess.query(RerankingTable).filter_by(name=name).first()
+                if not item:
+                    raise ValueError(f"Model {name} not found")
+                item.spec = spec
+                item.default = default
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to update model {name}: {e}")
+
+        self.load()
+
+    def vendors(self) -> dict:
+        """Return list of vendors"""
+        return {vendor.__qualname__: vendor for vendor in self._vendors}
+
+
+reranking_models_manager = RerankingManager()
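Note: the manager mirrors the existing embedding manager. A sketch of registering and fetching a model programmatically (the "tei" name and endpoint URL are assumptions, not part of this diff):

# hypothetical usage, not part of this diff
from ktem.rerankings.manager import reranking_models_manager

reranking_models_manager.add(
    name="tei",
    spec={
        "__type__": "kotaemon.rerankings.TeiFastReranking",
        "endpoint_url": "http://localhost:8080/rerank",
    },
    default=True,
)
reranker = reranking_models_manager.get_default()  # or reranking_models_manager["tei"]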
libs/ktem/ktem/rerankings/ui.py (new file, 390 lines)
@@ -0,0 +1,390 @@
+from copy import deepcopy
+
+import gradio as gr
+import pandas as pd
+import yaml
+from ktem.app import BasePage
+from ktem.utils.file import YAMLNoDateSafeLoader
+from theflow.utils.modules import deserialize
+
+from .manager import reranking_models_manager
+
+
+def format_description(cls):
+    params = cls.describe()["params"]
+    params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
+    for key, value in params.items():
+        if isinstance(value["auto_callback"], str):
+            continue
+        params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
+    return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
+
+
+class RerankingManagement(BasePage):
+    def __init__(self, app):
+        self._app = app
+        self.spec_desc_default = (
+            "# Spec description\n\nSelect a model to view the spec description."
+        )
+        self.on_building_ui()
+
+    def on_building_ui(self):
+        with gr.Tab(label="View"):
+            self.rerank_list = gr.DataFrame(
+                headers=["name", "vendor", "default"],
+                interactive=False,
+            )
+
+            with gr.Column(visible=False) as self._selected_panel:
+                self.selected_rerank_name = gr.Textbox(value="", visible=False)
+                with gr.Row():
+                    with gr.Column():
+                        self.edit_default = gr.Checkbox(
+                            label="Set default",
+                            info=(
+                                "Set this Reranking model as default. This default "
+                                "Reranking will be used by other components by default "
+                                "if no Reranking is specified for such components."
+                            ),
+                        )
+                        self.edit_spec = gr.Textbox(
+                            label="Specification",
+                            info="Specification of the Reranking model in YAML format",
+                            lines=10,
+                        )
+
+                        with gr.Accordion(
+                            label="Test connection", visible=False, open=False
+                        ) as self._check_connection_panel:
+                            with gr.Row():
+                                with gr.Column(scale=4):
+                                    self.connection_logs = gr.HTML(
+                                        "Logs",
+                                    )
+
+                                with gr.Column(scale=1):
+                                    self.btn_test_connection = gr.Button("Test")
+
+                        with gr.Row(visible=False) as self._selected_panel_btn:
+                            with gr.Column():
+                                self.btn_edit_save = gr.Button(
+                                    "Save", min_width=10, variant="primary"
+                                )
+                            with gr.Column():
+                                self.btn_delete = gr.Button(
+                                    "Delete", min_width=10, variant="stop"
+                                )
+                                with gr.Row():
+                                    self.btn_delete_yes = gr.Button(
+                                        "Confirm Delete",
+                                        variant="stop",
+                                        visible=False,
+                                        min_width=10,
+                                    )
+                                    self.btn_delete_no = gr.Button(
+                                        "Cancel", visible=False, min_width=10
+                                    )
+                            with gr.Column():
+                                self.btn_close = gr.Button("Close", min_width=10)
+
+                    with gr.Column():
+                        self.edit_spec_desc = gr.Markdown("# Spec description")
+
+        with gr.Tab(label="Add"):
+            with gr.Row():
+                with gr.Column(scale=2):
+                    self.name = gr.Textbox(
+                        label="Name",
+                        info=(
+                            "Must be unique and non-empty. "
+                            "The name will be used to identify the reranking model."
+                        ),
+                    )
+                    self.rerank_choices = gr.Dropdown(
+                        label="Vendors",
+                        info=(
+                            "Choose the vendor of the Reranking model. Each vendor "
+                            "has different specification."
+                        ),
+                    )
+                    self.spec = gr.Textbox(
+                        label="Specification",
+                        info="Specification of the Reranking model in YAML format.",
+                    )
+                    self.default = gr.Checkbox(
+                        label="Set default",
+                        info=(
+                            "Set this Reranking model as default. This default "
+                            "Reranking will be used by other components by default "
+                            "if no Reranking is specified for such components."
+                        ),
+                    )
+                    self.btn_new = gr.Button("Add", variant="primary")
+
+                with gr.Column(scale=3):
+                    self.spec_desc = gr.Markdown(self.spec_desc_default)
+
+    def _on_app_created(self):
+        """Called when the app is created"""
+        self._app.app.load(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self._app.app.load(
+            lambda: gr.update(choices=list(reranking_models_manager.vendors().keys())),
+            outputs=[self.rerank_choices],
+        )
+
+    def on_rerank_vendor_change(self, vendor):
+        vendor = reranking_models_manager.vendors()[vendor]
+
+        required: dict = {}
+        desc = vendor.describe()
+        for key, value in desc["params"].items():
+            if value.get("required", False):
+                required[key] = value.get("default", None)
+
+        return yaml.dump(required), format_description(vendor)
+
+    def on_register_events(self):
+        self.rerank_choices.select(
+            self.on_rerank_vendor_change,
+            inputs=[self.rerank_choices],
+            outputs=[self.spec, self.spec_desc],
+        )
+        self.btn_new.click(
+            self.create_rerank,
+            inputs=[self.name, self.rerank_choices, self.spec, self.default],
+            outputs=None,
+        ).success(self.list_rerankings, inputs=[], outputs=[self.rerank_list]).success(
+            lambda: ("", None, "", False, self.spec_desc_default),
+            outputs=[
+                self.name,
+                self.rerank_choices,
+                self.spec,
+                self.default,
+                self.spec_desc,
+            ],
+        )
+        self.rerank_list.select(
+            self.select_rerank,
+            inputs=self.rerank_list,
+            outputs=[self.selected_rerank_name],
+            show_progress="hidden",
+        )
+        self.selected_rerank_name.change(
+            self.on_selected_rerank_change,
+            inputs=[self.selected_rerank_name],
+            outputs=[
+                self._selected_panel,
+                self._selected_panel_btn,
+                # delete section
+                self.btn_delete,
+                self.btn_delete_yes,
+                self.btn_delete_no,
+                # edit section
+                self.edit_spec,
+                self.edit_spec_desc,
+                self.edit_default,
+                self._check_connection_panel,
+            ],
+            show_progress="hidden",
+        ).success(lambda: gr.update(value=""), outputs=[self.connection_logs])
+
+        self.btn_delete.click(
+            self.on_btn_delete_click,
+            inputs=[],
+            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+            show_progress="hidden",
+        )
+        self.btn_delete_yes.click(
+            self.delete_rerank,
+            inputs=[self.selected_rerank_name],
+            outputs=[self.selected_rerank_name],
+            show_progress="hidden",
+        ).then(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self.btn_delete_no.click(
+            lambda: (
+                gr.update(visible=True),
+                gr.update(visible=False),
+                gr.update(visible=False),
+            ),
+            inputs=[],
+            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+            show_progress="hidden",
+        )
+        self.btn_edit_save.click(
+            self.save_rerank,
+            inputs=[
+                self.selected_rerank_name,
+                self.edit_default,
+                self.edit_spec,
+            ],
+            show_progress="hidden",
+        ).then(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self.btn_close.click(lambda: "", outputs=[self.selected_rerank_name])
+
+        self.btn_test_connection.click(
+            self.check_connection,
+            inputs=[self.selected_rerank_name, self.edit_spec],
+            outputs=[self.connection_logs],
+        )
+
+    def create_rerank(self, name, choices, spec, default):
+        try:
+            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
+            spec["__type__"] = (
+                reranking_models_manager.vendors()[choices].__module__
+                + "."
+                + reranking_models_manager.vendors()[choices].__qualname__
+            )
+
+            reranking_models_manager.add(name, spec=spec, default=default)
+            gr.Info(f'Create Reranking model "{name}" successfully')
+        except Exception as e:
+            raise gr.Error(f"Failed to create Reranking model {name}: {e}")
+
+    def list_rerankings(self):
+        """List the Reranking models"""
+        items = []
+        for item in reranking_models_manager.info().values():
+            record = {}
+            record["name"] = item["name"]
+            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
+            record["default"] = item["default"]
+            items.append(record)
+
+        if items:
+            rerank_list = pd.DataFrame.from_records(items)
+        else:
+            rerank_list = pd.DataFrame.from_records(
+                [{"name": "-", "vendor": "-", "default": "-"}]
+            )
+
+        return rerank_list
+
+    def select_rerank(self, rerank_list, ev: gr.SelectData):
+        if ev.value == "-" and ev.index[0] == 0:
+            gr.Info("No reranking model is loaded. Please add first")
+            return ""
+
+        if not ev.selected:
+            return ""
+
+        return rerank_list["name"][ev.index[0]]
+
+    def on_selected_rerank_change(self, selected_rerank_name):
+        if selected_rerank_name == "":
+            _check_connection_panel = gr.update(visible=False)
+            _selected_panel = gr.update(visible=False)
+            _selected_panel_btn = gr.update(visible=False)
+            btn_delete = gr.update(visible=True)
+            btn_delete_yes = gr.update(visible=False)
+            btn_delete_no = gr.update(visible=False)
+            edit_spec = gr.update(value="")
+            edit_spec_desc = gr.update(value="")
+            edit_default = gr.update(value=False)
+        else:
+            _check_connection_panel = gr.update(visible=True)
+            _selected_panel = gr.update(visible=True)
+            _selected_panel_btn = gr.update(visible=True)
+            btn_delete = gr.update(visible=True)
+            btn_delete_yes = gr.update(visible=False)
+            btn_delete_no = gr.update(visible=False)
+
+            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
+            vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
+            vendor = reranking_models_manager.vendors()[vendor_str]
+
+            edit_spec = yaml.dump(info["spec"])
+            edit_spec_desc = format_description(vendor)
+            edit_default = info["default"]
+
+        return (
+            _selected_panel,
+            _selected_panel_btn,
+            btn_delete,
+            btn_delete_yes,
+            btn_delete_no,
+            edit_spec,
+            edit_spec_desc,
+            edit_default,
+            _check_connection_panel,
+        )
+
+    def on_btn_delete_click(self):
+        btn_delete = gr.update(visible=False)
+        btn_delete_yes = gr.update(visible=True)
+        btn_delete_no = gr.update(visible=True)
+
+        return btn_delete, btn_delete_yes, btn_delete_no
+
+    def check_connection(self, selected_rerank_name, selected_spec):
+        log_content: str = ""
+        try:
+            log_content += f"- Testing model: {selected_rerank_name}<br>"
+            yield log_content
+
+            # Parse content & init model
+            info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
+
+            # Parse content & create dummy response
+            spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
+            info["spec"].update(spec)
+
+            rerank = deserialize(info["spec"], safe=False)
+
+            if rerank is None:
+                raise Exception(f"Cannot find model: {selected_rerank_name}")
+
+            log_content += "- Sending a message ([`Hello`], `Hi`)<br>"
+            yield log_content
+            _ = rerank(["Hello"], "Hi")
+
+            log_content += (
+                "<mark style='background: green; color: white'>- Connection success. "
+                "</mark><br>"
+            )
+            yield log_content
+
+            gr.Info(f"Reranking model {selected_rerank_name} connected successfully")
+        except Exception as e:
+            print(e)
+            log_content += (
+                f"<mark style='color: yellow; background: red'>- Connection failed. "
+                f"Got error:\n {str(e)}</mark>"
+            )
+            yield log_content
+
+        return log_content
+
+    def save_rerank(self, selected_rerank_name, default, spec):
+        try:
+            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
+            spec["__type__"] = reranking_models_manager.info()[selected_rerank_name][
+                "spec"
+            ]["__type__"]
+            reranking_models_manager.update(
+                selected_rerank_name, spec=spec, default=default
+            )
+            gr.Info(f'Save Reranking model "{selected_rerank_name}" successfully')
+        except Exception as e:
+            gr.Error(f'Failed to save Reranking model "{selected_rerank_name}": {e}')
+
+    def delete_rerank(self, selected_rerank_name):
+        try:
+            reranking_models_manager.delete(selected_rerank_name)
+        except Exception as e:
+            gr.Error(f'Failed to delete Reranking model "{selected_rerank_name}": {e}')
+            return selected_rerank_name

+        return ""
@@ -154,9 +154,9 @@ class Render:
             if doc.metadata.get("llm_trulens_score") is not None
             else 0.0
         )
-        cohere_reranking_score = (
-            round(doc.metadata["cohere_reranking_score"], 2)
-            if doc.metadata.get("cohere_reranking_score") is not None
+        reranking_score = (
+            round(doc.metadata["reranking_score"], 2)
+            if doc.metadata.get("reranking_score") is not None
             else 0.0
         )
         item_type_prefix = doc.metadata.get("type", "")
@@ -166,8 +166,8 @@ class Render:

         if llm_reranking_score > 0:
             relevant_score = llm_reranking_score
-        elif cohere_reranking_score > 0:
-            relevant_score = cohere_reranking_score
+        elif reranking_score > 0:
+            relevant_score = reranking_score
         else:
             relevant_score = 0.0
@@ -179,7 +179,7 @@ class Render:
             "<b>  LLM relevant score:</b>"
             f" {llm_reranking_score}<br>"
             "<b>  Reranking score:</b>"
-            f" {cohere_reranking_score}<br>",
+            f" {reranking_score}<br>",
         )

         text = doc.text if not override_text else override_text