Allow users to select reasoning pipeline. Fix small issues with user UI, Cohere name (#50)

* Fix user page

* Allow changing LLM in reasoning pipeline

* Fix CohereEmbedding name
Duc Nguyen (john) 2024-04-25 17:18:12 +07:00 committed by GitHub
parent e29bec6275
commit a8725710af
6 changed files with 44 additions and 13 deletions

View File

@@ -3,7 +3,7 @@ from .endpoint_based import EndpointEmbeddings
from .fastembed import FastEmbedEmbeddings
from .langchain_based import (
LCAzureOpenAIEmbeddings,
-LCCohereEmbdeddings,
+LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
LCOpenAIEmbeddings,
)
@@ -14,7 +14,7 @@ __all__ = [
"EndpointEmbeddings",
"LCOpenAIEmbeddings",
"LCAzureOpenAIEmbeddings",
-"LCCohereEmbdeddings",
+"LCCohereEmbeddings",
"LCHuggingFaceEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",

View File

@@ -159,7 +159,7 @@ class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
return AzureOpenAIEmbeddings
-class LCCohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
+class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
def __init__(

View File

@@ -9,7 +9,7 @@ from kotaemon.embeddings import (
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
LCAzureOpenAIEmbeddings,
-LCCohereEmbdeddings,
+LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
)
@@ -148,7 +148,7 @@ def test_lchuggingface_embeddings(
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
)
def test_lccohere_embeddings(langchain_cohere_embedding_call):
-model = LCCohereEmbdeddings(
+model = LCCohereEmbeddings(
model="embed-english-light-v2.0", cohere_api_key="my-api-key"
)

View File

@@ -1,4 +1,4 @@
-from typing import Optional, Type
+from typing import Optional, Type, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -71,6 +71,14 @@ class LLMManager:
"""Check if model exists"""
return key in self._models
+@overload
+def get(self, key: str, default: None) -> Optional[ChatLLM]:
+    ...
+@overload
+def get(self, key: str, default: ChatLLM) -> ChatLLM:
+    ...
def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
"""Get model by name with default value"""
return self._models.get(key, default)
@@ -138,6 +146,7 @@ class LLMManager:
def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
name = name.strip()
if not name:
raise ValueError("Name must not be empty")
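
A side note on the `get` overloads added above: `typing.overload` only affects static type checking; at runtime the single implementation handles both call shapes and simply delegates to the dict lookup. A minimal, self-contained sketch of the same pattern, using stub classes rather than the actual ktem types:

from typing import Optional, overload

class ChatLLM:
    """Stand-in for ktem's chat model base class (illustration only)."""

class LLMManager:
    def __init__(self) -> None:
        self._models: dict[str, ChatLLM] = {}

    @overload
    def get(self, key: str, default: None) -> Optional[ChatLLM]: ...
    @overload
    def get(self, key: str, default: ChatLLM) -> ChatLLM: ...

    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
        # Runtime behaviour is a plain dict lookup; the overloads only tell the
        # type checker that the result cannot be None when the default isn't.
        return self._models.get(key, default)

manager = LLMManager()
maybe_llm = manager.get("gpt-4", None)        # checker infers Optional[ChatLLM]
surely_llm = manager.get("gpt-4", ChatLLM())  # checker infers ChatLLM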

View File

@@ -142,7 +142,7 @@ class UserManagement(BasePage):
)
self.admin_edit = gr.Checkbox(label="Admin")
-with gr.Row() as self._selected_panel_btn:
+with gr.Row(visible=False) as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button("Save")
with gr.Column():
@@ -338,7 +338,7 @@ class UserManagement(BasePage):
if not ev.selected:
return -1
-return user_list["id"][ev.index[0]]
+return int(user_list["id"][ev.index[0]])
def on_selected_user_change(self, selected_user_id):
if selected_user_id == -1:

View File

@@ -680,12 +680,15 @@ class FullQAPipeline(BaseReasoning):
retrievers: the retrievers to use
"""
prefix = f"reasoning.options.{cls.get_info()['id']}"
-pipeline = FullQAPipeline(retrievers=retrievers)
+pipeline = cls(retrievers=retrievers)
+llm_name = settings.get(f"{prefix}.llm", None)
+llm = llms.get(llm_name, llms.get_default())
# answering pipeline configuration
answer_pipeline = pipeline.answering_pipeline
-answer_pipeline.llm = llms.get_default()
-answer_pipeline.citation_pipeline.llm = llms.get_default()
+answer_pipeline.llm = llm
+answer_pipeline.citation_pipeline.llm = llm
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
@@ -694,14 +697,14 @@
settings["reasoning.lang"], "English"
)
-pipeline.add_query_context.llm = llms.get_default()
+pipeline.add_query_context.llm = llm
pipeline.add_query_context.n_last_interactions = settings[
f"{prefix}.n_last_interactions"
]
pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
-pipeline.rewrite_pipeline.llm = llms.get_default()
+pipeline.rewrite_pipeline.llm = llm
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
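
One consequence of the wiring above: the value stored under `reasoning.options.<id>.llm` is only a key into the LLM pool, and the empty string used for the "(default)" choice is never a registered key (LLMManager.add rejects empty names), so `llms.get(llm_name, llms.get_default())` falls back to the application default. A standalone sketch of that resolution rule, with a made-up pool rather than the real manager:

from typing import Optional

def resolve_llm(selected: Optional[str], pool: dict[str, str], default: str) -> str:
    # Mirrors `llms.get(llm_name, llms.get_default())`: an empty or unknown
    # selection quietly falls back to the default model.
    if selected and selected in pool:
        return pool[selected]
    return default

pool = {"azure-gpt4": "azure-model", "local-llama": "llama-model"}
assert resolve_llm("", pool, "azure-model") == "azure-model"               # "(default)" choice
assert resolve_llm("local-llama", pool, "azure-model") == "llama-model"    # explicit choice
assert resolve_llm("retired-model", pool, "azure-model") == "azure-model"  # stale key
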
@@ -709,7 +712,26 @@
@classmethod
def get_user_settings(cls) -> dict:
from ktem.llms.manager import llms
+llm = ""
+choices = [("(default)", "")]
+try:
+choices += [(_, _) for _ in llms.options().keys()]
+except Exception as e:
+logger.exception(f"Failed to get LLM options: {e}")
return {
+"llm": {
+"name": "Language model",
+"value": llm,
+"component": "dropdown",
+"choices": choices,
+"info": (
+"The language model to use for generating the answer. If None, "
+"the application default language model will be used."
+),
+},
"highlight_citation": {
"name": "Highlight Citation",
"value": False,