Allow users to select reasoning pipeline. Fix small issues with user UI, cohere name (#50)
* Fix user page
* Allow changing LLM in reasoning pipeline
* Fix CohereEmbedding name
This commit is contained in:
parent e29bec6275
commit a8725710af
@@ -3,7 +3,7 @@ from .endpoint_based import EndpointEmbeddings
 from .fastembed import FastEmbedEmbeddings
 from .langchain_based import (
     LCAzureOpenAIEmbeddings,
-    LCCohereEmbdeddings,
+    LCCohereEmbeddings,
     LCHuggingFaceEmbeddings,
     LCOpenAIEmbeddings,
 )
@@ -14,7 +14,7 @@ __all__ = [
     "EndpointEmbeddings",
     "LCOpenAIEmbeddings",
     "LCAzureOpenAIEmbeddings",
-    "LCCohereEmbdeddings",
+    "LCCohereEmbeddings",
     "LCHuggingFaceEmbeddings",
     "OpenAIEmbeddings",
     "AzureOpenAIEmbeddings",
@@ -159,7 +159,7 @@ class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         return AzureOpenAIEmbeddings


-class LCCohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
+class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
     """Wrapper around Langchain's Cohere embedding, focusing on key parameters"""

     def __init__(
@@ -9,7 +9,7 @@ from kotaemon.embeddings import (
     AzureOpenAIEmbeddings,
     FastEmbedEmbeddings,
     LCAzureOpenAIEmbeddings,
-    LCCohereEmbdeddings,
+    LCCohereEmbeddings,
     LCHuggingFaceEmbeddings,
     OpenAIEmbeddings,
 )
@@ -148,7 +148,7 @@ def test_lchuggingface_embeddings(
     side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
 )
 def test_lccohere_embeddings(langchain_cohere_embedding_call):
-    model = LCCohereEmbdeddings(
+    model = LCCohereEmbeddings(
         model="embed-english-light-v2.0", cohere_api_key="my-api-key"
     )

@@ -1,4 +1,4 @@
-from typing import Optional, Type
+from typing import Optional, Type, overload

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -71,6 +71,14 @@ class LLMManager:
         """Check if model exists"""
         return key in self._models

+    @overload
+    def get(self, key: str, default: None) -> Optional[ChatLLM]:
+        ...
+
+    @overload
+    def get(self, key: str, default: ChatLLM) -> ChatLLM:
+        ...
+
     def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
         """Get model by name with default value"""
         return self._models.get(key, default)
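What the new @overload pair buys callers: when the default argument is a concrete ChatLLM, a type checker can prove the result is never None. A minimal sketch (the caller code and the "my-model" key are illustrative, not from this diff):

    from ktem.llms.manager import llms

    fallback = llms.get_default()                # some concrete ChatLLM
    maybe_llm = llms.get("my-model", None)       # checker infers Optional[ChatLLM]
    surely_llm = llms.get("my-model", fallback)  # checker infers ChatLLM, no None check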
@@ -138,6 +146,7 @@ class LLMManager:

     def add(self, name: str, spec: dict, default: bool):
         """Add a new model to the pool"""
+        name = name.strip()
         if not name:
             raise ValueError("Name must not be empty")
@@ -142,7 +142,7 @@ class UserManagement(BasePage):
             )
             self.admin_edit = gr.Checkbox(label="Admin")

-        with gr.Row() as self._selected_panel_btn:
+        with gr.Row(visible=False) as self._selected_panel_btn:
             with gr.Column():
                 self.btn_edit_save = gr.Button("Save")
             with gr.Column():
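With visible=False, the edit-button row starts hidden instead of dangling under an empty form. It is presumably revealed again from the selection handler; a sketch of that common Gradio pattern (illustrative, not shown in this diff):

    def on_selected_user_change(self, selected_user_id):
        # reveal the action buttons only when a real user (id != -1) is selected
        return gr.update(visible=selected_user_id != -1)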
@@ -338,7 +338,7 @@ class UserManagement(BasePage):
         if not ev.selected:
             return -1

-        return user_list["id"][ev.index[0]]
+        return int(user_list["id"][ev.index[0]])

     def on_selected_user_change(self, selected_user_id):
         if selected_user_id == -1:
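The int(...) cast is more than cosmetic: indexing a pandas column returns a numpy scalar, which is not JSON-serializable and can fail strict comparisons downstream. A quick illustration (assuming user_list is a pandas DataFrame, as the indexing pattern suggests):

    import pandas as pd

    user_list = pd.DataFrame({"id": [1, 2, 3]})
    raw = user_list["id"][0]  # numpy.int64; json.dumps(raw) raises TypeError
    selected = int(raw)       # plain Python int, safe to hand to Gradio state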
@@ -680,12 +680,15 @@ class FullQAPipeline(BaseReasoning):
             retrievers: the retrievers to use
         """
         prefix = f"reasoning.options.{cls.get_info()['id']}"
-        pipeline = FullQAPipeline(retrievers=retrievers)
+        pipeline = cls(retrievers=retrievers)
+
+        llm_name = settings.get(f"{prefix}.llm", None)
+        llm = llms.get(llm_name, llms.get_default())

         # answering pipeline configuration
         answer_pipeline = pipeline.answering_pipeline
-        answer_pipeline.llm = llms.get_default()
-        answer_pipeline.citation_pipeline.llm = llms.get_default()
+        answer_pipeline.llm = llm
+        answer_pipeline.citation_pipeline.llm = llm
         answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
         answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
         answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
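Two changes in this hunk: cls(retrievers=retrievers) lets subclasses reuse this classmethod without being forced back into FullQAPipeline, and the LLM is now resolved once from the user's per-pipeline setting instead of hard-coded to the default. The fallback chain, sketched with illustrative values ("my-model" is a hypothetical registered name):

    # "(default)" in the settings dropdown stores "" -> no such key -> default model
    llm = llms.get("", llms.get_default())
    # a concrete choice stores a registered model name -> that model is used
    llm = llms.get("my-model", llms.get_default())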
@@ -694,14 +697,14 @@ class FullQAPipeline(BaseReasoning):
             settings["reasoning.lang"], "English"
         )

-        pipeline.add_query_context.llm = llms.get_default()
+        pipeline.add_query_context.llm = llm
         pipeline.add_query_context.n_last_interactions = settings[
             f"{prefix}.n_last_interactions"
         ]

         pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
         pipeline.use_rewrite = states.get("app", {}).get("regen", False)
-        pipeline.rewrite_pipeline.llm = llms.get_default()
+        pipeline.rewrite_pipeline.llm = llm
         pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
             settings["reasoning.lang"], "English"
         )
@@ -709,7 +712,26 @@ class FullQAPipeline(BaseReasoning):

     @classmethod
     def get_user_settings(cls) -> dict:
+        from ktem.llms.manager import llms
+
+        llm = ""
+        choices = [("(default)", "")]
+        try:
+            choices += [(_, _) for _ in llms.options().keys()]
+        except Exception as e:
+            logger.exception(f"Failed to get LLM options: {e}")
+
         return {
+            "llm": {
+                "name": "Language model",
+                "value": llm,
+                "component": "dropdown",
+                "choices": choices,
+                "info": (
+                    "The language model to use for generating the answer. If None, "
+                    "the application default language model will be used."
+                ),
+            },
             "highlight_citation": {
                 "name": "Highlight Citation",
                 "value": False,
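This settings block is the other half of the feature: it declares a "Language model" dropdown populated from the registered LLM names, with an empty string standing in for the application default, and the classmethod above reads the stored choice back out. The round trip, with illustrative names:

    # user picks "my-model" under the reasoning settings
    # -> settings["reasoning.options.<pipeline-id>.llm"] == "my-model"
    # -> llms.get("my-model", llms.get_default()) wires that model into the
    #    answering, citation, query-context, and rewrite sub-pipelines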