improve llm selection for simple reasoning pipeline
parent b2089245f2
commit f9cc40ca25
@@ -17,8 +17,7 @@ if machine == "x86_64":
 BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}"
 EXTENSION = ".exe" if os.name == "nt" else ""
 BINARY_URL = (
-    "some-endpoint.com"
-    f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
+    "some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
 )

 BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}"
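Note on the hunk above: adjacent string literals in Python are concatenated at compile time, f-strings included, so folding the two URL fragments onto one line changes formatting only, not the resulting URL. A minimal sketch with invented values:

    VERSION = "0.1.0"                       # invented for illustration
    BINARY_REMOTE_NAME = "frpc_linux_amd64"  # invented for illustration
    EXTENSION = ""
    # adjacent literals merge into a single string at compile time
    BINARY_URL = (
        "some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
    )
    assert BINARY_URL == "some-endpoint.com/kotaemon/tunneling/0.1.0/frpc_linux_amd64"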
@@ -194,7 +194,6 @@ class ChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore


 class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
-
     def __init__(
         self,
         azure_endpoint: str | None = None,
@@ -1,4 +1,5 @@
 """Common components, some kind of config"""
+
 import logging
 from functools import cache
 from pathlib import Path
@@ -71,7 +72,7 @@ class ModelPool:
         }

     def options(self) -> dict:
-        """Present a list of models"""
+        """Present a dict of models"""
         return self._models

     def get_random_name(self) -> str:
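Later hunks rely on this pool exposing options(), get_lowest_cost_name(), and llms[name]-style lookup, none of which appear in the diff itself. The following is a minimal sketch of the assumed interface, not the project's actual implementation:

    # sketch of the pool interface the hunks below assume; details invented
    class ModelPool:
        def __init__(self, models: dict):
            self._models = models  # model name -> model instance

        def options(self) -> dict:
            """Present a dict of models"""
            return self._models

        def __getitem__(self, name: str):
            # name-based lookup used by FullQAPipeline below
            return self._models[name]

        def get_lowest_cost_name(self) -> str:
            # assumption: some cost-based policy; here just the first registered name
            return next(iter(self._models))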
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 import shutil
 import warnings
 from collections import defaultdict
@@ -8,7 +9,7 @@ from hashlib import sha256
 from pathlib import Path
 from typing import Optional

-from ktem.components import embeddings, filestorage_path, llms
+from ktem.components import embeddings, filestorage_path
 from ktem.db.models import engine
 from llama_index.vector_stores import (
     FilterCondition,
@@ -25,10 +26,12 @@ from theflow.utils.modules import import_dotted_string
 from kotaemon.base import RetrievedDocument
 from kotaemon.indices import VectorIndexing, VectorRetrieval
 from kotaemon.indices.ingests import DocumentIngestor
-from kotaemon.indices.rankings import BaseReranking, LLMReranking
+from kotaemon.indices.rankings import BaseReranking

 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever

+logger = logging.getLogger(__name__)
+

 @lru_cache
 def dev_settings():
@@ -67,7 +70,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
     vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
         embedding=embeddings.get_default(),
     )
-    reranker: BaseReranking = LLMReranking.withx(llm=llms.get_lowest_cost())
+    reranker: BaseReranking
     get_extra_table: bool = False

     def run(
@@ -153,7 +156,23 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):

     @classmethod
     def get_user_settings(cls) -> dict:
+        from ktem.components import llms
+
+        try:
+            reranking_llm = llms.get_lowest_cost_name()
+            reranking_llm_choices = list(llms.options().keys())
+        except Exception as e:
+            logger.error(e)
+            reranking_llm = None
+            reranking_llm_choices = []
+
         return {
+            "reranking_llm": {
+                "name": "LLM for reranking",
+                "value": reranking_llm,
+                "component": "dropdown",
+                "choices": reranking_llm_choices,
+            },
             "separate_embedding": {
                 "name": "Use separate embedding",
                 "value": False,
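llms is imported inside the method (presumably to sidestep a circular import with ktem.components), and the try/except means a missing or misconfigured model pool degrades to an empty dropdown instead of breaking the settings page. A self-contained sketch of the dict shape this produces, with invented model names:

    def fake_get_user_settings(pool: dict | None) -> dict:
        # mirrors the structure above; `pool` stands in for ktem's llms
        try:
            if pool is None:
                raise RuntimeError("no LLMs configured")  # hypothetical failure
            reranking_llm = min(pool)           # stand-in for get_lowest_cost_name()
            reranking_llm_choices = list(pool)
        except Exception:
            reranking_llm, reranking_llm_choices = None, []
        return {
            "reranking_llm": {
                "name": "LLM for reranking",
                "value": reranking_llm,
                "component": "dropdown",
                "choices": reranking_llm_choices,
            },
        }

    print(fake_get_user_settings({"gpt-3.5-turbo": 1, "gpt-4": 2}))
    print(fake_get_user_settings(None))  # dropdown falls back to empty, no crash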
@@ -185,7 +204,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             },
             "use_reranking": {
                 "name": "Use reranking",
-                "value": True,
+                "value": False,
                 "choices": [True, False],
                 "component": "checkbox",
             },
@@ -199,7 +218,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
            settings: the settings of the app
            kwargs: other arguments
        """
-        retriever = cls(get_extra_table=user_settings["prioritize_table"])
+        retriever = cls(
+            get_extra_table=user_settings["prioritize_table"],
+            reranker=user_settings["reranking_llm"],
+        )
        if not user_settings["use_reranking"]:
            retriever.reranker = None  # type: ignore

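With reranker now a constructor argument, the value the user picked in the dropdown flows straight into the pipeline, and the use_reranking toggle still wins by nulling the field afterwards. A toy sketch of that flow (class and settings values invented, standing in for DocumentRetrievalPipeline):

    # toy stand-in just to show how the two settings interact
    class Retriever:
        def __init__(self, get_extra_table: bool, reranker):
            self.get_extra_table = get_extra_table
            self.reranker = reranker

    user_settings = {
        "prioritize_table": True,
        "reranking_llm": "gpt-3.5-turbo",  # invented model name
        "use_reranking": False,
    }
    retriever = Retriever(
        get_extra_table=user_settings["prioritize_table"],
        reranker=user_settings["reranking_llm"],
    )
    if not user_settings["use_reranking"]:
        retriever.reranker = None  # the checkbox overrides the dropdown
    print(retriever.reranker)  # None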
@@ -159,6 +159,7 @@ class AnswerWithContextPipeline(BaseComponent):
     qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
     qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT

+    enable_citation: bool = False
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese
@@ -200,7 +201,8 @@ class AnswerWithContextPipeline(BaseComponent):
            lang=self.lang,
        )

-        if evidence:
+        citation_task = None
+        if evidence and self.enable_citation:
            citation_task = asyncio.create_task(
                self.citation_pipeline.ainvoke(context=evidence, question=question)
            )
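Initializing citation_task to None and testing citation_task is not None later (next hunk) is the standard guard for an optionally-created asyncio task; the old code awaited whenever evidence existed, even with citation disabled. A self-contained sketch of the pattern, with a stub coroutine in place of the citation pipeline:

    from __future__ import annotations
    import asyncio

    async def answer(evidence: str | None, enable_citation: bool) -> str | None:
        citation_task = None
        if evidence and enable_citation:
            # start citation work concurrently with answer generation
            citation_task = asyncio.create_task(
                asyncio.sleep(0, result=f"cite:{evidence}")  # stub pipeline
            )
        # ... generate the answer here ...
        if citation_task is not None:
            return await citation_task
        return None

    print(asyncio.run(answer("doc-1", True)))   # cite:doc-1
    print(asyncio.run(answer("doc-1", False)))  # None, task never created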
@@ -226,7 +228,7 @@ class AnswerWithContextPipeline(BaseComponent):

        # retrieve the citation
        print("Waiting for citation task")
-        if evidence:
+        if citation_task is not None:
            citation = await citation_task
        else:
            citation = None
@@ -353,7 +355,15 @@ class FullQAPipeline(BaseReasoning):
        _id = cls.get_info()["id"]

        pipeline = FullQAPipeline(retrievers=retrievers)
-        pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
+        pipeline.answering_pipeline.llm = llms[
+            settings[f"reasoning.options.{_id}.main_llm"]
+        ]
+        pipeline.answering_pipeline.citation_pipeline.llm = llms[
+            settings[f"reasoning.options.{_id}.citation_llm"]
+        ]
+        pipeline.answering_pipeline.enable_citation = settings[
+            f"reasoning.options.{_id}.highlight_citation"
+        ]
        pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
            settings["reasoning.lang"], "English"
        )
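The hardcoded llms.get_highest_accuracy() default is gone: the pipeline now looks models up by the names the user picked, stored under flattened settings keys. A rough sketch of that lookup, reusing the ModelPool sketch from earlier (all names and values invented):

    llms = ModelPool({"gpt-4": "model-A", "gpt-3.5-turbo": "model-B"})
    settings = {
        "reasoning.options.simple.main_llm": "gpt-4",
        "reasoning.options.simple.citation_llm": "gpt-3.5-turbo",
        "reasoning.options.simple.highlight_citation": False,
        "reasoning.lang": "ja",
    }
    _id = "simple"
    main_llm = llms[settings[f"reasoning.options.{_id}.main_llm"]]            # "model-A"
    citation_llm = llms[settings[f"reasoning.options.{_id}.citation_llm"]]    # "model-B"
    enable_citation = settings[f"reasoning.options.{_id}.highlight_citation"]
    lang = {"en": "English", "ja": "Japanese"}.get(settings["reasoning.lang"], "English")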
@@ -384,7 +394,7 @@ class FullQAPipeline(BaseReasoning):
        return {
            "highlight_citation": {
                "name": "Highlight Citation",
-                "value": True,
+                "value": False,
                "component": "checkbox",
            },
            "citation_llm": {