Allow file index to be private (#45)

* Fix broken reranker

* Allow private file index

* Avoid setting the default user id to 1 when user management is enabled
This commit is contained in:
Duc Nguyen (john) 2024-04-25 14:24:35 +07:00 committed by GitHub
parent 456f020caf
commit e29bec6275
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 145 additions and 36 deletions

View File

@ -109,11 +109,16 @@ class BaseIndex(abc.ABC):
return {} return {}
@abc.abstractmethod @abc.abstractmethod
def get_indexing_pipeline(self, settings: dict) -> "BaseComponent": def get_indexing_pipeline(
self, settings: dict, user_id: Optional[int]
) -> "BaseComponent":
"""Return the indexing pipeline that populates the entities into the index """Return the indexing pipeline that populates the entities into the index
Args: Args:
settings: the user settings of the index settings: the user settings of the index
user_id: the user id who is accessing the index
TODO: instead of having a user_id, should have an app_state
which might also contain the settings.
Returns: Returns:
BaseIndexing: the indexing pipeline BaseIndexing: the indexing pipeline

View File

@ -31,6 +31,26 @@ class FileIndex(BaseIndex):
def __init__(self, app, id: int, name: str, config: dict): def __init__(self, app, id: int, name: str, config: dict):
super().__init__(app, id, name, config) super().__init__(app, id, name, config)
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
self._selector_ui_cls: Type
self._selector_ui: Any = None
self._index_ui_cls: Type
self._index_ui: Any = None
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
def _setup_resources(self):
"""Setup resources for the file index
The resources include:
- Database table
- Vector store
- Document store
- File storage path
"""
Base = declarative_base() Base = declarative_base()
Source = type( Source = type(
"Source", "Source",
@ -50,6 +70,7 @@ class FileIndex(BaseIndex):
"date_created": Column( "date_created": Column(
DateTime(timezone=True), server_default=func.now() DateTime(timezone=True), server_default=func.now()
), ),
"user": Column(Integer, default=1),
}, },
) )
Index = type( Index = type(
@ -61,6 +82,7 @@ class FileIndex(BaseIndex):
"source_id": Column(String), "source_id": Column(String),
"target_id": Column(String), "target_id": Column(String),
"relation_type": Column(Integer), "relation_type": Column(Integer),
"user": Column(Integer, default=1),
}, },
) )
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}") self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
@ -74,16 +96,6 @@ class FileIndex(BaseIndex):
"FileStoragePath": self._fs_path, "FileStoragePath": self._fs_path,
} }
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
self._selector_ui_cls: Type
self._selector_ui: Any = None
self._index_ui_cls: Type
self._index_ui: Any = None
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
def _setup_indexing_cls(self): def _setup_indexing_cls(self):
"""Retrieve the indexing class for the file index """Retrieve the indexing class for the file index
@ -247,6 +259,7 @@ class FileIndex(BaseIndex):
self.config = config self.config = config
# create the resources # create the resources
self._setup_resources()
self._resources["Source"].metadata.create_all(engine) # type: ignore self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True) self._fs_path.mkdir(parents=True, exist_ok=True)
@ -255,6 +268,7 @@ class FileIndex(BaseIndex):
"""Clean up the index when the user delete it""" """Clean up the index when the user delete it"""
import shutil import shutil
self._setup_resources()
self._resources["Source"].__table__.drop(engine) # type: ignore self._resources["Source"].__table__.drop(engine) # type: ignore
self._resources["Index"].__table__.drop(engine) # type: ignore self._resources["Index"].__table__.drop(engine) # type: ignore
self._vs.drop() self._vs.drop()
@ -263,6 +277,7 @@ class FileIndex(BaseIndex):
def on_start(self): def on_start(self):
"""Setup the classes and hooks""" """Setup the classes and hooks"""
self._setup_resources()
self._setup_indexing_cls() self._setup_indexing_cls()
self._setup_retriever_cls() self._setup_retriever_cls()
self._setup_file_index_ui_cls() self._setup_file_index_ui_cls()
@ -326,9 +341,16 @@ class FileIndex(BaseIndex):
"Set 0 to disable." "Set 0 to disable."
), ),
}, },
"private": {
"name": "Make private",
"value": False,
"component": "radio",
"choices": [("Yes", True), ("No", False)],
"info": "If private, files will not be accessible across users.",
},
} }
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing: def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
"""Define the interface of the indexing pipeline""" """Define the interface of the indexing pipeline"""
prefix = f"index.options.{self.id}." prefix = f"index.options.{self.id}."
@ -341,6 +363,7 @@ class FileIndex(BaseIndex):
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config) obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.set_resources(resources=self._resources) obj.set_resources(resources=self._resources)
obj._user_id = user_id
return obj return obj

View File

@ -13,6 +13,7 @@ import gradio as gr
from ktem.components import filestorage_path from ktem.components import filestorage_path
from ktem.db.models import engine from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager from ktem.embeddings.manager import embedding_models_manager
from ktem.llms.manager import llms
from llama_index.vector_stores import ( from llama_index.vector_stores import (
FilterCondition, FilterCondition,
FilterOperator, FilterOperator,
@ -28,7 +29,7 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.base import RetrievedDocument from kotaemon.base import RetrievedDocument
from kotaemon.indices import VectorIndexing, VectorRetrieval from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests import DocumentIngestor from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.rankings import BaseReranking from kotaemon.indices.rankings import BaseReranking, LLMReranking
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@ -72,7 +73,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
""" """
vector_retrieval: VectorRetrieval = VectorRetrieval.withx() vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
reranker: BaseReranking reranker: BaseReranking = LLMReranking.withx()
get_extra_table: bool = False get_extra_table: bool = False
mmr: bool = False mmr: bool = False
top_k: int = 5 top_k: int = 5
@ -225,12 +226,15 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
""" """
retriever = cls( retriever = cls(
get_extra_table=user_settings["prioritize_table"], get_extra_table=user_settings["prioritize_table"],
reranker=user_settings["reranking_llm"],
top_k=user_settings["num_retrieval"], top_k=user_settings["num_retrieval"],
mmr=user_settings["mmr"], mmr=user_settings["mmr"],
) )
if not user_settings["use_reranking"]: if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore retriever.reranker = None # type: ignore
else:
retriever.reranker.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
retriever.vector_retrieval.embedding = embedding_models_manager[ retriever.vector_retrieval.embedding = embedding_models_manager[
index_settings.get("embedding", embedding_models_manager.get_default_name()) index_settings.get("embedding", embedding_models_manager.get_default_name())
@ -342,6 +346,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
name=Path(file_path).name, name=Path(file_path).name,
path=file_hash, path=file_hash,
size=Path(file_path).stat().st_size, size=Path(file_path).stat().st_size,
user=self._user_id, # type: ignore
) )
file_to_source[file_path] = source file_to_source[file_path] = source

View File

@ -168,6 +168,25 @@ class FileIndexPage(BasePage):
def on_subscribe_public_events(self): def on_subscribe_public_events(self):
"""Subscribe to the declared public event of the app""" """Subscribe to the declared public event of the app"""
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.list_file,
"inputs": [self._app.user_id],
"outputs": [self.file_list_state, self.file_list],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.list_file,
"inputs": [self._app.user_id],
"outputs": [self.file_list_state, self.file_list],
"show_progress": "hidden",
},
)
def file_selected(self, file_id): def file_selected(self, file_id):
if file_id is None: if file_id is None:
@ -257,7 +276,7 @@ class FileIndexPage(BasePage):
) )
.then( .then(
fn=self.list_file, fn=self.list_file,
inputs=None, inputs=[self._app.user_id],
outputs=[self.file_list_state, self.file_list], outputs=[self.file_list_state, self.file_list],
) )
) )
@ -294,12 +313,13 @@ class FileIndexPage(BasePage):
self.files, self.files,
self.reindex, self.reindex,
self._app.settings_state, self._app.settings_state,
self._app.user_id,
], ],
outputs=[self.file_output], outputs=[self.file_output],
concurrency_limit=20, concurrency_limit=20,
).then( ).then(
fn=self.list_file, fn=self.list_file,
inputs=None, inputs=[self._app.user_id],
outputs=[self.file_list_state, self.file_list], outputs=[self.file_list_state, self.file_list],
concurrency_limit=20, concurrency_limit=20,
) )
@ -317,11 +337,11 @@ class FileIndexPage(BasePage):
"""Called when the app is created""" """Called when the app is created"""
self._app.app.load( self._app.app.load(
self.list_file, self.list_file,
inputs=None, inputs=[self._app.user_id],
outputs=[self.file_list_state, self.file_list], outputs=[self.file_list_state, self.file_list],
) )
def index_fn(self, files, reindex: bool, settings): def index_fn(self, files, reindex: bool, settings, user_id):
"""Upload and index the files """Upload and index the files
Args: Args:
@ -342,7 +362,7 @@ class FileIndexPage(BasePage):
gr.Info(f"Start indexing {len(files)} files...") gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline # get the pipeline
indexing_pipeline = self._index.get_indexing_pipeline(settings) indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id)
result = indexing_pipeline(files, reindex=reindex) result = indexing_pipeline(files, reindex=reindex)
if result is None: if result is None:
@ -360,7 +380,7 @@ class FileIndexPage(BasePage):
return gr.update(value=file_path, visible=True) return gr.update(value=file_path, visible=True)
def index_files_from_dir(self, folder_path, reindex, settings): def index_files_from_dir(self, folder_path, reindex, settings, user_id):
"""This should be constructable by users """This should be constructable by users
It means that the users can build their own index. It means that the users can build their own index.
@ -428,12 +448,28 @@ class FileIndexPage(BasePage):
for p in exclude_patterns: for p in exclude_patterns:
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)] files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
return self.index_fn(files, reindex, settings) return self.index_fn(files, reindex, settings, user_id)
def list_file(self, user_id):
if user_id is None:
# not signed in
return [], pd.DataFrame.from_records(
[
{
"id": "-",
"name": "-",
"size": "-",
"text_length": "-",
"date_created": "-",
}
]
)
def list_file(self):
Source = self._index._resources["Source"] Source = self._index._resources["Source"]
with Session(engine) as session: with Session(engine) as session:
statement = select(Source) statement = select(Source)
if self._index.config.get("private", False):
statement = statement.where(Source.user == user_id)
results = [ results = [
{ {
"id": each[0].id, "id": each[0].id,
@ -513,10 +549,12 @@ class FileSelector(BasePage):
self.on_building_ui() self.on_building_ui()
def default(self): def default(self):
return "disabled", [] if self._app.f_user_management:
return "disabled", [], -1
return "disabled", [], 1
def on_building_ui(self): def on_building_ui(self):
default_mode, default_selector = self.default() default_mode, default_selector, user_id = self.default()
self.mode = gr.Radio( self.mode = gr.Radio(
value=default_mode, value=default_mode,
@ -529,25 +567,30 @@ class FileSelector(BasePage):
) )
self.selector = gr.Dropdown( self.selector = gr.Dropdown(
label="Files", label="Files",
choices=default_selector, value=default_selector,
choices=[],
multiselect=True, multiselect=True,
container=False, container=False,
interactive=True, interactive=True,
visible=False, visible=False,
) )
self.selector_user_id = gr.State(value=user_id)
def on_register_events(self): def on_register_events(self):
self.mode.change( self.mode.change(
fn=lambda mode: gr.update(visible=mode == "select"), fn=lambda mode, user_id: (gr.update(visible=mode == "select"), user_id),
inputs=[self.mode], inputs=[self.mode, self._app.user_id],
outputs=[self.selector], outputs=[self.selector, self.selector_user_id],
) )
def as_gradio_component(self): def as_gradio_component(self):
return [self.mode, self.selector] return [self.mode, self.selector, self.selector_user_id]
def get_selected_ids(self, components): def get_selected_ids(self, components):
mode, selected = components[0], components[1] mode, selected, user_id = components[0], components[1], components[2]
if user_id is None:
return []
if mode == "disabled": if mode == "disabled":
return [] return []
elif mode == "select": elif mode == "select":
@ -556,17 +599,31 @@ class FileSelector(BasePage):
file_ids = [] file_ids = []
with Session(engine) as session: with Session(engine) as session:
statement = select(self._index._resources["Source"].id) statement = select(self._index._resources["Source"].id)
if self._index.config.get("private", False):
statement = statement.where(
self._index._resources["Source"].user == user_id
)
results = session.execute(statement).all() results = session.execute(statement).all()
for (id,) in results: for (id,) in results:
file_ids.append(id) file_ids.append(id)
return file_ids return file_ids
def load_files(self, selected_files): def load_files(self, selected_files, user_id):
options = [] options: list = []
available_ids = [] available_ids = []
if user_id is None:
# not signed in
return gr.update(value=selected_files, choices=options)
with Session(engine) as session: with Session(engine) as session:
statement = select(self._index._resources["Source"]) statement = select(self._index._resources["Source"])
if self._index.config.get("private", False):
statement = statement.where(
self._index._resources["Source"].user == user_id
)
results = session.execute(statement).all() results = session.execute(statement).all()
for result in results: for result in results:
available_ids.append(result[0].id) available_ids.append(result[0].id)
@ -583,7 +640,7 @@ class FileSelector(BasePage):
def _on_app_created(self): def _on_app_created(self):
self._app.app.load( self._app.app.load(
self.load_files, self.load_files,
inputs=self.selector, inputs=[self.selector, self._app.user_id],
outputs=[self.selector], outputs=[self.selector],
) )
@ -592,7 +649,26 @@ class FileSelector(BasePage):
name=f"onFileIndex{self._index.id}Changed", name=f"onFileIndex{self._index.id}Changed",
definition={ definition={
"fn": self.load_files, "fn": self.load_files,
"inputs": [self.selector], "inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector], "outputs": [self.selector],
"show_progress": "hidden", "show_progress": "hidden",
}, },

View File

@ -365,7 +365,7 @@ class ChatPage(BasePage):
Args: Args:
settings: the settings of the app settings: the settings of the app
is_regen: whether the regen button is clicked state: the state of the app
selected: the list of file ids that will be served as context. If None, then selected: the list of file ids that will be served as context. If None, then
consider using all files consider using all files