diff --git a/libs/ktem/ktem/index/base.py b/libs/ktem/ktem/index/base.py index 5183762..002b765 100644 --- a/libs/ktem/ktem/index/base.py +++ b/libs/ktem/ktem/index/base.py @@ -109,11 +109,16 @@ class BaseIndex(abc.ABC): return {} @abc.abstractmethod - def get_indexing_pipeline(self, settings: dict) -> "BaseComponent": + def get_indexing_pipeline( + self, settings: dict, user_id: Optional[int] + ) -> "BaseComponent": """Return the indexing pipeline that populates the entities into the index Args: settings: the user settings of the index + user_id: the user id who is accessing the index + TODO: instead of having a user_id, should have an app_state + which might also contain the settings. Returns: BaseIndexing: the indexing pipeline diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py index 5e97983..0d6838d 100644 --- a/libs/ktem/ktem/index/file/index.py +++ b/libs/ktem/ktem/index/file/index.py @@ -31,6 +31,26 @@ class FileIndex(BaseIndex): def __init__(self, app, id: int, name: str, config: dict): super().__init__(app, id, name, config) + + self._indexing_pipeline_cls: Type[BaseFileIndexIndexing] + self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]] + self._selector_ui_cls: Type + self._selector_ui: Any = None + self._index_ui_cls: Type + self._index_ui: Any = None + + self._default_settings: dict[str, dict] = {} + self._setting_mappings: dict[str, dict] = {} + + def _setup_resources(self): + """Setup resources for the file index + + The resources include: + - Database table + - Vector store + - Document store + - File storage path + """ Base = declarative_base() Source = type( "Source", @@ -50,6 +70,7 @@ class FileIndex(BaseIndex): "date_created": Column( DateTime(timezone=True), server_default=func.now() ), + "user": Column(Integer, default=1), }, ) Index = type( @@ -61,6 +82,7 @@ class FileIndex(BaseIndex): "source_id": Column(String), "target_id": Column(String), "relation_type": Column(Integer), + "user": Column(Integer, default=1), }, ) self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}") @@ -74,16 +96,6 @@ class FileIndex(BaseIndex): "FileStoragePath": self._fs_path, } - self._indexing_pipeline_cls: Type[BaseFileIndexIndexing] - self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]] - self._selector_ui_cls: Type - self._selector_ui: Any = None - self._index_ui_cls: Type - self._index_ui: Any = None - - self._default_settings: dict[str, dict] = {} - self._setting_mappings: dict[str, dict] = {} - def _setup_indexing_cls(self): """Retrieve the indexing class for the file index @@ -247,6 +259,7 @@ class FileIndex(BaseIndex): self.config = config # create the resources + self._setup_resources() self._resources["Source"].metadata.create_all(engine) # type: ignore self._resources["Index"].metadata.create_all(engine) # type: ignore self._fs_path.mkdir(parents=True, exist_ok=True) @@ -255,6 +268,7 @@ class FileIndex(BaseIndex): """Clean up the index when the user delete it""" import shutil + self._setup_resources() self._resources["Source"].__table__.drop(engine) # type: ignore self._resources["Index"].__table__.drop(engine) # type: ignore self._vs.drop() @@ -263,6 +277,7 @@ class FileIndex(BaseIndex): def on_start(self): """Setup the classes and hooks""" + self._setup_resources() self._setup_indexing_cls() self._setup_retriever_cls() self._setup_file_index_ui_cls() @@ -326,9 +341,16 @@ class FileIndex(BaseIndex): "Set 0 to disable." ), }, + "private": { + "name": "Make private", + "value": False, + "component": "radio", + "choices": [("Yes", True), ("No", False)], + "info": "If private, files will not be accessible across users.", + }, } - def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing: + def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing: """Define the interface of the indexing pipeline""" prefix = f"index.options.{self.id}." @@ -341,6 +363,7 @@ class FileIndex(BaseIndex): obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config) obj.set_resources(resources=self._resources) + obj._user_id = user_id return obj diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index 375f3dd..558d2a9 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -13,6 +13,7 @@ import gradio as gr from ktem.components import filestorage_path from ktem.db.models import engine from ktem.embeddings.manager import embedding_models_manager +from ktem.llms.manager import llms from llama_index.vector_stores import ( FilterCondition, FilterOperator, @@ -28,7 +29,7 @@ from theflow.utils.modules import import_dotted_string from kotaemon.base import RetrievedDocument from kotaemon.indices import VectorIndexing, VectorRetrieval from kotaemon.indices.ingests import DocumentIngestor -from kotaemon.indices.rankings import BaseReranking +from kotaemon.indices.rankings import BaseReranking, LLMReranking from .base import BaseFileIndexIndexing, BaseFileIndexRetriever @@ -72,7 +73,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): """ vector_retrieval: VectorRetrieval = VectorRetrieval.withx() - reranker: BaseReranking + reranker: BaseReranking = LLMReranking.withx() get_extra_table: bool = False mmr: bool = False top_k: int = 5 @@ -225,12 +226,15 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): """ retriever = cls( get_extra_table=user_settings["prioritize_table"], - reranker=user_settings["reranking_llm"], top_k=user_settings["num_retrieval"], mmr=user_settings["mmr"], ) if not user_settings["use_reranking"]: retriever.reranker = None # type: ignore + else: + retriever.reranker.llm = llms.get( + user_settings["reranking_llm"], llms.get_default() + ) retriever.vector_retrieval.embedding = embedding_models_manager[ index_settings.get("embedding", embedding_models_manager.get_default_name()) @@ -342,6 +346,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): name=Path(file_path).name, path=file_hash, size=Path(file_path).stat().st_size, + user=self._user_id, # type: ignore ) file_to_source[file_path] = source diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index ca00894..b02137d 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -168,6 +168,25 @@ class FileIndexPage(BasePage): def on_subscribe_public_events(self): """Subscribe to the declared public event of the app""" + if self._app.f_user_management: + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.list_file, + "inputs": [self._app.user_id], + "outputs": [self.file_list_state, self.file_list], + "show_progress": "hidden", + }, + ) + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": self.list_file, + "inputs": [self._app.user_id], + "outputs": [self.file_list_state, self.file_list], + "show_progress": "hidden", + }, + ) def file_selected(self, file_id): if file_id is None: @@ -257,7 +276,7 @@ class FileIndexPage(BasePage): ) .then( fn=self.list_file, - inputs=None, + inputs=[self._app.user_id], outputs=[self.file_list_state, self.file_list], ) ) @@ -294,12 +313,13 @@ class FileIndexPage(BasePage): self.files, self.reindex, self._app.settings_state, + self._app.user_id, ], outputs=[self.file_output], concurrency_limit=20, ).then( fn=self.list_file, - inputs=None, + inputs=[self._app.user_id], outputs=[self.file_list_state, self.file_list], concurrency_limit=20, ) @@ -317,11 +337,11 @@ class FileIndexPage(BasePage): """Called when the app is created""" self._app.app.load( self.list_file, - inputs=None, + inputs=[self._app.user_id], outputs=[self.file_list_state, self.file_list], ) - def index_fn(self, files, reindex: bool, settings): + def index_fn(self, files, reindex: bool, settings, user_id): """Upload and index the files Args: @@ -342,7 +362,7 @@ class FileIndexPage(BasePage): gr.Info(f"Start indexing {len(files)} files...") # get the pipeline - indexing_pipeline = self._index.get_indexing_pipeline(settings) + indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id) result = indexing_pipeline(files, reindex=reindex) if result is None: @@ -360,7 +380,7 @@ class FileIndexPage(BasePage): return gr.update(value=file_path, visible=True) - def index_files_from_dir(self, folder_path, reindex, settings): + def index_files_from_dir(self, folder_path, reindex, settings, user_id): """This should be constructable by users It means that the users can build their own index. @@ -428,12 +448,28 @@ class FileIndexPage(BasePage): for p in exclude_patterns: files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)] - return self.index_fn(files, reindex, settings) + return self.index_fn(files, reindex, settings, user_id) + + def list_file(self, user_id): + if user_id is None: + # not signed in + return [], pd.DataFrame.from_records( + [ + { + "id": "-", + "name": "-", + "size": "-", + "text_length": "-", + "date_created": "-", + } + ] + ) - def list_file(self): Source = self._index._resources["Source"] with Session(engine) as session: statement = select(Source) + if self._index.config.get("private", False): + statement = statement.where(Source.user == user_id) results = [ { "id": each[0].id, @@ -513,10 +549,12 @@ class FileSelector(BasePage): self.on_building_ui() def default(self): - return "disabled", [] + if self._app.f_user_management: + return "disabled", [], -1 + return "disabled", [], 1 def on_building_ui(self): - default_mode, default_selector = self.default() + default_mode, default_selector, user_id = self.default() self.mode = gr.Radio( value=default_mode, @@ -529,25 +567,30 @@ class FileSelector(BasePage): ) self.selector = gr.Dropdown( label="Files", - choices=default_selector, + value=default_selector, + choices=[], multiselect=True, container=False, interactive=True, visible=False, ) + self.selector_user_id = gr.State(value=user_id) def on_register_events(self): self.mode.change( - fn=lambda mode: gr.update(visible=mode == "select"), - inputs=[self.mode], - outputs=[self.selector], + fn=lambda mode, user_id: (gr.update(visible=mode == "select"), user_id), + inputs=[self.mode, self._app.user_id], + outputs=[self.selector, self.selector_user_id], ) def as_gradio_component(self): - return [self.mode, self.selector] + return [self.mode, self.selector, self.selector_user_id] def get_selected_ids(self, components): - mode, selected = components[0], components[1] + mode, selected, user_id = components[0], components[1], components[2] + if user_id is None: + return [] + if mode == "disabled": return [] elif mode == "select": @@ -556,17 +599,31 @@ class FileSelector(BasePage): file_ids = [] with Session(engine) as session: statement = select(self._index._resources["Source"].id) + if self._index.config.get("private", False): + statement = statement.where( + self._index._resources["Source"].user == user_id + ) results = session.execute(statement).all() for (id,) in results: file_ids.append(id) return file_ids - def load_files(self, selected_files): - options = [] + def load_files(self, selected_files, user_id): + options: list = [] available_ids = [] + if user_id is None: + # not signed in + return gr.update(value=selected_files, choices=options) + with Session(engine) as session: statement = select(self._index._resources["Source"]) + if self._index.config.get("private", False): + + statement = statement.where( + self._index._resources["Source"].user == user_id + ) + results = session.execute(statement).all() for result in results: available_ids.append(result[0].id) @@ -583,7 +640,7 @@ class FileSelector(BasePage): def _on_app_created(self): self._app.app.load( self.load_files, - inputs=self.selector, + inputs=[self.selector, self._app.user_id], outputs=[self.selector], ) @@ -592,8 +649,27 @@ class FileSelector(BasePage): name=f"onFileIndex{self._index.id}Changed", definition={ "fn": self.load_files, - "inputs": [self.selector], + "inputs": [self.selector, self._app.user_id], "outputs": [self.selector], "show_progress": "hidden", }, ) + if self._app.f_user_management: + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.load_files, + "inputs": [self.selector, self._app.user_id], + "outputs": [self.selector], + "show_progress": "hidden", + }, + ) + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": self.load_files, + "inputs": [self.selector, self._app.user_id], + "outputs": [self.selector], + "show_progress": "hidden", + }, + ) diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index b8be90a..fe541d7 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -365,7 +365,7 @@ class ChatPage(BasePage): Args: settings: the settings of the app - is_regen: whether the regen button is clicked + state: the state of the app selected: the list of file ids that will be served as context. If None, then consider using all files