improve llm selection for simple reasoning pipeline

ian 2024-03-28 16:35:13 +07:00
parent b2089245f2
commit f9cc40ca25
5 changed files with 44 additions and 13 deletions

View File

@@ -17,8 +17,7 @@ if machine == "x86_64":
 BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}"
 EXTENSION = ".exe" if os.name == "nt" else ""
 BINARY_URL = (
-    "some-endpoint.com"
-    f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
+    "some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
 )
 BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}"
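Note: this is a formatting-only change. Adjacent string literals in Python are concatenated at compile time, so both spellings build the same URL. A quick self-contained check (placeholder values, not the real endpoint):

VERSION, BINARY_REMOTE_NAME, EXTENSION = "0.1", "frpc_linux_amd64", ""

multi_line = (
    "some-endpoint.com"
    f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
)
one_line = (
    "some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
)
assert multi_line == one_line  # identical strings either way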

View File

@@ -194,7 +194,6 @@ class ChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
 class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
     def __init__(
         self,
         azure_endpoint: str | None = None,

View File

@@ -1,4 +1,5 @@
 """Common components, some kind of config"""
 import logging
 from functools import cache
 from pathlib import Path
@@ -71,7 +72,7 @@ class ModelPool:
         }

     def options(self) -> dict:
-        """Present a list of models"""
+        """Present a dict of models"""
         return self._models

     def get_random_name(self) -> str:
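For context: the rest of this commit leans on a small surface of the model pool, options() returning a name-to-model dict, get_lowest_cost_name(), and name-based lookup (llms[name] in the reasoning pipeline below). A minimal sketch of that assumed interface, not the actual ktem implementation:

class ModelPool:
    """Sketch of the pool interface this commit relies on (assumed)."""

    def __init__(self, models: dict, costs: dict):
        self._models = models  # name -> model instance
        self._costs = costs    # name -> relative cost, used for selection

    def options(self) -> dict:
        """Present a dict of models, keyed by name."""
        return self._models

    def get_lowest_cost_name(self) -> str:
        # Raises on an empty pool, which is why the retrieval pipeline
        # below wraps the call in try/except.
        return min(self._costs, key=self._costs.get)

    def __getitem__(self, name: str):
        return self._models[name]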

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 import shutil
 import warnings
 from collections import defaultdict
@@ -8,7 +9,7 @@ from hashlib import sha256
 from pathlib import Path
 from typing import Optional

-from ktem.components import embeddings, filestorage_path, llms
+from ktem.components import embeddings, filestorage_path
 from ktem.db.models import engine
 from llama_index.vector_stores import (
     FilterCondition,
@@ -25,10 +26,12 @@ from theflow.utils.modules import import_dotted_string
 from kotaemon.base import RetrievedDocument
 from kotaemon.indices import VectorIndexing, VectorRetrieval
 from kotaemon.indices.ingests import DocumentIngestor
-from kotaemon.indices.rankings import BaseReranking, LLMReranking
+from kotaemon.indices.rankings import BaseReranking

 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever

+logger = logging.getLogger(__name__)
+

 @lru_cache
 def dev_settings():
@@ -67,7 +70,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
     vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
         embedding=embeddings.get_default(),
     )
-    reranker: BaseReranking = LLMReranking.withx(llm=llms.get_lowest_cost())
+    reranker: BaseReranking
     get_extra_table: bool = False

     def run(
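Note on the reranker change: the old class-level default (LLMReranking.withx(llm=llms.get_lowest_cost())) chose an LLM when the class body was evaluated, i.e. at import time, before any user setting could be read. Dropping the default defers the choice to construction. A simplified illustration of the difference (not theflow's actual semantics):

def pick_lowest_cost_llm() -> str:
    print("LLM picked")  # side effect makes the evaluation time visible
    return "cheap-llm"

class Before:
    # evaluated once, at class-definition (import) time
    reranker = pick_lowest_cost_llm()

class After:
    reranker: str  # no default; supplied from user settings at build time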
@@ -153,7 +156,23 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):

     @classmethod
     def get_user_settings(cls) -> dict:
+        from ktem.components import llms
+
+        try:
+            reranking_llm = llms.get_lowest_cost_name()
+            reranking_llm_choices = list(llms.options().keys())
+        except Exception as e:
+            logger.error(e)
+            reranking_llm = None
+            reranking_llm_choices = []
+
         return {
+            "reranking_llm": {
+                "name": "LLM for reranking",
+                "value": reranking_llm,
+                "component": "dropdown",
+                "choices": reranking_llm_choices,
+            },
             "separate_embedding": {
                 "name": "Use separate embedding",
                 "value": False,
@@ -185,7 +204,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             },
             "use_reranking": {
                 "name": "Use reranking",
-                "value": True,
+                "value": False,
                 "choices": [True, False],
                 "component": "checkbox",
             },
@@ -199,7 +218,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             settings: the settings of the app
             kwargs: other arguments
         """
-        retriever = cls(get_extra_table=user_settings["prioritize_table"])
+        retriever = cls(
+            get_extra_table=user_settings["prioritize_table"],
+            reranker=user_settings["reranking_llm"],
+        )

         if not user_settings["use_reranking"]:
             retriever.reranker = None  # type: ignore

View File

@@ -159,6 +159,7 @@ class AnswerWithContextPipeline(BaseComponent):
     qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
     qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT

+    enable_citation: bool = False
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese
@@ -200,7 +201,8 @@ class AnswerWithContextPipeline(BaseComponent):
             lang=self.lang,
         )

-        if evidence:
+        citation_task = None
+        if evidence and self.enable_citation:
             citation_task = asyncio.create_task(
                 self.citation_pipeline.ainvoke(context=evidence, question=question)
             )
@@ -226,7 +228,7 @@ class AnswerWithContextPipeline(BaseComponent):

         # retrieve the citation
         print("Waiting for citation task")
-        if evidence:
+        if citation_task is not None:
             citation = await citation_task
         else:
             citation = None
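The citation handling is the usual asyncio pattern of optionally starting a background task early, doing other work (here, generating the answer), and awaiting only if the task exists; guarding on citation_task rather than evidence also covers the case where evidence is present but citation is disabled. A self-contained sketch of the pattern (names illustrative):

import asyncio

async def fetch_citation(evidence: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for citation_pipeline.ainvoke(...)
    return f"citation for {evidence!r}"

async def answer(evidence: str | None, enable_citation: bool) -> None:
    citation_task = None
    if evidence and enable_citation:
        # starts immediately and runs concurrently with the work below
        citation_task = asyncio.create_task(fetch_citation(evidence))

    print("streaming answer ...")  # main answer generation happens here

    if citation_task is not None:
        print(await citation_task)

asyncio.run(answer("some evidence", enable_citation=True))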
@@ -353,7 +355,15 @@ class FullQAPipeline(BaseReasoning):

         _id = cls.get_info()["id"]
         pipeline = FullQAPipeline(retrievers=retrievers)
-        pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
+        pipeline.answering_pipeline.llm = llms[
+            settings[f"reasoning.options.{_id}.main_llm"]
+        ]
+        pipeline.answering_pipeline.citation_pipeline.llm = llms[
+            settings[f"reasoning.options.{_id}.citation_llm"]
+        ]
+        pipeline.answering_pipeline.enable_citation = settings[
+            f"reasoning.options.{_id}.highlight_citation"
+        ]
         pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
             settings["reasoning.lang"], "English"
         )
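Model selection is now driven by flat settings keys of the form reasoning.options.<pipeline id>.<field>, resolved against the pool by name instead of the old hard-coded llms.get_highest_accuracy(). A sketch of the lookup with illustrative values:

llms = {"gpt-4": "<gpt-4 client>", "gpt-3.5-turbo": "<gpt-3.5 client>"}
settings = {
    "reasoning.options.simple.main_llm": "gpt-4",
    "reasoning.options.simple.citation_llm": "gpt-3.5-turbo",
    "reasoning.options.simple.highlight_citation": False,
}

_id = "simple"  # hypothetical pipeline id
main_llm = llms[settings[f"reasoning.options.{_id}.main_llm"]]
citation_llm = llms[settings[f"reasoning.options.{_id}.citation_llm"]]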
@@ -384,7 +394,7 @@ class FullQAPipeline(BaseReasoning):
         return {
             "highlight_citation": {
                 "name": "Highlight Citation",
-                "value": True,
+                "value": False,
                 "component": "checkbox",
             },
             "citation_llm": {