diff --git a/.gitignore b/.gitignore
index 0114278..5c91c3e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -452,6 +452,7 @@ $RECYCLE.BIN/
.theflow/
# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
+*.py[coid]
logs/
.gitsecret/keys/random_seed
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 21356ce..3f68b56 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -52,7 +52,12 @@ repos:
hooks:
- id: mypy
additional_dependencies:
- [types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
+ [
+ types-PyYAML==6.0.12.11,
+ "types-requests",
+ "sqlmodel",
+ "types-Markdown",
+ ]
args: ["--check-untyped-defs", "--ignore-missing-imports"]
exclude: "^templates/"
- repo: https://github.com/codespell-project/codespell
diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py
index f1a53c7..3192a07 100644
--- a/libs/kotaemon/kotaemon/indices/qa/citation.py
+++ b/libs/kotaemon/kotaemon/indices/qa/citation.py
@@ -104,18 +104,16 @@ class CitationPipeline(BaseComponent):
print("CitationPipeline: invoking LLM")
llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
print("CitationPipeline: finish invoking LLM")
+ if not llm_output.messages:
+ return None
+ function_output = llm_output.messages[0].additional_kwargs["function_call"][
+ "arguments"
+ ]
+ output = QuestionAnswer.parse_raw(function_output)
except Exception as e:
print(e)
return None
- if not llm_output.messages:
- return None
-
- function_output = llm_output.messages[0].additional_kwargs["function_call"][
- "arguments"
- ]
- output = QuestionAnswer.parse_raw(function_output)
-
return output
async def ainvoke(self, context: str, question: str):
diff --git a/libs/kotaemon/kotaemon/loaders/__init__.py b/libs/kotaemon/kotaemon/loaders/__init__.py
index 28cb5f3..a59d713 100644
--- a/libs/kotaemon/kotaemon/loaders/__init__.py
+++ b/libs/kotaemon/kotaemon/loaders/__init__.py
@@ -5,7 +5,7 @@ from .docx_loader import DocxReader
from .excel_loader import PandasExcelReader
from .html_loader import HtmlReader
from .mathpix_loader import MathpixPDFReader
-from .ocr_loader import OCRReader
+from .ocr_loader import ImageReader, OCRReader
from .unstructured_loader import UnstructuredReader
__all__ = [
@@ -13,6 +13,7 @@ __all__ = [
"BaseReader",
"PandasExcelReader",
"MathpixPDFReader",
+ "ImageReader",
"OCRReader",
"DirectoryReader",
"UnstructuredReader",
diff --git a/libs/kotaemon/kotaemon/loaders/adobe_loader.py b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
index dd8cbc9..09a802c 100644
--- a/libs/kotaemon/kotaemon/loaders/adobe_loader.py
+++ b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
@@ -10,14 +10,6 @@ from llama_index.readers.base import BaseReader
from kotaemon.base import Document
-from .utils.adobe import (
- generate_figure_captions,
- load_json,
- parse_figure_paths,
- parse_table_paths,
- request_adobe_service,
-)
-
logger = logging.getLogger(__name__)
DEFAULT_VLM_ENDPOINT = (
@@ -74,6 +66,13 @@ class AdobeReader(BaseReader):
includes 3 types: text, table, and image
"""
+ from .utils.adobe import (
+ generate_figure_captions,
+ load_json,
+ parse_figure_paths,
+ parse_table_paths,
+ request_adobe_service,
+ )
filename = file.name
filepath = str(Path(file).resolve())
diff --git a/libs/kotaemon/kotaemon/loaders/ocr_loader.py b/libs/kotaemon/kotaemon/loaders/ocr_loader.py
index e689717..bb1ac5d 100644
--- a/libs/kotaemon/kotaemon/loaders/ocr_loader.py
+++ b/libs/kotaemon/kotaemon/loaders/ocr_loader.py
@@ -125,3 +125,70 @@ class OCRReader(BaseReader):
)
return documents
+
+
+class ImageReader(BaseReader):
+ """Read PDF using OCR, with high focus on table extraction
+
+ Example:
+ ```python
+ >> from knowledgehub.loaders import OCRReader
+ >> reader = OCRReader()
+ >> documents = reader.load_data("path/to/pdf")
+ ```
+
+ Args:
+ endpoint: URL to FullOCR endpoint. If not provided, will look for
+ environment variable `OCR_READER_ENDPOINT` or use the default
+ `knowledgehub.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT`
+ (http://127.0.0.1:8000/v2/ai/infer/)
+ use_ocr: whether to use OCR to read text (e.g: from images, tables) in the PDF
+ If False, only the table and text within table cells will be extracted.
+ """
+
+ def __init__(self, endpoint: Optional[str] = None):
+ """Init the OCR reader with OCR endpoint (FullOCR pipeline)"""
+ super().__init__()
+ self.ocr_endpoint = endpoint or os.getenv(
+ "OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT
+ )
+
+ def load_data(
+ self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
+ ) -> List[Document]:
+ """Load data using OCR reader
+
+ Args:
+ file_path (Path): Path to PDF file
+ debug_path (Path): Path to store debug image output
+ artifact_path (Path): Path to OCR endpoints artifacts directory
+
+ Returns:
+ List[Document]: list of documents extracted from the PDF file
+ """
+ file_path = Path(file_path).resolve()
+
+ with file_path.open("rb") as content:
+ files = {"input": content}
+ data = {"job_id": uuid4(), "table_only": False}
+
+ # call the API from FullOCR endpoint
+ if "response_content" in kwargs:
+ # overriding response content if specified
+ ocr_results = kwargs["response_content"]
+ else:
+ # call original API
+ resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
+ ocr_results = resp.json()["result"]
+
+ extra_info = extra_info or {}
+ result = []
+ for ocr_result in ocr_results:
+ result.append(
+ Document(
+ content=ocr_result["csv_string"],
+ metadata=extra_info,
+ )
+ )
+
+ return result
diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py
index 9bac904..64e8a9d 100644
--- a/libs/ktem/ktem/app.py
+++ b/libs/ktem/ktem/app.py
@@ -229,7 +229,9 @@ class BasePage:
def _on_app_created(self):
"""Called when the app is created"""
- def as_gradio_component(self) -> Optional[gr.components.Component]:
+ def as_gradio_component(
+ self,
+ ) -> Optional[gr.components.Component | list[gr.components.Component]]:
"""Return the gradio components responsible for events
Note: in ideal scenario, this method shouldn't be necessary.
diff --git a/libs/ktem/ktem/index/base.py b/libs/ktem/ktem/index/base.py
index 50bdd9e..5183762 100644
--- a/libs/ktem/ktem/index/base.py
+++ b/libs/ktem/ktem/index/base.py
@@ -1,6 +1,6 @@
import abc
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from ktem.app import BasePage
@@ -57,7 +57,7 @@ class BaseIndex(abc.ABC):
self._app = app
self.id = id
self.name = name
- self._config = config # admin settings
+ self.config = config # admin settings
def on_create(self):
"""Create the index for the first time"""
@@ -121,7 +121,7 @@ class BaseIndex(abc.ABC):
...
def get_retriever_pipelines(
- self, settings: dict, selected: Optional[list]
+ self, settings: dict, selected: Any = None
) -> list["BaseComponent"]:
"""Return the retriever pipelines to retrieve the entity from the index"""
return []
diff --git a/libs/ktem/ktem/index/file/base.py b/libs/ktem/ktem/index/file/base.py
index 5f8e6f4..4f28f51 100644
--- a/libs/ktem/ktem/index/file/base.py
+++ b/libs/ktem/ktem/index/file/base.py
@@ -127,3 +127,11 @@ class BaseFileIndexIndexing(BaseComponent):
the absolute file storage path to the file
"""
raise NotImplementedError
+
+ def warning(self, msg):
+ """Log a warning message
+
+ Args:
+ msg: the message to log
+ """
+ print(msg)
diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py
index ab1f35a..5fe3955 100644
--- a/libs/ktem/ktem/index/file/index.py
+++ b/libs/ktem/ktem/index/file/index.py
@@ -13,7 +13,6 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
-from .ui import FileIndexPage, FileSelector
class FileIndex(BaseIndex):
@@ -77,9 +76,15 @@ class FileIndex(BaseIndex):
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
+ self._selector_ui_cls: Type
+ self._selector_ui: Any = None
+ self._index_ui_cls: Type
+ self._index_ui: Any = None
self._setup_indexing_cls()
self._setup_retriever_cls()
+ self._setup_file_index_ui_cls()
+ self._setup_file_selector_ui_cls()
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
@@ -91,14 +96,14 @@ class FileIndex(BaseIndex):
The indexing class will is retrieved from the following order. Stop at the
first order found:
- - `FILE_INDEX_PIPELINE` in self._config
+ - `FILE_INDEX_PIPELINE` in self.config
- `FILE_INDEX_{id}_PIPELINE` in the flowsettings
- `FILE_INDEX_PIPELINE` in the flowsettings
- The default .pipelines.IndexDocumentPipeline
"""
- if "FILE_INDEX_PIPELINE" in self._config:
+ if "FILE_INDEX_PIPELINE" in self.config:
self._indexing_pipeline_cls = import_dotted_string(
- self._config["FILE_INDEX_PIPELINE"], safe=False
+ self.config["FILE_INDEX_PIPELINE"], safe=False
)
return
@@ -125,15 +130,15 @@ class FileIndex(BaseIndex):
The retriever classes will is retrieved from the following order. Stop at the
first order found:
- - `FILE_INDEX_RETRIEVER_PIPELINES` in self._config
+ - `FILE_INDEX_RETRIEVER_PIPELINES` in self.config
- `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings
- `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings
- The default .pipelines.DocumentRetrievalPipeline
"""
- if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config:
+ if "FILE_INDEX_RETRIEVER_PIPELINES" in self.config:
self._retriever_pipeline_cls = [
import_dotted_string(each, safe=False)
- for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"]
+ for each in self.config["FILE_INDEX_RETRIEVER_PIPELINES"]
]
return
@@ -157,6 +162,76 @@ class FileIndex(BaseIndex):
self._retriever_pipeline_cls = [DocumentRetrievalPipeline]
+ def _setup_file_selector_ui_cls(self):
+ """Retrieve the file selector UI for the file index
+
+ There can be multiple retriever classes.
+
+ The retriever classes will is retrieved from the following order. Stop at the
+ first order found:
+ - `FILE_INDEX_SELECTOR_UI` in self.config
+ - `FILE_INDEX_{id}_SELECTOR_UI` in the flowsettings
+ - `FILE_INDEX_SELECTOR_UI` in the flowsettings
+ - The default .ui.FileSelector
+ """
+ if "FILE_INDEX_SELECTOR_UI" in self.config:
+ self._selector_ui_cls = import_dotted_string(
+ self.config["FILE_INDEX_SELECTOR_UI"], safe=False
+ )
+ return
+
+ if hasattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"):
+ self._selector_ui_cls = import_dotted_string(
+ getattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"),
+ safe=False,
+ )
+ return
+
+ if hasattr(flowsettings, "FILE_INDEX_SELECTOR_UI"):
+ self._selector_ui_cls = import_dotted_string(
+ getattr(flowsettings, "FILE_INDEX_SELECTOR_UI"), safe=False
+ )
+ return
+
+ from .ui import FileSelector
+
+ self._selector_ui_cls = FileSelector
+
+ def _setup_file_index_ui_cls(self):
+ """Retrieve the Index UI class
+
+ There can be multiple retriever classes.
+
+ The retriever classes will is retrieved from the following order. Stop at the
+ first order found:
+ - `FILE_INDEX_UI` in self.config
+ - `FILE_INDEX_{id}_UI` in the flowsettings
+ - `FILE_INDEX_UI` in the flowsettings
+ - The default .ui.FileIndexPage
+ """
+ if "FILE_INDEX_UI" in self.config:
+ self._index_ui_cls = import_dotted_string(
+ self.config["FILE_INDEX_UI"], safe=False
+ )
+ return
+
+ if hasattr(flowsettings, f"FILE_INDEX_{self.id}_UI"):
+ self._index_ui_cls = import_dotted_string(
+ getattr(flowsettings, f"FILE_INDEX_{self.id}_UI"),
+ safe=False,
+ )
+ return
+
+ if hasattr(flowsettings, "FILE_INDEX_UI"):
+ self._index_ui_cls = import_dotted_string(
+ getattr(flowsettings, "FILE_INDEX_UI"), safe=False
+ )
+ return
+
+ from .ui import FileIndexPage
+
+ self._index_ui_cls = FileIndexPage
+
def on_create(self):
"""Create the index for the first time
@@ -165,6 +240,13 @@ class FileIndex(BaseIndex):
2. Create the vectorstore
3. Create the docstore
"""
+ file_types_str = self.config.get(
+ "supported_file_types",
+ self.get_admin_settings()["supported_file_types"]["value"],
+ )
+ file_types = [each.strip() for each in file_types_str.split(",")]
+ self.config["supported_file_types"] = file_types
+
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True)
@@ -180,10 +262,14 @@ class FileIndex(BaseIndex):
shutil.rmtree(self._fs_path)
def get_selector_component_ui(self):
- return FileSelector(self._app, self)
+ if self._selector_ui is None:
+ self._selector_ui = self._selector_ui_cls(self._app, self)
+ return self._selector_ui
def get_index_page_ui(self):
- return FileIndexPage(self._app, self)
+ if self._index_ui is None:
+ self._index_ui = self._index_ui_cls(self._app, self)
+ return self._index_ui
def get_user_settings(self):
if self._default_settings:
@@ -210,7 +296,31 @@ class FileIndex(BaseIndex):
"value": embedding_default,
"component": "dropdown",
"choices": embedding_choices,
- }
+ },
+ "supported_file_types": {
+ "name": "Supported file types",
+ "value": (
+ "image, .pdf, .txt, .csv, .xlsx, .doc, .docx, .pptx, .html, .zip"
+ ),
+ "component": "text",
+ },
+ "max_file_size": {
+ "name": "Max file size (MB) - set 0 to disable",
+ "value": 1000,
+ "component": "number",
+ },
+ "max_number_of_files": {
+ "name": "Max number of files that can be indexed - set 0 to disable",
+ "value": 0,
+ "component": "number",
+ },
+ "max_number_of_text_length": {
+ "name": (
+ "Max amount of characters that can be indexed - set 0 to disable"
+ ),
+ "value": 0,
+ "component": "number",
+ },
}
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
@@ -224,14 +334,15 @@ class FileIndex(BaseIndex):
else:
stripped_settings[key] = value
- obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config)
+ obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.set_resources(resources=self._resources)
return obj
def get_retriever_pipelines(
- self, settings: dict, selected: Optional[list] = None
+ self, settings: dict, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
+ # retrieval settings
prefix = f"index.options.{self.id}."
stripped_settings = {}
for key, value in settings.items():
@@ -240,9 +351,12 @@ class FileIndex(BaseIndex):
else:
stripped_settings[key] = value
+ # transform selected id
+ selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)
+
retrievers = []
for cls in self._retriever_pipeline_cls:
- obj = cls.get_pipeline(stripped_settings, self._config, selected)
+ obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
if obj is None:
continue
obj.set_resources(self._resources)
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index 68b3a4d..b63d89c 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -9,6 +9,7 @@ from hashlib import sha256
from pathlib import Path
from typing import Optional
+import gradio as gr
from ktem.components import embeddings, filestorage_path
from ktem.db.models import engine
from llama_index.vector_stores import (
@@ -18,7 +19,7 @@ from llama_index.vector_stores import (
MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
-from sqlalchemy import select
+from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from theflow.settings import settings
from theflow.utils.modules import import_dotted_string
@@ -279,6 +280,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
to_index: list[str] = []
file_to_hash: dict[str, str] = {}
errors = []
+ to_update = []
for file_path in file_paths:
abs_path = str(Path(file_path).resolve())
@@ -291,16 +293,26 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
statement = select(Source).where(Source.name == Path(abs_path).name)
item = session.execute(statement).first()
- if item and not reindex:
- errors.append(Path(abs_path).name)
- continue
+ if item:
+ if not reindex:
+ errors.append(Path(abs_path).name)
+ continue
+ else:
+ to_update.append(Path(abs_path).name)
to_index.append(abs_path)
if errors:
+ error_files = ", ".join(errors)
+ if len(error_files) > 100:
+ error_files = error_files[:80] + "..."
print(
- "Files already exist. Please rename/remove them or enable reindex.\n"
- f"{errors}"
+ "Skip these files already exist. Please rename/remove them or "
+ f"enable reindex:\n{errors}"
+ )
+ self.warning(
+ "Skip these files already exist. Please rename/remove them or "
+ f"enable reindex:\n{error_files}"
)
if not to_index:
@@ -310,9 +322,19 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
for path in to_index:
shutil.copy(path, filestorage_path / file_to_hash[path])
- # prepare record info
+ # extract the file & prepare record info
file_to_source: dict = {}
+ extraction_errors = []
+ nodes = []
for file_path, file_hash in file_to_hash.items():
+ if str(Path(file_path).resolve()) not in to_index:
+ continue
+
+ extraction_result = self.file_ingestor(file_path)
+ if not extraction_result:
+ extraction_errors.append(Path(file_path).name)
+ continue
+ nodes.extend(extraction_result)
source = Source(
name=Path(file_path).name,
path=file_hash,
@@ -320,9 +342,23 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
)
file_to_source[file_path] = source
- # extract the files
- nodes = self.file_ingestor(to_index)
- print("Extracted", len(to_index), "files into", len(nodes), "nodes")
+ if extraction_errors:
+ msg = "Failed to extract these files: {}".format(
+ ", ".join(extraction_errors)
+ )
+ print(msg)
+ self.warning(msg)
+
+ if not nodes:
+ return [], []
+
+ print(
+ "Extracted",
+ len(to_index) - len(extraction_errors),
+ "files into",
+ len(nodes),
+ "nodes",
+ )
# index the files
print("Indexing the files into vector store")
@@ -332,7 +368,11 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
# persist to the index
print("Persisting the vector and the document into index")
file_ids = []
+ to_update = list(set(to_update))
with Session(engine) as session:
+ if to_update:
+ session.execute(delete(Source).where(Source.name.in_(to_update)))
+
for source in file_to_source.values():
session.add(source)
session.commit()
@@ -404,3 +444,6 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
super().set_resources(resources)
self.indexing_vector_pipeline.vector_store = self._VS
self.indexing_vector_pipeline.doc_store = self._DS
+
+ def warning(self, msg):
+ gr.Warning(msg)
diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py
index 9da2b4a..11d491f 100644
--- a/libs/ktem/ktem/index/file/ui.py
+++ b/libs/ktem/ktem/index/file/ui.py
@@ -1,29 +1,48 @@
import os
import tempfile
+from pathlib import Path
import gradio as gr
import pandas as pd
+from gradio.data_classes import FileData
+from gradio.utils import NamedString
from ktem.app import BasePage
from ktem.db.engine import engine
from sqlalchemy import select
from sqlalchemy.orm import Session
+class File(gr.File):
+ """Subclass from gr.File to maintain the original filename
+
+ The issue happens when user uploads file with name like: !@#$%%^&*().pdf
+ """
+
+ def _process_single_file(self, f: FileData) -> NamedString | bytes:
+ file_name = f.path
+ if self.type == "filepath":
+ if f.orig_name and Path(file_name).name != f.orig_name:
+ file_name = str(Path(file_name).parent / f.orig_name)
+ os.rename(f.path, file_name)
+ file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
+ file.name = file_name
+ return NamedString(file_name)
+ elif self.type == "binary":
+ with open(file_name, "rb") as file_data:
+ return file_data.read()
+ else:
+ raise ValueError(
+ "Unknown type: "
+ + str(type)
+ + ". Please choose from: 'filepath', 'binary'."
+ )
+
+
class DirectoryUpload(BasePage):
- def __init__(self, app):
- self._app = app
- self._supported_file_types = [
- "image",
- ".pdf",
- ".txt",
- ".csv",
- ".xlsx",
- ".doc",
- ".docx",
- ".pptx",
- ".html",
- ".zip",
- ]
+ def __init__(self, app, index):
+ super().__init__(app)
+ self._index = index
+ self._supported_file_types = self._index.config.get("supported_file_types", [])
self.on_building_ui()
def on_building_ui(self):
@@ -50,18 +69,7 @@ class FileIndexPage(BasePage):
def __init__(self, app, index):
super().__init__(app)
self._index = index
- self._supported_file_types = [
- "image",
- ".pdf",
- ".txt",
- ".csv",
- ".xlsx",
- ".doc",
- ".docx",
- ".pptx",
- ".html",
- ".zip",
- ]
+ self._supported_file_types = self._index.config.get("supported_file_types", [])
self.selected_panel_false = "Selected file: (please select above)"
self.selected_panel_true = "Selected file: {name}"
# TODO: on_building_ui is not correctly named if it's always called in
@@ -69,13 +77,32 @@ class FileIndexPage(BasePage):
self.public_events = [f"onFileIndex{index.id}Changed"]
self.on_building_ui()
+ def upload_instruction(self) -> str:
+ msgs = []
+ if self._supported_file_types:
+ msgs.append(
+ f"- Supported file types: {', '.join(self._supported_file_types)}"
+ )
+
+ if max_file_size := self._index.config.get("max_file_size", 0):
+ msgs.append(f"- Maximum file size: {max_file_size} MB")
+
+ if max_number_of_files := self._index.config.get("max_number_of_files", 0):
+ msgs.append(f"- The index can have maximum {max_number_of_files} files")
+
+ if msgs:
+ return "\n".join(msgs)
+
+ return ""
+
def on_building_ui(self):
"""Build the UI of the app"""
- with gr.Accordion(label="File upload", open=False):
- gr.Markdown(
- f"Supported file types: {', '.join(self._supported_file_types)}",
- )
- self.files = gr.File(
+ with gr.Accordion(label="File upload", open=True) as self.upload:
+ msg = self.upload_instruction()
+ if msg:
+ gr.Markdown(msg)
+
+ self.files = File(
file_types=self._supported_file_types,
file_count="multiple",
container=False,
@@ -98,18 +125,20 @@ class FileIndexPage(BasePage):
interactive=False,
)
- with gr.Row():
+ with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None)
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button("Deselect", visible=False)
- with gr.Row():
+ with gr.Row() as self.tools:
with gr.Column():
self.view_button = gr.Button("View Text (WIP)")
with gr.Column():
self.delete_button = gr.Button("Delete")
with gr.Row():
- self.delete_yes = gr.Button("Confirm Delete", visible=False)
+ self.delete_yes = gr.Button(
+ "Confirm Delete", variant="primary", visible=False
+ )
self.delete_no = gr.Button("Cancel", visible=False)
def on_subscribe_public_events(self):
@@ -242,10 +271,12 @@ class FileIndexPage(BasePage):
self._app.settings_state,
],
outputs=[self.file_output],
+ concurrency_limit=20,
).then(
fn=self.list_file,
inputs=None,
outputs=[self.file_list_state, self.file_list],
+ concurrency_limit=20,
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onUploaded = onUploaded.then(**event)
@@ -274,6 +305,15 @@ class FileIndexPage(BasePage):
selected_files: the list of files already selected
settings: the settings of the app
"""
+ if not files:
+ gr.Info("No uploaded file")
+ return gr.update()
+
+ errors = self.validate(files)
+ if errors:
+ gr.Warning(", ".join(errors))
+ return gr.update()
+
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
@@ -409,6 +449,35 @@ class FileIndexPage(BasePage):
name=list_files["name"][ev.index[0]]
)
+ def validate(self, files: list[str]):
+ """Validate if the files are valid"""
+ paths = [Path(file) for file in files]
+ errors = []
+ if max_file_size := self._index.config.get("max_file_size", 0):
+ errors_max_size = []
+ for path in paths:
+ if path.stat().st_size > max_file_size * 1e6:
+ errors_max_size.append(path.name)
+ if errors_max_size:
+ str_errors = ", ".join(errors_max_size)
+ if len(str_errors) > 60:
+ str_errors = str_errors[:55] + "..."
+ errors.append(
+ f"Maximum file size ({max_file_size} MB) exceeded: {str_errors}"
+ )
+
+ if max_number_of_files := self._index.config.get("max_number_of_files", 0):
+ with Session(engine) as session:
+ current_num_files = session.query(
+ self._index._db_tables["Source"].id
+ ).count()
+ if len(paths) + current_num_files > max_number_of_files:
+ errors.append(
+ f"Maximum number of files ({max_number_of_files}) will be exceeded"
+ )
+
+ return errors
+
class FileSelector(BasePage):
"""File selector UI in the Chat page"""
@@ -430,6 +499,9 @@ class FileSelector(BasePage):
def as_gradio_component(self):
return self.selector
+ def get_selected_ids(self, selected):
+ return selected
+
def load_files(self, selected_files):
options = []
available_ids = []
diff --git a/libs/ktem/ktem/index/manager.py b/libs/ktem/ktem/index/manager.py
index 72c4f99..af1c8d4 100644
--- a/libs/ktem/ktem/index/manager.py
+++ b/libs/ktem/ktem/index/manager.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Optional, Type
from ktem.db.models import engine
from sqlmodel import Session, select
@@ -49,15 +49,19 @@ class IndexManager:
Returns:
BaseIndex: the index object
"""
+ index_cls = import_dotted_string(index_type, safe=False)
+ index = index_cls(app=self._app, id=id, name=name, config=config)
+ index.on_create()
+
with Session(engine) as session:
- index_entry = Index(id=id, name=name, config=config, index_type=index_type)
+ index_entry = Index(
+ id=index.id, name=index.name, config=index.config, index_type=index_type
+ )
session.add(index_entry)
session.commit()
session.refresh(index_entry)
- index_cls = import_dotted_string(index_type, safe=False)
- index = index_cls(app=self._app, id=id, name=name, config=config)
- index.on_create()
+ index.id = index_entry.id
return index
@@ -77,7 +81,7 @@ class IndexManager:
self._indices.append(index)
return index
- def exists(self, id: int) -> bool:
+ def exists(self, id: Optional[int] = None, name: Optional[str] = None) -> bool:
"""Check if the index exists
Args:
@@ -86,9 +90,19 @@ class IndexManager:
Returns:
bool: True if the index exists, False otherwise
"""
- with Session(engine) as session:
- index = session.get(Index, id)
- return index is not None
+ if id:
+ with Session(engine) as session:
+ index = session.get(Index, id)
+ return index is not None
+
+ if name:
+ with Session(engine) as session:
+ index = session.exec(
+ select(Index).where(Index.name == name)
+ ).one_or_none()
+ return index is not None
+
+ return False
def on_application_startup(self):
"""This method is called by the base application when the application starts
diff --git a/libs/ktem/ktem/main.py b/libs/ktem/ktem/main.py
index c375ed7..1d76d04 100644
--- a/libs/ktem/ktem/main.py
+++ b/libs/ktem/ktem/main.py
@@ -27,7 +27,7 @@ class App(BaseApp):
if self.f_user_management:
from ktem.pages.login import LoginPage
- with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]:
+ with gr.Tab("Welcome", elem_id="login-tab") as self._tabs["login-tab"]:
self.login_page = LoginPage(self)
with gr.Tab(
@@ -62,6 +62,9 @@ class App(BaseApp):
def on_subscribe_public_events(self):
if self.f_user_management:
+ from ktem.db.engine import engine
+ from ktem.db.models import User
+ from sqlmodel import Session, select
def signed_in_out(user_id):
if not user_id:
@@ -73,14 +76,31 @@ class App(BaseApp):
)
for k in self._tabs.keys()
)
- return list(
- (
- gr.update(visible=True)
- if k != "login-tab"
- else gr.update(visible=False)
- )
- for k in self._tabs.keys()
- )
+
+ with Session(engine) as session:
+ user = session.exec(select(User).where(User.id == user_id)).first()
+ if user is None:
+ return list(
+ (
+ gr.update(visible=True)
+ if k == "login-tab"
+ else gr.update(visible=False)
+ )
+ for k in self._tabs.keys()
+ )
+
+ is_admin = user.admin
+
+ tabs_update = []
+ for k in self._tabs.keys():
+ if k == "login-tab":
+ tabs_update.append(gr.update(visible=False))
+ elif k == "admin-tab":
+ tabs_update.append(gr.update(visible=is_admin))
+ else:
+ tabs_update.append(gr.update(visible=True))
+
+ return tabs_update
self.subscribe_event(
name="onSignIn",
diff --git a/libs/ktem/ktem/pages/admin/user.py b/libs/ktem/ktem/pages/admin/user.py
index 519fb0f..6411b30 100644
--- a/libs/ktem/ktem/pages/admin/user.py
+++ b/libs/ktem/ktem/pages/admin/user.py
@@ -40,7 +40,7 @@ def validate_username(usn):
if len(usn) > 32:
errors.append("Username must be at most 32 characters long")
- if not usn.strip("_").isalnum():
+ if not usn.replace("_", "").isalnum():
errors.append(
"Username must contain only alphanumeric characters and underscores"
)
@@ -97,8 +97,6 @@ def validate_password(pwd, pwd_cnf):
class UserManagement(BasePage):
def __init__(self, app):
self._app = app
- self.selected_panel_false = "Selected user: (please select above)"
- self.selected_panel_true = "Selected user: {name}"
self.on_building_ui()
if hasattr(flowsettings, "KH_FEATURE_USER_MANAGEMENT_ADMIN") and hasattr(
@@ -126,7 +124,38 @@ class UserManagement(BasePage):
gr.Info(f'User "{usn}" created successfully')
def on_building_ui(self):
- with gr.Accordion(label="Create user", open=False):
+ with gr.Tab(label="User list"):
+ self.state_user_list = gr.State(value=None)
+ self.user_list = gr.DataFrame(
+ headers=["id", "name", "admin"],
+ interactive=False,
+ )
+
+ with gr.Group(visible=False) as self._selected_panel:
+ self.selected_user_id = gr.Number(value=-1, visible=False)
+ self.usn_edit = gr.Textbox(label="Username")
+ with gr.Row():
+ self.pwd_edit = gr.Textbox(label="Change password", type="password")
+ self.pwd_cnf_edit = gr.Textbox(
+ label="Confirm change password",
+ type="password",
+ )
+ self.admin_edit = gr.Checkbox(label="Admin")
+
+ with gr.Row() as self._selected_panel_btn:
+ with gr.Column():
+ self.btn_edit_save = gr.Button("Save")
+ with gr.Column():
+ self.btn_delete = gr.Button("Delete")
+ with gr.Row():
+ self.btn_delete_yes = gr.Button(
+ "Confirm delete", variant="primary", visible=False
+ )
+ self.btn_delete_no = gr.Button("Cancel", visible=False)
+ with gr.Column():
+ self.btn_close = gr.Button("Close")
+
+ with gr.Tab(label="Create user"):
self.usn_new = gr.Textbox(label="Username", interactive=True)
self.pwd_new = gr.Textbox(
label="Password", type="password", interactive=True
@@ -139,52 +168,28 @@ class UserManagement(BasePage):
gr.Markdown(PASSWORD_RULE)
self.btn_new = gr.Button("Create user")
- gr.Markdown("## User list")
- self.btn_list_user = gr.Button("Refresh user list")
- self.state_user_list = gr.State(value=None)
- self.user_list = gr.DataFrame(
- headers=["id", "name", "admin"],
- interactive=False,
- )
-
- with gr.Row():
- self.selected_user_id = gr.State(value=None)
- self.selected_panel = gr.Markdown(self.selected_panel_false)
- self.deselect_button = gr.Button("Deselect", visible=False)
-
- with gr.Group():
- self.btn_delete = gr.Button("Delete user")
- with gr.Row():
- self.btn_delete_yes = gr.Button("Confirm", visible=False)
- self.btn_delete_no = gr.Button("Cancel", visible=False)
-
- gr.Markdown("## User details")
- self.usn_edit = gr.Textbox(label="Username")
- self.pwd_edit = gr.Textbox(label="Password", type="password")
- self.pwd_cnf_edit = gr.Textbox(label="Confirm password", type="password")
- self.admin_edit = gr.Checkbox(label="Admin")
- self.btn_edit_save = gr.Button("Save")
-
def on_register_events(self):
self.btn_new.click(
self.create_user,
inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
- outputs=None,
- )
- self.btn_list_user.click(
- self.list_users, inputs=None, outputs=[self.state_user_list, self.user_list]
+ outputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
+ ).then(
+ self.list_users,
+ inputs=self._app.user_id,
+ outputs=[self.state_user_list, self.user_list],
)
self.user_list.select(
self.select_user,
inputs=self.user_list,
- outputs=[self.selected_user_id, self.selected_panel],
+ outputs=[self.selected_user_id],
show_progress="hidden",
)
- self.selected_panel.change(
+ self.selected_user_id.change(
self.on_selected_user_change,
inputs=[self.selected_user_id],
outputs=[
- self.deselect_button,
+ self._selected_panel,
+ self._selected_panel_btn,
# delete section
self.btn_delete,
self.btn_delete_yes,
@@ -197,12 +202,6 @@ class UserManagement(BasePage):
],
show_progress="hidden",
)
- self.deselect_button.click(
- lambda: (None, self.selected_panel_false),
- inputs=None,
- outputs=[self.selected_user_id, self.selected_panel],
- show_progress="hidden",
- )
self.btn_delete.click(
self.on_btn_delete_click,
inputs=[self.selected_user_id],
@@ -211,9 +210,13 @@ class UserManagement(BasePage):
)
self.btn_delete_yes.click(
self.delete_user,
- inputs=[self.selected_user_id],
- outputs=[self.selected_user_id, self.selected_panel],
+ inputs=[self._app.user_id, self.selected_user_id],
+ outputs=[self.selected_user_id],
show_progress="hidden",
+ ).then(
+ self.list_users,
+ inputs=self._app.user_id,
+ outputs=[self.state_user_list, self.user_list],
)
self.btn_delete_no.click(
lambda: (
@@ -234,21 +237,53 @@ class UserManagement(BasePage):
self.pwd_cnf_edit,
self.admin_edit,
],
- outputs=None,
+ outputs=[self.pwd_edit, self.pwd_cnf_edit],
show_progress="hidden",
+ ).then(
+ self.list_users,
+ inputs=self._app.user_id,
+ outputs=[self.state_user_list, self.user_list],
+ )
+ self.btn_close.click(
+ lambda: -1,
+ outputs=[self.selected_user_id],
+ )
+
+ def on_subscribe_public_events(self):
+ self._app.subscribe_event(
+ name="onSignIn",
+ definition={
+ "fn": self.list_users,
+ "inputs": [self._app.user_id],
+ "outputs": [self.state_user_list, self.user_list],
+ },
+ )
+ self._app.subscribe_event(
+ name="onSignOut",
+ definition={
+ "fn": lambda: ("", "", "", None, None, -1),
+ "outputs": [
+ self.usn_new,
+ self.pwd_new,
+ self.pwd_cnf_new,
+ self.state_user_list,
+ self.user_list,
+ self.selected_user_id,
+ ],
+ },
)
def create_user(self, usn, pwd, pwd_cnf):
errors = validate_username(usn)
if errors:
gr.Warning(errors)
- return
+ return usn, pwd, pwd_cnf
errors = validate_password(pwd, pwd_cnf)
print(errors)
if errors:
gr.Warning(errors)
- return
+ return usn, pwd, pwd_cnf
with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower())
@@ -265,8 +300,22 @@ class UserManagement(BasePage):
session.commit()
gr.Info(f'User "{usn}" created successfully')
- def list_users(self):
+ return "", "", ""
+
+ def list_users(self, user_id):
+ if user_id is None:
+ return [], pd.DataFrame.from_records(
+ [{"id": "-", "username": "-", "admin": "-"}]
+ )
+
with Session(engine) as session:
+ statement = select(User).where(User.id == user_id)
+ user = session.exec(statement).one()
+ if not user.admin:
+ return [], pd.DataFrame.from_records(
+ [{"id": "-", "username": "-", "admin": "-"}]
+ )
+
statement = select(User)
results = [
{"id": user.id, "username": user.username, "admin": user.admin}
@@ -284,18 +333,17 @@ class UserManagement(BasePage):
def select_user(self, user_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No user is loaded. Please refresh the user list")
- return None, self.selected_panel_false
+ return -1
if not ev.selected:
- return None, self.selected_panel_false
+ return -1
- return user_list["id"][ev.index[0]], self.selected_panel_true.format(
- name=user_list["username"][ev.index[0]]
- )
+ return user_list["id"][ev.index[0]]
def on_selected_user_change(self, selected_user_id):
- if selected_user_id is None:
- deselect_button = gr.update(visible=False)
+ if selected_user_id == -1:
+ _selected_panel = gr.update(visible=False)
+ _selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
@@ -304,7 +352,8 @@ class UserManagement(BasePage):
pwd_cnf_edit = gr.update(value="")
admin_edit = gr.update(value=False)
else:
- deselect_button = gr.update(visible=True)
+ _selected_panel = gr.update(visible=True)
+ _selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
@@ -319,7 +368,8 @@ class UserManagement(BasePage):
admin_edit = gr.update(value=user.admin)
return (
- deselect_button,
+ _selected_panel,
+ _selected_panel_btn,
btn_delete,
btn_delete_yes,
btn_delete_no,
@@ -344,17 +394,16 @@ class UserManagement(BasePage):
return btn_delete, btn_delete_yes, btn_delete_no
def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin):
- if usn:
- errors = validate_username(usn)
- if errors:
- gr.Warning(errors)
- return
+ errors = validate_username(usn)
+ if errors:
+ gr.Warning(errors)
+ return pwd, pwd_cnf
if pwd:
errors = validate_password(pwd, pwd_cnf)
if errors:
gr.Warning(errors)
- return
+ return pwd, pwd_cnf
with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id))
@@ -367,11 +416,17 @@ class UserManagement(BasePage):
session.commit()
gr.Info(f'User "{usn}" updated successfully')
- def delete_user(self, selected_user_id):
+ return "", ""
+
+ def delete_user(self, current_user, selected_user_id):
+ if current_user == selected_user_id:
+ gr.Warning("You cannot delete yourself")
+ return selected_user_id
+
with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id))
user = session.exec(statement).one()
session.delete(user)
session.commit()
gr.Info(f'User "{user.username}" deleted successfully')
- return None, self.selected_panel_false
+ return -1
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index d2bba87..a83bd83 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -7,8 +7,10 @@ from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
+from theflow.settings import settings as flowsettings
from .chat_panel import ChatPanel
+from .chat_suggestion import ChatSuggestion
from .common import STATE
from .control import ConversationControl
from .report import ReportIssue
@@ -26,24 +28,39 @@ class ChatPage(BasePage):
with gr.Column(scale=1):
self.chat_control = ConversationControl(self._app)
+ if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
+ self.chat_suggestion = ChatSuggestion(self._app)
+
for index in self._app.index_manager.indices:
- index.selector = -1
+ index.selector = None
index_ui = index.get_selector_component_ui()
if not index_ui:
+ # the index doesn't have a selector UI component
continue
- index_ui.unrender()
+ index_ui.unrender() # need to rerender later within Accordion
with gr.Accordion(label=f"{index.name} Index", open=False):
index_ui.render()
gr_index = index_ui.as_gradio_component()
if gr_index:
- index.selector = len(self._indices_input)
- self._indices_input.append(gr_index)
+ if isinstance(gr_index, list):
+ index.selector = tuple(
+ range(
+ len(self._indices_input),
+ len(self._indices_input) + len(gr_index),
+ )
+ )
+ self._indices_input.extend(gr_index)
+ else:
+ index.selector = len(self._indices_input)
+ self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)
self.report_issue = ReportIssue(self._app)
+
with gr.Column(scale=6):
self.chat_panel = ChatPanel(self._app)
+
with gr.Column(scale=3):
with gr.Accordion(label="Information panel", open=True):
self.info_panel = gr.HTML(elem_id="chat-info-panel")
@@ -54,11 +71,24 @@ class ChatPage(BasePage):
self.chat_panel.text_input.submit,
self.chat_panel.submit_btn.click,
],
- fn=self.chat_panel.submit_msg,
- inputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
- outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+ fn=self.submit_msg,
+ inputs=[
+ self.chat_panel.text_input,
+ self.chat_panel.chatbot,
+ self._app.user_id,
+ self.chat_control.conversation_id,
+ self.chat_control.conversation_rn,
+ ],
+ outputs=[
+ self.chat_panel.text_input,
+ self.chat_panel.chatbot,
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ ],
+ concurrency_limit=20,
show_progress="hidden",
- ).then(
+ ).success(
fn=self.chat_fn,
inputs=[
self.chat_control.conversation_id,
@@ -72,6 +102,7 @@ class ChatPage(BasePage):
self.info_panel,
self.chat_state,
],
+ concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.update_data_source,
@@ -82,6 +113,7 @@ class ChatPage(BasePage):
]
+ self._indices_input,
outputs=None,
+ concurrency_limit=20,
)
self.chat_panel.regen_btn.click(
@@ -98,6 +130,7 @@ class ChatPage(BasePage):
self.info_panel,
self.chat_state,
],
+ concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.update_data_source,
@@ -108,6 +141,7 @@ class ChatPage(BasePage):
]
+ self._indices_input,
outputs=None,
+ concurrency_limit=20,
)
self.chat_panel.chatbot.like(
@@ -116,7 +150,12 @@ class ChatPage(BasePage):
outputs=None,
)
- self.chat_control.conversation.change(
+ self.chat_control.btn_new.click(
+ self.chat_control.new_conv,
+ inputs=self._app.user_id,
+ outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
+ show_progress="hidden",
+ ).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
outputs=[
@@ -124,12 +163,71 @@ class ChatPage(BasePage):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
+ self.info_panel,
self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
)
+ self.chat_control.btn_del.click(
+ lambda id: self.toggle_delete(id),
+ inputs=[self.chat_control.conversation_id],
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+ self.chat_control.btn_del_conf.click(
+ self.chat_control.delete_conv,
+ inputs=[self.chat_control.conversation_id, self._app.user_id],
+ outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
+ show_progress="hidden",
+ ).then(
+ self.chat_control.select_conv,
+ inputs=[self.chat_control.conversation],
+ outputs=[
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ self.chat_panel.chatbot,
+ self.info_panel,
+ ]
+ + self._indices_input,
+ show_progress="hidden",
+ ).then(
+ lambda: self.toggle_delete(""),
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+ self.chat_control.btn_del_cnl.click(
+ lambda: self.toggle_delete(""),
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+ self.chat_control.conversation_rn_btn.click(
+ self.chat_control.rename_conv,
+ inputs=[
+ self.chat_control.conversation_id,
+ self.chat_control.conversation_rn,
+ self._app.user_id,
+ ],
+ outputs=[self.chat_control.conversation, self.chat_control.conversation],
+ show_progress="hidden",
+ )
+
+ self.chat_control.conversation.select(
+ self.chat_control.select_conv,
+ inputs=[self.chat_control.conversation],
+ outputs=[
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ self.chat_panel.chatbot,
+ self.info_panel,
+ ]
+ + self._indices_input,
+ show_progress="hidden",
+ ).then(
+ lambda: self.toggle_delete(""),
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+
self.report_issue.report_btn.click(
self.report_issue.report,
inputs=[
@@ -140,11 +238,77 @@ class ChatPage(BasePage):
self.chat_panel.chatbot,
self._app.settings_state,
self._app.user_id,
+ self.info_panel,
self.chat_state,
]
+ self._indices_input,
outputs=None,
)
+ if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
+ self.chat_suggestion.example.select(
+ self.chat_suggestion.select_example,
+ outputs=[self.chat_panel.text_input],
+ show_progress="hidden",
+ )
+
+ def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
+ """Submit a message to the chatbot"""
+ if not chat_input:
+ raise ValueError("Input is empty")
+
+ if not conv_id:
+ id_, update = self.chat_control.new_conv(user_id)
+ with Session(engine) as session:
+ statement = select(Conversation).where(Conversation.id == id_)
+ name = session.exec(statement).one().name
+ new_conv_id = id_
+ conv_update = update
+ new_conv_name = name
+ else:
+ new_conv_id = conv_id
+ conv_update = gr.update()
+ new_conv_name = conv_name
+
+ return (
+ "",
+ chat_history + [(chat_input, None)],
+ new_conv_id,
+ conv_update,
+ new_conv_name,
+ )
+
+ def toggle_delete(self, conv_id):
+ if conv_id:
+ return gr.update(visible=False), gr.update(visible=True)
+ else:
+ return gr.update(visible=True), gr.update(visible=False)
+
+ def on_subscribe_public_events(self):
+ if self._app.f_user_management:
+ self._app.subscribe_event(
+ name="onSignIn",
+ definition={
+ "fn": self.chat_control.reload_conv,
+ "inputs": [self._app.user_id],
+ "outputs": [self.chat_control.conversation],
+ "show_progress": "hidden",
+ },
+ )
+
+ self._app.subscribe_event(
+ name="onSignOut",
+ definition={
+ "fn": lambda: self.chat_control.select_conv(""),
+ "outputs": [
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ self.chat_panel.chatbot,
+ ]
+ + self._indices_input,
+ "show_progress": "hidden",
+ },
+ )
def update_data_source(self, convo_id, messages, state, *selecteds):
"""Update the data source"""
@@ -154,8 +318,12 @@ class ChatPage(BasePage):
selecteds_ = {}
for index in self._app.index_manager.indices:
- if index.selector != -1:
+ if index.selector is None:
+ continue
+ if isinstance(index.selector, int):
selecteds_[str(index.id)] = selecteds[index.selector]
+ else:
+ selecteds_[str(index.id)] = [selecteds[i] for i in index.selector]
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
@@ -205,8 +373,11 @@ class ChatPage(BasePage):
retrievers = []
for index in self._app.index_manager.indices:
index_selected = []
- if index.selector != -1:
+ if isinstance(index.selector, int):
index_selected = selecteds[index.selector]
+ if isinstance(index.selector, tuple):
+ for i in index.selector:
+ index_selected.append(selecteds[i])
iretrievers = index.get_retriever_pipelines(settings, index_selected)
retrievers += iretrievers
@@ -250,7 +421,10 @@ class ChatPage(BasePage):
break
if "output" in response:
- text += response["output"]
+ if response["output"] is None:
+ text = ""
+ else:
+ text += response["output"]
if "evidence" in response:
if response["evidence"] is None:
diff --git a/libs/ktem/ktem/pages/chat/chat_suggestion.py b/libs/ktem/ktem/pages/chat/chat_suggestion.py
new file mode 100644
index 0000000..23332c0
--- /dev/null
+++ b/libs/ktem/ktem/pages/chat/chat_suggestion.py
@@ -0,0 +1,26 @@
+import gradio as gr
+from ktem.app import BasePage
+from theflow.settings import settings as flowsettings
+
+
+class ChatSuggestion(BasePage):
+ def __init__(self, app):
+ self._app = app
+ self.on_building_ui()
+
+ def on_building_ui(self):
+ chat_samples = getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", [])
+ chat_samples = [[each] for each in chat_samples]
+ with gr.Accordion(label="Chat Suggestion", open=False) as self.accordion:
+ self.example = gr.DataFrame(
+ value=chat_samples,
+ headers=["Sample"],
+ interactive=False,
+ wrap=True,
+ )
+
+ def as_gradio_component(self):
+ return self.example
+
+ def select_example(self, ev: gr.SelectData):
+ return ev.value
diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py
index e714112..f2ed99b 100644
--- a/libs/ktem/ktem/pages/chat/control.py
+++ b/libs/ktem/ktem/pages/chat/control.py
@@ -10,6 +10,17 @@ from .common import STATE
logger = logging.getLogger(__name__)
+def is_conv_name_valid(name):
+ """Check if the conversation name is valid"""
+ errors = []
+ if len(name) == 0:
+ errors.append("Name cannot be empty")
+ elif len(name) > 40:
+ errors.append("Name cannot be longer than 40 characters")
+
+ return "; ".join(errors)
+
+
class ConversationControl(BasePage):
"""Manage conversation"""
@@ -28,9 +39,17 @@ class ConversationControl(BasePage):
interactive=True,
)
- with gr.Row():
- self.conversation_new_btn = gr.Button(value="New", min_width=10)
- self.conversation_del_btn = gr.Button(value="Delete", min_width=10)
+ with gr.Row() as self._new_delete:
+ self.btn_new = gr.Button(value="New", min_width=10)
+ self.btn_del = gr.Button(value="Delete", min_width=10)
+
+ with gr.Row(visible=False) as self._delete_confirm:
+ self.btn_del_conf = gr.Button(
+ value="Delete",
+ variant="primary",
+ min_width=10,
+ )
+ self.btn_del_cnl = gr.Button(value="Cancel", min_width=10)
with gr.Row():
self.conversation_rn = gr.Text(
@@ -52,48 +71,6 @@ class ConversationControl(BasePage):
# outputs=[current_state],
# )
- def on_subscribe_public_events(self):
- if self._app.f_user_management:
- self._app.subscribe_event(
- name="onSignIn",
- definition={
- "fn": self.reload_conv,
- "inputs": [self._app.user_id],
- "outputs": [self.conversation],
- "show_progress": "hidden",
- },
- )
-
- self._app.subscribe_event(
- name="onSignOut",
- definition={
- "fn": self.reload_conv,
- "inputs": [self._app.user_id],
- "outputs": [self.conversation],
- "show_progress": "hidden",
- },
- )
-
- def on_register_events(self):
- self.conversation_new_btn.click(
- self.new_conv,
- inputs=self._app.user_id,
- outputs=[self.conversation_id, self.conversation],
- show_progress="hidden",
- )
- self.conversation_del_btn.click(
- self.delete_conv,
- inputs=[self.conversation_id, self._app.user_id],
- outputs=[self.conversation_id, self.conversation],
- show_progress="hidden",
- )
- self.conversation_rn_btn.click(
- self.rename_conv,
- inputs=[self.conversation_id, self.conversation_rn, self._app.user_id],
- outputs=[self.conversation, self.conversation],
- show_progress="hidden",
- )
-
def load_chat_history(self, user_id):
"""Reload chat history"""
options = []
@@ -112,7 +89,7 @@ class ConversationControl(BasePage):
def reload_conv(self, user_id):
conv_list = self.load_chat_history(user_id)
if conv_list:
- return gr.update(value=conv_list[0][1], choices=conv_list)
+ return gr.update(value=None, choices=conv_list)
else:
return gr.update(value=None, choices=[])
@@ -133,10 +110,15 @@ class ConversationControl(BasePage):
return id_, gr.update(value=id_, choices=history)
def delete_conv(self, conversation_id, user_id):
- """Create new chat"""
+ """Delete the selected conversation"""
+ if not conversation_id:
+ gr.Warning("No conversation selected.")
+ return None, gr.update()
+
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return None, gr.update()
+
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
@@ -161,6 +143,7 @@ class ConversationControl(BasePage):
name = result.name
selected = result.data_source.get("selected", {})
chats = result.data_source.get("messages", [])
+ info_panel = ""
state = result.data_source.get("state", STATE)
except Exception as e:
logger.warning(e)
@@ -168,22 +151,36 @@ class ConversationControl(BasePage):
name = ""
selected = {}
chats = []
+ info_panel = ""
state = STATE
indices = []
for index in self._app.index_manager.indices:
# assume that the index has selector
- if index.selector == -1:
+ if index.selector is None:
continue
- indices.append(selected.get(str(index.id), []))
+ if isinstance(index.selector, int):
+ indices.append(selected.get(str(index.id), []))
+ if isinstance(index.selector, tuple):
+ indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
- return id_, id_, name, chats, state, *indices
+ return id_, id_, name, chats, info_panel, state, *indices
def rename_conv(self, conversation_id, new_name, user_id):
"""Rename the conversation"""
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return gr.update(), ""
+
+ if not conversation_id:
+ gr.Warning("No conversation selected.")
+ return gr.update(), ""
+
+ errors = is_conv_name_valid(new_name)
+ if errors:
+ gr.Warning(errors)
+ return gr.update(), conversation_id
+
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py
index 25d83f8..dfe0301 100644
--- a/libs/ktem/ktem/pages/chat/report.py
+++ b/libs/ktem/ktem/pages/chat/report.py
@@ -48,13 +48,19 @@ class ReportIssue(BasePage):
chat_history: list,
settings: dict,
user_id: Optional[int],
+ info_panel: str,
chat_state: dict,
- *selecteds
+ *selecteds,
):
selecteds_ = {}
for index in self._app.index_manager.indices:
- if index.selector != -1:
- selecteds_[str(index.id)] = selecteds[index.selector]
+ if index.selector is not None:
+ if isinstance(index.selector, int):
+ selecteds_[str(index.id)] = selecteds[index.selector]
+ elif isinstance(index.selector, tuple):
+ selecteds_[str(index.id)] = [selecteds[_] for _ in index.selector]
+ else:
+ print(f"Unknown selector type: {index.selector}")
with Session(engine) as session:
issue = IssueReport(
@@ -66,6 +72,7 @@ class ReportIssue(BasePage):
chat={
"conv_id": conv_id,
"chat_history": chat_history,
+ "info_panel": info_panel,
"chat_state": chat_state,
"selecteds": selecteds_,
},
diff --git a/libs/ktem/ktem/pages/login.py b/libs/ktem/ktem/pages/login.py
index 6fe15d0..d5c57e5 100644
--- a/libs/ktem/ktem/pages/login.py
+++ b/libs/ktem/ktem/pages/login.py
@@ -31,11 +31,10 @@ class LoginPage(BasePage):
self.on_building_ui()
def on_building_ui(self):
- gr.Markdown("Welcome to Kotaemon")
- self.usn = gr.Textbox(label="Username")
- self.pwd = gr.Textbox(label="Password", type="password")
- self.btn_login = gr.Button("Login")
- self._dummy = gr.State()
+ gr.Markdown("# Welcome to Kotaemon")
+ self.usn = gr.Textbox(label="Username", visible=False)
+ self.pwd = gr.Textbox(label="Password", type="password", visible=False)
+ self.btn_login = gr.Button("Login", visible=False)
def on_register_events(self):
onSignIn = gr.on(
@@ -45,24 +44,56 @@ class LoginPage(BasePage):
outputs=[self._app.user_id, self.usn, self.pwd],
show_progress="hidden",
js=signin_js,
+ ).then(
+ self.toggle_login_visibility,
+ inputs=[self._app.user_id],
+ outputs=[self.usn, self.pwd, self.btn_login],
)
for event in self._app.get_event("onSignIn"):
onSignIn = onSignIn.success(**event)
+ def toggle_login_visibility(self, user_id):
+ return (
+ gr.update(visible=user_id is None),
+ gr.update(visible=user_id is None),
+ gr.update(visible=user_id is None),
+ )
+
def _on_app_created(self):
- self._app.app.load(
- None,
- inputs=None,
- outputs=[self.usn, self.pwd],
+ onSignIn = self._app.app.load(
+ self.login,
+ inputs=[self.usn, self.pwd],
+ outputs=[self._app.user_id, self.usn, self.pwd],
+ show_progress="hidden",
js=fetch_creds,
+ ).then(
+ self.toggle_login_visibility,
+ inputs=[self._app.user_id],
+ outputs=[self.usn, self.pwd, self.btn_login],
+ )
+ for event in self._app.get_event("onSignIn"):
+ onSignIn = onSignIn.success(**event)
+
+ def on_subscribe_public_events(self):
+ self._app.subscribe_event(
+ name="onSignOut",
+ definition={
+ "fn": self.toggle_login_visibility,
+ "inputs": [self._app.user_id],
+ "outputs": [self.usn, self.pwd, self.btn_login],
+ "show_progress": "hidden",
+ },
)
def login(self, usn, pwd):
+ if not usn or not pwd:
+ return None, usn, pwd
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
with Session(engine) as session:
stmt = select(User).where(
- User.username_lower == usn.lower(), User.password == hashed_password
+ User.username_lower == usn.lower().strip(),
+ User.password == hashed_password,
)
result = session.exec(stmt).all()
if result:
diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py
index 0fce2e8..20912cb 100644
--- a/libs/ktem/ktem/pages/settings.py
+++ b/libs/ktem/ktem/pages/settings.py
@@ -164,9 +164,14 @@ class SettingsPage(BasePage):
show_progress="hidden",
)
onSignOutClick = self.signout.click(
- lambda: (None, "Current user: ___"),
+ lambda: (None, "Current user: ___", "", ""),
inputs=None,
- outputs=[self._user_id, self.current_name],
+ outputs=[
+ self._user_id,
+ self.current_name,
+ self.password_change,
+ self.password_change_confirm,
+ ],
show_progress="hidden",
js=signout_js,
).then(
@@ -192,8 +197,12 @@ class SettingsPage(BasePage):
self.password_change_btn = gr.Button("Change password", interactive=True)
def change_password(self, user_id, password, password_confirm):
- if password != password_confirm:
- gr.Warning("Password does not match")
+ from ktem.pages.admin.user import validate_password
+
+ errors = validate_password(password, password_confirm)
+ if errors:
+ print(errors)
+ gr.Warning(errors)
return password, password_confirm
with Session(engine) as session:
diff --git a/libs/ktem/ktem/reasoning/base.py b/libs/ktem/ktem/reasoning/base.py
index 80cf016..6d6e486 100644
--- a/libs/ktem/ktem/reasoning/base.py
+++ b/libs/ktem/ktem/reasoning/base.py
@@ -34,12 +34,16 @@ class BaseReasoning(BaseComponent):
@classmethod
def get_pipeline(
- cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None
+ cls,
+ user_settings: dict,
+ state: dict,
+ retrievers: Optional[list["BaseComponent"]] = None,
) -> "BaseReasoning":
"""Get the reasoning pipeline for the app to execute
Args:
user_setting: user settings
+ state: conversation state
retrievers (list): List of retrievers
"""
return cls()
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index 23d8363..082c20f 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -22,6 +22,8 @@ from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
from kotaemon.loaders.utils.gpt4v import stream_gpt4v
+from .base import BaseReasoning
+
logger = logging.getLogger(__name__)
EVIDENCE_MODE_TEXT = 0
@@ -204,7 +206,7 @@ class AnswerWithContextPipeline(BaseComponent):
lang: str = "English" # support English and Japanese
async def run( # type: ignore
- self, question: str, evidence: str, evidence_mode: int = 0
+ self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Document:
"""Answer the question based on the evidence
@@ -336,7 +338,7 @@ class RewriteQuestionPipeline(BaseComponent):
return Document(text=output)
-class FullQAPipeline(BaseComponent):
+class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer"""
class Config:
@@ -352,6 +354,8 @@ class FullQAPipeline(BaseComponent):
async def run( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
+ import markdown
+
docs = []
doc_ids = []
if self.use_rewrite:
@@ -364,12 +368,16 @@ class FullQAPipeline(BaseComponent):
docs.append(doc)
doc_ids.append(doc.doc_id)
for doc in docs:
+ # TODO: a better approach to show the information
+ text = markdown.markdown(
+ doc.text, extensions=["markdown.extensions.tables"]
+ )
self.report_output(
{
"evidence": (
""
f"{doc.metadata['file_name']}
"
- f"{doc.text}"
+ f"{text}"
"
"
)
}
@@ -378,7 +386,12 @@ class FullQAPipeline(BaseComponent):
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = await self.answering_pipeline(
- question=message, evidence=evidence, evidence_mode=evidence_mode
+ question=message,
+ history=history,
+ evidence=evidence,
+ evidence_mode=evidence_mode,
+ conv_id=conv_id,
+ **kwargs,
)
# prepare citation
@@ -388,14 +401,29 @@ class FullQAPipeline(BaseComponent):
for quote in fact_with_evidence.substring_quote:
for doc in docs:
start_idx = doc.text.find(quote)
- if start_idx >= 0:
+ if start_idx == -1:
+ continue
+
+ end_idx = start_idx + len(quote)
+
+ current_idx = start_idx
+ if "|" not in doc.text[start_idx:end_idx]:
spans[doc.doc_id].append(
- {
- "start": start_idx,
- "end": start_idx + len(quote),
- }
+ {"start": start_idx, "end": end_idx}
)
- break
+ else:
+ while doc.text[current_idx:end_idx].find("|") != -1:
+ match_idx = doc.text[current_idx:end_idx].find("|")
+ spans[doc.doc_id].append(
+ {
+ "start": current_idx,
+ "end": current_idx + match_idx,
+ }
+ )
+ current_idx += match_idx + 2
+ if current_idx > end_idx:
+ break
+ break
id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True
@@ -414,12 +442,15 @@ class FullQAPipeline(BaseComponent):
if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :]
+ text_out = markdown.markdown(
+ text, extensions=["markdown.extensions.tables"]
+ )
self.report_output(
{
"evidence": (
""
f"{id2docs[id].metadata['file_name']}
"
- f"{text}"
+ f"{text_out}"
"
"
)
}
@@ -434,12 +465,15 @@ class FullQAPipeline(BaseComponent):
{"evidence": "Retrieved segments without matching evidence:\n"}
)
for id in list(not_detected):
+ text_out = markdown.markdown(
+ id2docs[id].text, extensions=["markdown.extensions.tables"]
+ )
self.report_output(
{
"evidence": (
""
f"{id2docs[id].metadata['file_name']}
"
- f"{id2docs[id].text}"
+ f"{text_out}"
"
"
)
}