feat: add quick file selection upon tagging on Chat input (#533) bump:patch
* fix: improve inline citation logic without rag * fix: improve explanation for citation options * feat: add quick file selection on Chat input
This commit is contained in:
parent
f15abdbb23
commit
ab6b3fc529
|
@ -152,6 +152,20 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
|
|||
def replace_citation_with_link(self, answer: str):
|
||||
# Define the regex pattern to match 【number】
|
||||
pattern = r"【\d+】"
|
||||
|
||||
# Regular expression to match merged citations
|
||||
multi_pattern = r"【([\d,\s]+)】"
|
||||
|
||||
# Function to replace merged citations with independent ones
|
||||
def split_citations(match):
|
||||
# Extract the numbers, split by comma, and create individual citations
|
||||
numbers = match.group(1).split(",")
|
||||
return "".join(f"【{num.strip()}】" for num in numbers)
|
||||
|
||||
# Replace merged citations in the text
|
||||
answer = re.sub(multi_pattern, split_citations, answer)
|
||||
|
||||
# Find all citations in the answer
|
||||
matches = re.finditer(pattern, answer)
|
||||
|
||||
matched_citations = set()
|
||||
|
@ -240,25 +254,30 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
|
|||
# try streaming first
|
||||
print("Trying LLM streaming")
|
||||
for out_msg in self.llm.stream(messages):
|
||||
if START_ANSWER in output:
|
||||
if not final_answer:
|
||||
try:
|
||||
left_over_answer = output.split(START_ANSWER)[1].lstrip()
|
||||
except IndexError:
|
||||
left_over_answer = ""
|
||||
if left_over_answer:
|
||||
out_msg.text = left_over_answer + out_msg.text
|
||||
if evidence:
|
||||
if START_ANSWER in output:
|
||||
if not final_answer:
|
||||
try:
|
||||
left_over_answer = output.split(START_ANSWER)[
|
||||
1
|
||||
].lstrip()
|
||||
except IndexError:
|
||||
left_over_answer = ""
|
||||
if left_over_answer:
|
||||
out_msg.text = left_over_answer + out_msg.text
|
||||
|
||||
final_answer += (
|
||||
out_msg.text.lstrip() if not final_answer else out_msg.text
|
||||
)
|
||||
final_answer += (
|
||||
out_msg.text.lstrip() if not final_answer else out_msg.text
|
||||
)
|
||||
yield Document(channel="chat", content=out_msg.text)
|
||||
|
||||
# check for the edge case of citation list is repeated
|
||||
# with smaller LLMs
|
||||
if START_CITATION in out_msg.text:
|
||||
break
|
||||
else:
|
||||
yield Document(channel="chat", content=out_msg.text)
|
||||
|
||||
# check for the edge case of citation list is repeated
|
||||
# with smaller LLMs
|
||||
if START_CITATION in out_msg.text:
|
||||
break
|
||||
|
||||
output += out_msg.text
|
||||
logprobs += out_msg.logprobs
|
||||
except NotImplementedError:
|
||||
|
@ -289,8 +308,10 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
|
|||
|
||||
# yield the final answer
|
||||
final_answer = self.replace_citation_with_link(final_answer)
|
||||
yield Document(channel="chat", content=None)
|
||||
yield Document(channel="chat", content=final_answer)
|
||||
|
||||
if final_answer:
|
||||
yield Document(channel="chat", content=None)
|
||||
yield Document(channel="chat", content=final_answer)
|
||||
|
||||
return answer
|
||||
|
||||
|
|
|
@ -26,6 +26,9 @@ def find_start_end_phrase(
|
|||
matches = []
|
||||
matched_length = 0
|
||||
for sentence in [start_phrase, end_phrase]:
|
||||
if sentence is None:
|
||||
continue
|
||||
|
||||
match = SequenceMatcher(
|
||||
None, sentence, context, autojunk=False
|
||||
).find_longest_match()
|
||||
|
|
|
@ -177,6 +177,10 @@ class BaseApp:
|
|||
"<script>"
|
||||
f"{self._svg_js}"
|
||||
"</script>"
|
||||
"<script type='module' "
|
||||
"src='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.min.js'>" # noqa
|
||||
"</script>"
|
||||
"<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.css'/>" # noqa
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
|
|
|
@ -365,3 +365,20 @@ details.evidence {
|
|||
color: #10b981;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
/* pop-up for file tag in chat input*/
|
||||
.tribute-container ul {
|
||||
background-color: var(--background-fill-primary) !important;
|
||||
color: var(--body-text-color) !important;
|
||||
font-family: var(--font);
|
||||
font-size: var(--text-md);
|
||||
}
|
||||
|
||||
.tribute-container li.highlight {
|
||||
background-color: var(--border-color-primary) !important;
|
||||
}
|
||||
|
||||
/* a fix for flickering background in Gradio DataFrame */
|
||||
tbody:not(.row_odd) {
|
||||
background: var(--table-even-background-fill);
|
||||
}
|
||||
|
|
|
@ -29,6 +29,25 @@ function() {
|
|||
}
|
||||
"""
|
||||
|
||||
# JS callback executed in the browser when the hidden file-choices JSON
# component changes. It binds a Tribute.js autocomplete menu to the chat
# textarea so a tagged file is inserted as "filename" (double-quoted).
# `file_list` is indexed as file_list[i][0] for the display name.
update_file_list_js = """
function(file_list) {
    var values = [];
    // tolerate a null/undefined payload (e.g. before any file is loaded)
    file_list = file_list || [];
    for (var i = 0; i < file_list.length; i++) {
        values.push({
            key: file_list[i][0],
            value: '"' + file_list[i][0] + '"',
        });
    }
    // keep the lookup local instead of leaking an implicit global
    var input_box = document.querySelector('#chat-input textarea');
    if (!input_box) {
        // chat input is not rendered (yet): nothing to bind to
        return;
    }
    // unbind the menu attached by a previous call, otherwise every refresh
    // of the file list stacks another (stale) menu on the same textarea
    if (window.ktemFileTribute) {
        window.ktemFileTribute.detach(input_box);
    }
    var tribute = new Tribute({
        values: values,
        noMatchTemplate: "",
        allowSpaces: true,
    })
    window.ktemFileTribute = tribute;
    tribute.attach(input_box);
}
"""
|
||||
|
||||
|
||||
class File(gr.File):
|
||||
"""Subclass from gr.File to maintain the original filename
|
||||
|
@ -1429,6 +1448,10 @@ class FileSelector(BasePage):
|
|||
visible=False,
|
||||
)
|
||||
self.selector_user_id = gr.State(value=user_id)
|
||||
self.selector_choices = gr.JSON(
|
||||
value=[],
|
||||
visible=False,
|
||||
)
|
||||
|
||||
def on_register_events(self):
|
||||
self.mode.change(
|
||||
|
@ -1436,6 +1459,14 @@ class FileSelector(BasePage):
|
|||
inputs=[self.mode, self._app.user_id],
|
||||
outputs=[self.selector, self.selector_user_id],
|
||||
)
|
||||
# attach special event for the first index
|
||||
if self._index.id == 1:
|
||||
self.selector_choices.change(
|
||||
fn=None,
|
||||
inputs=[self.selector_choices],
|
||||
js=update_file_list_js,
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
def as_gradio_component(self):
|
||||
return [self.mode, self.selector, self.selector_user_id]
|
||||
|
@ -1468,7 +1499,7 @@ class FileSelector(BasePage):
|
|||
available_ids = []
|
||||
if user_id is None:
|
||||
# not signed in
|
||||
return gr.update(value=selected_files, choices=options)
|
||||
return gr.update(value=selected_files, choices=options), options
|
||||
|
||||
with Session(engine) as session:
|
||||
# get file list from Source table
|
||||
|
@ -1501,13 +1532,13 @@ class FileSelector(BasePage):
|
|||
each for each in selected_files if each in available_ids_set
|
||||
]
|
||||
|
||||
return gr.update(value=selected_files, choices=options)
|
||||
return gr.update(value=selected_files, choices=options), options
|
||||
|
||||
def _on_app_created(self):
|
||||
self._app.app.load(
|
||||
self.load_files,
|
||||
inputs=[self.selector, self._app.user_id],
|
||||
outputs=[self.selector],
|
||||
outputs=[self.selector, self.selector_choices],
|
||||
)
|
||||
|
||||
def on_subscribe_public_events(self):
|
||||
|
@ -1516,26 +1547,18 @@ class FileSelector(BasePage):
|
|||
definition={
|
||||
"fn": self.load_files,
|
||||
"inputs": [self.selector, self._app.user_id],
|
||||
"outputs": [self.selector],
|
||||
"outputs": [self.selector, self.selector_choices],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
if self._app.f_user_management:
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
definition={
|
||||
"fn": self.load_files,
|
||||
"inputs": [self.selector, self._app.user_id],
|
||||
"outputs": [self.selector],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
self._app.subscribe_event(
|
||||
name="onSignOut",
|
||||
definition={
|
||||
"fn": self.load_files,
|
||||
"inputs": [self.selector, self._app.user_id],
|
||||
"outputs": [self.selector],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
for event_name in ["onSignIn", "onSignOut"]:
|
||||
self._app.subscribe_event(
|
||||
name=event_name,
|
||||
definition={
|
||||
"fn": self.load_files,
|
||||
"inputs": [self.selector, self._app.user_id],
|
||||
"outputs": [self.selector, self.selector_choices],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
|
|
@ -8,7 +8,7 @@ import gradio as gr
|
|||
from ktem.app import BasePage
|
||||
from ktem.components import reasonings
|
||||
from ktem.db.models import Conversation, engine
|
||||
from ktem.index.file.ui import File
|
||||
from ktem.index.file.ui import File, chat_input_focus_js
|
||||
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
|
||||
SuggestConvNamePipeline,
|
||||
)
|
||||
|
@ -22,7 +22,7 @@ from theflow.settings import settings as flowsettings
|
|||
from kotaemon.base import Document
|
||||
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
||||
|
||||
from ...utils import SUPPORTED_LANGUAGE_MAP
|
||||
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex
|
||||
from .chat_panel import ChatPanel
|
||||
from .common import STATE
|
||||
from .control import ConversationControl
|
||||
|
@ -113,6 +113,7 @@ class ChatPage(BasePage):
|
|||
self.state_plot_history = gr.State([])
|
||||
self.state_plot_panel = gr.State(None)
|
||||
self.state_follow_up = gr.State(None)
|
||||
self.first_selector_choices = gr.State(None)
|
||||
|
||||
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
|
||||
self.chat_control = ConversationControl(self._app)
|
||||
|
@ -130,6 +131,11 @@ class ChatPage(BasePage):
|
|||
):
|
||||
index_ui.render()
|
||||
gr_index = index_ui.as_gradio_component()
|
||||
|
||||
# get the file selector choices for the first index
|
||||
if index_id == 0:
|
||||
self.first_selector_choices = index_ui.selector_choices
|
||||
|
||||
if gr_index:
|
||||
if isinstance(gr_index, list):
|
||||
index.selector = tuple(
|
||||
|
@ -272,6 +278,7 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation_rn,
|
||||
self.state_follow_up,
|
||||
self.first_selector_choices,
|
||||
],
|
||||
outputs=[
|
||||
self.chat_panel.text_input,
|
||||
|
@ -280,6 +287,9 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.state_follow_up,
|
||||
# file selector from the first index
|
||||
self._indices_input[0],
|
||||
self._indices_input[1],
|
||||
],
|
||||
concurrency_limit=20,
|
||||
show_progress="hidden",
|
||||
|
@ -426,6 +436,10 @@ class ChatPage(BasePage):
|
|||
fn=self._json_to_plot,
|
||||
inputs=self.state_plot_panel,
|
||||
outputs=self.plot_panel,
|
||||
).then(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
js=chat_input_focus_js,
|
||||
)
|
||||
|
||||
self.chat_control.btn_del.click(
|
||||
|
@ -516,7 +530,12 @@ class ChatPage(BasePage):
|
|||
lambda: self.toggle_delete(""),
|
||||
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
|
||||
).then(
|
||||
fn=None, inputs=None, outputs=None, js=pdfview_js
|
||||
fn=lambda: True,
|
||||
inputs=None,
|
||||
outputs=[self._preview_links],
|
||||
js=pdfview_js,
|
||||
).then(
|
||||
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
|
||||
)
|
||||
|
||||
# evidence display on message selection
|
||||
|
@ -535,7 +554,12 @@ class ChatPage(BasePage):
|
|||
inputs=self.state_plot_panel,
|
||||
outputs=self.plot_panel,
|
||||
).then(
|
||||
fn=None, inputs=None, outputs=None, js=pdfview_js
|
||||
fn=lambda: True,
|
||||
inputs=None,
|
||||
outputs=[self._preview_links],
|
||||
js=pdfview_js,
|
||||
).then(
|
||||
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
|
||||
)
|
||||
|
||||
self.chat_control.cb_is_public.change(
|
||||
|
@ -585,7 +609,14 @@ class ChatPage(BasePage):
|
|||
)
|
||||
|
||||
def submit_msg(
|
||||
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest
|
||||
self,
|
||||
chat_input,
|
||||
chat_history,
|
||||
user_id,
|
||||
conv_id,
|
||||
conv_name,
|
||||
chat_suggest,
|
||||
first_selector_choices,
|
||||
):
|
||||
"""Submit a message to the chatbot"""
|
||||
if not chat_input:
|
||||
|
@ -593,6 +624,24 @@ class ChatPage(BasePage):
|
|||
|
||||
chat_input_text = chat_input.get("text", "")
|
||||
|
||||
# get all file names with pattern @"filename" in input_str
|
||||
file_names, chat_input_text = get_file_names_regex(chat_input_text)
|
||||
first_selector_choices_map = {
|
||||
item[0]: item[1] for item in first_selector_choices
|
||||
}
|
||||
file_ids = []
|
||||
|
||||
if file_names:
|
||||
for file_name in file_names:
|
||||
file_id = first_selector_choices_map.get(file_name)
|
||||
if file_id:
|
||||
file_ids.append(file_id)
|
||||
|
||||
if file_ids:
|
||||
selector_output = ["select", file_ids]
|
||||
else:
|
||||
selector_output = [gr.update(), gr.update()]
|
||||
|
||||
# check if regen mode is active
|
||||
if chat_input_text:
|
||||
chat_history = chat_history + [(chat_input_text, None)]
|
||||
|
@ -620,14 +669,14 @@ class ChatPage(BasePage):
|
|||
new_conv_name = conv_name
|
||||
new_chat_suggestion = chat_suggest
|
||||
|
||||
return (
|
||||
return [
|
||||
{},
|
||||
chat_history,
|
||||
new_conv_id,
|
||||
conv_update,
|
||||
new_conv_name,
|
||||
new_chat_suggestion,
|
||||
)
|
||||
] + selector_output
|
||||
|
||||
def toggle_delete(self, conv_id):
|
||||
if conv_id:
|
||||
|
|
|
@ -25,7 +25,7 @@ class ChatPanel(BasePage):
|
|||
interactive=True,
|
||||
scale=20,
|
||||
file_count="multiple",
|
||||
placeholder="Chat input",
|
||||
placeholder="Type a message (or tag a file with @filename)",
|
||||
container=False,
|
||||
show_label=False,
|
||||
elem_id="chat-input",
|
||||
|
|
|
@ -410,7 +410,11 @@ class FullQAPipeline(BaseReasoning):
|
|||
"name": "Citation style",
|
||||
"value": "highlight",
|
||||
"component": "radio",
|
||||
"choices": ["highlight", "inline", "off"],
|
||||
"choices": [
|
||||
("highlight (long answer)", "highlight"),
|
||||
("inline (precise answer)", "inline"),
|
||||
("off", "off"),
|
||||
],
|
||||
},
|
||||
"create_mindmap": {
|
||||
"name": "Create Mindmap",
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from .conversation import get_file_names_regex
|
||||
from .lang import SUPPORTED_LANGUAGE_MAP
|
||||
|
||||
__all__ = ["SUPPORTED_LANGUAGE_MAP"]
|
||||
__all__ = ["SUPPORTED_LANGUAGE_MAP", "get_file_names_regex"]
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import re
|
||||
|
||||
|
||||
def sync_retrieval_n_message(
|
||||
messages: list[list[str]],
|
||||
retrievals: list[str],
|
||||
|
@ -16,5 +19,15 @@ def sync_retrieval_n_message(
|
|||
return retrievals
|
||||
|
||||
|
||||
def get_file_names_regex(input_str: str) -> tuple[list[str], str]:
    """Extract file names tagged as @"filename" and remove the tags from the text.

    Args:
        input_str: raw chat input that may contain zero or more @"filename" tags.

    Returns:
        A tuple ``(file_names, cleaned_text)``: the tagged names in order of
        appearance (blank tags such as ``@""`` are ignored), and the input with
        every tag removed and leading/trailing whitespace stripped.
    """
    # a tag is an "@" immediately followed by a double-quoted (possibly empty) name
    pattern = r'@"([^"]*)"'

    # a blank name can never refer to a real file, so do not report it
    file_names = [name for name in re.findall(pattern, input_str) if name.strip()]

    # also swallow the spaces/tabs right before a tag, so removing a tag from
    # the middle of a sentence does not leave a doubled space behind
    cleaned = re.sub(r"[ \t]*" + pattern, "", input_str).strip()

    return file_names, cleaned
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(sync_retrieval_n_message([[""], [""], [""]], []))
|
||||
|
|
Loading…
Reference in New Issue
Block a user