fix: vastly improve chat UI responsiveness by reordering Gradio events (#360) bump:patch

This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-10-04 17:15:49 +07:00 committed by GitHub
parent b01fc217b2
commit dfd00fe752
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,12 +1,8 @@
import asyncio import asyncio
import csv
from copy import deepcopy from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional from typing import Optional
import gradio as gr import gradio as gr
from filelock import FileLock
from ktem.app import BasePage from ktem.app import BasePage
from ktem.components import reasonings from ktem.components import reasonings
from ktem.db.models import Conversation, engine from ktem.db.models import Conversation, engine
@ -38,6 +34,7 @@ function() {
for (var i = 0; i < links.length; i++) { for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal; links[i].onclick = openModal;
} }
return [links.length]
} }
""" """
@ -48,19 +45,18 @@ class ChatPage(BasePage):
self._indices_input = [] self._indices_input = []
self.on_building_ui() self.on_building_ui()
self._preview_links = gr.State(value=None)
self._reasoning_type = gr.State(value=None) self._reasoning_type = gr.State(value=None)
self._llm_type = gr.State(value=None) self._llm_type = gr.State(value=None)
self._conversation_renamed = gr.State(value=False) self._conversation_renamed = gr.State(value=False)
self.info_panel_expanded = gr.State(value=True) self._info_panel_expanded = gr.State(value=True)
def on_building_ui(self): def on_building_ui(self):
with gr.Row(): with gr.Row():
self.state_chat = gr.State(STATE) self.state_chat = gr.State(STATE)
self.state_retrieval_history = gr.State([]) self.state_retrieval_history = gr.State([])
self.state_chat_history = gr.State([])
self.state_plot_history = gr.State([]) self.state_plot_history = gr.State([])
self.state_settings = gr.State({})
self.state_info_panel = gr.State("")
self.state_plot_panel = gr.State(None) self.state_plot_panel = gr.State(None)
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column: with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
@ -203,37 +199,11 @@ class ChatPage(BasePage):
], ],
concurrency_limit=20, concurrency_limit=20,
show_progress="minimal", show_progress="minimal",
).success(
fn=self.backup_original_info,
inputs=[
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
],
outputs=[
self.state_chat_history,
self.state_settings,
self.state_info_panel,
],
).then( ).then(
fn=self.persist_data_source, fn=lambda: True,
inputs=[ inputs=None,
self.chat_control.conversation_id, outputs=[self._preview_links],
self._app.user_id, js=pdfview_js,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
).success( ).success(
fn=self.check_and_suggest_name_conv, fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot, inputs=self.chat_panel.chatbot,
@ -256,7 +226,23 @@ class ChatPage(BasePage):
], ],
show_progress="hidden", show_progress="hidden",
).then( ).then(
fn=None, inputs=None, outputs=None, js=pdfview_js fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
) )
self.chat_panel.regen_btn.click( self.chat_panel.regen_btn.click(
@ -281,23 +267,10 @@ class ChatPage(BasePage):
concurrency_limit=20, concurrency_limit=20,
show_progress="minimal", show_progress="minimal",
).then( ).then(
fn=self.persist_data_source, fn=lambda: True,
inputs=[ inputs=None,
self.chat_control.conversation_id, outputs=[self._preview_links],
self._app.user_id, js=pdfview_js,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
).success( ).success(
fn=self.check_and_suggest_name_conv, fn=self.check_and_suggest_name_conv,
inputs=self.chat_panel.chatbot, inputs=self.chat_panel.chatbot,
@ -320,7 +293,23 @@ class ChatPage(BasePage):
], ],
show_progress="hidden", show_progress="hidden",
).then( ).then(
fn=None, inputs=None, outputs=None, js=pdfview_js fn=self.persist_data_source,
inputs=[
self.chat_control.conversation_id,
self._app.user_id,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_panel.chatbot,
self.state_chat,
]
+ self._indices_input,
outputs=[
self.state_retrieval_history,
self.state_plot_history,
],
concurrency_limit=20,
) )
self.chat_control.btn_info_expand.click( self.chat_control.btn_info_expand.click(
@ -328,29 +317,15 @@ class ChatPage(BasePage):
gr.update(scale=INFO_PANEL_SCALES[is_expanded]), gr.update(scale=INFO_PANEL_SCALES[is_expanded]),
not is_expanded, not is_expanded,
), ),
inputs=self.info_panel_expanded, inputs=self._info_panel_expanded,
outputs=[self.info_column, self.info_panel_expanded], outputs=[self.info_column, self._info_panel_expanded],
) )
self.chat_panel.chatbot.like( self.chat_panel.chatbot.like(
fn=self.is_liked, fn=self.is_liked,
inputs=[self.chat_control.conversation_id], inputs=[self.chat_control.conversation_id],
outputs=None, outputs=None,
).success(
self.save_log,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.info_panel,
self.state_chat_history,
self.state_settings,
self.state_info_panel,
gr.State(getattr(flowsettings, "KH_APP_DATA_DIR", "logs")),
],
outputs=None,
) )
self.chat_control.btn_new.click( self.chat_control.btn_new.click(
self.chat_control.new_conv, self.chat_control.new_conv,
inputs=self._app.user_id, inputs=self._app.user_id,
@ -701,7 +676,15 @@ class ChatPage(BasePage):
def message_selected(self, retrieval_history, plot_history, msg: gr.SelectData): def message_selected(self, retrieval_history, plot_history, msg: gr.SelectData):
index = msg.index[0] index = msg.index[0]
return retrieval_history[index], plot_history[index] try:
retrieval_content, plot_content = (
retrieval_history[index],
plot_history[index],
)
except IndexError:
retrieval_content, plot_content = gr.update(), None
return retrieval_content, plot_content
def create_pipeline( def create_pipeline(
self, self,
@ -889,96 +872,3 @@ class ChatPage(BasePage):
renamed = True renamed = True
return new_name, renamed return new_name, renamed
def backup_original_info(
self, chat_history, settings, info_pannel, original_chat_history
):
original_chat_history.append(chat_history[-1])
return original_chat_history, settings, info_pannel
def save_log(
self,
conversation_id,
chat_history,
settings,
info_panel,
original_chat_history,
original_settings,
original_info_panel,
log_dir,
):
if not Path(log_dir).exists():
Path(log_dir).mkdir(parents=True)
lock = FileLock(Path(log_dir) / ".lock")
# get current date
today = datetime.now()
formatted_date = today.strftime("%d%m%Y_%H")
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
data_source = deepcopy(result.data_source)
likes = data_source.get("likes", [])
if not likes:
return
feedback = likes[-1][-1]
message_index = likes[-1][0]
current_message = chat_history[message_index[0]]
original_message = original_chat_history[message_index[0]]
is_original = all(
[
current_item == original_item
for current_item, original_item in zip(
current_message, original_message
)
]
)
dataframe = [
[
conversation_id,
message_index,
current_message[0],
current_message[1],
chat_history,
settings,
info_panel,
feedback,
is_original,
original_message[1],
original_chat_history,
original_settings,
original_info_panel,
]
]
with lock:
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
is_log_file_exist = log_file.is_file()
with open(log_file, "a") as f:
writer = csv.writer(f)
# write headers
if not is_log_file_exist:
writer.writerow(
[
"Conversation ID",
"Message ID",
"Question",
"Answer",
"Chat History",
"Settings",
"Evidences",
"Feedback",
"Original/ Rewritten",
"Original Answer",
"Original Chat History",
"Original Settings",
"Original Evidences",
]
)
writer.writerows(dataframe)