feat: tweak the 'Chat suggestion' feature to tie it to conversations (#341) #none
Signed-off-by: Kennywu <jdlow@live.cn>
This commit is contained in:
parent
96f58a445a
commit
49a083fd9f
|
@ -63,6 +63,7 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
|
|||
KH_DOC_DIR = this_dir / "docs"
|
||||
|
||||
KH_MODE = "dev"
|
||||
KH_FEATURE_CHAT_SUGGESTION = config("KH_FEATURE_CHAT_SUGGESTION", default=False)
|
||||
KH_FEATURE_USER_MANAGEMENT = config(
|
||||
"KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
|
||||
)
|
||||
|
|
|
@ -34,7 +34,7 @@ class BaseConversation(SQLModel):
|
|||
|
||||
is_public: bool = Field(default=False)
|
||||
|
||||
# contains messages + current files
|
||||
# contains messages + current files + chat_suggestions
|
||||
data_source: dict = Field(default={}, sa_column=Column(JSON))
|
||||
|
||||
date_created: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
import ast
|
||||
import asyncio
|
||||
import csv
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
from filelock import FileLock
|
||||
from ktem.app import BasePage
|
||||
from ktem.components import reasonings
|
||||
from ktem.db.models import Conversation, engine
|
||||
|
@ -10,6 +16,9 @@ from ktem.index.file.ui import File
|
|||
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
|
||||
SuggestConvNamePipeline,
|
||||
)
|
||||
from ktem.reasoning.prompt_optimization.suggest_followup_chat import (
|
||||
SuggestFollowupQuesPipeline,
|
||||
)
|
||||
from plotly.io import from_json
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings as flowsettings
|
||||
|
@ -17,8 +26,8 @@ from theflow.settings import settings as flowsettings
|
|||
from kotaemon.base import Document
|
||||
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
||||
|
||||
from ...utils import SUPPORTED_LANGUAGE_MAP
|
||||
from .chat_panel import ChatPanel
|
||||
from .chat_suggestion import ChatSuggestion
|
||||
from .common import STATE
|
||||
from .control import ConversationControl
|
||||
from .report import ReportIssue
|
||||
|
@ -50,6 +59,7 @@ class ChatPage(BasePage):
|
|||
self._reasoning_type = gr.State(value=None)
|
||||
self._llm_type = gr.State(value=None)
|
||||
self._conversation_renamed = gr.State(value=False)
|
||||
self._suggestion_updated = gr.State(value=False)
|
||||
self._info_panel_expanded = gr.State(value=True)
|
||||
|
||||
def on_building_ui(self):
|
||||
|
@ -58,13 +68,11 @@ class ChatPage(BasePage):
|
|||
self.state_retrieval_history = gr.State([])
|
||||
self.state_plot_history = gr.State([])
|
||||
self.state_plot_panel = gr.State(None)
|
||||
self.state_follow_up = gr.State(None)
|
||||
|
||||
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
|
||||
self.chat_control = ConversationControl(self._app)
|
||||
|
||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||
self.chat_suggestion = ChatSuggestion(self._app)
|
||||
|
||||
for index_id, index in enumerate(self._app.index_manager.indices):
|
||||
index.selector = None
|
||||
index_ui = index.get_selector_component_ui()
|
||||
|
@ -156,6 +164,11 @@ class ChatPage(BasePage):
|
|||
return plot
|
||||
|
||||
def on_register_events(self):
|
||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||
self.state_follow_up = self.chat_control.chat_suggestion.example
|
||||
else:
|
||||
self.state_follow_up = self.chat_control.followup_suggestions
|
||||
|
||||
gr.on(
|
||||
triggers=[
|
||||
self.chat_panel.text_input.submit,
|
||||
|
@ -168,6 +181,7 @@ class ChatPage(BasePage):
|
|||
self._app.user_id,
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation_rn,
|
||||
self.state_follow_up,
|
||||
],
|
||||
outputs=[
|
||||
self.chat_panel.text_input,
|
||||
|
@ -175,6 +189,7 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.state_follow_up,
|
||||
],
|
||||
concurrency_limit=20,
|
||||
show_progress="hidden",
|
||||
|
@ -225,6 +240,30 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation_rn,
|
||||
],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
fn=self.suggest_chat_conv,
|
||||
inputs=[
|
||||
self._app.settings_state,
|
||||
self.chat_panel.chatbot,
|
||||
],
|
||||
outputs=[
|
||||
self.state_follow_up,
|
||||
self._suggestion_updated,
|
||||
],
|
||||
show_progress="hidden",
|
||||
).success(
|
||||
self.chat_control.update_chat_suggestions,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.state_follow_up,
|
||||
self._suggestion_updated,
|
||||
self._app.user_id,
|
||||
],
|
||||
outputs=[
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation,
|
||||
],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
fn=self.persist_data_source,
|
||||
inputs=[
|
||||
|
@ -292,6 +331,30 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation_rn,
|
||||
],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
fn=self.suggest_chat_conv,
|
||||
inputs=[
|
||||
self._app.settings_state,
|
||||
self.chat_panel.chatbot,
|
||||
],
|
||||
outputs=[
|
||||
self.state_follow_up,
|
||||
self._suggestion_updated,
|
||||
],
|
||||
show_progress="hidden",
|
||||
).success(
|
||||
self.chat_control.update_chat_suggestions,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.state_follow_up,
|
||||
self._suggestion_updated,
|
||||
self._app.user_id,
|
||||
],
|
||||
outputs=[
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation,
|
||||
],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
fn=self.persist_data_source,
|
||||
inputs=[
|
||||
|
@ -339,6 +402,7 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.chat_panel.chatbot,
|
||||
self.state_follow_up,
|
||||
self.info_panel,
|
||||
self.state_plot_panel,
|
||||
self.state_retrieval_history,
|
||||
|
@ -372,6 +436,7 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.chat_panel.chatbot,
|
||||
self.state_follow_up,
|
||||
self.info_panel,
|
||||
self.state_plot_panel,
|
||||
self.state_retrieval_history,
|
||||
|
@ -423,6 +488,7 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.chat_panel.chatbot,
|
||||
self.state_follow_up,
|
||||
self.info_panel,
|
||||
self.state_plot_panel,
|
||||
self.state_retrieval_history,
|
||||
|
@ -501,13 +567,15 @@ class ChatPage(BasePage):
|
|||
)
|
||||
|
||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||
self.chat_suggestion.example.select(
|
||||
self.chat_suggestion.select_example,
|
||||
self.state_follow_up.select(
|
||||
self.chat_control.chat_suggestion.select_example,
|
||||
outputs=[self.chat_panel.text_input],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
|
||||
def submit_msg(
|
||||
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest
|
||||
):
|
||||
"""Submit a message to the chatbot"""
|
||||
if not chat_input:
|
||||
raise ValueError("Input is empty")
|
||||
|
@ -517,13 +585,20 @@ class ChatPage(BasePage):
|
|||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == id_)
|
||||
name = session.exec(statement).one().name
|
||||
suggestion = (
|
||||
session.exec(statement)
|
||||
.one()
|
||||
.data_source.get("chat_suggestions", [])
|
||||
)
|
||||
new_conv_id = id_
|
||||
conv_update = update
|
||||
new_conv_name = name
|
||||
new_chat_suggestion = suggestion
|
||||
else:
|
||||
new_conv_id = conv_id
|
||||
conv_update = gr.update()
|
||||
new_conv_name = conv_name
|
||||
new_chat_suggestion = chat_suggest
|
||||
|
||||
return (
|
||||
"",
|
||||
|
@ -531,6 +606,7 @@ class ChatPage(BasePage):
|
|||
new_conv_id,
|
||||
conv_update,
|
||||
new_conv_name,
|
||||
new_chat_suggestion,
|
||||
)
|
||||
|
||||
def toggle_delete(self, conv_id):
|
||||
|
@ -872,3 +948,118 @@ class ChatPage(BasePage):
|
|||
renamed = True
|
||||
|
||||
return new_name, renamed
|
||||
|
||||
def suggest_chat_conv(self, settings, chat_history):
|
||||
suggest_pipeline = SuggestFollowupQuesPipeline()
|
||||
suggest_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
|
||||
settings["reasoning.lang"], "English"
|
||||
)
|
||||
|
||||
updated = False
|
||||
|
||||
suggested_ques = []
|
||||
if len(chat_history) >= 1:
|
||||
suggested_resp = suggest_pipeline(chat_history).text
|
||||
if ques_res := re.search(r"\[(.*?)\]", re.sub("\n", "", suggested_resp)):
|
||||
ques_res_str = ques_res.group()
|
||||
try:
|
||||
suggested_ques = ast.literal_eval(ques_res_str)
|
||||
suggested_ques = [[x] for x in suggested_ques]
|
||||
updated = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return suggested_ques, updated
|
||||
|
||||
def backup_original_info(
|
||||
self, chat_history, settings, info_pannel, original_chat_history
|
||||
):
|
||||
original_chat_history.append(chat_history[-1])
|
||||
return original_chat_history, settings, info_pannel
|
||||
|
||||
def save_log(
|
||||
self,
|
||||
conversation_id,
|
||||
chat_history,
|
||||
settings,
|
||||
info_panel,
|
||||
original_chat_history,
|
||||
original_settings,
|
||||
original_info_panel,
|
||||
log_dir,
|
||||
):
|
||||
if not Path(log_dir).exists():
|
||||
Path(log_dir).mkdir(parents=True)
|
||||
|
||||
lock = FileLock(Path(log_dir) / ".lock")
|
||||
# get current date
|
||||
today = datetime.now()
|
||||
formatted_date = today.strftime("%d%m%Y_%H")
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||
result = session.exec(statement).one()
|
||||
|
||||
data_source = deepcopy(result.data_source)
|
||||
likes = data_source.get("likes", [])
|
||||
if not likes:
|
||||
return
|
||||
|
||||
feedback = likes[-1][-1]
|
||||
message_index = likes[-1][0]
|
||||
|
||||
current_message = chat_history[message_index[0]]
|
||||
original_message = original_chat_history[message_index[0]]
|
||||
is_original = all(
|
||||
[
|
||||
current_item == original_item
|
||||
for current_item, original_item in zip(
|
||||
current_message, original_message
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
dataframe = [
|
||||
[
|
||||
conversation_id,
|
||||
message_index,
|
||||
current_message[0],
|
||||
current_message[1],
|
||||
chat_history,
|
||||
settings,
|
||||
info_panel,
|
||||
feedback,
|
||||
is_original,
|
||||
original_message[1],
|
||||
original_chat_history,
|
||||
original_settings,
|
||||
original_info_panel,
|
||||
]
|
||||
]
|
||||
|
||||
with lock:
|
||||
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
|
||||
is_log_file_exist = log_file.is_file()
|
||||
with open(log_file, "a") as f:
|
||||
writer = csv.writer(f)
|
||||
# write headers
|
||||
if not is_log_file_exist:
|
||||
writer.writerow(
|
||||
[
|
||||
"Conversation ID",
|
||||
"Message ID",
|
||||
"Question",
|
||||
"Answer",
|
||||
"Chat History",
|
||||
"Settings",
|
||||
"Evidences",
|
||||
"Feedback",
|
||||
"Original/ Rewritten",
|
||||
"Original Answer",
|
||||
"Original Chat History",
|
||||
"Original Settings",
|
||||
"Original Evidences",
|
||||
]
|
||||
)
|
||||
|
||||
writer.writerows(dataframe)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
|
@ -9,6 +10,7 @@ from sqlmodel import Session, or_, select
|
|||
import flowsettings
|
||||
|
||||
from ...utils.conversation import sync_retrieval_n_message
|
||||
from .chat_suggestion import ChatSuggestion
|
||||
from .common import STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -103,6 +105,10 @@ class ConversationControl(BasePage):
|
|||
visible=False,
|
||||
)
|
||||
|
||||
self.followup_suggestions = gr.State([])
|
||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||
self.chat_suggestion = ChatSuggestion(self._app)
|
||||
|
||||
def load_chat_history(self, user_id):
|
||||
"""Reload chat history"""
|
||||
|
||||
|
@ -220,6 +226,8 @@ class ConversationControl(BasePage):
|
|||
|
||||
chats = result.data_source.get("messages", [])
|
||||
|
||||
chat_suggestions = result.data_source.get("chat_suggestions", [])
|
||||
|
||||
retrieval_history: list[str] = result.data_source.get(
|
||||
"retrieval_messages", []
|
||||
)
|
||||
|
@ -243,6 +251,7 @@ class ConversationControl(BasePage):
|
|||
name = ""
|
||||
selected = {}
|
||||
chats = []
|
||||
chat_suggestions = []
|
||||
retrieval_history = []
|
||||
plot_history = []
|
||||
info_panel = ""
|
||||
|
@ -265,6 +274,7 @@ class ConversationControl(BasePage):
|
|||
id_,
|
||||
name,
|
||||
chats,
|
||||
chat_suggestions,
|
||||
info_panel,
|
||||
plot_data,
|
||||
retrieval_history,
|
||||
|
@ -311,6 +321,46 @@ class ConversationControl(BasePage):
|
|||
gr.update(visible=False),
|
||||
)
|
||||
|
||||
def update_chat_suggestions(
|
||||
self, conversation_id, new_suggestions, is_updated, user_id
|
||||
):
|
||||
"""Update the conversation's chat suggestions"""
|
||||
if not is_updated:
|
||||
return (
|
||||
gr.update(),
|
||||
conversation_id,
|
||||
gr.update(visible=False),
|
||||
)
|
||||
|
||||
if user_id is None:
|
||||
gr.Warning("Please sign in first (Settings → User Settings)")
|
||||
return gr.update(), ""
|
||||
|
||||
if not conversation_id:
|
||||
gr.Warning("No conversation selected.")
|
||||
return gr.update(), ""
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||
result = session.exec(statement).one()
|
||||
|
||||
data_source = deepcopy(result.data_source)
|
||||
data_source["chat_suggestions"] = [
|
||||
[x] for x in new_suggestions.iloc[:, 0].tolist()
|
||||
]
|
||||
|
||||
result.data_source = data_source
|
||||
session.add(result)
|
||||
session.commit()
|
||||
|
||||
history = self.load_chat_history(user_id)
|
||||
gr.Info("Chat suggestions updated.")
|
||||
return (
|
||||
gr.update(choices=history),
|
||||
conversation_id,
|
||||
gr.update(visible=False),
|
||||
)
|
||||
|
||||
def _on_app_created(self):
|
||||
"""Reload the conversation once the app is created"""
|
||||
self._app.app.load(
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
import logging
|
||||
|
||||
from ktem.llms.manager import llms
|
||||
|
||||
from kotaemon.base import AIMessage, BaseComponent, Document, HumanMessage, Node
|
||||
from kotaemon.llms import ChatLLM, PromptTemplate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuggestFollowupQuesPipeline(BaseComponent):
|
||||
"""Suggest a list of follow-up questions based on the chat history."""
|
||||
|
||||
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
|
||||
SUGGEST_QUESTIONS_PROMPT_TEMPLATE = (
|
||||
"Based on the chat history above. "
|
||||
"your task is to generate 3 to 5 relevant follow-up questions. "
|
||||
"These questions should be simple, clear, "
|
||||
"and designed to guide the conversation further. "
|
||||
"Ensure that the questions are open-ended to encourage detailed responses. "
|
||||
"Respond in JSON format with 'questions' key. "
|
||||
"Answer using the language {lang} same as the question. "
|
||||
"If the question uses Chinese, the answer should be in Chinese.\n"
|
||||
)
|
||||
prompt_template: str = SUGGEST_QUESTIONS_PROMPT_TEMPLATE
|
||||
extra_prompt: str = """Example of valid response:
|
||||
```json
|
||||
{
|
||||
"questions": ["the weather is good", "what's your favorite city"]
|
||||
}
|
||||
```"""
|
||||
lang: str = "English"
|
||||
|
||||
def run(self, chat_history: list[tuple[str, str]]) -> Document:
|
||||
prompt_template = PromptTemplate(self.prompt_template)
|
||||
prompt = prompt_template.populate(lang=self.lang) + self.extra_prompt
|
||||
|
||||
messages = []
|
||||
for human, ai in chat_history[-3:]:
|
||||
messages.append(HumanMessage(content=human))
|
||||
messages.append(AIMessage(content=ai))
|
||||
|
||||
messages.append(HumanMessage(content=prompt))
|
||||
|
||||
return self.llm(messages)
|
Loading…
Reference in New Issue
Block a user