Improve LLM selection for the simple reasoning pipeline and fix the non-persistent settings bug

- Improve LLM selection for the simple reasoning pipeline
- Enable LLM selection for reranking
- Fix the non-persistent settings bug
This commit is contained in:
ian_Cin 2024-03-28 16:39:09 +07:00 committed by GitHub
commit e8d3c70276
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 75 additions and 13 deletions

View File

@ -17,8 +17,7 @@ if machine == "x86_64":
BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}" BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}"
EXTENSION = ".exe" if os.name == "nt" else "" EXTENSION = ".exe" if os.name == "nt" else ""
BINARY_URL = ( BINARY_URL = (
"some-endpoint.com" "some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
) )
BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}" BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}"

View File

@ -194,7 +194,6 @@ class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
def __init__( def __init__(
self, self,
azure_endpoint: str | None = None, azure_endpoint: str | None = None,

View File

@ -1,4 +1,5 @@
"""Common components, some kind of config""" """Common components, some kind of config"""
import logging import logging
from functools import cache from functools import cache
from pathlib import Path from pathlib import Path
@ -71,7 +72,7 @@ class ModelPool:
} }
def options(self) -> dict: def options(self) -> dict:
"""Present a list of models""" """Present a dict of models"""
return self._models return self._models
def get_random_name(self) -> str: def get_random_name(self) -> str:

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging
import shutil import shutil
import warnings import warnings
from collections import defaultdict from collections import defaultdict
@ -8,7 +9,7 @@ from hashlib import sha256
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from ktem.components import embeddings, filestorage_path, llms from ktem.components import embeddings, filestorage_path
from ktem.db.models import engine from ktem.db.models import engine
from llama_index.vector_stores import ( from llama_index.vector_stores import (
FilterCondition, FilterCondition,
@ -25,10 +26,12 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.base import RetrievedDocument from kotaemon.base import RetrievedDocument
from kotaemon.indices import VectorIndexing, VectorRetrieval from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests import DocumentIngestor from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.rankings import BaseReranking, LLMReranking from kotaemon.indices.rankings import BaseReranking
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
logger = logging.getLogger(__name__)
@lru_cache @lru_cache
def dev_settings(): def dev_settings():
@ -67,7 +70,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
vector_retrieval: VectorRetrieval = VectorRetrieval.withx( vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
embedding=embeddings.get_default(), embedding=embeddings.get_default(),
) )
reranker: BaseReranking = LLMReranking.withx(llm=llms.get_lowest_cost()) reranker: BaseReranking
get_extra_table: bool = False get_extra_table: bool = False
def run( def run(
@ -153,7 +156,23 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
@classmethod @classmethod
def get_user_settings(cls) -> dict: def get_user_settings(cls) -> dict:
from ktem.components import llms
try:
reranking_llm = llms.get_lowest_cost_name()
reranking_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)
reranking_llm = None
reranking_llm_choices = []
return { return {
"reranking_llm": {
"name": "LLM for reranking",
"value": reranking_llm,
"component": "dropdown",
"choices": reranking_llm_choices,
},
"separate_embedding": { "separate_embedding": {
"name": "Use separate embedding", "name": "Use separate embedding",
"value": False, "value": False,
@ -185,7 +204,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
}, },
"use_reranking": { "use_reranking": {
"name": "Use reranking", "name": "Use reranking",
"value": True, "value": False,
"choices": [True, False], "choices": [True, False],
"component": "checkbox", "component": "checkbox",
}, },
@ -199,7 +218,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
settings: the settings of the app settings: the settings of the app
kwargs: other arguments kwargs: other arguments
""" """
retriever = cls(get_extra_table=user_settings["prioritize_table"]) retriever = cls(
get_extra_table=user_settings["prioritize_table"],
reranker=user_settings["reranking_llm"],
)
if not user_settings["use_reranking"]: if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore retriever.reranker = None # type: ignore

View File

@ -87,6 +87,28 @@ class SettingsPage(BasePage):
self.reasoning_tab() self.reasoning_tab()
def on_subscribe_public_events(self): def on_subscribe_public_events(self):
"""
Subscribes to public events related to user management.
This function is responsible for subscribing to the "onSignIn" event, which is
triggered when a user signs in. It registers two event handlers for this event.
The first event handler, "load_setting", is responsible for loading the user's
settings when they sign in. It takes the user ID as input and returns the
settings state and a list of component outputs. The progress indicator for this
event is set to "hidden".
The second event handler, "get_name", is responsible for retrieving the
username of the current user. It takes the user ID as input and returns the
username if it exists, otherwise it returns "___". The progress indicator for
this event is also set to "hidden".
Parameters:
self (object): The instance of the class.
Returns:
None
"""
if self._app.f_user_management: if self._app.f_user_management:
self._app.subscribe_event( self._app.subscribe_event(
name="onSignIn", name="onSignIn",
@ -290,3 +312,12 @@ class SettingsPage(BasePage):
def component_names(self): def component_names(self):
"""Get the setting components""" """Get the setting components"""
return self._settings_keys return self._settings_keys
def _on_app_created(self):
if not self._app.f_user_management:
self._app.app.load(
self.load_setting,
inputs=self._user_id,
outputs=[self._settings_state] + self.components(),
show_progress="hidden",
)

View File

@ -159,6 +159,7 @@ class AnswerWithContextPipeline(BaseComponent):
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
enable_citation: bool = False
system_prompt: str = "" system_prompt: str = ""
lang: str = "English" # support English and Japanese lang: str = "English" # support English and Japanese
@ -200,7 +201,8 @@ class AnswerWithContextPipeline(BaseComponent):
lang=self.lang, lang=self.lang,
) )
if evidence: citation_task = None
if evidence and self.enable_citation:
citation_task = asyncio.create_task( citation_task = asyncio.create_task(
self.citation_pipeline.ainvoke(context=evidence, question=question) self.citation_pipeline.ainvoke(context=evidence, question=question)
) )
@ -226,7 +228,7 @@ class AnswerWithContextPipeline(BaseComponent):
# retrieve the citation # retrieve the citation
print("Waiting for citation task") print("Waiting for citation task")
if evidence: if citation_task is not None:
citation = await citation_task citation = await citation_task
else: else:
citation = None citation = None
@ -353,7 +355,15 @@ class FullQAPipeline(BaseReasoning):
_id = cls.get_info()["id"] _id = cls.get_info()["id"]
pipeline = FullQAPipeline(retrievers=retrievers) pipeline = FullQAPipeline(retrievers=retrievers)
pipeline.answering_pipeline.llm = llms.get_highest_accuracy() pipeline.answering_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.main_llm"]
]
pipeline.answering_pipeline.citation_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.citation_llm"]
]
pipeline.answering_pipeline.enable_citation = settings[
f"reasoning.options.{_id}.highlight_citation"
]
pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get( pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English" settings["reasoning.lang"], "English"
) )
@ -384,7 +394,7 @@ class FullQAPipeline(BaseReasoning):
return { return {
"highlight_citation": { "highlight_citation": {
"name": "Highlight Citation", "name": "Highlight Citation",
"value": True, "value": False,
"component": "checkbox", "component": "checkbox",
}, },
"citation_llm": { "citation_llm": {