Fix UI bugs (#8)

* Auto create conversation when the user starts

* Add conversation rename rule check

* Fix empty name during save

* Confirm deleting conversation

* Show warning if users don't select file when upload files in the File Index

* Feedback when user uploads duplicated file

* Limit the file types

* Fix valid username

* Allow login when username with leading and trailing whitespaces

* Improve the user

* Disable admin panel for non-admnin user

* Refresh user lists after creating/deleting users

* Auto logging in

* Clear admin information upon signing out

* Fix unable to receive uploaded filename that include special characters, like !@#$%^&*().pdf

* Set upload validation for FileIndex

* Improve user management UI/UIX

* Show extraction error when indexing file

* Return selected user -1 when signing out

* Fix default supported file types in file index

* Validate changing password

* Allow the selector to contain mulitple gradio components

* A more tolerable placeholder screen

* Allow chat suggestion box

* Increase concurrency limit

* Make adobe loader optional

* Use BaseReasoning

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin 2024-04-03 16:33:54 +07:00 committed by GitHub
parent 43a18ba070
commit ecf09b275f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 936 additions and 255 deletions

1
.gitignore vendored
View File

@ -452,6 +452,7 @@ $RECYCLE.BIN/
.theflow/
# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
*.py[coid]
logs/
.gitsecret/keys/random_seed

View File

@ -52,7 +52,12 @@ repos:
hooks:
- id: mypy
additional_dependencies:
[types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
[
types-PyYAML==6.0.12.11,
"types-requests",
"sqlmodel",
"types-Markdown",
]
args: ["--check-untyped-defs", "--ignore-missing-imports"]
exclude: "^templates/"
- repo: https://github.com/codespell-project/codespell

View File

@ -104,18 +104,16 @@ class CitationPipeline(BaseComponent):
print("CitationPipeline: invoking LLM")
llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
print("CitationPipeline: finish invoking LLM")
if not llm_output.messages:
return None
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]
output = QuestionAnswer.parse_raw(function_output)
except Exception as e:
print(e)
return None
if not llm_output.messages:
return None
function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments"
]
output = QuestionAnswer.parse_raw(function_output)
return output
async def ainvoke(self, context: str, question: str):

View File

@ -5,7 +5,7 @@ from .docx_loader import DocxReader
from .excel_loader import PandasExcelReader
from .html_loader import HtmlReader
from .mathpix_loader import MathpixPDFReader
from .ocr_loader import OCRReader
from .ocr_loader import ImageReader, OCRReader
from .unstructured_loader import UnstructuredReader
__all__ = [
@ -13,6 +13,7 @@ __all__ = [
"BaseReader",
"PandasExcelReader",
"MathpixPDFReader",
"ImageReader",
"OCRReader",
"DirectoryReader",
"UnstructuredReader",

View File

@ -10,14 +10,6 @@ from llama_index.readers.base import BaseReader
from kotaemon.base import Document
from .utils.adobe import (
generate_figure_captions,
load_json,
parse_figure_paths,
parse_table_paths,
request_adobe_service,
)
logger = logging.getLogger(__name__)
DEFAULT_VLM_ENDPOINT = (
@ -74,6 +66,13 @@ class AdobeReader(BaseReader):
includes 3 types: text, table, and image
"""
from .utils.adobe import (
generate_figure_captions,
load_json,
parse_figure_paths,
parse_table_paths,
request_adobe_service,
)
filename = file.name
filepath = str(Path(file).resolve())

View File

@ -125,3 +125,70 @@ class OCRReader(BaseReader):
)
return documents
class ImageReader(BaseReader):
"""Read PDF using OCR, with high focus on table extraction
Example:
```python
>> from knowledgehub.loaders import OCRReader
>> reader = OCRReader()
>> documents = reader.load_data("path/to/pdf")
```
Args:
endpoint: URL to FullOCR endpoint. If not provided, will look for
environment variable `OCR_READER_ENDPOINT` or use the default
`knowledgehub.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT`
(http://127.0.0.1:8000/v2/ai/infer/)
use_ocr: whether to use OCR to read text (e.g: from images, tables) in the PDF
If False, only the table and text within table cells will be extracted.
"""
def __init__(self, endpoint: Optional[str] = None):
"""Init the OCR reader with OCR endpoint (FullOCR pipeline)"""
super().__init__()
self.ocr_endpoint = endpoint or os.getenv(
"OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT
)
def load_data(
self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
) -> List[Document]:
"""Load data using OCR reader
Args:
file_path (Path): Path to PDF file
debug_path (Path): Path to store debug image output
artifact_path (Path): Path to OCR endpoints artifacts directory
Returns:
List[Document]: list of documents extracted from the PDF file
"""
file_path = Path(file_path).resolve()
with file_path.open("rb") as content:
files = {"input": content}
data = {"job_id": uuid4(), "table_only": False}
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
ocr_results = resp.json()["result"]
extra_info = extra_info or {}
result = []
for ocr_result in ocr_results:
result.append(
Document(
content=ocr_result["csv_string"],
metadata=extra_info,
)
)
return result

View File

@ -229,7 +229,9 @@ class BasePage:
def _on_app_created(self):
"""Called when the app is created"""
def as_gradio_component(self) -> Optional[gr.components.Component]:
def as_gradio_component(
self,
) -> Optional[gr.components.Component | list[gr.components.Component]]:
"""Return the gradio components responsible for events
Note: in ideal scenario, this method shouldn't be necessary.

View File

@ -1,6 +1,6 @@
import abc
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from ktem.app import BasePage
@ -57,7 +57,7 @@ class BaseIndex(abc.ABC):
self._app = app
self.id = id
self.name = name
self._config = config # admin settings
self.config = config # admin settings
def on_create(self):
"""Create the index for the first time"""
@ -121,7 +121,7 @@ class BaseIndex(abc.ABC):
...
def get_retriever_pipelines(
self, settings: dict, selected: Optional[list]
self, settings: dict, selected: Any = None
) -> list["BaseComponent"]:
"""Return the retriever pipelines to retrieve the entity from the index"""
return []

View File

@ -127,3 +127,11 @@ class BaseFileIndexIndexing(BaseComponent):
the absolute file storage path to the file
"""
raise NotImplementedError
def warning(self, msg):
"""Log a warning message
Args:
msg: the message to log
"""
print(msg)

View File

@ -13,7 +13,6 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .ui import FileIndexPage, FileSelector
class FileIndex(BaseIndex):
@ -77,9 +76,15 @@ class FileIndex(BaseIndex):
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
self._selector_ui_cls: Type
self._selector_ui: Any = None
self._index_ui_cls: Type
self._index_ui: Any = None
self._setup_indexing_cls()
self._setup_retriever_cls()
self._setup_file_index_ui_cls()
self._setup_file_selector_ui_cls()
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
@ -91,14 +96,14 @@ class FileIndex(BaseIndex):
The indexing class will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_PIPELINE` in self._config
- `FILE_INDEX_PIPELINE` in self.config
- `FILE_INDEX_{id}_PIPELINE` in the flowsettings
- `FILE_INDEX_PIPELINE` in the flowsettings
- The default .pipelines.IndexDocumentPipeline
"""
if "FILE_INDEX_PIPELINE" in self._config:
if "FILE_INDEX_PIPELINE" in self.config:
self._indexing_pipeline_cls = import_dotted_string(
self._config["FILE_INDEX_PIPELINE"], safe=False
self.config["FILE_INDEX_PIPELINE"], safe=False
)
return
@ -125,15 +130,15 @@ class FileIndex(BaseIndex):
The retriever classes will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_RETRIEVER_PIPELINES` in self._config
- `FILE_INDEX_RETRIEVER_PIPELINES` in self.config
- `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings
- `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings
- The default .pipelines.DocumentRetrievalPipeline
"""
if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config:
if "FILE_INDEX_RETRIEVER_PIPELINES" in self.config:
self._retriever_pipeline_cls = [
import_dotted_string(each, safe=False)
for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"]
for each in self.config["FILE_INDEX_RETRIEVER_PIPELINES"]
]
return
@ -157,6 +162,76 @@ class FileIndex(BaseIndex):
self._retriever_pipeline_cls = [DocumentRetrievalPipeline]
def _setup_file_selector_ui_cls(self):
"""Retrieve the file selector UI for the file index
There can be multiple retriever classes.
The retriever classes will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_SELECTOR_UI` in self.config
- `FILE_INDEX_{id}_SELECTOR_UI` in the flowsettings
- `FILE_INDEX_SELECTOR_UI` in the flowsettings
- The default .ui.FileSelector
"""
if "FILE_INDEX_SELECTOR_UI" in self.config:
self._selector_ui_cls = import_dotted_string(
self.config["FILE_INDEX_SELECTOR_UI"], safe=False
)
return
if hasattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"):
self._selector_ui_cls = import_dotted_string(
getattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"),
safe=False,
)
return
if hasattr(flowsettings, "FILE_INDEX_SELECTOR_UI"):
self._selector_ui_cls = import_dotted_string(
getattr(flowsettings, "FILE_INDEX_SELECTOR_UI"), safe=False
)
return
from .ui import FileSelector
self._selector_ui_cls = FileSelector
def _setup_file_index_ui_cls(self):
"""Retrieve the Index UI class
There can be multiple retriever classes.
The retriever classes will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_UI` in self.config
- `FILE_INDEX_{id}_UI` in the flowsettings
- `FILE_INDEX_UI` in the flowsettings
- The default .ui.FileIndexPage
"""
if "FILE_INDEX_UI" in self.config:
self._index_ui_cls = import_dotted_string(
self.config["FILE_INDEX_UI"], safe=False
)
return
if hasattr(flowsettings, f"FILE_INDEX_{self.id}_UI"):
self._index_ui_cls = import_dotted_string(
getattr(flowsettings, f"FILE_INDEX_{self.id}_UI"),
safe=False,
)
return
if hasattr(flowsettings, "FILE_INDEX_UI"):
self._index_ui_cls = import_dotted_string(
getattr(flowsettings, "FILE_INDEX_UI"), safe=False
)
return
from .ui import FileIndexPage
self._index_ui_cls = FileIndexPage
def on_create(self):
"""Create the index for the first time
@ -165,6 +240,13 @@ class FileIndex(BaseIndex):
2. Create the vectorstore
3. Create the docstore
"""
file_types_str = self.config.get(
"supported_file_types",
self.get_admin_settings()["supported_file_types"]["value"],
)
file_types = [each.strip() for each in file_types_str.split(",")]
self.config["supported_file_types"] = file_types
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True)
@ -180,10 +262,14 @@ class FileIndex(BaseIndex):
shutil.rmtree(self._fs_path)
def get_selector_component_ui(self):
return FileSelector(self._app, self)
if self._selector_ui is None:
self._selector_ui = self._selector_ui_cls(self._app, self)
return self._selector_ui
def get_index_page_ui(self):
return FileIndexPage(self._app, self)
if self._index_ui is None:
self._index_ui = self._index_ui_cls(self._app, self)
return self._index_ui
def get_user_settings(self):
if self._default_settings:
@ -210,7 +296,31 @@ class FileIndex(BaseIndex):
"value": embedding_default,
"component": "dropdown",
"choices": embedding_choices,
}
},
"supported_file_types": {
"name": "Supported file types",
"value": (
"image, .pdf, .txt, .csv, .xlsx, .doc, .docx, .pptx, .html, .zip"
),
"component": "text",
},
"max_file_size": {
"name": "Max file size (MB) - set 0 to disable",
"value": 1000,
"component": "number",
},
"max_number_of_files": {
"name": "Max number of files that can be indexed - set 0 to disable",
"value": 0,
"component": "number",
},
"max_number_of_text_length": {
"name": (
"Max amount of characters that can be indexed - set 0 to disable"
),
"value": 0,
"component": "number",
},
}
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
@ -224,14 +334,15 @@ class FileIndex(BaseIndex):
else:
stripped_settings[key] = value
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config)
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.set_resources(resources=self._resources)
return obj
def get_retriever_pipelines(
self, settings: dict, selected: Optional[list] = None
self, settings: dict, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
# retrieval settings
prefix = f"index.options.{self.id}."
stripped_settings = {}
for key, value in settings.items():
@ -240,9 +351,12 @@ class FileIndex(BaseIndex):
else:
stripped_settings[key] = value
# transform selected id
selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)
retrievers = []
for cls in self._retriever_pipeline_cls:
obj = cls.get_pipeline(stripped_settings, self._config, selected)
obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
if obj is None:
continue
obj.set_resources(self._resources)

View File

@ -9,6 +9,7 @@ from hashlib import sha256
from pathlib import Path
from typing import Optional
import gradio as gr
from ktem.components import embeddings, filestorage_path
from ktem.db.models import engine
from llama_index.vector_stores import (
@ -18,7 +19,7 @@ from llama_index.vector_stores import (
MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from theflow.settings import settings
from theflow.utils.modules import import_dotted_string
@ -279,6 +280,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
to_index: list[str] = []
file_to_hash: dict[str, str] = {}
errors = []
to_update = []
for file_path in file_paths:
abs_path = str(Path(file_path).resolve())
@ -291,16 +293,26 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
statement = select(Source).where(Source.name == Path(abs_path).name)
item = session.execute(statement).first()
if item and not reindex:
errors.append(Path(abs_path).name)
continue
if item:
if not reindex:
errors.append(Path(abs_path).name)
continue
else:
to_update.append(Path(abs_path).name)
to_index.append(abs_path)
if errors:
error_files = ", ".join(errors)
if len(error_files) > 100:
error_files = error_files[:80] + "..."
print(
"Files already exist. Please rename/remove them or enable reindex.\n"
f"{errors}"
"Skip these files already exist. Please rename/remove them or "
f"enable reindex:\n{errors}"
)
self.warning(
"Skip these files already exist. Please rename/remove them or "
f"enable reindex:\n{error_files}"
)
if not to_index:
@ -310,9 +322,19 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
for path in to_index:
shutil.copy(path, filestorage_path / file_to_hash[path])
# prepare record info
# extract the file & prepare record info
file_to_source: dict = {}
extraction_errors = []
nodes = []
for file_path, file_hash in file_to_hash.items():
if str(Path(file_path).resolve()) not in to_index:
continue
extraction_result = self.file_ingestor(file_path)
if not extraction_result:
extraction_errors.append(Path(file_path).name)
continue
nodes.extend(extraction_result)
source = Source(
name=Path(file_path).name,
path=file_hash,
@ -320,9 +342,23 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
)
file_to_source[file_path] = source
# extract the files
nodes = self.file_ingestor(to_index)
print("Extracted", len(to_index), "files into", len(nodes), "nodes")
if extraction_errors:
msg = "Failed to extract these files: {}".format(
", ".join(extraction_errors)
)
print(msg)
self.warning(msg)
if not nodes:
return [], []
print(
"Extracted",
len(to_index) - len(extraction_errors),
"files into",
len(nodes),
"nodes",
)
# index the files
print("Indexing the files into vector store")
@ -332,7 +368,11 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
# persist to the index
print("Persisting the vector and the document into index")
file_ids = []
to_update = list(set(to_update))
with Session(engine) as session:
if to_update:
session.execute(delete(Source).where(Source.name.in_(to_update)))
for source in file_to_source.values():
session.add(source)
session.commit()
@ -404,3 +444,6 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
super().set_resources(resources)
self.indexing_vector_pipeline.vector_store = self._VS
self.indexing_vector_pipeline.doc_store = self._DS
def warning(self, msg):
gr.Warning(msg)

View File

@ -1,29 +1,48 @@
import os
import tempfile
from pathlib import Path
import gradio as gr
import pandas as pd
from gradio.data_classes import FileData
from gradio.utils import NamedString
from ktem.app import BasePage
from ktem.db.engine import engine
from sqlalchemy import select
from sqlalchemy.orm import Session
class File(gr.File):
"""Subclass from gr.File to maintain the original filename
The issue happens when user uploads file with name like: !@#$%%^&*().pdf
"""
def _process_single_file(self, f: FileData) -> NamedString | bytes:
file_name = f.path
if self.type == "filepath":
if f.orig_name and Path(file_name).name != f.orig_name:
file_name = str(Path(file_name).parent / f.orig_name)
os.rename(f.path, file_name)
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
file.name = file_name
return NamedString(file_name)
elif self.type == "binary":
with open(file_name, "rb") as file_data:
return file_data.read()
else:
raise ValueError(
"Unknown type: "
+ str(type)
+ ". Please choose from: 'filepath', 'binary'."
)
class DirectoryUpload(BasePage):
def __init__(self, app):
self._app = app
self._supported_file_types = [
"image",
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
def __init__(self, app, index):
super().__init__(app)
self._index = index
self._supported_file_types = self._index.config.get("supported_file_types", [])
self.on_building_ui()
def on_building_ui(self):
@ -50,18 +69,7 @@ class FileIndexPage(BasePage):
def __init__(self, app, index):
super().__init__(app)
self._index = index
self._supported_file_types = [
"image",
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
self._supported_file_types = self._index.config.get("supported_file_types", [])
self.selected_panel_false = "Selected file: (please select above)"
self.selected_panel_true = "Selected file: {name}"
# TODO: on_building_ui is not correctly named if it's always called in
@ -69,13 +77,32 @@ class FileIndexPage(BasePage):
self.public_events = [f"onFileIndex{index.id}Changed"]
self.on_building_ui()
def upload_instruction(self) -> str:
msgs = []
if self._supported_file_types:
msgs.append(
f"- Supported file types: {', '.join(self._supported_file_types)}"
)
if max_file_size := self._index.config.get("max_file_size", 0):
msgs.append(f"- Maximum file size: {max_file_size} MB")
if max_number_of_files := self._index.config.get("max_number_of_files", 0):
msgs.append(f"- The index can have maximum {max_number_of_files} files")
if msgs:
return "\n".join(msgs)
return ""
def on_building_ui(self):
"""Build the UI of the app"""
with gr.Accordion(label="File upload", open=False):
gr.Markdown(
f"Supported file types: {', '.join(self._supported_file_types)}",
)
self.files = gr.File(
with gr.Accordion(label="File upload", open=True) as self.upload:
msg = self.upload_instruction()
if msg:
gr.Markdown(msg)
self.files = File(
file_types=self._supported_file_types,
file_count="multiple",
container=False,
@ -98,18 +125,20 @@ class FileIndexPage(BasePage):
interactive=False,
)
with gr.Row():
with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None)
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button("Deselect", visible=False)
with gr.Row():
with gr.Row() as self.tools:
with gr.Column():
self.view_button = gr.Button("View Text (WIP)")
with gr.Column():
self.delete_button = gr.Button("Delete")
with gr.Row():
self.delete_yes = gr.Button("Confirm Delete", visible=False)
self.delete_yes = gr.Button(
"Confirm Delete", variant="primary", visible=False
)
self.delete_no = gr.Button("Cancel", visible=False)
def on_subscribe_public_events(self):
@ -242,10 +271,12 @@ class FileIndexPage(BasePage):
self._app.settings_state,
],
outputs=[self.file_output],
concurrency_limit=20,
).then(
fn=self.list_file,
inputs=None,
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onUploaded = onUploaded.then(**event)
@ -274,6 +305,15 @@ class FileIndexPage(BasePage):
selected_files: the list of files already selected
settings: the settings of the app
"""
if not files:
gr.Info("No uploaded file")
return gr.update()
errors = self.validate(files)
if errors:
gr.Warning(", ".join(errors))
return gr.update()
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
@ -409,6 +449,35 @@ class FileIndexPage(BasePage):
name=list_files["name"][ev.index[0]]
)
def validate(self, files: list[str]):
"""Validate if the files are valid"""
paths = [Path(file) for file in files]
errors = []
if max_file_size := self._index.config.get("max_file_size", 0):
errors_max_size = []
for path in paths:
if path.stat().st_size > max_file_size * 1e6:
errors_max_size.append(path.name)
if errors_max_size:
str_errors = ", ".join(errors_max_size)
if len(str_errors) > 60:
str_errors = str_errors[:55] + "..."
errors.append(
f"Maximum file size ({max_file_size} MB) exceeded: {str_errors}"
)
if max_number_of_files := self._index.config.get("max_number_of_files", 0):
with Session(engine) as session:
current_num_files = session.query(
self._index._db_tables["Source"].id
).count()
if len(paths) + current_num_files > max_number_of_files:
errors.append(
f"Maximum number of files ({max_number_of_files}) will be exceeded"
)
return errors
class FileSelector(BasePage):
"""File selector UI in the Chat page"""
@ -430,6 +499,9 @@ class FileSelector(BasePage):
def as_gradio_component(self):
return self.selector
def get_selected_ids(self, selected):
return selected
def load_files(self, selected_files):
options = []
available_ids = []

View File

@ -1,4 +1,4 @@
from typing import Type
from typing import Optional, Type
from ktem.db.models import engine
from sqlmodel import Session, select
@ -49,15 +49,19 @@ class IndexManager:
Returns:
BaseIndex: the index object
"""
index_cls = import_dotted_string(index_type, safe=False)
index = index_cls(app=self._app, id=id, name=name, config=config)
index.on_create()
with Session(engine) as session:
index_entry = Index(id=id, name=name, config=config, index_type=index_type)
index_entry = Index(
id=index.id, name=index.name, config=index.config, index_type=index_type
)
session.add(index_entry)
session.commit()
session.refresh(index_entry)
index_cls = import_dotted_string(index_type, safe=False)
index = index_cls(app=self._app, id=id, name=name, config=config)
index.on_create()
index.id = index_entry.id
return index
@ -77,7 +81,7 @@ class IndexManager:
self._indices.append(index)
return index
def exists(self, id: int) -> bool:
def exists(self, id: Optional[int] = None, name: Optional[str] = None) -> bool:
"""Check if the index exists
Args:
@ -86,9 +90,19 @@ class IndexManager:
Returns:
bool: True if the index exists, False otherwise
"""
with Session(engine) as session:
index = session.get(Index, id)
return index is not None
if id:
with Session(engine) as session:
index = session.get(Index, id)
return index is not None
if name:
with Session(engine) as session:
index = session.exec(
select(Index).where(Index.name == name)
).one_or_none()
return index is not None
return False
def on_application_startup(self):
"""This method is called by the base application when the application starts

View File

@ -27,7 +27,7 @@ class App(BaseApp):
if self.f_user_management:
from ktem.pages.login import LoginPage
with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]:
with gr.Tab("Welcome", elem_id="login-tab") as self._tabs["login-tab"]:
self.login_page = LoginPage(self)
with gr.Tab(
@ -62,6 +62,9 @@ class App(BaseApp):
def on_subscribe_public_events(self):
if self.f_user_management:
from ktem.db.engine import engine
from ktem.db.models import User
from sqlmodel import Session, select
def signed_in_out(user_id):
if not user_id:
@ -73,14 +76,31 @@ class App(BaseApp):
)
for k in self._tabs.keys()
)
return list(
(
gr.update(visible=True)
if k != "login-tab"
else gr.update(visible=False)
)
for k in self._tabs.keys()
)
with Session(engine) as session:
user = session.exec(select(User).where(User.id == user_id)).first()
if user is None:
return list(
(
gr.update(visible=True)
if k == "login-tab"
else gr.update(visible=False)
)
for k in self._tabs.keys()
)
is_admin = user.admin
tabs_update = []
for k in self._tabs.keys():
if k == "login-tab":
tabs_update.append(gr.update(visible=False))
elif k == "admin-tab":
tabs_update.append(gr.update(visible=is_admin))
else:
tabs_update.append(gr.update(visible=True))
return tabs_update
self.subscribe_event(
name="onSignIn",

View File

@ -40,7 +40,7 @@ def validate_username(usn):
if len(usn) > 32:
errors.append("Username must be at most 32 characters long")
if not usn.strip("_").isalnum():
if not usn.replace("_", "").isalnum():
errors.append(
"Username must contain only alphanumeric characters and underscores"
)
@ -97,8 +97,6 @@ def validate_password(pwd, pwd_cnf):
class UserManagement(BasePage):
def __init__(self, app):
self._app = app
self.selected_panel_false = "Selected user: (please select above)"
self.selected_panel_true = "Selected user: {name}"
self.on_building_ui()
if hasattr(flowsettings, "KH_FEATURE_USER_MANAGEMENT_ADMIN") and hasattr(
@ -126,7 +124,38 @@ class UserManagement(BasePage):
gr.Info(f'User "{usn}" created successfully')
def on_building_ui(self):
with gr.Accordion(label="Create user", open=False):
with gr.Tab(label="User list"):
self.state_user_list = gr.State(value=None)
self.user_list = gr.DataFrame(
headers=["id", "name", "admin"],
interactive=False,
)
with gr.Group(visible=False) as self._selected_panel:
self.selected_user_id = gr.Number(value=-1, visible=False)
self.usn_edit = gr.Textbox(label="Username")
with gr.Row():
self.pwd_edit = gr.Textbox(label="Change password", type="password")
self.pwd_cnf_edit = gr.Textbox(
label="Confirm change password",
type="password",
)
self.admin_edit = gr.Checkbox(label="Admin")
with gr.Row() as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button("Save")
with gr.Column():
self.btn_delete = gr.Button("Delete")
with gr.Row():
self.btn_delete_yes = gr.Button(
"Confirm delete", variant="primary", visible=False
)
self.btn_delete_no = gr.Button("Cancel", visible=False)
with gr.Column():
self.btn_close = gr.Button("Close")
with gr.Tab(label="Create user"):
self.usn_new = gr.Textbox(label="Username", interactive=True)
self.pwd_new = gr.Textbox(
label="Password", type="password", interactive=True
@ -139,52 +168,28 @@ class UserManagement(BasePage):
gr.Markdown(PASSWORD_RULE)
self.btn_new = gr.Button("Create user")
gr.Markdown("## User list")
self.btn_list_user = gr.Button("Refresh user list")
self.state_user_list = gr.State(value=None)
self.user_list = gr.DataFrame(
headers=["id", "name", "admin"],
interactive=False,
)
with gr.Row():
self.selected_user_id = gr.State(value=None)
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button("Deselect", visible=False)
with gr.Group():
self.btn_delete = gr.Button("Delete user")
with gr.Row():
self.btn_delete_yes = gr.Button("Confirm", visible=False)
self.btn_delete_no = gr.Button("Cancel", visible=False)
gr.Markdown("## User details")
self.usn_edit = gr.Textbox(label="Username")
self.pwd_edit = gr.Textbox(label="Password", type="password")
self.pwd_cnf_edit = gr.Textbox(label="Confirm password", type="password")
self.admin_edit = gr.Checkbox(label="Admin")
self.btn_edit_save = gr.Button("Save")
def on_register_events(self):
self.btn_new.click(
self.create_user,
inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
outputs=None,
)
self.btn_list_user.click(
self.list_users, inputs=None, outputs=[self.state_user_list, self.user_list]
outputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
).then(
self.list_users,
inputs=self._app.user_id,
outputs=[self.state_user_list, self.user_list],
)
self.user_list.select(
self.select_user,
inputs=self.user_list,
outputs=[self.selected_user_id, self.selected_panel],
outputs=[self.selected_user_id],
show_progress="hidden",
)
self.selected_panel.change(
self.selected_user_id.change(
self.on_selected_user_change,
inputs=[self.selected_user_id],
outputs=[
self.deselect_button,
self._selected_panel,
self._selected_panel_btn,
# delete section
self.btn_delete,
self.btn_delete_yes,
@ -197,12 +202,6 @@ class UserManagement(BasePage):
],
show_progress="hidden",
)
self.deselect_button.click(
lambda: (None, self.selected_panel_false),
inputs=None,
outputs=[self.selected_user_id, self.selected_panel],
show_progress="hidden",
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=[self.selected_user_id],
@ -211,9 +210,13 @@ class UserManagement(BasePage):
)
self.btn_delete_yes.click(
self.delete_user,
inputs=[self.selected_user_id],
outputs=[self.selected_user_id, self.selected_panel],
inputs=[self._app.user_id, self.selected_user_id],
outputs=[self.selected_user_id],
show_progress="hidden",
).then(
self.list_users,
inputs=self._app.user_id,
outputs=[self.state_user_list, self.user_list],
)
self.btn_delete_no.click(
lambda: (
@ -234,21 +237,53 @@ class UserManagement(BasePage):
self.pwd_cnf_edit,
self.admin_edit,
],
outputs=None,
outputs=[self.pwd_edit, self.pwd_cnf_edit],
show_progress="hidden",
).then(
self.list_users,
inputs=self._app.user_id,
outputs=[self.state_user_list, self.user_list],
)
self.btn_close.click(
lambda: -1,
outputs=[self.selected_user_id],
)
def on_subscribe_public_events(self):
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.list_users,
"inputs": [self._app.user_id],
"outputs": [self.state_user_list, self.user_list],
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": lambda: ("", "", "", None, None, -1),
"outputs": [
self.usn_new,
self.pwd_new,
self.pwd_cnf_new,
self.state_user_list,
self.user_list,
self.selected_user_id,
],
},
)
def create_user(self, usn, pwd, pwd_cnf):
errors = validate_username(usn)
if errors:
gr.Warning(errors)
return
return usn, pwd, pwd_cnf
errors = validate_password(pwd, pwd_cnf)
print(errors)
if errors:
gr.Warning(errors)
return
return usn, pwd, pwd_cnf
with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower())
@ -265,8 +300,22 @@ class UserManagement(BasePage):
session.commit()
gr.Info(f'User "{usn}" created successfully')
def list_users(self):
return "", "", ""
def list_users(self, user_id):
if user_id is None:
return [], pd.DataFrame.from_records(
[{"id": "-", "username": "-", "admin": "-"}]
)
with Session(engine) as session:
statement = select(User).where(User.id == user_id)
user = session.exec(statement).one()
if not user.admin:
return [], pd.DataFrame.from_records(
[{"id": "-", "username": "-", "admin": "-"}]
)
statement = select(User)
results = [
{"id": user.id, "username": user.username, "admin": user.admin}
@ -284,18 +333,17 @@ class UserManagement(BasePage):
def select_user(self, user_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No user is loaded. Please refresh the user list")
return None, self.selected_panel_false
return -1
if not ev.selected:
return None, self.selected_panel_false
return -1
return user_list["id"][ev.index[0]], self.selected_panel_true.format(
name=user_list["username"][ev.index[0]]
)
return user_list["id"][ev.index[0]]
def on_selected_user_change(self, selected_user_id):
if selected_user_id is None:
deselect_button = gr.update(visible=False)
if selected_user_id == -1:
_selected_panel = gr.update(visible=False)
_selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
@ -304,7 +352,8 @@ class UserManagement(BasePage):
pwd_cnf_edit = gr.update(value="")
admin_edit = gr.update(value=False)
else:
deselect_button = gr.update(visible=True)
_selected_panel = gr.update(visible=True)
_selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
@ -319,7 +368,8 @@ class UserManagement(BasePage):
admin_edit = gr.update(value=user.admin)
return (
deselect_button,
_selected_panel,
_selected_panel_btn,
btn_delete,
btn_delete_yes,
btn_delete_no,
@ -344,17 +394,16 @@ class UserManagement(BasePage):
return btn_delete, btn_delete_yes, btn_delete_no
def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin):
if usn:
errors = validate_username(usn)
if errors:
gr.Warning(errors)
return
errors = validate_username(usn)
if errors:
gr.Warning(errors)
return pwd, pwd_cnf
if pwd:
errors = validate_password(pwd, pwd_cnf)
if errors:
gr.Warning(errors)
return
return pwd, pwd_cnf
with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id))
@ -367,11 +416,17 @@ class UserManagement(BasePage):
session.commit()
gr.Info(f'User "{usn}" updated successfully')
def delete_user(self, selected_user_id):
return "", ""
def delete_user(self, current_user, selected_user_id):
if current_user == selected_user_id:
gr.Warning("You cannot delete yourself")
return selected_user_id
with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id))
user = session.exec(statement).one()
session.delete(user)
session.commit()
gr.Info(f'User "{user.username}" deleted successfully')
return None, self.selected_panel_false
return -1

View File

@ -7,8 +7,10 @@ from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE
from .control import ConversationControl
from .report import ReportIssue
@ -26,24 +28,39 @@ class ChatPage(BasePage):
with gr.Column(scale=1):
self.chat_control = ConversationControl(self._app)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion = ChatSuggestion(self._app)
for index in self._app.index_manager.indices:
index.selector = -1
index.selector = None
index_ui = index.get_selector_component_ui()
if not index_ui:
# the index doesn't have a selector UI component
continue
index_ui.unrender()
index_ui.unrender() # need to rerender later within Accordion
with gr.Accordion(label=f"{index.name} Index", open=False):
index_ui.render()
gr_index = index_ui.as_gradio_component()
if gr_index:
index.selector = len(self._indices_input)
self._indices_input.append(gr_index)
if isinstance(gr_index, list):
index.selector = tuple(
range(
len(self._indices_input),
len(self._indices_input) + len(gr_index),
)
)
self._indices_input.extend(gr_index)
else:
index.selector = len(self._indices_input)
self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)
self.report_issue = ReportIssue(self._app)
with gr.Column(scale=6):
self.chat_panel = ChatPanel(self._app)
with gr.Column(scale=3):
with gr.Accordion(label="Information panel", open=True):
self.info_panel = gr.HTML(elem_id="chat-info-panel")
@ -54,11 +71,24 @@ class ChatPage(BasePage):
self.chat_panel.text_input.submit,
self.chat_panel.submit_btn.click,
],
fn=self.chat_panel.submit_msg,
inputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
fn=self.submit_msg,
inputs=[
self.chat_panel.text_input,
self.chat_panel.chatbot,
self._app.user_id,
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
],
outputs=[
self.chat_panel.text_input,
self.chat_panel.chatbot,
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
concurrency_limit=20,
show_progress="hidden",
).then(
).success(
fn=self.chat_fn,
inputs=[
self.chat_control.conversation_id,
@ -72,6 +102,7 @@ class ChatPage(BasePage):
self.info_panel,
self.chat_state,
],
concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.update_data_source,
@ -82,6 +113,7 @@ class ChatPage(BasePage):
]
+ self._indices_input,
outputs=None,
concurrency_limit=20,
)
self.chat_panel.regen_btn.click(
@ -98,6 +130,7 @@ class ChatPage(BasePage):
self.info_panel,
self.chat_state,
],
concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.update_data_source,
@ -108,6 +141,7 @@ class ChatPage(BasePage):
]
+ self._indices_input,
outputs=None,
concurrency_limit=20,
)
self.chat_panel.chatbot.like(
@ -116,7 +150,12 @@ class ChatPage(BasePage):
outputs=None,
)
self.chat_control.conversation.change(
self.chat_control.btn_new.click(
self.chat_control.new_conv,
inputs=self._app.user_id,
outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
show_progress="hidden",
).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
outputs=[
@ -124,12 +163,71 @@ class ChatPage(BasePage):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
)
self.chat_control.btn_del.click(
lambda id: self.toggle_delete(id),
inputs=[self.chat_control.conversation_id],
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.btn_del_conf.click(
self.chat_control.delete_conv,
inputs=[self.chat_control.conversation_id, self._app.user_id],
outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
show_progress="hidden",
).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
]
+ self._indices_input,
show_progress="hidden",
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.btn_del_cnl.click(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.conversation_rn_btn.click(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self._app.user_id,
],
outputs=[self.chat_control.conversation, self.chat_control.conversation],
show_progress="hidden",
)
self.chat_control.conversation.select(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
]
+ self._indices_input,
show_progress="hidden",
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.report_issue.report_btn.click(
self.report_issue.report,
inputs=[
@ -140,11 +238,77 @@ class ChatPage(BasePage):
self.chat_panel.chatbot,
self._app.settings_state,
self._app.user_id,
self.info_panel,
self.chat_state,
]
+ self._indices_input,
outputs=None,
)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion.example.select(
self.chat_suggestion.select_example,
outputs=[self.chat_panel.text_input],
show_progress="hidden",
)
def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
"""Submit a message to the chatbot"""
if not chat_input:
raise ValueError("Input is empty")
if not conv_id:
id_, update = self.chat_control.new_conv(user_id)
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == id_)
name = session.exec(statement).one().name
new_conv_id = id_
conv_update = update
new_conv_name = name
else:
new_conv_id = conv_id
conv_update = gr.update()
new_conv_name = conv_name
return (
"",
chat_history + [(chat_input, None)],
new_conv_id,
conv_update,
new_conv_name,
)
def toggle_delete(self, conv_id):
if conv_id:
return gr.update(visible=False), gr.update(visible=True)
else:
return gr.update(visible=True), gr.update(visible=False)
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.chat_control.reload_conv,
"inputs": [self._app.user_id],
"outputs": [self.chat_control.conversation],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": lambda: self.chat_control.select_conv(""),
"outputs": [
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
]
+ self._indices_input,
"show_progress": "hidden",
},
)
def update_data_source(self, convo_id, messages, state, *selecteds):
"""Update the data source"""
@ -154,8 +318,12 @@ class ChatPage(BasePage):
selecteds_ = {}
for index in self._app.index_manager.indices:
if index.selector != -1:
if index.selector is None:
continue
if isinstance(index.selector, int):
selecteds_[str(index.id)] = selecteds[index.selector]
else:
selecteds_[str(index.id)] = [selecteds[i] for i in index.selector]
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
@ -205,8 +373,11 @@ class ChatPage(BasePage):
retrievers = []
for index in self._app.index_manager.indices:
index_selected = []
if index.selector != -1:
if isinstance(index.selector, int):
index_selected = selecteds[index.selector]
if isinstance(index.selector, tuple):
for i in index.selector:
index_selected.append(selecteds[i])
iretrievers = index.get_retriever_pipelines(settings, index_selected)
retrievers += iretrievers
@ -250,7 +421,10 @@ class ChatPage(BasePage):
break
if "output" in response:
text += response["output"]
if response["output"] is None:
text = ""
else:
text += response["output"]
if "evidence" in response:
if response["evidence"] is None:

View File

@ -0,0 +1,26 @@
import gradio as gr
from ktem.app import BasePage
from theflow.settings import settings as flowsettings
class ChatSuggestion(BasePage):
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
chat_samples = getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", [])
chat_samples = [[each] for each in chat_samples]
with gr.Accordion(label="Chat Suggestion", open=False) as self.accordion:
self.example = gr.DataFrame(
value=chat_samples,
headers=["Sample"],
interactive=False,
wrap=True,
)
def as_gradio_component(self):
return self.example
def select_example(self, ev: gr.SelectData):
return ev.value

View File

@ -10,6 +10,17 @@ from .common import STATE
logger = logging.getLogger(__name__)
def is_conv_name_valid(name):
"""Check if the conversation name is valid"""
errors = []
if len(name) == 0:
errors.append("Name cannot be empty")
elif len(name) > 40:
errors.append("Name cannot be longer than 40 characters")
return "; ".join(errors)
class ConversationControl(BasePage):
"""Manage conversation"""
@ -28,9 +39,17 @@ class ConversationControl(BasePage):
interactive=True,
)
with gr.Row():
self.conversation_new_btn = gr.Button(value="New", min_width=10)
self.conversation_del_btn = gr.Button(value="Delete", min_width=10)
with gr.Row() as self._new_delete:
self.btn_new = gr.Button(value="New", min_width=10)
self.btn_del = gr.Button(value="Delete", min_width=10)
with gr.Row(visible=False) as self._delete_confirm:
self.btn_del_conf = gr.Button(
value="Delete",
variant="primary",
min_width=10,
)
self.btn_del_cnl = gr.Button(value="Cancel", min_width=10)
with gr.Row():
self.conversation_rn = gr.Text(
@ -52,48 +71,6 @@ class ConversationControl(BasePage):
# outputs=[current_state],
# )
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.reload_conv,
"inputs": [self._app.user_id],
"outputs": [self.conversation],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.reload_conv,
"inputs": [self._app.user_id],
"outputs": [self.conversation],
"show_progress": "hidden",
},
)
def on_register_events(self):
self.conversation_new_btn.click(
self.new_conv,
inputs=self._app.user_id,
outputs=[self.conversation_id, self.conversation],
show_progress="hidden",
)
self.conversation_del_btn.click(
self.delete_conv,
inputs=[self.conversation_id, self._app.user_id],
outputs=[self.conversation_id, self.conversation],
show_progress="hidden",
)
self.conversation_rn_btn.click(
self.rename_conv,
inputs=[self.conversation_id, self.conversation_rn, self._app.user_id],
outputs=[self.conversation, self.conversation],
show_progress="hidden",
)
def load_chat_history(self, user_id):
"""Reload chat history"""
options = []
@ -112,7 +89,7 @@ class ConversationControl(BasePage):
def reload_conv(self, user_id):
conv_list = self.load_chat_history(user_id)
if conv_list:
return gr.update(value=conv_list[0][1], choices=conv_list)
return gr.update(value=None, choices=conv_list)
else:
return gr.update(value=None, choices=[])
@ -133,10 +110,15 @@ class ConversationControl(BasePage):
return id_, gr.update(value=id_, choices=history)
def delete_conv(self, conversation_id, user_id):
"""Create new chat"""
"""Delete the selected conversation"""
if not conversation_id:
gr.Warning("No conversation selected.")
return None, gr.update()
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return None, gr.update()
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
@ -161,6 +143,7 @@ class ConversationControl(BasePage):
name = result.name
selected = result.data_source.get("selected", {})
chats = result.data_source.get("messages", [])
info_panel = ""
state = result.data_source.get("state", STATE)
except Exception as e:
logger.warning(e)
@ -168,22 +151,36 @@ class ConversationControl(BasePage):
name = ""
selected = {}
chats = []
info_panel = ""
state = STATE
indices = []
for index in self._app.index_manager.indices:
# assume that the index has selector
if index.selector == -1:
if index.selector is None:
continue
indices.append(selected.get(str(index.id), []))
if isinstance(index.selector, int):
indices.append(selected.get(str(index.id), []))
if isinstance(index.selector, tuple):
indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
return id_, id_, name, chats, state, *indices
return id_, id_, name, chats, info_panel, state, *indices
def rename_conv(self, conversation_id, new_name, user_id):
"""Rename the conversation"""
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return gr.update(), ""
if not conversation_id:
gr.Warning("No conversation selected.")
return gr.update(), ""
errors = is_conv_name_valid(new_name)
if errors:
gr.Warning(errors)
return gr.update(), conversation_id
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()

View File

@ -48,13 +48,19 @@ class ReportIssue(BasePage):
chat_history: list,
settings: dict,
user_id: Optional[int],
info_panel: str,
chat_state: dict,
*selecteds
*selecteds,
):
selecteds_ = {}
for index in self._app.index_manager.indices:
if index.selector != -1:
selecteds_[str(index.id)] = selecteds[index.selector]
if index.selector is not None:
if isinstance(index.selector, int):
selecteds_[str(index.id)] = selecteds[index.selector]
elif isinstance(index.selector, tuple):
selecteds_[str(index.id)] = [selecteds[_] for _ in index.selector]
else:
print(f"Unknown selector type: {index.selector}")
with Session(engine) as session:
issue = IssueReport(
@ -66,6 +72,7 @@ class ReportIssue(BasePage):
chat={
"conv_id": conv_id,
"chat_history": chat_history,
"info_panel": info_panel,
"chat_state": chat_state,
"selecteds": selecteds_,
},

View File

@ -31,11 +31,10 @@ class LoginPage(BasePage):
self.on_building_ui()
def on_building_ui(self):
gr.Markdown("Welcome to Kotaemon")
self.usn = gr.Textbox(label="Username")
self.pwd = gr.Textbox(label="Password", type="password")
self.btn_login = gr.Button("Login")
self._dummy = gr.State()
gr.Markdown("# Welcome to Kotaemon")
self.usn = gr.Textbox(label="Username", visible=False)
self.pwd = gr.Textbox(label="Password", type="password", visible=False)
self.btn_login = gr.Button("Login", visible=False)
def on_register_events(self):
onSignIn = gr.on(
@ -45,24 +44,56 @@ class LoginPage(BasePage):
outputs=[self._app.user_id, self.usn, self.pwd],
show_progress="hidden",
js=signin_js,
).then(
self.toggle_login_visibility,
inputs=[self._app.user_id],
outputs=[self.usn, self.pwd, self.btn_login],
)
for event in self._app.get_event("onSignIn"):
onSignIn = onSignIn.success(**event)
def toggle_login_visibility(self, user_id):
return (
gr.update(visible=user_id is None),
gr.update(visible=user_id is None),
gr.update(visible=user_id is None),
)
def _on_app_created(self):
self._app.app.load(
None,
inputs=None,
outputs=[self.usn, self.pwd],
onSignIn = self._app.app.load(
self.login,
inputs=[self.usn, self.pwd],
outputs=[self._app.user_id, self.usn, self.pwd],
show_progress="hidden",
js=fetch_creds,
).then(
self.toggle_login_visibility,
inputs=[self._app.user_id],
outputs=[self.usn, self.pwd, self.btn_login],
)
for event in self._app.get_event("onSignIn"):
onSignIn = onSignIn.success(**event)
def on_subscribe_public_events(self):
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.toggle_login_visibility,
"inputs": [self._app.user_id],
"outputs": [self.usn, self.pwd, self.btn_login],
"show_progress": "hidden",
},
)
def login(self, usn, pwd):
if not usn or not pwd:
return None, usn, pwd
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
with Session(engine) as session:
stmt = select(User).where(
User.username_lower == usn.lower(), User.password == hashed_password
User.username_lower == usn.lower().strip(),
User.password == hashed_password,
)
result = session.exec(stmt).all()
if result:

View File

@ -164,9 +164,14 @@ class SettingsPage(BasePage):
show_progress="hidden",
)
onSignOutClick = self.signout.click(
lambda: (None, "Current user: ___"),
lambda: (None, "Current user: ___", "", ""),
inputs=None,
outputs=[self._user_id, self.current_name],
outputs=[
self._user_id,
self.current_name,
self.password_change,
self.password_change_confirm,
],
show_progress="hidden",
js=signout_js,
).then(
@ -192,8 +197,12 @@ class SettingsPage(BasePage):
self.password_change_btn = gr.Button("Change password", interactive=True)
def change_password(self, user_id, password, password_confirm):
if password != password_confirm:
gr.Warning("Password does not match")
from ktem.pages.admin.user import validate_password
errors = validate_password(password, password_confirm)
if errors:
print(errors)
gr.Warning(errors)
return password, password_confirm
with Session(engine) as session:

View File

@ -34,12 +34,16 @@ class BaseReasoning(BaseComponent):
@classmethod
def get_pipeline(
cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None
cls,
user_settings: dict,
state: dict,
retrievers: Optional[list["BaseComponent"]] = None,
) -> "BaseReasoning":
"""Get the reasoning pipeline for the app to execute
Args:
user_setting: user settings
state: conversation state
retrievers (list): List of retrievers
"""
return cls()

View File

@ -22,6 +22,8 @@ from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
from kotaemon.loaders.utils.gpt4v import stream_gpt4v
from .base import BaseReasoning
logger = logging.getLogger(__name__)
EVIDENCE_MODE_TEXT = 0
@ -204,7 +206,7 @@ class AnswerWithContextPipeline(BaseComponent):
lang: str = "English" # support English and Japanese
async def run( # type: ignore
self, question: str, evidence: str, evidence_mode: int = 0
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Document:
"""Answer the question based on the evidence
@ -336,7 +338,7 @@ class RewriteQuestionPipeline(BaseComponent):
return Document(text=output)
class FullQAPipeline(BaseComponent):
class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer"""
class Config:
@ -352,6 +354,8 @@ class FullQAPipeline(BaseComponent):
async def run( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
import markdown
docs = []
doc_ids = []
if self.use_rewrite:
@ -364,12 +368,16 @@ class FullQAPipeline(BaseComponent):
docs.append(doc)
doc_ids.append(doc.doc_id)
for doc in docs:
# TODO: a better approach to show the information
text = markdown.markdown(
doc.text, extensions=["markdown.extensions.tables"]
)
self.report_output(
{
"evidence": (
"<details open>"
f"<summary>{doc.metadata['file_name']}</summary>"
f"{doc.text}"
f"{text}"
"</details><br>"
)
}
@ -378,7 +386,12 @@ class FullQAPipeline(BaseComponent):
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = await self.answering_pipeline(
question=message, evidence=evidence, evidence_mode=evidence_mode
question=message,
history=history,
evidence=evidence,
evidence_mode=evidence_mode,
conv_id=conv_id,
**kwargs,
)
# prepare citation
@ -388,14 +401,29 @@ class FullQAPipeline(BaseComponent):
for quote in fact_with_evidence.substring_quote:
for doc in docs:
start_idx = doc.text.find(quote)
if start_idx >= 0:
if start_idx == -1:
continue
end_idx = start_idx + len(quote)
current_idx = start_idx
if "|" not in doc.text[start_idx:end_idx]:
spans[doc.doc_id].append(
{
"start": start_idx,
"end": start_idx + len(quote),
}
{"start": start_idx, "end": end_idx}
)
break
else:
while doc.text[current_idx:end_idx].find("|") != -1:
match_idx = doc.text[current_idx:end_idx].find("|")
spans[doc.doc_id].append(
{
"start": current_idx,
"end": current_idx + match_idx,
}
)
current_idx += match_idx + 2
if current_idx > end_idx:
break
break
id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True
@ -414,12 +442,15 @@ class FullQAPipeline(BaseComponent):
if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :]
text_out = markdown.markdown(
text, extensions=["markdown.extensions.tables"]
)
self.report_output(
{
"evidence": (
"<details open>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{text}"
f"{text_out}"
"</details><br>"
)
}
@ -434,12 +465,15 @@ class FullQAPipeline(BaseComponent):
{"evidence": "Retrieved segments without matching evidence:\n"}
)
for id in list(not_detected):
text_out = markdown.markdown(
id2docs[id].text, extensions=["markdown.extensions.tables"]
)
self.report_output(
{
"evidence": (
"<details>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{id2docs[id].text}"
f"{text_out}"
"</details><br>"
)
}