Allow file selector to be disabled (#36)
* Allow file selector to be disabled * Update docs and variable names
This commit is contained in:
parent
e19893a509
commit
1b2082a140
Binary file not shown.
Before Width: | Height: | Size: 138 KiB After Width: | Height: | Size: 73 KiB |
Binary file not shown.
Before Width: | Height: | Size: 66 KiB After Width: | Height: | Size: 40 KiB |
|
@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
|
||||||
1. Conversation Settings Panel
|
1. Conversation Settings Panel
|
||||||
- Here you can select, create, rename, and delete conversations.
|
- Here you can select, create, rename, and delete conversations.
|
||||||
- By default, a new conversation is created automatically if no conversation is selected.
|
- By default, a new conversation is created automatically if no conversation is selected.
|
||||||
- Below that you have the file index, where you can select which files to retrieve references from.
|
- Below that you have the file index, where you can choose whether to disable, select all files, or select which files to retrieve references from.
|
||||||
- These are the files you have uploaded to the application from the `File Index` tab.
|
- If you choose "Disabled", no files will be considered as context during chat.
|
||||||
- If no file is selected, all files will be used.
|
- If you choose "Search All", all files will be considered during chat.
|
||||||
|
- If you choose "Select", a dropdown will appear for you to select the
|
||||||
|
files to be considered during chat. If no files are selected, then no
|
||||||
|
files will be considered during chat.
|
||||||
2. Chat Panel
|
2. Chat Panel
|
||||||
- This is where you can chat with the chatbot.
|
- This is where you can chat with the chatbot.
|
||||||
3. Information panel
|
3. Information panel
|
||||||
|
|
|
@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
|
||||||
1. Conversation Settings Panel
|
1. Conversation Settings Panel
|
||||||
- Here you can select, create, rename, and delete conversations.
|
- Here you can select, create, rename, and delete conversations.
|
||||||
- By default, a new conversation is created automatically if no conversation is selected.
|
- By default, a new conversation is created automatically if no conversation is selected.
|
||||||
- Below that you have the file index, where you can select which files to retrieve references from.
|
- Below that you have the file index, where you can choose whether to disable, select all files, or select which files to retrieve references from.
|
||||||
- These are the files you have uploaded to the application from the `File Index` tab.
|
- If you choose "Disabled", no files will be considered as context during chat.
|
||||||
- If no file is selected, all files will be used.
|
- If you choose "Search All", all files will be considered during chat.
|
||||||
|
- If you choose "Select", a dropdown will appear for you to select the
|
||||||
|
files to be considered during chat. If no files are selected, then no
|
||||||
|
files will be considered during chat.
|
||||||
2. Chat Panel
|
2. Chat Panel
|
||||||
- This is where you can chat with the chatbot.
|
- This is where you can chat with the chatbot.
|
||||||
3. Information panel
|
3. Information panel
|
||||||
|
|
|
@ -67,58 +67,63 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
documents
|
documents
|
||||||
get_extra_table: if True, for each retrieved document, the pipeline will look
|
get_extra_table: if True, for each retrieved document, the pipeline will look
|
||||||
for surrounding tables (e.g. within the page)
|
for surrounding tables (e.g. within the page)
|
||||||
|
top_k: number of documents to retrieve
|
||||||
|
mmr: whether to use mmr to re-rank the documents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
|
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
|
||||||
reranker: BaseReranking
|
reranker: BaseReranking
|
||||||
get_extra_table: bool = False
|
get_extra_table: bool = False
|
||||||
|
mmr: bool = False
|
||||||
|
top_k: int = 5
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
top_k: int = 5,
|
|
||||||
mmr: bool = False,
|
|
||||||
doc_ids: Optional[list[str]] = None,
|
doc_ids: Optional[list[str]] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
) -> list[RetrievedDocument]:
|
) -> list[RetrievedDocument]:
|
||||||
"""Retrieve document excerpts similar to the text
|
"""Retrieve document excerpts similar to the text
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: the text to retrieve similar documents
|
text: the text to retrieve similar documents
|
||||||
top_k: number of documents to retrieve
|
|
||||||
mmr: whether to use mmr to re-rank the documents
|
|
||||||
doc_ids: list of document ids to constraint the retrieval
|
doc_ids: list of document ids to constraint the retrieval
|
||||||
"""
|
"""
|
||||||
|
if not doc_ids:
|
||||||
|
logger.info(f"Skip retrieval because of no selected files: {self}")
|
||||||
|
return []
|
||||||
|
|
||||||
Index = self._Index
|
Index = self._Index
|
||||||
|
|
||||||
kwargs = {}
|
retrieval_kwargs = {}
|
||||||
if doc_ids:
|
with Session(engine) as session:
|
||||||
with Session(engine) as session:
|
stmt = select(Index).where(
|
||||||
stmt = select(Index).where(
|
Index.relation_type == "vector",
|
||||||
Index.relation_type == "vector",
|
Index.source_id.in_(doc_ids), # type: ignore
|
||||||
Index.source_id.in_(doc_ids), # type: ignore
|
|
||||||
)
|
|
||||||
results = session.execute(stmt)
|
|
||||||
vs_ids = [r[0].target_id for r in results.all()]
|
|
||||||
|
|
||||||
kwargs["filters"] = MetadataFilters(
|
|
||||||
filters=[
|
|
||||||
MetadataFilter(
|
|
||||||
key="doc_id",
|
|
||||||
value=vs_id,
|
|
||||||
operator=FilterOperator.EQ,
|
|
||||||
)
|
|
||||||
for vs_id in vs_ids
|
|
||||||
],
|
|
||||||
condition=FilterCondition.OR,
|
|
||||||
)
|
)
|
||||||
|
results = session.execute(stmt)
|
||||||
|
vs_ids = [r[0].target_id for r in results.all()]
|
||||||
|
|
||||||
if mmr:
|
retrieval_kwargs["filters"] = MetadataFilters(
|
||||||
|
filters=[
|
||||||
|
MetadataFilter(
|
||||||
|
key="doc_id",
|
||||||
|
value=vs_id,
|
||||||
|
operator=FilterOperator.EQ,
|
||||||
|
)
|
||||||
|
for vs_id in vs_ids
|
||||||
|
],
|
||||||
|
condition=FilterCondition.OR,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.mmr:
|
||||||
# TODO: double check that llama-index MMR works correctly
|
# TODO: double check that llama-index MMR works correctly
|
||||||
kwargs["mode"] = VectorStoreQueryMode.MMR
|
retrieval_kwargs["mode"] = VectorStoreQueryMode.MMR
|
||||||
kwargs["mmr_threshold"] = 0.5
|
retrieval_kwargs["mmr_threshold"] = 0.5
|
||||||
|
|
||||||
# rerank
|
# rerank
|
||||||
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
|
docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs)
|
||||||
if docs and self.get_from_path("reranker"):
|
if docs and self.get_from_path("reranker"):
|
||||||
docs = self.reranker(docs, query=text)
|
docs = self.reranker(docs, query=text)
|
||||||
|
|
||||||
|
@ -221,6 +226,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
retriever = cls(
|
retriever = cls(
|
||||||
get_extra_table=user_settings["prioritize_table"],
|
get_extra_table=user_settings["prioritize_table"],
|
||||||
reranker=user_settings["reranking_llm"],
|
reranker=user_settings["reranking_llm"],
|
||||||
|
top_k=user_settings["num_retrieval"],
|
||||||
|
mmr=user_settings["mmr"],
|
||||||
)
|
)
|
||||||
if not user_settings["use_reranking"]:
|
if not user_settings["use_reranking"]:
|
||||||
retriever.reranker = None # type: ignore
|
retriever.reranker = None # type: ignore
|
||||||
|
@ -228,11 +235,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
retriever.vector_retrieval.embedding = embedding_models_manager[
|
retriever.vector_retrieval.embedding = embedding_models_manager[
|
||||||
index_settings.get("embedding", embedding_models_manager.get_default_name())
|
index_settings.get("embedding", embedding_models_manager.get_default_name())
|
||||||
]
|
]
|
||||||
kwargs = {
|
kwargs = {".doc_ids": selected}
|
||||||
".top_k": int(user_settings["num_retrieval"]),
|
|
||||||
".mmr": user_settings["mmr"],
|
|
||||||
".doc_ids": selected,
|
|
||||||
}
|
|
||||||
retriever.set_run(kwargs, temp=True)
|
retriever.set_run(kwargs, temp=True)
|
||||||
return retriever
|
return retriever
|
||||||
|
|
||||||
|
|
|
@ -512,20 +512,55 @@ class FileSelector(BasePage):
|
||||||
self._index = index
|
self._index = index
|
||||||
self.on_building_ui()
|
self.on_building_ui()
|
||||||
|
|
||||||
|
def default(self):
|
||||||
|
return "disabled", []
|
||||||
|
|
||||||
def on_building_ui(self):
|
def on_building_ui(self):
|
||||||
|
default_mode, default_selector = self.default()
|
||||||
|
|
||||||
|
self.mode = gr.Radio(
|
||||||
|
value=default_mode,
|
||||||
|
choices=[
|
||||||
|
("Disabled", "disabled"),
|
||||||
|
("Search All", "all"),
|
||||||
|
("Select", "select"),
|
||||||
|
],
|
||||||
|
container=False,
|
||||||
|
)
|
||||||
self.selector = gr.Dropdown(
|
self.selector = gr.Dropdown(
|
||||||
label="Files",
|
label="Files",
|
||||||
choices=[],
|
choices=default_selector,
|
||||||
multiselect=True,
|
multiselect=True,
|
||||||
container=False,
|
container=False,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
|
visible=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_register_events(self):
|
||||||
|
self.mode.change(
|
||||||
|
fn=lambda mode: gr.update(visible=mode == "select"),
|
||||||
|
inputs=[self.mode],
|
||||||
|
outputs=[self.selector],
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_gradio_component(self):
|
def as_gradio_component(self):
|
||||||
return self.selector
|
return [self.mode, self.selector]
|
||||||
|
|
||||||
def get_selected_ids(self, selected):
|
def get_selected_ids(self, components):
|
||||||
return selected
|
mode, selected = components[0], components[1]
|
||||||
|
if mode == "disabled":
|
||||||
|
return []
|
||||||
|
elif mode == "select":
|
||||||
|
return selected
|
||||||
|
|
||||||
|
file_ids = []
|
||||||
|
with Session(engine) as session:
|
||||||
|
statement = select(self._index._resources["Source"].id)
|
||||||
|
results = session.execute(statement).all()
|
||||||
|
for (id,) in results:
|
||||||
|
file_ids.append(id)
|
||||||
|
|
||||||
|
return file_ids
|
||||||
|
|
||||||
def load_files(self, selected_files):
|
def load_files(self, selected_files):
|
||||||
options = []
|
options = []
|
||||||
|
|
|
@ -52,9 +52,11 @@ class ChatPage(BasePage):
|
||||||
len(self._indices_input) + len(gr_index),
|
len(self._indices_input) + len(gr_index),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
index.default_selector = index_ui.default()
|
||||||
self._indices_input.extend(gr_index)
|
self._indices_input.extend(gr_index)
|
||||||
else:
|
else:
|
||||||
index.selector = len(self._indices_input)
|
index.selector = len(self._indices_input)
|
||||||
|
index.default_selector = index_ui.default()
|
||||||
self._indices_input.append(gr_index)
|
self._indices_input.append(gr_index)
|
||||||
setattr(self, f"_index_{index.id}", index_ui)
|
setattr(self, f"_index_{index.id}", index_ui)
|
||||||
|
|
||||||
|
|
|
@ -156,9 +156,9 @@ class ConversationControl(BasePage):
|
||||||
if index.selector is None:
|
if index.selector is None:
|
||||||
continue
|
continue
|
||||||
if isinstance(index.selector, int):
|
if isinstance(index.selector, int):
|
||||||
indices.append(selected.get(str(index.id), []))
|
indices.append(selected.get(str(index.id), index.default_selector))
|
||||||
if isinstance(index.selector, tuple):
|
if isinstance(index.selector, tuple):
|
||||||
indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
|
indices.extend(selected.get(str(index.id), index.default_selector))
|
||||||
|
|
||||||
return id_, id_, name, chats, info_panel, state, *indices
|
return id_, id_, name, chats, info_panel, state, *indices
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user