Allow file selector to be disabled (#36)

* Allow file selector to be disabled

* Update docs and variable names
This commit is contained in:
Duc Nguyen (john) 2024-04-16 18:43:56 +07:00 committed by GitHub
parent e19893a509
commit 1b2082a140
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 91 additions and 45 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 138 KiB

After

Width:  |  Height:  |  Size: 73 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 66 KiB

After

Width:  |  Height:  |  Size: 40 KiB

View File

@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
1. Conversation Settings Panel 1. Conversation Settings Panel
- Here you can select, create, rename, and delete conversations. - Here you can select, create, rename, and delete conversations.
- By default, a new conversation is created automatically if no conversation is selected. - By default, a new conversation is created automatically if no conversation is selected.
- Below that you have the file index, where you can select which files to retrieve references from. - Below that you have the file index, where you can choose to disable file retrieval, search all files, or select specific files to retrieve references from.
- These are the files you have uploaded to the application from the `File Index` tab. - If you choose "Disabled", no files will be considered as context during chat.
- If no file is selected, all files will be used. - If you choose "Search All", all files will be considered during chat.
- If you choose "Select", a dropdown will appear for you to select the
files to be considered during chat. If no files are selected, then no
files will be considered during chat.
2. Chat Panel 2. Chat Panel
- This is where you can chat with the chatbot. - This is where you can chat with the chatbot.
3. Information panel 3. Information panel

View File

@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
1. Conversation Settings Panel 1. Conversation Settings Panel
- Here you can select, create, rename, and delete conversations. - Here you can select, create, rename, and delete conversations.
- By default, a new conversation is created automatically if no conversation is selected. - By default, a new conversation is created automatically if no conversation is selected.
- Below that you have the file index, where you can select which files to retrieve references from. - Below that you have the file index, where you can choose to disable file retrieval, search all files, or select specific files to retrieve references from.
- These are the files you have uploaded to the application from the `File Index` tab. - If you choose "Disabled", no files will be considered as context during chat.
- If no file is selected, all files will be used. - If you choose "Search All", all files will be considered during chat.
- If you choose "Select", a dropdown will appear for you to select the
files to be considered during chat. If no files are selected, then no
files will be considered during chat.
2. Chat Panel 2. Chat Panel
- This is where you can chat with the chatbot. - This is where you can chat with the chatbot.
3. Information panel 3. Information panel

View File

@ -67,31 +67,36 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
documents documents
get_extra_table: if True, for each retrieved document, the pipeline will look get_extra_table: if True, for each retrieved document, the pipeline will look
for surrounding tables (e.g. within the page) for surrounding tables (e.g. within the page)
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
""" """
vector_retrieval: VectorRetrieval = VectorRetrieval.withx() vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
reranker: BaseReranking reranker: BaseReranking
get_extra_table: bool = False get_extra_table: bool = False
mmr: bool = False
top_k: int = 5
def run( def run(
self, self,
text: str, text: str,
top_k: int = 5,
mmr: bool = False,
doc_ids: Optional[list[str]] = None, doc_ids: Optional[list[str]] = None,
*args,
**kwargs,
) -> list[RetrievedDocument]: ) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text """Retrieve document excerpts similar to the text
Args: Args:
text: the text to retrieve similar documents text: the text to retrieve similar documents
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
doc_ids: list of document ids to constrain the retrieval doc_ids: list of document ids to constrain the retrieval
""" """
if not doc_ids:
logger.info(f"Skip retrieval because of no selected files: {self}")
return []
Index = self._Index Index = self._Index
kwargs = {} retrieval_kwargs = {}
if doc_ids:
with Session(engine) as session: with Session(engine) as session:
stmt = select(Index).where( stmt = select(Index).where(
Index.relation_type == "vector", Index.relation_type == "vector",
@ -100,7 +105,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
results = session.execute(stmt) results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()] vs_ids = [r[0].target_id for r in results.all()]
kwargs["filters"] = MetadataFilters( retrieval_kwargs["filters"] = MetadataFilters(
filters=[ filters=[
MetadataFilter( MetadataFilter(
key="doc_id", key="doc_id",
@ -112,13 +117,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
condition=FilterCondition.OR, condition=FilterCondition.OR,
) )
if mmr: if self.mmr:
# TODO: double check that llama-index MMR works correctly # TODO: double check that llama-index MMR works correctly
kwargs["mode"] = VectorStoreQueryMode.MMR retrieval_kwargs["mode"] = VectorStoreQueryMode.MMR
kwargs["mmr_threshold"] = 0.5 retrieval_kwargs["mmr_threshold"] = 0.5
# rerank # rerank
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs) docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs)
if docs and self.get_from_path("reranker"): if docs and self.get_from_path("reranker"):
docs = self.reranker(docs, query=text) docs = self.reranker(docs, query=text)
@ -221,6 +226,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
retriever = cls( retriever = cls(
get_extra_table=user_settings["prioritize_table"], get_extra_table=user_settings["prioritize_table"],
reranker=user_settings["reranking_llm"], reranker=user_settings["reranking_llm"],
top_k=user_settings["num_retrieval"],
mmr=user_settings["mmr"],
) )
if not user_settings["use_reranking"]: if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore retriever.reranker = None # type: ignore
@ -228,11 +235,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
retriever.vector_retrieval.embedding = embedding_models_manager[ retriever.vector_retrieval.embedding = embedding_models_manager[
index_settings.get("embedding", embedding_models_manager.get_default_name()) index_settings.get("embedding", embedding_models_manager.get_default_name())
] ]
kwargs = { kwargs = {".doc_ids": selected}
".top_k": int(user_settings["num_retrieval"]),
".mmr": user_settings["mmr"],
".doc_ids": selected,
}
retriever.set_run(kwargs, temp=True) retriever.set_run(kwargs, temp=True)
return retriever return retriever

View File

@ -512,21 +512,56 @@ class FileSelector(BasePage):
self._index = index self._index = index
self.on_building_ui() self.on_building_ui()
def default(self):
return "disabled", []
def on_building_ui(self): def on_building_ui(self):
default_mode, default_selector = self.default()
self.mode = gr.Radio(
value=default_mode,
choices=[
("Disabled", "disabled"),
("Search All", "all"),
("Select", "select"),
],
container=False,
)
self.selector = gr.Dropdown( self.selector = gr.Dropdown(
label="Files", label="Files",
choices=[], choices=default_selector,
multiselect=True, multiselect=True,
container=False, container=False,
interactive=True, interactive=True,
visible=False,
)
def on_register_events(self):
self.mode.change(
fn=lambda mode: gr.update(visible=mode == "select"),
inputs=[self.mode],
outputs=[self.selector],
) )
def as_gradio_component(self): def as_gradio_component(self):
return self.selector return [self.mode, self.selector]
def get_selected_ids(self, selected): def get_selected_ids(self, components):
mode, selected = components[0], components[1]
if mode == "disabled":
return []
elif mode == "select":
return selected return selected
file_ids = []
with Session(engine) as session:
statement = select(self._index._resources["Source"].id)
results = session.execute(statement).all()
for (id,) in results:
file_ids.append(id)
return file_ids
def load_files(self, selected_files): def load_files(self, selected_files):
options = [] options = []
available_ids = [] available_ids = []

View File

@ -52,9 +52,11 @@ class ChatPage(BasePage):
len(self._indices_input) + len(gr_index), len(self._indices_input) + len(gr_index),
) )
) )
index.default_selector = index_ui.default()
self._indices_input.extend(gr_index) self._indices_input.extend(gr_index)
else: else:
index.selector = len(self._indices_input) index.selector = len(self._indices_input)
index.default_selector = index_ui.default()
self._indices_input.append(gr_index) self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui) setattr(self, f"_index_{index.id}", index_ui)

View File

@ -156,9 +156,9 @@ class ConversationControl(BasePage):
if index.selector is None: if index.selector is None:
continue continue
if isinstance(index.selector, int): if isinstance(index.selector, int):
indices.append(selected.get(str(index.id), [])) indices.append(selected.get(str(index.id), index.default_selector))
if isinstance(index.selector, tuple): if isinstance(index.selector, tuple):
indices.extend(selected.get(str(index.id), [[]] * len(index.selector))) indices.extend(selected.get(str(index.id), index.default_selector))
return id_, id_, name, chats, info_panel, state, *indices return id_, id_, name, chats, info_panel, state, *indices