Allow file index to be private (#45)

* Fix breaking reranker

* Allow private file index

* Avoid setting default to 1 when user management is enabled
This commit is contained in:
Duc Nguyen (john) 2024-04-25 14:24:35 +07:00 committed by GitHub
parent 456f020caf
commit e29bec6275
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 145 additions and 36 deletions

View File

@ -109,11 +109,16 @@ class BaseIndex(abc.ABC):
return {}
@abc.abstractmethod
def get_indexing_pipeline(self, settings: dict) -> "BaseComponent":
def get_indexing_pipeline(
self, settings: dict, user_id: Optional[int]
) -> "BaseComponent":
"""Return the indexing pipeline that populates the entities into the index
Args:
settings: the user settings of the index
user_id: the user id who is accessing the index
TODO: instead of having a user_id, should have an app_state
which might also contain the settings.
Returns:
BaseIndexing: the indexing pipeline

View File

@ -31,6 +31,26 @@ class FileIndex(BaseIndex):
def __init__(self, app, id: int, name: str, config: dict):
# Forward the common index arguments (app, id, name, config) to the parent initializer.
super().__init__(app, id, name, config)
# NOTE: the four "*_cls" lines in this method are bare annotations (no "= value").
# They only declare a type and do NOT create the attribute on the instance.
# on_start() calls _setup_indexing_cls(), _setup_retriever_cls() and
# _setup_file_index_ui_cls(); presumably those assign the real classes -- TODO confirm.
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
self._selector_ui_cls: Type
# Selector UI object; starts as None.
self._selector_ui: Any = None
self._index_ui_cls: Type
# Index UI object; starts as None.
self._index_ui: Any = None
# Both settings dictionaries start empty; each instance gets its own dict objects.
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
def _setup_resources(self):
"""Setup resources for the file index
The resources include:
- Database table
- Vector store
- Document store
- File storage path
"""
Base = declarative_base()
Source = type(
"Source",
@ -50,6 +70,7 @@ class FileIndex(BaseIndex):
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"user": Column(Integer, default=1),
},
)
Index = type(
@ -61,6 +82,7 @@ class FileIndex(BaseIndex):
"source_id": Column(String),
"target_id": Column(String),
"relation_type": Column(Integer),
"user": Column(Integer, default=1),
},
)
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
@ -74,16 +96,6 @@ class FileIndex(BaseIndex):
"FileStoragePath": self._fs_path,
}
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
self._selector_ui_cls: Type
self._selector_ui: Any = None
self._index_ui_cls: Type
self._index_ui: Any = None
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
def _setup_indexing_cls(self):
"""Retrieve the indexing class for the file index
@ -247,6 +259,7 @@ class FileIndex(BaseIndex):
self.config = config
# create the resources
self._setup_resources()
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True)
@ -255,6 +268,7 @@ class FileIndex(BaseIndex):
"""Clean up the index when the user deletes it"""
import shutil
self._setup_resources()
self._resources["Source"].__table__.drop(engine) # type: ignore
self._resources["Index"].__table__.drop(engine) # type: ignore
self._vs.drop()
@ -263,6 +277,7 @@ class FileIndex(BaseIndex):
def on_start(self):
"""Setup the classes and hooks"""
self._setup_resources()
self._setup_indexing_cls()
self._setup_retriever_cls()
self._setup_file_index_ui_cls()
@ -326,9 +341,16 @@ class FileIndex(BaseIndex):
"Set 0 to disable."
),
},
"private": {
"name": "Make private",
"value": False,
"component": "radio",
"choices": [("Yes", True), ("No", False)],
"info": "If private, files will not be accessible across users.",
},
}
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
"""Define the interface of the indexing pipeline"""
prefix = f"index.options.{self.id}."
@ -341,6 +363,7 @@ class FileIndex(BaseIndex):
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.set_resources(resources=self._resources)
obj._user_id = user_id
return obj

View File

@ -13,6 +13,7 @@ import gradio as gr
from ktem.components import filestorage_path
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager
from ktem.llms.manager import llms
from llama_index.vector_stores import (
FilterCondition,
FilterOperator,
@ -28,7 +29,7 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.base import RetrievedDocument
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.rankings import BaseReranking
from kotaemon.indices.rankings import BaseReranking, LLMReranking
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@ -72,7 +73,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
reranker: BaseReranking
reranker: BaseReranking = LLMReranking.withx()
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5
@ -225,12 +226,15 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""
retriever = cls(
get_extra_table=user_settings["prioritize_table"],
reranker=user_settings["reranking_llm"],
top_k=user_settings["num_retrieval"],
mmr=user_settings["mmr"],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
else:
retriever.reranker.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
retriever.vector_retrieval.embedding = embedding_models_manager[
index_settings.get("embedding", embedding_models_manager.get_default_name())
@ -342,6 +346,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
name=Path(file_path).name,
path=file_hash,
size=Path(file_path).stat().st_size,
user=self._user_id, # type: ignore
)
file_to_source[file_path] = source

View File

@ -168,6 +168,25 @@ class FileIndexPage(BasePage):
def on_subscribe_public_events(self):
"""Subscribe to the declared public event of the app"""
# When the app's f_user_management flag is set, re-run list_file on the
# "onSignIn" and "onSignOut" events so the file table is refreshed with the
# current value of self._app.user_id.
# NOTE(review): indentation is not preserved in this view; both subscriptions
# presumably sit inside the f_user_management guard -- confirm against the file.
if self._app.f_user_management:
# Sign-in: list_file receives the user id and writes to file_list_state / file_list.
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.list_file,
"inputs": [self._app.user_id],
"outputs": [self.file_list_state, self.file_list],
"show_progress": "hidden",
},
)
# Sign-out: same handler, inputs and outputs as the sign-in subscription.
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.list_file,
"inputs": [self._app.user_id],
"outputs": [self.file_list_state, self.file_list],
"show_progress": "hidden",
},
)
def file_selected(self, file_id):
if file_id is None:
@ -257,7 +276,7 @@ class FileIndexPage(BasePage):
)
.then(
fn=self.list_file,
inputs=None,
inputs=[self._app.user_id],
outputs=[self.file_list_state, self.file_list],
)
)
@ -294,12 +313,13 @@ class FileIndexPage(BasePage):
self.files,
self.reindex,
self._app.settings_state,
self._app.user_id,
],
outputs=[self.file_output],
concurrency_limit=20,
).then(
fn=self.list_file,
inputs=None,
inputs=[self._app.user_id],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
@ -317,11 +337,11 @@ class FileIndexPage(BasePage):
"""Called when the app is created"""
self._app.app.load(
self.list_file,
inputs=None,
inputs=[self._app.user_id],
outputs=[self.file_list_state, self.file_list],
)
def index_fn(self, files, reindex: bool, settings):
def index_fn(self, files, reindex: bool, settings, user_id):
"""Upload and index the files
Args:
@ -342,7 +362,7 @@ class FileIndexPage(BasePage):
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
indexing_pipeline = self._index.get_indexing_pipeline(settings)
indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id)
result = indexing_pipeline(files, reindex=reindex)
if result is None:
@ -360,7 +380,7 @@ class FileIndexPage(BasePage):
return gr.update(value=file_path, visible=True)
def index_files_from_dir(self, folder_path, reindex, settings):
def index_files_from_dir(self, folder_path, reindex, settings, user_id):
"""This should be constructable by users
It means that the users can build their own index.
@ -428,12 +448,28 @@ class FileIndexPage(BasePage):
for p in exclude_patterns:
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
return self.index_fn(files, reindex, settings)
return self.index_fn(files, reindex, settings, user_id)
def list_file(self, user_id):
if user_id is None:
# not signed in
return [], pd.DataFrame.from_records(
[
{
"id": "-",
"name": "-",
"size": "-",
"text_length": "-",
"date_created": "-",
}
]
)
def list_file(self):
Source = self._index._resources["Source"]
with Session(engine) as session:
statement = select(Source)
if self._index.config.get("private", False):
statement = statement.where(Source.user == user_id)
results = [
{
"id": each[0].id,
@ -513,10 +549,12 @@ class FileSelector(BasePage):
self.on_building_ui()
def default(self):
return "disabled", []
if self._app.f_user_management:
return "disabled", [], -1
return "disabled", [], 1
def on_building_ui(self):
default_mode, default_selector = self.default()
default_mode, default_selector, user_id = self.default()
self.mode = gr.Radio(
value=default_mode,
@ -529,25 +567,30 @@ class FileSelector(BasePage):
)
self.selector = gr.Dropdown(
label="Files",
choices=default_selector,
value=default_selector,
choices=[],
multiselect=True,
container=False,
interactive=True,
visible=False,
)
self.selector_user_id = gr.State(value=user_id)
def on_register_events(self):
self.mode.change(
fn=lambda mode: gr.update(visible=mode == "select"),
inputs=[self.mode],
outputs=[self.selector],
fn=lambda mode, user_id: (gr.update(visible=mode == "select"), user_id),
inputs=[self.mode, self._app.user_id],
outputs=[self.selector, self.selector_user_id],
)
def as_gradio_component(self):
return [self.mode, self.selector]
return [self.mode, self.selector, self.selector_user_id]
def get_selected_ids(self, components):
mode, selected = components[0], components[1]
mode, selected, user_id = components[0], components[1], components[2]
if user_id is None:
return []
if mode == "disabled":
return []
elif mode == "select":
@ -556,17 +599,31 @@ class FileSelector(BasePage):
file_ids = []
with Session(engine) as session:
statement = select(self._index._resources["Source"].id)
if self._index.config.get("private", False):
statement = statement.where(
self._index._resources["Source"].user == user_id
)
results = session.execute(statement).all()
for (id,) in results:
file_ids.append(id)
return file_ids
def load_files(self, selected_files):
options = []
def load_files(self, selected_files, user_id):
options: list = []
available_ids = []
if user_id is None:
# not signed in
return gr.update(value=selected_files, choices=options)
with Session(engine) as session:
statement = select(self._index._resources["Source"])
if self._index.config.get("private", False):
statement = statement.where(
self._index._resources["Source"].user == user_id
)
results = session.execute(statement).all()
for result in results:
available_ids.append(result[0].id)
@ -583,7 +640,7 @@ class FileSelector(BasePage):
def _on_app_created(self):
self._app.app.load(
self.load_files,
inputs=self.selector,
inputs=[self.selector, self._app.user_id],
outputs=[self.selector],
)
@ -592,7 +649,26 @@ class FileSelector(BasePage):
name=f"onFileIndex{self._index.id}Changed",
definition={
"fn": self.load_files,
"inputs": [self.selector],
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},

View File

@ -365,7 +365,7 @@ class ChatPage(BasePage):
Args:
settings: the settings of the app
is_regen: whether the regen button is clicked
state: the state of the app
selected: the list of file ids that will be served as context. If None, then
consider using all files