feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)
* fix: utf-8 txt reader
* fix: revise vectorstore import and make it optional
* feat: add cohere chat model with tool call support
* fix: simplify citation pipeline
* fix: improve citation logic
* fix: improve decompose func call
* fix: revise question rewrite prompt
* fix: revise chat box default placeholder
* fix: add key from ktem to cohere rerank
* fix: conv name suggestion
* fix: ignore default key cohere rerank
* fix: improve test connection UI
* fix: reorder requirements
* feat: add first setup screen
* fix: update requirements
* fix: vectorstore tests
* fix: update cohere version
* fix: relax langchain core version
* fix: add demo mode
* fix: update flowsettings
* fix: typo
* fix: fix bool env passing
parent 0bdb9a32f2
commit 88d577b0cc
@@ -24,9 +24,13 @@ if not KH_APP_VERSION:
except Exception:
    KH_APP_VERSION = "local"

KH_ENABLE_FIRST_SETUP = True
KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)

# App can be ran from anywhere and it's not trivial to decide where to store app data.
# So let's use the same directory as the flowsetting.py file.
KH_APP_DATA_DIR = this_dir / "ktem_app_data"
KH_APP_DATA_EXISTS = KH_APP_DATA_DIR.exists()
KH_APP_DATA_DIR.mkdir(parents=True, exist_ok=True)

# User data directory

@@ -59,7 +63,9 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
KH_DOC_DIR = this_dir / "docs"

KH_MODE = "dev"
KH_FEATURE_USER_MANAGEMENT = True
KH_FEATURE_USER_MANAGEMENT = config(
    "KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
)
KH_USER_CAN_SEE_PUBLIC = None
KH_FEATURE_USER_MANAGEMENT_ADMIN = str(
    config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin")

@@ -202,6 +208,14 @@ KH_LLMS["groq"] = {
    },
    "default": False,
}
KH_LLMS["cohere"] = {
    "spec": {
        "__type__": "kotaemon.llms.chats.LCCohereChat",
        "model_name": "command-r-plus-08-2024",
        "api_key": "your-key",
    },
    "default": False,
}

# additional embeddings configurations
KH_EMBEDDINGS["cohere"] = {
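Side note (not part of the diff): a minimal sketch of how python-decouple's cast=bool behaves for the new KH_DEMO_MODE / KH_FEATURE_USER_MANAGEMENT settings, assuming python-decouple is installed. Environment values such as "true", "false", "1", or "0" come back as a real bool rather than a non-empty string, which is what the "fix bool env passing" bullet refers to.

import os

from decouple import config

os.environ["KH_DEMO_MODE"] = "false"
KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
print(type(KH_DEMO_MODE), KH_DEMO_MODE)  # <class 'bool'> False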
@@ -183,7 +183,7 @@ class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):

    def _get_lc_class(self):
        try:
            from langchain_community.embeddings import CohereEmbeddings
            from langchain_cohere import CohereEmbeddings
        except ImportError:
            from langchain.embeddings import CohereEmbeddings

@@ -1,4 +1,4 @@
from typing import Iterator, List
from typing import List

from pydantic import BaseModel, Field

@@ -7,53 +7,14 @@ from kotaemon.base.schema import HumanMessage, SystemMessage
from kotaemon.llms import BaseLLM


class FactWithEvidence(BaseModel):
    """Class representing a single statement.
class CiteEvidence(BaseModel):
    """List of evidences (maximum 5) to support the answer."""

    Each fact has a body and a list of sources.
    If there are multiple facts make sure to break them apart
    such that each one only uses a set of sources that are relevant to it.
    """

    fact: str = Field(..., description="Body of the sentence, as part of a response")
    substring_quote: List[str] = Field(
    evidences: List[str] = Field(
        ...,
        description=(
            "Each source should be a direct quote from the context, "
            "as a substring of the original content"
        ),
    )

    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
        import regex

        minor = quote
        major = context

        errs_ = 0
        s = regex.search(f"({minor}){{e<={errs_}}}", major)
        while s is None and errs_ <= errs:
            errs_ += 1
            s = regex.search(f"({minor}){{e<={errs_}}}", major)

        if s is not None:
            yield from s.spans()

    def get_spans(self, context: str) -> Iterator[str]:
        for quote in self.substring_quote:
            yield from self._get_span(quote, context)


class QuestionAnswer(BaseModel):
    """A question and its answer as a list of facts each one should have a source.
    each sentence contains a body and a list of sources."""

    question: str = Field(..., description="Question that was asked")
    answer: List[FactWithEvidence] = Field(
        ...,
        description=(
            "Body of the answer, each fact should be "
            "its separate object with a body and a list of sources"
            "as a substring of the original content (max 15 words)."
        ),
    )

@@ -68,7 +29,7 @@ class CitationPipeline(BaseComponent):
        return self.invoke(context, question)

    def prepare_llm(self, context: str, question: str):
        schema = QuestionAnswer.schema()
        schema = CiteEvidence.schema()
        function = {
            "name": schema["title"],
            "description": schema["description"],

@@ -76,7 +37,8 @@ class CitationPipeline(BaseComponent):
        }
        llm_kwargs = {
            "tools": [{"type": "function", "function": function}],
            "tool_choice": "auto",
            "tool_choice": "required",
            "tools_pydantic": [CiteEvidence],
        }
        messages = [
            SystemMessage(

@@ -85,7 +47,12 @@ class CitationPipeline(BaseComponent):
                    "questions with correct and exact citations."
                )
            ),
            HumanMessage(content="Answer question using the following context"),
            HumanMessage(
                content=(
                    "Answer question using the following context. "
                    "Use the provided function CiteEvidence() to cite your sources."
                )
            ),
            HumanMessage(content=context),
            HumanMessage(content=f"Question: {question}"),
            HumanMessage(

@@ -103,14 +70,24 @@ class CitationPipeline(BaseComponent):
            print("CitationPipeline: invoking LLM")
            llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
            print("CitationPipeline: finish invoking LLM")
            if not llm_output.messages or not llm_output.additional_kwargs.get(
                "tool_calls"
            ):
            if not llm_output.additional_kwargs.get("tool_calls"):
                return None
            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
                "arguments"
            ]
            output = QuestionAnswer.parse_raw(function_output)

            first_func = llm_output.additional_kwargs["tool_calls"][0]

            if "function" in first_func:
                # openai and cohere format
                function_output = first_func["function"]["arguments"]
            else:
                # anthropic format
                function_output = first_func["args"]

            print("CitationPipeline:", function_output)

            if isinstance(function_output, str):
                output = CiteEvidence.parse_raw(function_output)
            else:
                output = CiteEvidence.parse_obj(function_output)
        except Exception as e:
            print(e)
            return None

@@ -118,18 +95,4 @@ class CitationPipeline(BaseComponent):
        return output

    async def ainvoke(self, context: str, question: str):
        messages, llm_kwargs = self.prepare_llm(context, question)

        try:
            print("CitationPipeline: async invoking LLM")
            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
            print("CitationPipeline: finish async invoking LLM")
            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
                "arguments"
            ]
            output = QuestionAnswer.parse_raw(function_output)
        except Exception as e:
            print(e)
            return None

        return output
        raise NotImplementedError()

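A hedged sketch (not from the diff) of how the two tool-call payload shapes handled above map onto CiteEvidence. The sample payloads and the import path are assumptions for illustration only.

from kotaemon.indices.qa.citation import CiteEvidence  # assumed import path

# OpenAI/Cohere return the arguments as a JSON string; Anthropic returns a dict.
openai_style = {"function": {"arguments": '{"evidences": ["a direct quote"]}'}}
anthropic_style = {"args": {"evidences": ["a direct quote"]}}

for call in (openai_style, anthropic_style):
    payload = call["function"]["arguments"] if "function" in call else call["args"]
    evidence = (
        CiteEvidence.parse_raw(payload)
        if isinstance(payload, str)
        else CiteEvidence.parse_obj(payload)
    )
    print(evidence.evidences)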
@@ -10,6 +10,7 @@ from .base import BaseReranking
class CohereReranking(BaseReranking):
    model_name: str = "rerank-multilingual-v2.0"
    cohere_api_key: str = config("COHERE_API_KEY", "")
    use_key_from_ktem: bool = False

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use Cohere Reranker model to re-order documents

@@ -18,9 +19,25 @@ class CohereReranking(BaseReranking):
            import cohere
        except ImportError:
            raise ImportError(
                "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
                "Please install Cohere `pip install cohere` to use Cohere Reranking"
            )

        # try to get COHERE_API_KEY from embeddings
        if not self.cohere_api_key and self.use_key_from_ktem:
            try:
                from ktem.embeddings.manager import (
                    embedding_models_manager as embeddings,
                )

                cohere_model = embeddings.get("cohere")
                ktem_cohere_api_key = cohere_model._kwargs.get(  # type: ignore
                    "cohere_api_key"
                )
                if ktem_cohere_api_key != "your-key":
                    self.cohere_api_key = ktem_cohere_api_key
            except Exception as e:
                print("Cannot get Cohere API key from `ktem`", e)

        if not self.cohere_api_key:
            print("Cohere API key not found. Skipping reranking.")
            return documents

@@ -35,7 +52,7 @@ class CohereReranking(BaseReranking):
        response = cohere_client.rerank(
            model=self.model_name, query=query, documents=_docs
        )
        print("Cohere score", [r.relevance_score for r in response.results])
        # print("Cohere score", [r.relevance_score for r in response.results])
        for r in response.results:
            doc = documents[r.index]
            doc.metadata["cohere_reranking_score"] = r.relevance_score

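A hedged usage sketch, not part of the diff, of how a caller opts into the new use_key_from_ktem fallback. The package-level import path and the sample documents are assumptions; with no key configured the call simply returns the documents unchanged, as the guard above shows.

from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking  # assumed import path

reranker = CohereReranking(use_key_from_ktem=True)
docs = [Document(text="alpha mentions setup"), Document(text="beta mentions nothing")]
reranked = reranker.run(docs, query="which document talks about setup?")
print([d.metadata.get("cohere_reranking_score") for d in reranked])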
@@ -10,6 +10,7 @@ from .chats import (
    LCAnthropicChat,
    LCAzureChatOpenAI,
    LCChatOpenAI,
    LCCohereChat,
    LCGeminiChat,
    LlamaCppChat,
)

@@ -31,6 +32,7 @@ __all__ = [
    "ChatOpenAI",
    "LCAnthropicChat",
    "LCGeminiChat",
    "LCCohereChat",
    "LCAzureChatOpenAI",
    "LCChatOpenAI",
    "LlamaCppChat",

@@ -5,6 +5,7 @@ from .langchain_based import (
    LCAzureChatOpenAI,
    LCChatMixin,
    LCChatOpenAI,
    LCCohereChat,
    LCGeminiChat,
)
from .llamacpp import LlamaCppChat

@@ -18,6 +19,7 @@ __all__ = [
    "ChatOpenAI",
    "LCAnthropicChat",
    "LCGeminiChat",
    "LCCohereChat",
    "LCChatOpenAI",
    "LCAzureChatOpenAI",
    "LCChatMixin",

@@ -18,6 +18,9 @@ class LCChatMixin:
            "Please return the relevant Langchain class in in _get_lc_class"
        )

    def _get_tool_call_kwargs(self):
        return {}

    def __init__(self, stream: bool = False, **params):
        self._lc_class = self._get_lc_class()
        self._obj = self._lc_class(**params)

@@ -56,9 +59,7 @@ class LCChatMixin:
            total_tokens = pred.llm_output["token_usage"]["total_tokens"]
            prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
        except Exception:
            logger.warning(
                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
            )
            pass

        return LLMInterface(
            text=all_text[0] if len(all_text) > 0 else "",

@@ -83,8 +84,30 @@ class LCChatMixin:
            LLMInterface: generated response
        """
        input_ = self.prepare_message(messages)

        if "tools_pydantic" in kwargs:
            tools = kwargs.pop(
                "tools_pydantic",
            )
            lc_tool_call = self._obj.bind_tools(tools)
            pred = lc_tool_call.invoke(
                input_,
                **self._get_tool_call_kwargs(),
            )
            if pred.tool_calls:
                tool_calls = pred.tool_calls
            else:
                tool_calls = pred.additional_kwargs.get("tool_calls", [])

            output = LLMInterface(
                content="",
                additional_kwargs={"tool_calls": tool_calls},
            )
        else:
            pred = self._obj.generate(messages=[input_], **kwargs)
            return self.prepare_response(pred)
            output = self.prepare_response(pred)

        return output

    async def ainvoke(
        self, messages: str | BaseMessage | list[BaseMessage], **kwargs

@@ -235,6 +258,9 @@ class LCAnthropicChat(LCChatMixin, ChatLLM):  # type: ignore
        required=True,
    )

    def _get_tool_call_kwargs(self):
        return {"tool_choice": {"type": "any"}}

    def __init__(
        self,
        api_key: str | None = None,

@@ -291,3 +317,35 @@ class LCGeminiChat(LCChatMixin, ChatLLM):  # type: ignore
            raise ImportError("Please install langchain-google-genai")

        return ChatGoogleGenerativeAI


class LCCohereChat(LCChatMixin, ChatLLM):  # type: ignore
    api_key: str = Param(
        help="API key (https://dashboard.cohere.com/api-keys)", required=True
    )
    model_name: str = Param(
        help=("Model name to use (https://dashboard.cohere.com/playground/chat)"),
        required=True,
    )

    def __init__(
        self,
        api_key: str | None = None,
        model_name: str | None = None,
        temperature: float = 0.7,
        **params,
    ):
        super().__init__(
            cohere_api_key=api_key,
            model_name=model_name,
            temperature=temperature,
            **params,
        )

    def _get_lc_class(self):
        try:
            from langchain_cohere import ChatCohere
        except ImportError:
            raise ImportError("Please install langchain-cohere")

        return ChatCohere

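A hedged sketch, not from the diff, of how a caller exercises the new tools_pydantic path on a langchain-backed chat model. The SubQuery schema and the placeholder API key are illustrative only; a real call requires valid Cohere credentials.

from pydantic import BaseModel, Field

from kotaemon.llms import LCCohereChat


class SubQuery(BaseModel):
    """A narrowly-scoped sub-question."""

    sub_query: str = Field(..., description="A very specific sub-question")


llm = LCCohereChat(model_name="command-r-plus-08-2024", api_key="<COHERE_API_KEY>")
# bind_tools() is used under the hood; tool calls come back in additional_kwargs
result = llm.invoke("Compare Cohere and Ollama setup", tools_pydantic=[SubQuery])
print(result.additional_kwargs.get("tool_calls", []))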
@@ -292,6 +292,9 @@ class ChatOpenAI(BaseChatOpenAI):

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        if "tools_pydantic" in kwargs:
            kwargs.pop("tools_pydantic")

        params_ = {
            "model": self.model,
            "temperature": self.temperature,

@@ -360,6 +363,9 @@ class AzureChatOpenAI(BaseChatOpenAI):

    def openai_response(self, client, **kwargs):
        """Get the openai response"""
        if "tools_pydantic" in kwargs:
            kwargs.pop("tools_pydantic")

        params_ = {
            "model": self.azure_deployment,
            "temperature": self.temperature,

@@ -15,7 +15,7 @@ class TxtReader(BaseReader):
    def load_data(
        self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
    ) -> list[Document]:
        with open(file_path, "r") as f:
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        metadata = extra_info or {}

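A hedged sketch of why the explicit encoding matters here: without it, open() falls back to the platform's locale encoding (often cp1252 on Windows) and can raise UnicodeDecodeError on valid UTF-8 text files. The sample file name is illustrative.

from pathlib import Path

sample = Path("sample.txt")
sample.write_text("café ☕, non-ASCII content", encoding="utf-8")

with open(sample, "r", encoding="utf-8") as f:  # deterministic on every platform
    print(f.read())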
@@ -73,17 +73,25 @@ class BaseVectorStore(ABC):


class LlamaIndexVectorStore(BaseVectorStore):
    _li_class: type[LIVectorStore | BasePydanticVectorStore]
    """Mixin for LlamaIndex based vectorstores"""

    _li_class: type[LIVectorStore | BasePydanticVectorStore] | None

    def _get_li_class(self):
        raise NotImplementedError(
            "Please return the relevant LlamaIndex class in in _get_li_class"
        )

    def __init__(self, *args, **kwargs):
        if self._li_class is None:
            raise AttributeError(
                "Require `_li_class` to set a VectorStore class from LlamarIndex"
            )
        # get li_class from the method if not set
        if not self._li_class:
            LIClass = self._get_li_class()
        else:
            LIClass = self._li_class

        from dataclasses import fields

        self._client = self._li_class(*args, **kwargs)
        self._client = LIClass(*args, **kwargs)

        self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
        for key in ["query_embedding", "similarity_top_k", "node_ids"]:

@@ -97,6 +105,9 @@ class LlamaIndexVectorStore(BaseVectorStore):
        return setattr(self._client, name, value)

    def __getattr__(self, name: str) -> Any:
        if name == "_li_class":
            return super().__getattribute__(name)

        return getattr(self._client, name)

    def add(

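A hedged sketch, not in the diff, of the deferred-import pattern this refactor enables: a subclass leaves _li_class as None and resolves the LlamaIndex class only at construction time, so an optional backend is not imported at module load. It mirrors what the Milvus and Qdrant stores below do; the ExampleVectorStore name and the base import path are assumptions.

from kotaemon.storages.vectorstores.base import LlamaIndexVectorStore  # assumed path


class ExampleVectorStore(LlamaIndexVectorStore):
    _li_class = None  # resolved lazily instead of at import time

    def _get_li_class(self):
        try:
            from llama_index.vector_stores.qdrant import (
                QdrantVectorStore as LIQdrantVectorStore,
            )
        except ImportError:
            raise ImportError(
                "Please install missing package: "
                "'pip install llama-index-vector-stores-qdrant'"
            )

        return LIQdrantVectorStore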
@@ -1,7 +1,5 @@
import os
from typing import Any, Optional, Type, cast

from llama_index.vector_stores.milvus import MilvusVectorStore as LIMilvusVectorStore
from typing import Any, Optional, cast

from kotaemon.base import DocumentWithEmbedding


@@ -9,7 +7,20 @@ from .base import LlamaIndexVectorStore


class MilvusVectorStore(LlamaIndexVectorStore):
    _li_class: Type[LIMilvusVectorStore] = LIMilvusVectorStore
    _li_class = None

    def _get_li_class(self):
        try:
            from llama_index.vector_stores.milvus import (
                MilvusVectorStore as LIMilvusVectorStore,
            )
        except ImportError:
            raise ImportError(
                "Please install missing package: "
                "'pip install llama-index-vector-stores-milvus'"
            )

        return LIMilvusVectorStore

    def __init__(
        self,

@@ -46,6 +57,10 @@ class MilvusVectorStore(LlamaIndexVectorStore):
            dim=dim,
            **self._kwargs,
        )
        from llama_index.vector_stores.milvus import (
            MilvusVectorStore as LIMilvusVectorStore,
        )

        self._client = cast(LIMilvusVectorStore, self._client)
        self._inited = True

@@ -1,12 +1,23 @@
from typing import Any, List, Optional, Type, cast

from llama_index.vector_stores.qdrant import QdrantVectorStore as LIQdrantVectorStore
from typing import Any, List, Optional, cast

from .base import LlamaIndexVectorStore


class QdrantVectorStore(LlamaIndexVectorStore):
    _li_class: Type[LIQdrantVectorStore] = LIQdrantVectorStore
    _li_class = None

    def _get_li_class(self):
        try:
            from llama_index.vector_stores.qdrant import (
                QdrantVectorStore as LIQdrantVectorStore,
            )
        except ImportError:
            raise ImportError(
                "Please install missing package: "
                "'pip install llama-index-vector-stores-qdrant'"
            )

        return LIQdrantVectorStore

    def __init__(
        self,

@@ -29,6 +40,10 @@ class QdrantVectorStore(LlamaIndexVectorStore):
            client_kwargs=client_kwargs,
            **kwargs,
        )
        from llama_index.vector_stores.qdrant import (
            QdrantVectorStore as LIQdrantVectorStore,
        )

        self._client = cast(LIQdrantVectorStore, self._client)

    def delete(self, ids: List[str], **kwargs):

@@ -30,16 +30,15 @@ dependencies = [
    "fastapi<=0.112.1",
    "gradio>=4.31.0,<4.40",
    "html2text==2024.2.26",
    "langchain>=0.1.16,<0.2.0",
    "langchain-anthropic",
    "langchain-community>=0.0.34,<0.1.0",
    "langchain>=0.1.16,<0.2.16",
    "langchain-community>=0.0.34,<=0.2.11",
    "langchain-openai>=0.1.4,<0.2.0",
    "langchain-anthropic",
    "langchain-cohere>=0.2.4,<0.3.0",
    "llama-hub>=0.0.79,<0.1.0",
    "llama-index>=0.10.40,<0.11.0",
    "llama-index-vector-stores-chroma>=0.1.9",
    "llama-index-vector-stores-lancedb",
    "llama-index-vector-stores-milvus",
    "llama-index-vector-stores-qdrant",
    "openai>=1.23.6,<2",
    "openpyxl>=3.1.2,<3.2",
    "opentelemetry-exporter-otlp-proto-grpc>=1.25.0",  # https://github.com/chroma-core/chroma/issues/2571

@@ -75,6 +74,9 @@ adv = [
    "llama-cpp-python<0.2.8",
    "sentence-transformers",
    "wikipedia>=1.4.0,<1.5",
    "llama-index>=0.10.40,<0.11.0",
    "llama-index-vector-stores-milvus",
    "llama-index-vector-stores-qdrant",
]
dev = [
    "black",

@@ -135,7 +135,7 @@ def test_lchuggingface_embeddings(

@skip_when_cohere_not_installed
@patch(
    "langchain.embeddings.cohere.CohereEmbeddings.embed_documents",
    "langchain_cohere.CohereEmbeddings.embed_documents",
    side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
)
def test_lccohere_embeddings(langchain_cohere_embedding_call):

@@ -354,7 +354,7 @@ class EmbeddingManagement(BasePage):
            _ = emb("Hi")

            log_content += (
                "<mark style='background: yellow; color: red'>- Connection success. "
                "<mark style='background: green; color: white'>- Connection success. "
                "</mark><br>"
            )
            yield log_content

@@ -285,7 +285,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
            ],
            retrieval_mode=user_settings["retrieval_mode"],
            llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
            rerankers=[CohereReranking()],
            rerankers=[CohereReranking(use_key_from_ktem=True)],
        )
        if not user_settings["use_reranking"]:
            retriever.rerankers = []  # type: ignore

@@ -828,7 +828,6 @@ class FileIndexPage(BasePage):
            ]
        )

        print(f"{len(results)=}, {len(file_list)=}")
        return results, file_list

    def interact_file_list(self, list_files, ev: gr.SelectData):

@@ -58,6 +58,7 @@ class LLMManager:
            AzureChatOpenAI,
            ChatOpenAI,
            LCAnthropicChat,
            LCCohereChat,
            LCGeminiChat,
            LlamaCppChat,
        )

@@ -67,6 +68,7 @@ class LLMManager:
            AzureChatOpenAI,
            LCAnthropicChat,
            LCGeminiChat,
            LCCohereChat,
            LlamaCppChat,
        ]

@@ -353,7 +353,7 @@ class LLMManagement(BasePage):
            respond = llm("Hi")

            log_content += (
                f"<mark style='background: yellow; color: red'>- Connection success. "
                f"<mark style='background: green; color: white'>- Connection success. "
                f"Got response:\n {respond}</mark><br>"
            )
            yield log_content

@@ -1,9 +1,27 @@
import gradio as gr
from decouple import config
from ktem.app import BaseApp
from ktem.pages.chat import ChatPage
from ktem.pages.help import HelpPage
from ktem.pages.resources import ResourcesTab
from ktem.pages.settings import SettingsPage
from ktem.pages.setup import SetupPage
from theflow.settings import settings as flowsettings

KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False)
KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True)

# override first setup setting
if config("KH_FIRST_SETUP", default=False, cast=bool):
    KH_APP_DATA_EXISTS = False


def toggle_first_setup_visibility():
    global KH_APP_DATA_EXISTS
    is_first_setup = KH_DEMO_MODE or not KH_APP_DATA_EXISTS
    KH_APP_DATA_EXISTS = True
    return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup)


class App(BaseApp):

@@ -99,13 +117,17 @@ class App(BaseApp):
        ) as self._tabs["help-tab"]:
            self.help_page = HelpPage(self)

        if KH_ENABLE_FIRST_SETUP:
            with gr.Column(visible=False) as self.setup_page_wrapper:
                self.setup_page = SetupPage(self)

    def on_subscribe_public_events(self):
        if self.f_user_management:
            from ktem.db.engine import engine
            from ktem.db.models import User
            from sqlmodel import Session, select

            def signed_in_out(user_id):
            def toggle_login_visibility(user_id):
                if not user_id:
                    return list(
                        (

@@ -146,7 +168,7 @@ class App(BaseApp):
            self.subscribe_event(
                name="onSignIn",
                definition={
                    "fn": signed_in_out,
                    "fn": toggle_login_visibility,
                    "inputs": [self.user_id],
                    "outputs": list(self._tabs.values()) + [self.tabs],
                    "show_progress": "hidden",

@@ -156,9 +178,30 @@ class App(BaseApp):
            self.subscribe_event(
                name="onSignOut",
                definition={
                    "fn": signed_in_out,
                    "fn": toggle_login_visibility,
                    "inputs": [self.user_id],
                    "outputs": list(self._tabs.values()) + [self.tabs],
                    "show_progress": "hidden",
                },
            )

        if KH_ENABLE_FIRST_SETUP:
            self.subscribe_event(
                name="onFirstSetupComplete",
                definition={
                    "fn": toggle_first_setup_visibility,
                    "inputs": [],
                    "outputs": [self.setup_page_wrapper, self.tabs],
                    "show_progress": "hidden",
                },
            )

    def _on_app_created(self):
        """Called when the app is created"""

        if KH_ENABLE_FIRST_SETUP:
            self.app.load(
                toggle_first_setup_visibility,
                inputs=[],
                outputs=[self.setup_page_wrapper, self.tabs],
            )

@@ -883,7 +883,8 @@ class ChatPage(BasePage):

        # check if this is a newly created conversation
        if len(chat_history) == 1:
            suggested_name = suggest_pipeline(chat_history).text[:40]
            suggested_name = suggest_pipeline(chat_history).text
            suggested_name = suggested_name.replace('"', "").replace("'", "")[:40]
            new_name = gr.update(value=suggested_name)
            renamed = True

@@ -11,8 +11,8 @@ class ChatPanel(BasePage):
        self.chatbot = gr.Chatbot(
            label=self._app.app_name,
            placeholder=(
                "This is the beginning of a new conversation.\nMake sure to have added"
                " a LLM by following the instructions in the Help tab."
                "This is the beginning of a new conversation.\nIf you are new, "
                "visit the Help tab for quick instructions."
            ),
            show_label=False,
            elem_id="main-chat-bot",

libs/ktem/ktem/pages/setup.py (new file, 347 lines)
@@ -0,0 +1,347 @@
import json

import gradio as gr
import requests
from ktem.app import BasePage
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from theflow.settings import settings as flowsettings

KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
DEFAULT_OLLAMA_URL = "http://localhost:11434/api"


DEMO_MESSAGE = (
    "This is a public space. Please use the "
    '"Duplicate Space" function on the top right '
    "corner to setup your own space."
)


def pull_model(name: str, stream: bool = True):
    payload = {"name": name}
    headers = {"Content-Type": "application/json"}

    response = requests.post(
        DEFAULT_OLLAMA_URL + "/pull", json=payload, headers=headers, stream=stream
    )

    # Check if the request was successful
    response.raise_for_status()

    if stream:
        for line in response.iter_lines():
            if line:
                data = json.loads(line.decode("utf-8"))
                yield data
                if data.get("status") == "success":
                    break
    else:
        data = response.json()

        return data


class SetupPage(BasePage):

    public_events = ["onFirstSetupComplete"]

    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
        self.radio_model = gr.Radio(
            [
                ("Cohere API (*free registration* available) - recommended", "cohere"),
                ("OpenAI API (for more advance models)", "openai"),
                ("Local LLM (for completely *private RAG*)", "ollama"),
            ],
            label="Select your model provider",
            value="cohere",
            info=(
                "Note: You can change this later. "
                "If you are not sure, go with the first option "
                "which fits most normal users."
            ),
            interactive=True,
        )

        with gr.Column(visible=False) as self.openai_option:
            gr.Markdown(
                (
                    "#### OpenAI API Key\n\n"
                    "(create at https://platform.openai.com/api-keys)"
                )
            )
            self.openai_api_key = gr.Textbox(
                show_label=False, placeholder="OpenAI API Key"
            )

        with gr.Column(visible=True) as self.cohere_option:
            gr.Markdown(
                (
                    "#### Cohere API Key\n\n"
                    "(register your free API key "
                    "at https://dashboard.cohere.com/api-keys)"
                )
            )
            self.cohere_api_key = gr.Textbox(
                show_label=False, placeholder="Cohere API Key"
            )

        with gr.Column(visible=False) as self.ollama_option:
            gr.Markdown(
                (
                    "#### Setup Ollama\n\n"
                    "Download and install Ollama from "
                    "https://ollama.com/"
                )
            )

        self.setup_log = gr.HTML(
            show_label=False,
        )

        with gr.Row():
            self.btn_finish = gr.Button("Proceed", variant="primary")
            self.btn_skip = gr.Button(
                "I am an advance user. Skip this.", variant="stop"
            )

    def on_register_events(self):
        onFirstSetupComplete = gr.on(
            triggers=[
                self.btn_finish.click,
                self.cohere_api_key.submit,
                self.openai_api_key.submit,
            ],
            fn=self.update_model,
            inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
            outputs=[self.setup_log],
            show_progress="hidden",
        )
        if not KH_DEMO_MODE:
            onSkipSetup = gr.on(
                triggers=[self.btn_skip.click],
                fn=lambda: None,
                inputs=[],
                show_progress="hidden",
                outputs=[self.radio_model],
            )

            for event in self._app.get_event("onFirstSetupComplete"):
                onSkipSetup = onSkipSetup.success(**event)

        onFirstSetupComplete = onFirstSetupComplete.success(
            fn=self.update_default_settings,
            inputs=[self.radio_model, self._app.settings_state],
            outputs=self._app.settings_state,
        )
        for event in self._app.get_event("onFirstSetupComplete"):
            onFirstSetupComplete = onFirstSetupComplete.success(**event)

        self.radio_model.change(
            fn=self.switch_options_view,
            inputs=[self.radio_model],
            show_progress="hidden",
            outputs=[self.cohere_option, self.openai_option, self.ollama_option],
        )

    def update_model(
        self,
        cohere_api_key,
        openai_api_key,
        radio_model_value,
    ):
        # skip if KH_DEMO_MODE
        if KH_DEMO_MODE:
            raise gr.Error(DEMO_MESSAGE)

        log_content = ""
        if not radio_model_value:
            gr.Info("Skip setup models.")
            yield gr.value(visible=False)
            return

        if radio_model_value == "cohere":
            if cohere_api_key:
                llms.update(
                    name="cohere",
                    spec={
                        "__type__": "kotaemon.llms.chats.LCCohereChat",
                        "model_name": "command-r-plus-08-2024",
                        "api_key": cohere_api_key,
                    },
                    default=True,
                )
                embeddings.update(
                    name="cohere",
                    spec={
                        "__type__": "kotaemon.embeddings.LCCohereEmbeddings",
                        "model": "embed-multilingual-v2.0",
                        "cohere_api_key": cohere_api_key,
                        "user_agent": "default",
                    },
                    default=True,
                )
        elif radio_model_value == "openai":
            if openai_api_key:
                llms.update(
                    name="openai",
                    spec={
                        "__type__": "kotaemon.llms.ChatOpenAI",
                        "base_url": "https://api.openai.com/v1",
                        "model": "gpt-4o",
                        "api_key": openai_api_key,
                        "timeout": 20,
                    },
                    default=True,
                )
                embeddings.update(
                    name="openai",
                    spec={
                        "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
                        "base_url": "https://api.openai.com/v1",
                        "model": "text-embedding-3-large",
                        "api_key": openai_api_key,
                        "timeout": 10,
                        "context_length": 8191,
                    },
                    default=True,
                )
        elif radio_model_value == "ollama":
            llms.update(
                name="ollama",
                spec={
                    "__type__": "kotaemon.llms.ChatOpenAI",
                    "base_url": "http://localhost:11434/v1/",
                    "model": "llama3.1:8b",
                    "api_key": "ollama",
                },
                default=True,
            )
            embeddings.update(
                name="ollama",
                spec={
                    "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
                    "base_url": "http://localhost:11434/v1/",
                    "model": "nomic-embed-text",
                    "api_key": "ollama",
                },
                default=True,
            )

            # download required models through ollama
            llm_model_name = llms.get("ollama").model  # type: ignore
            emb_model_name = embeddings.get("ollama").model  # type: ignore

            try:
                for model_name in [emb_model_name, llm_model_name]:
                    log_content += f"- Downloading model `{model_name}` from Ollama<br>"
                    yield log_content

                    pre_download_log = log_content

                    for response in pull_model(model_name):
                        complete = response.get("completed", 0)
                        total = response.get("total", 0)
                        if complete > 0 and total > 0:
                            ratio = int(complete / total * 100)
                            log_content = (
                                pre_download_log
                                + f"- {response.get('status')}: {ratio}%<br>"
                            )
                        else:
                            if "pulling" not in response.get("status", ""):
                                log_content += f"- {response.get('status')}<br>"

                        yield log_content
            except Exception as e:
                log_content += (
                    "Make sure you have download and installed Ollama correctly."
                    f"Got error: {str(e)}"
                )
                yield log_content
                raise gr.Error("Failed to download model from Ollama.")

        # test models connection
        llm_output = emb_output = None

        # LLM model
        log_content += f"- Testing LLM model: {radio_model_value}<br>"
        yield log_content

        llm = llms.get(radio_model_value)  # type: ignore
        log_content += "- Sending a message `Hi`<br>"
        yield log_content
        try:
            llm_output = llm("Hi")
        except Exception as e:
            log_content += (
                f"<mark style='color: yellow; background: red'>- Connection failed. "
                f"Got error:\n {str(e)}</mark>"
            )

        if llm_output:
            log_content += (
                "<mark style='background: green; color: white'>- Connection success. "
                "</mark><br>"
            )
        yield log_content

        if llm_output:
            # embedding model
            log_content += f"- Testing Embedding model: {radio_model_value}<br>"
            yield log_content

            emb = embeddings.get(radio_model_value)
            assert emb, f"Embedding model {radio_model_value} not found."

            log_content += "- Sending a message `Hi`<br>"
            yield log_content
            try:
                emb_output = emb("Hi")
            except Exception as e:
                log_content += (
                    f"<mark style='color: yellow; background: red'>"
                    "- Connection failed. "
                    f"Got error:\n {str(e)}</mark>"
                )

            if emb_output:
                log_content += (
                    "<mark style='background: green; color: white'>"
                    "- Connection success. "
                    "</mark><br>"
                )
            yield log_content

        if llm_output and emb_output:
            gr.Info("Setup models completed successfully!")
        else:
            raise gr.Error(
                "Setup models failed. Please verify your connection and API key."
            )

    def update_default_settings(self, radio_model_value, default_settings):
        # revise default settings
        # reranking llm
        default_settings["index.options.1.reranking_llm"] = radio_model_value
        if radio_model_value == "ollama":
            default_settings["index.options.1.use_llm_reranking"] = False

        return default_settings

    def switch_options_view(self, radio_model_value):
        components_visible = [gr.update(visible=False) for _ in range(3)]

        values = ["cohere", "openai", "ollama", None]
        assert radio_model_value in values, f"Invalid value {radio_model_value}"

        if radio_model_value is not None:
            idx = values.index(radio_model_value)
            components_visible[idx] = gr.update(visible=True)

        return components_visible

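A hedged usage sketch, not part of the new file: streaming progress out of pull_model() above. It assumes a local Ollama server is listening on the default port; the model name is just an example.

from ktem.pages.setup import pull_model

for status in pull_model("nomic-embed-text"):
    completed, total = status.get("completed", 0), status.get("total", 0)
    if completed and total:
        print(f"{status.get('status')}: {completed * 100 // total}%")
    else:
        print(status.get("status"))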
@@ -52,6 +52,7 @@ class DecomposeQuestionPipeline(RewriteQuestionPipeline):
        llm_kwargs = {
            "tools": [{"type": "function", "function": function}],
            "tool_choice": "auto",
            "tools_pydantic": [SubQuery],
        }

        messages = [

@@ -7,6 +7,7 @@ DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Only output the rephrased question without additional information. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "

@@ -39,10 +39,13 @@ EVIDENCE_MODE_TABLE = 1
EVIDENCE_MODE_CHATBOT = 2
EVIDENCE_MODE_FIGURE = 3
MAX_IMAGES = 10
CITATION_TIMEOUT = 5.0


def find_text(search_span, context):
    sentence_list = search_span.split("\n")
    context = context.replace("\n", " ")

    matches = []
    # don't search for small text
    if len(search_span) > 5:

@@ -50,7 +53,7 @@ def find_text(search_span, context):
            match = SequenceMatcher(
                None, sentence, context, autojunk=False
            ).find_longest_match()
            if match.size > len(sentence) * 0.35:
            if match.size > max(len(sentence) * 0.35, 5):
                matches.append((match.b, match.b + match.size))

    return matches

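A hedged sketch, not from the diff, of what the tightened threshold changes: with max(len(sentence) * 0.35, 5), a very short sentence can no longer produce a citation highlight from a tiny incidental overlap. The sample strings are made up.

from difflib import SequenceMatcher

sentence = "API"
context = "the API key is stored in the application settings"
match = SequenceMatcher(None, sentence, context, autojunk=False).find_longest_match()

print(match.size > len(sentence) * 0.35)          # True: the old rule would match
print(match.size > max(len(sentence) * 0.35, 5))  # False: the new rule rejects it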
@@ -200,15 +203,6 @@ DEFAULT_QA_FIGURE_PROMPT = (
    "Answer: "
)  # noqa

DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "
)  # noqa

CONTEXT_RELEVANT_WARNING_SCORE = 0.7


@@ -391,7 +385,8 @@ class AnswerWithContextPipeline(BaseComponent):
        qa_score = None

        if citation_thread:
            citation_thread.join()
            citation_thread.join(timeout=CITATION_TIMEOUT)

        answer = Document(
            text=output,
            metadata={"citation": citation, "qa_score": qa_score},

@@ -525,9 +520,9 @@ class FullQAPipeline(BaseReasoning):
        spans = defaultdict(list)
        has_llm_score = any("llm_trulens_score" in doc.metadata for doc in docs)

        if answer.metadata["citation"] and answer.metadata["citation"].answer:
            for fact_with_evidence in answer.metadata["citation"].answer:
                for quote in fact_with_evidence.substring_quote:
        if answer.metadata["citation"]:
            evidences = answer.metadata["citation"].evidences
            for quote in evidences:
                matched_excerpts = []
                for doc in docs:
                    matches = find_text(quote, doc.text)

@@ -542,7 +537,7 @@ class FullQAPipeline(BaseReasoning):
                        )
                        matched_excerpts.append(doc.text[start:end])

                print("Matched citation:", quote, matched_excerpts),
                # print("Matched citation:", quote, matched_excerpts),

        id2docs = {doc.doc_id: doc for doc in docs}
        not_detected = set(id2docs.keys()) - set(spans.keys())

@@ -75,7 +75,6 @@ class Render:
        if not highlight_text:
            try:
                lang = detect(text.replace("\n", " "))["lang"]
                print("lang", lang)
                if lang not in ["ja", "cn"]:
                    highlight_words = [
                        t[:-1] if t.endswith("-") else t for t in text.split("\n")

@@ -83,10 +82,13 @@ class Render:
                    highlight_text = highlight_words[0]
                    phrase = "true"
                else:
                    highlight_text = text.replace("\n", "")
                    phrase = "false"

                print("highlight_text", highlight_text, phrase)
                    highlight_text = (
                        text.replace("\n", "").replace('"', "").replace("'", "")
                    )

                # print("highlight_text", highlight_text, phrase, lang)
            except Exception as e:
                print(e)
                highlight_text = text

@@ -162,8 +164,15 @@ class Render:
        if item_type_prefix:
            item_type_prefix += " from "

        if llm_reranking_score > 0:
            relevant_score = llm_reranking_score
        elif cohere_reranking_score > 0:
            relevant_score = cohere_reranking_score
        else:
            relevant_score = 0.0

        rendered_score = Render.collapsible(
            header=f"<b> Relevance score</b>: {llm_reranking_score}",
            header=f"<b> Relevance score</b>: {relevant_score:.1f}",
            content="<b>  Vectorstore score:</b>"
            f" {vectorstore_score}"
            f"{text_search_str}"