Allow users to add LLM within the UI (#6)

* Rename AzureChatOpenAI to LCAzureChatOpenAI
* Provide vanilla ChatOpenAI and AzureChatOpenAI
* Remove the highest accuracy, lowest cost criteria

These criteria are unnecessary: users, not pipeline creators, should choose
which LLM to use. Requiring this information up front is also cumbersome and
degrades the user experience.

* Remove the LLM selection from the simple reasoning pipeline
* Provide a dedicated stream method to generate the output
* Return a placeholder message to the chat if the text is empty
Commit a203fc0f7c (parent e187e23dd1)
Author: Duc Nguyen (john), committed via GitHub
Date: 2024-04-06 11:53:17 +07:00
35 changed files with 1339 additions and 169 deletions
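With this change, a `KH_LLMS` entry in flowsettings only needs a `spec` dict (the component description that gets deserialized) and an optional `default` flag; the old `accuracy`/`cost` ranking fields are gone. Below is a rough sketch of what registering the newly provided vanilla `ChatOpenAI` could look like; the parameter names inside `spec` and the `config` helper are assumptions, not taken from this diff.

```python
# Hypothetical flowsettings.py snippet: register a vanilla OpenAI chat model.
# The "spec" dict is deserialized into a kotaemon.llms.ChatOpenAI instance;
# the field names below are illustrative, check the class for the real ones.
from decouple import config  # assumed to be the same helper flowsettings uses

KH_LLMS["openai"] = {
    "spec": {
        "__type__": "kotaemon.llms.ChatOpenAI",
        "api_key": config("OPENAI_API_KEY", default=""),
        "model": config("OPENAI_CHAT_MODEL", default="gpt-3.5-turbo"),
        "temperature": 0,
        "timeout": 20,
    },
    "default": False,
}
```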


@@ -40,16 +40,15 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
):
if config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""):
KH_LLMS["azure"] = {
"def": {
"spec": {
"__type__": "kotaemon.llms.AzureChatOpenAI",
"temperature": 0,
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_version": config("OPENAI_API_VERSION", default="")
or "2024-02-15-preview",
"deployment_name": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
"request_timeout": 10,
"stream": False,
"azure_deployment": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
"timeout": 20,
},
"default": False,
"accuracy": 5,
@@ -57,7 +56,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
}
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
KH_EMBEDDINGS["azure"] = {
"def": {
"spec": {
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
@@ -164,5 +163,11 @@ KH_INDICES = [
"name": "File",
"config": {},
"index_type": "ktem.index.file.FileIndex",
}
},
{
"id": 2,
"name": "Sample",
"config": {},
"index_type": "ktem.index.file.FileIndex",
},
]


@@ -3,6 +3,7 @@
import logging
from functools import cache
from pathlib import Path
from typing import Optional
from theflow.settings import settings
from theflow.utils.modules import deserialize
@@ -48,7 +49,7 @@ class ModelPool:
self._default: list[str] = []
for name, model in conf.items():
self._models[name] = deserialize(model["def"], safe=False)
self._models[name] = deserialize(model["spec"], safe=False)
if model.get("default", False):
self._default.append(name)
@@ -58,11 +59,27 @@ class ModelPool:
self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf"))))
def __getitem__(self, key: str) -> BaseComponent:
"""Get model by name"""
return self._models[key]
def __setitem__(self, key: str, value: BaseComponent):
"""Set model by name"""
self._models[key] = value
def __delitem__(self, key: str):
"""Delete model by name"""
del self._models[key]
def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models
def get(
self, key: str, default: Optional[BaseComponent] = None
) -> Optional[BaseComponent]:
"""Get model by name with default value"""
return self._models.get(key, default)
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
@@ -169,4 +186,3 @@ llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
indices = ModelPool("Indices", {})
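
The added dunder methods make `ModelPool` usable as a small mutable mapping. A quick illustrative sketch follows; the endpoint and key are placeholders, and the constructor arguments mirror the test further below.

```python
from ktem.components import llms
from kotaemon.llms import LCAzureChatOpenAI

# Register a model at runtime, then access it through the mapping protocol.
llms["my-azure"] = LCAzureChatOpenAI(
    openai_api_base="https://test.openai.azure.com/",  # placeholder endpoint
    openai_api_key="some-key",                         # placeholder key
    openai_api_version="2023-03-15-preview",
)
assert "my-azure" in llms
model = llms.get("my-azure")  # mapping-style get(); returns None if missing
del llms["my-azure"]
```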


@@ -157,10 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
@classmethod
def get_user_settings(cls) -> dict:
from ktem.components import llms
from ktem.llms.manager import llms
try:
reranking_llm = llms.get_lowest_cost_name()
reranking_llm = llms.get_default_name()
reranking_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)


libs/ktem/ktem/llms/db.py (new file)

@@ -0,0 +1,36 @@
from typing import Type
from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string
class Base(DeclarativeBase):
pass
class BaseLLMTable(Base):
"""Base table to store language model"""
__abstract__ = True
name = Column(String, primary_key=True, unique=True)
spec = Column(JSON, default={})
default = Column(Boolean, default=False)
_base_llm: Type[BaseLLMTable] = (
import_dotted_string(flowsettings.KH_TABLE_LLM, safe=False)
if hasattr(flowsettings, "KH_TABLE_LLM")
else BaseLLMTable
)
class LLMTable(_base_llm): # type: ignore
__tablename__ = "llm_table"
if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
LLMTable.metadata.create_all(engine)
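
Since the concrete base class is resolved through `import_dotted_string`, a deployment can point `KH_TABLE_LLM` at its own abstract subclass to extend the schema. A hedged sketch; the module path and the extra column are invented for illustration.

```python
# myproject/tables.py: hypothetical override adding one extra column.
from ktem.llms.db import BaseLLMTable
from sqlalchemy import Column, String


class CompanyLLMTable(BaseLLMTable):
    __abstract__ = True  # stays abstract; ktem's LLMTable will subclass it
    owner = Column(String, default="")


# flowsettings.py
KH_TABLE_LLM = "myproject.tables.CompanyLLMTable"
```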

libs/ktem/ktem/llms/manager.py (new file)

@@ -0,0 +1,191 @@
from typing import Optional, Type
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize
from kotaemon.base import BaseComponent
from .db import LLMTable, engine
class LLMManager:
"""Represent a pool of models"""
def __init__(self):
self._models: dict[str, BaseComponent] = {}
self._info: dict[str, dict] = {}
self._default: str = ""
self._vendors: list[Type] = []
if hasattr(flowsettings, "KH_LLMS"):
for name, model in flowsettings.KH_LLMS.items():
with Session(engine) as session:
stmt = select(LLMTable).where(LLMTable.name == name)
result = session.execute(stmt)
if not result.first():
item = LLMTable(
name=name,
spec=model["spec"],
default=model.get("default", False),
)
session.add(item)
session.commit()
self.load()
self.load_vendors()
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._default = {}, {}, ""
with Session(engine) as session:
stmt = select(LLMTable)
items = session.execute(stmt)
for (item,) in items:
self._models[item.name] = deserialize(item.spec, safe=False)
self._info[item.name] = {
"name": item.name,
"spec": item.spec,
"default": item.default,
}
if item.default:
self._default = item.name
def load_vendors(self):
from kotaemon.llms import (
AzureChatOpenAI,
ChatOpenAI,
EndpointChatLLM,
LlamaCppChat,
)
self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]
def __getitem__(self, key: str) -> BaseComponent:
"""Get model by name"""
return self._models[key]
def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models
def get(
self, key: str, default: Optional[BaseComponent] = None
) -> Optional[BaseComponent]:
"""Get model by name with default value"""
return self._models.get(key, default)
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
"label": "LLM",
"choices": list(self._models.keys()),
"value": self.get_default_name(),
}
def options(self) -> dict:
"""Present a dict of models"""
return self._models
def get_random_name(self) -> str:
"""Get the name of random model
Returns:
str: random model name in the pool
"""
import random
if not self._models:
raise ValueError("No models in pool")
return random.choice(list(self._models.keys()))
def get_default_name(self) -> str:
"""Get the name of default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
str: model name
"""
if not self._models:
raise ValueError("No models in pool")
if not self._default:
return self.get_random_name()
return self._default
def get_random(self) -> BaseComponent:
"""Get random model"""
return self._models[self.get_random_name()]
def get_default(self) -> BaseComponent:
"""Get default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
BaseComponent: model
"""
return self._models[self.get_default_name()]
def info(self) -> dict:
"""List all models"""
return self._info
def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
try:
with Session(engine) as session:
item = LLMTable(name=name, spec=spec, default=default)
session.add(item)
session.commit()
except Exception as e:
raise ValueError(f"Failed to add model {name}: {e}")
self.load()
def delete(self, name: str):
"""Delete a model from the pool"""
try:
with Session(engine) as session:
item = session.query(LLMTable).filter_by(name=name).first()
session.delete(item)
session.commit()
except Exception as e:
raise ValueError(f"Failed to delete model {name}: {e}")
self.load()
def update(self, name: str, spec: dict, default: bool):
"""Update a model in the pool"""
try:
with Session(engine) as session:
if default:
# turn all models to non-default
session.query(LLMTable).update({"default": False})
session.commit()
item = session.query(LLMTable).filter_by(name=name).first()
if not item:
raise ValueError(f"Model {name} not found")
item.spec = spec
item.default = default
session.commit()
except Exception as e:
raise ValueError(f"Failed to update model {name}: {e}")
self.load()
def vendors(self) -> dict:
"""Return list of vendors"""
return {vendor.__qualname__: vendor for vendor in self._vendors}
llms = LLMManager()
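
A short usage sketch of the manager outside the UI, assuming the database has been initialised; the model name and spec values below are placeholders and the `ChatOpenAI` field names are assumed.

```python
from ktem.llms.manager import llms

# Persist a new model, then read it back. The spec mirrors what the UI stores.
llms.add(
    "gpt35",
    spec={
        "__type__": "kotaemon.llms.ChatOpenAI",
        "api_key": "sk-placeholder",
        "model": "gpt-3.5-turbo",  # assumed parameter name
    },
    default=True,
)

print(llms.get_default_name())  # "gpt35", because it was added as default
model = llms.get("gpt35")       # deserialized kotaemon component, or None
llms.update("gpt35", spec=llms.info()["gpt35"]["spec"], default=False)
llms.delete("gpt35")
```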

libs/ktem/ktem/llms/ui.py (new file)

@@ -0,0 +1,318 @@
from copy import deepcopy
import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage
from .manager import llms
def format_description(cls):
params = cls.describe()["params"]
params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
for key, value in params.items():
if isinstance(value["auto_callback"], str):
continue
params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
class LLMManagement(BasePage):
def __init__(self, app):
self._app = app
self.spec_desc_default = (
"# Spec description\n\nSelect an LLM to view the spec description."
)
self.on_building_ui()
def on_building_ui(self):
with gr.Tab(label="View"):
self.llm_list = gr.DataFrame(
headers=["name", "vendor", "default"],
interactive=False,
)
with gr.Column(visible=False) as self._selected_panel:
self.selected_llm_name = gr.Textbox(value="", visible=False)
with gr.Row():
with gr.Column():
self.edit_default = gr.Checkbox(
label="Set default",
info=(
"Set this LLM as default. If no default is set, a "
"random LLM will be used."
),
)
self.edit_spec = gr.Textbox(
label="Specification",
info="Specification of the LLM in YAML format",
lines=10,
)
with gr.Row(visible=False) as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button("Save", min_width=10)
with gr.Column():
self.btn_delete = gr.Button("Delete", min_width=10)
with gr.Row():
self.btn_delete_yes = gr.Button(
"Confirm delete",
variant="primary",
visible=False,
min_width=10,
)
self.btn_delete_no = gr.Button(
"Cancel", visible=False, min_width=10
)
with gr.Column():
self.btn_close = gr.Button("Close", min_width=10)
with gr.Column():
self.edit_spec_desc = gr.Markdown("# Spec description")
with gr.Tab(label="Add"):
with gr.Row():
with gr.Column(scale=2):
self.name = gr.Textbox(
label="LLM name",
info=(
"Must be unique. The name will be used to identify the LLM."
),
)
self.llm_choices = gr.Dropdown(
label="LLM vendors",
info=(
"Choose the vendor for the LLM. Each vendor has different "
"specification."
),
)
self.spec = gr.Textbox(
label="Specification",
info="Specification of the LLM in YAML format",
)
self.default = gr.Checkbox(
label="Set default",
info=(
"Set this LLM as default. This default LLM will be used "
"by default across the application."
),
)
self.btn_new = gr.Button("Create LLM")
with gr.Column(scale=3):
self.spec_desc = gr.Markdown(self.spec_desc_default)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
self.list_llms,
inputs=None,
outputs=[self.llm_list],
)
self._app.app.load(
lambda: gr.update(choices=list(llms.vendors().keys())),
outputs=[self.llm_choices],
)
def on_llm_vendor_change(self, vendor):
vendor = llms.vendors()[vendor]
required: dict = {}
desc = vendor.describe()
for key, value in desc["params"].items():
if value.get("required", False):
required[key] = None
return yaml.dump(required), format_description(vendor)
def on_register_events(self):
self.llm_choices.select(
self.on_llm_vendor_change,
inputs=[self.llm_choices],
outputs=[self.spec, self.spec_desc],
)
self.btn_new.click(
self.create_llm,
inputs=[self.name, self.llm_choices, self.spec, self.default],
outputs=None,
).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
self.llm_choices,
self.spec,
self.default,
self.spec_desc,
],
)
self.llm_list.select(
self.select_llm,
inputs=self.llm_list,
outputs=[self.selected_llm_name],
show_progress="hidden",
)
self.selected_llm_name.change(
self.on_selected_llm_change,
inputs=[self.selected_llm_name],
outputs=[
self._selected_panel,
self._selected_panel_btn,
# delete section
self.btn_delete,
self.btn_delete_yes,
self.btn_delete_no,
# edit section
self.edit_spec,
self.edit_spec_desc,
self.edit_default,
],
show_progress="hidden",
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=None,
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_delete_yes.click(
self.delete_llm,
inputs=[self.selected_llm_name],
outputs=[self.selected_llm_name],
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
outputs=[self.llm_list],
)
self.btn_delete_no.click(
lambda: (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_edit_save.click(
self.save_llm,
inputs=[
self.selected_llm_name,
self.edit_default,
self.edit_spec,
],
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
outputs=[self.llm_list],
)
self.btn_close.click(
lambda: "",
outputs=[self.selected_llm_name],
)
def create_llm(self, name, choices, spec, default):
try:
spec = yaml.safe_load(spec)
spec["__type__"] = (
llms.vendors()[choices].__module__
+ "."
+ llms.vendors()[choices].__qualname__
)
llms.add(name, spec=spec, default=default)
gr.Info(f"LLM {name} created successfully")
except Exception as e:
gr.Error(f"Failed to create LLM {name}: {e}")
def list_llms(self):
"""List the LLMs"""
items = []
for item in llms.info().values():
record = {}
record["name"] = item["name"]
record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
record["default"] = item["default"]
items.append(record)
if items:
llm_list = pd.DataFrame.from_records(items)
else:
llm_list = pd.DataFrame.from_records(
[{"name": "-", "vendor": "-", "default": "-"}]
)
return llm_list
def select_llm(self, llm_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No LLM is loaded. Please add LLM first")
return ""
if not ev.selected:
return ""
return llm_list["name"][ev.index[0]]
def on_selected_llm_change(self, selected_llm_name):
if selected_llm_name == "":
_selected_panel = gr.update(visible=False)
_selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
edit_spec = gr.update(value="")
edit_spec_desc = gr.update(value="")
edit_default = gr.update(value=False)
else:
_selected_panel = gr.update(visible=True)
_selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
info = deepcopy(llms.info()[selected_llm_name])
vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
vendor = llms.vendors()[vendor_str]
edit_spec = yaml.dump(info["spec"])
edit_spec_desc = format_description(vendor)
edit_default = info["default"]
return (
_selected_panel,
_selected_panel_btn,
btn_delete,
btn_delete_yes,
btn_delete_no,
edit_spec,
edit_spec_desc,
edit_default,
)
def on_btn_delete_click(self):
btn_delete = gr.update(visible=False)
btn_delete_yes = gr.update(visible=True)
btn_delete_no = gr.update(visible=True)
return btn_delete, btn_delete_yes, btn_delete_no
def save_llm(self, selected_llm_name, default, spec):
try:
spec = yaml.safe_load(spec)
spec["__type__"] = llms.info()[selected_llm_name]["spec"]["__type__"]
llms.update(selected_llm_name, spec=spec, default=default)
gr.Info(f"LLM {selected_llm_name} saved successfully")
except Exception as e:
gr.Error(f"Failed to save LLM {selected_llm_name}: {e}")
def delete_llm(self, selected_llm_name):
try:
llms.delete(selected_llm_name)
except Exception as e:
gr.Error(f"Failed to delete LLM {selected_llm_name}: {e}")
return selected_llm_name
return ""


@@ -1,6 +1,7 @@
import gradio as gr
from ktem.app import BasePage
from ktem.db.models import User, engine
from ktem.llms.ui import LLMManagement
from sqlmodel import Session, select
from .user import UserManagement
@@ -16,6 +17,9 @@ class AdminPage(BasePage):
with gr.Tab("User Management", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
with gr.Tab("LLM Management") as self.llm_management_tab:
self.llm_management = LLMManagement(self._app)
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(


@@ -9,6 +9,8 @@ from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
from kotaemon.base import Document
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE
@@ -189,6 +191,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
@@ -220,6 +223,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
@@ -392,7 +396,7 @@ class ChatPage(BasePage):
return pipeline, reasoning_state
async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
@@ -403,52 +407,43 @@ class ChatPage(BasePage):
pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
pipeline.set_output_queue(queue)
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
text, refs = "", ""
len_ref = -1 # for logging purpose
msg_placeholder = getattr(
flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
)
print(msg_placeholder)
while True:
try:
response = queue.get_nowait()
except Exception:
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [
(chat_input, text or msg_placeholder)
], refs, state
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
len_ref = -1 # for logging purpose
for response in pipeline.stream(chat_input, conversation_id, chat_history):
if not isinstance(response, Document):
continue
if response is None:
queue.task_done()
print("Chat completed")
break
if response.channel is None:
continue
if "output" in response:
if response["output"] is None:
if response.channel == "chat":
if response.content is None:
text = ""
else:
text += response["output"]
text += response.content
if "evidence" in response:
if response["evidence"] is None:
if response.channel == "info":
if response.content is None:
refs = ""
else:
refs += response["evidence"]
refs += response.content
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [(chat_input, text)], refs, state
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
async def regen_fn(
self, conversation_id, chat_history, settings, state, *selecteds
):
def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Regen function"""
if not chat_history:
gr.Warning("Empty chat")
@@ -456,12 +451,11 @@ class ChatPage(BasePage):
return
state["app"]["regen"] = True
async for chat, refs, state in self.chat_fn(
for chat, refs, state in self.chat_fn(
conversation_id, chat_history, settings, state, *selecteds
):
new_state = deepcopy(state)
new_state["app"]["regen"] = False
yield chat, refs, new_state
else:
state["app"]["regen"] = False
yield chat_history, "", state
state["app"]["regen"] = False


@@ -4,10 +4,10 @@ import logging
import re
from collections import defaultdict
from functools import partial
from typing import Generator
import tiktoken
from ktem.components import llms
from theflow.settings import settings as flowsettings
from ktem.llms.manager import llms
from kotaemon.base import (
BaseComponent,
@@ -190,10 +190,10 @@ class AnswerWithContextPipeline(BaseComponent):
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
vlm_endpoint: str = ""
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
)
qa_template: str = DEFAULT_QA_TEXT_PROMPT
@@ -297,13 +297,95 @@ class AnswerWithContextPipeline(BaseComponent):
return answer
def stream( # type: ignore
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Generator[Document, None, Document]:
"""Answer the question based on the evidence
def extract_evidence_images(self, evidence: str):
"""Util function to extract and isolate images from context/evidence"""
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
matches = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)
return context, matches
In addition to the question and the evidence, this method also takes into
account evidence_mode, which indicates what kind of evidence is provided.
The kind of evidence affects:
1. How the evidence is represented.
2. The prompt to generate the answer.
By default, the evidence_mode is 0, which means the evidence is plain text with
no particular semantic representation. The evidence_mode can be:
1. "table": There will be HTML markup telling that there is a table
within the evidence.
2. "chatbot": There will be HTML markup telling that there is a chatbot.
This chatbot is a scenario, extracted from an Excel file, where each
row corresponds to an interaction.
Args:
question: the original question posed by user
evidence: the text that contain relevant information to answer the question
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
"""
if evidence_mode == EVIDENCE_MODE_TEXT:
prompt_template = PromptTemplate(self.qa_template)
elif evidence_mode == EVIDENCE_MODE_TABLE:
prompt_template = PromptTemplate(self.qa_table_template)
elif evidence_mode == EVIDENCE_MODE_FIGURE:
prompt_template = PromptTemplate(self.qa_figure_template)
else:
prompt_template = PromptTemplate(self.qa_chatbot_template)
images = []
if evidence_mode == EVIDENCE_MODE_FIGURE:
# isolate image from evidence
evidence, images = self.extract_evidence_images(evidence)
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
else:
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
output = ""
if evidence_mode == EVIDENCE_MODE_FIGURE:
for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
output += text
yield Document(channel="chat", content=text)
else:
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
messages.append(HumanMessage(content=prompt))
try:
# try streaming first
print("Trying LLM streaming")
for text in self.llm.stream(messages):
output += text.text
yield Document(channel="chat", content=text.text)
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
yield Document(channel="chat", content=output)
# retrieve the citation
citation = None
if evidence and self.enable_citation:
citation = self.citation_pipeline.invoke(
context=evidence, question=question
)
answer = Document(text=output, metadata={"citation": citation})
return answer
def extract_evidence_images(self, evidence: str):
"""Util function to extract and isolate images from context/evidence"""
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
matches = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)
return context, matches
class RewriteQuestionPipeline(BaseComponent):
@@ -315,27 +397,19 @@ class RewriteQuestionPipeline(BaseComponent):
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
rewrite_template: str = DEFAULT_REWRITE_PROMPT
lang: str = "English"
async def run(self, question: str) -> Document: # type: ignore
def run(self, question: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.rewrite_template)
prompt = prompt_template.populate(question=question, lang=self.lang)
messages = [
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
output = ""
for text in self.llm(messages):
if "content" in text:
output += text[1]
self.report_output({"chat_input": text[1]})
break
await asyncio.sleep(0)
return Document(text=output)
return self.llm(messages)
class FullQAPipeline(BaseReasoning):
@@ -351,7 +425,7 @@ class FullQAPipeline(BaseReasoning):
rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
use_rewrite: bool = False
async def run( # type: ignore
async def ainvoke( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
import markdown
@@ -482,6 +556,132 @@ class FullQAPipeline(BaseReasoning):
self.report_output(None)
return answer
def stream( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Generator[Document, None, Document]:
import markdown
docs = []
doc_ids = []
if self.use_rewrite:
message = self.rewrite_pipeline(question=message).text
for retriever in self.retrievers:
for doc in retriever(text=message):
if doc.doc_id not in doc_ids:
docs.append(doc)
doc_ids.append(doc.doc_id)
for doc in docs:
# TODO: a better approach to show the information
text = markdown.markdown(
doc.text, extensions=["markdown.extensions.tables"]
)
yield Document(
content=(
"<details open>"
f"<summary>{doc.metadata['file_name']}</summary>"
f"{text}"
"</details><br>"
),
channel="info",
)
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = yield from self.answering_pipeline.stream(
question=message,
history=history,
evidence=evidence,
evidence_mode=evidence_mode,
conv_id=conv_id,
**kwargs,
)
# prepare citation
spans = defaultdict(list)
if answer.metadata["citation"] is not None:
for fact_with_evidence in answer.metadata["citation"].answer:
for quote in fact_with_evidence.substring_quote:
for doc in docs:
start_idx = doc.text.find(quote)
if start_idx == -1:
continue
end_idx = start_idx + len(quote)
current_idx = start_idx
if "|" not in doc.text[start_idx:end_idx]:
spans[doc.doc_id].append(
{"start": start_idx, "end": end_idx}
)
else:
while doc.text[current_idx:end_idx].find("|") != -1:
match_idx = doc.text[current_idx:end_idx].find("|")
spans[doc.doc_id].append(
{
"start": current_idx,
"end": current_idx + match_idx,
}
)
current_idx += match_idx + 2
if current_idx > end_idx:
break
break
id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True
not_detected = set(id2docs.keys()) - set(spans.keys())
yield Document(channel="info", content=None)
for id, ss in spans.items():
if not ss:
not_detected.add(id)
continue
ss = sorted(ss, key=lambda x: x["start"])
text = id2docs[id].text[: ss[0]["start"]]
for idx, span in enumerate(ss):
text += (
"<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
)
if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :]
text_out = markdown.markdown(
text, extensions=["markdown.extensions.tables"]
)
yield Document(
content=(
"<details open>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{text_out}"
"</details><br>"
),
channel="info",
)
lack_evidence = False
if lack_evidence:
yield Document(channel="info", content="No evidence found.\n")
if not_detected:
yield Document(
channel="info",
content="Retrieved segments without matching evidence:\n",
)
for id in list(not_detected):
text_out = markdown.markdown(
id2docs[id].text, extensions=["markdown.extensions.tables"]
)
yield Document(
content=(
"<details>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{text_out}"
"</details><br>"
),
channel="info",
)
return answer
@classmethod
def get_pipeline(cls, settings, states, retrievers):
"""Get the reasoning pipeline
@@ -493,12 +693,9 @@ class FullQAPipeline(BaseReasoning):
_id = cls.get_info()["id"]
pipeline = FullQAPipeline(retrievers=retrievers)
pipeline.answering_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.main_llm"]
]
pipeline.answering_pipeline.citation_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.citation_llm"]
]
pipeline.answering_pipeline.llm = llms.get_default()
pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()
pipeline.answering_pipeline.enable_citation = settings[
f"reasoning.options.{_id}.highlight_citation"
]
@@ -512,7 +709,7 @@ class FullQAPipeline(BaseReasoning):
f"reasoning.options.{_id}.qa_prompt"
]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
pipeline.rewrite_pipeline.llm = llms.get_default()
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
@@ -520,38 +717,12 @@ class FullQAPipeline(BaseReasoning):
@classmethod
def get_user_settings(cls) -> dict:
from ktem.components import llms
try:
citation_llm = llms.get_lowest_cost_name()
citation_llm_choices = list(llms.options().keys())
main_llm = llms.get_highest_accuracy_name()
main_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)
citation_llm = None
citation_llm_choices = []
main_llm = None
main_llm_choices = []
return {
"highlight_citation": {
"name": "Highlight Citation",
"value": False,
"component": "checkbox",
},
"citation_llm": {
"name": "LLM for citation",
"value": citation_llm,
"component": "dropdown",
"choices": citation_llm_choices,
},
"main_llm": {
"name": "LLM for main generation",
"value": main_llm,
"component": "dropdown",
"choices": main_llm_choices,
},
"system_prompt": {
"name": "System Prompt",
"value": "This is a question answering system",


@@ -7,7 +7,7 @@ from index import ReaderIndexingPipeline
from openai.resources.embeddings import Embeddings
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
openai_embedding = json.load(f)
@@ -61,7 +61,7 @@ def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
assert len(results) == 1
# create llm
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
openai_api_version="2023-03-15-preview",


@@ -2,4 +2,4 @@ from ktem.main import App
app = App()
demo = app.make()
demo.queue().launch(favicon_path=app._favicon, inbrowser=True)
demo.queue().launch(favicon_path=app._favicon)